From 721214249b5d610fd0eb30dbead929d9f7664568 Mon Sep 17 00:00:00 2001 From: Katostrofik Date: Thu, 14 May 2026 01:39:14 -0400 Subject: [PATCH 01/55] SYCL: fix multi-GPU system RAM exhaustion by using Level Zero allocations (llama/21597) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * SYCL: fix multi-GPU system RAM exhaustion by using Level Zero allocations Replace sycl::malloc_device with zeMemAllocDevice for GPU memory allocation in the SYCL backend. sycl::malloc_device triggers the xe kernel driver's DMA-buf/TTM path which mirrors every VRAM allocation 1:1 in system RAM. zeMemAllocDevice uses the SVM/P2P path with no host staging. On a dual Intel Arc Pro B70 system (64GB VRAM, 64GB RAM), a 15.6 GiB model consumed 60 GiB of system RAM via sycl::malloc_device, causing OOM crashes. With zeMemAllocDevice, the same workload uses ~6.7 GiB of system RAM with no performance regression. All Level Zero calls include automatic fallback to the original SYCL allocation path if Level Zero interop is unavailable. * SYCL: address review feedback - remove try/catch, check device types, deduplicate - Remove try/catch from malloc/free/memcpy helpers, check backend and device type upfront instead (ggml_sycl_is_level_zero, ggml_sycl_is_dgpu) - Move shared helpers (is_level_zero, is_dgpu, free_device) to common.cpp and declare in common.hpp to eliminate code duplication - Use SYCL_CHECK(CHECK_TRY_ERROR()) for fallback sycl::free calls - Guard dev2dev_memcpy L0 path to dGPU-to-dGPU only, preserving the host-staged path for iGPU-to-dGPU transfers - Add Windows Level Zero SDK path detection (LEVEL_ZERO_V1_SDK_PATH) in CMakeLists.txt (co-authored with @arthw) * SYCL: add build/runtime flags for Level Zero, address review feedback Implements the architecture suggested by @arthw: compile-time and runtime flags to cleanly separate Level Zero and SYCL memory API paths. - Add GGML_SYCL_SUPPORT_LEVEL_ZERO cmake option (default ON). All Level Zero code is wrapped in #ifdef so the build works on systems without the Level Zero SDK installed (e.g. CPU-only CI servers). Both the loader library and headers are checked before enabling. - Add GGML_SYCL_ENABLE_LEVEL_ZERO runtime env var (default 1). Controls whether Level Zero or SYCL memory APIs are used. Only one API style is used per session, no mixing. If Level Zero is enabled but the devices don't support the Level Zero backend, it auto-disables with a warning. - Remove Level Zero code from dpct_malloc. It was unused (dpct::device_memory is not called anywhere in the backend) and used try/catch for flow control. - Update SYCL.md with documentation for both new parameters. Tested on Intel Arc Pro B70 (32GB), single-GPU and dual-GPU, with both GGML_SYCL_SUPPORT_LEVEL_ZERO=ON and OFF builds. AI-assisted development (Claude). Code reviewed and tested on my hardware. * SYCL: unify Level Zero malloc/free call sites, address review feedback Move ggml_sycl_malloc_device to common.cpp alongside ggml_sycl_free_device. Both functions are now unconditionally available — Level Zero code is #ifdef'd inside the functions, not at call sites. All call sites use uniform SYCL_CHECK(CHECK_TRY_ERROR()) wrapping with no #ifdef blocks. Addresses arthw's review: wrap all malloc/free in SYCL_CHECK for stack traces on failure, eliminate duplicated #ifdef/else patterns at 6 call sites (-29 lines net). Co-Authored-By: Claude Opus 4.6 (1M context) * SYCL: add Level Zero SDK to CI, fix device check and missed alloc paths Add Level Zero SDK installation to Ubuntu and Windows SYCL CI jobs so the Level Zero code path is compiled and tested in CI. Fix two bugs found during extended dual-GPU testing (no ONEAPI_DEVICE_SELECTOR set): - The Level Zero backend check was iterating all SYCL devices including CPU. The OpenCL CPU device caused Level Zero to be disabled for the GPUs, defeating the fix on multi-GPU systems. Added is_gpu() filter so only GPU devices are checked. - sycl_ext_malloc_device/sycl_ext_free (tensor reorder temp buffers) were still calling sycl::malloc/sycl::free directly, bypassing the Level Zero path. Routed through ggml_sycl_malloc_device/free_device for consistency with the other device memory call sites. Co-Authored-By: Claude Opus 4.6 (1M context) * SYCL: address arthw review feedback on Level Zero memory API structure - Move ggml_sycl_malloc_device to static function in ggml-sycl.cpp; only ggml_sycl_free_device (used by common.cpp) stays in common.cpp - Switch both helpers to use g_ggml_sycl_enable_level_zero global instead of per-call queue backend checks - Remove #ifdef wrapper from global definition; always declare at 0, add #else branch in init block so it stays 0 when L0 not compiled in - Update init loop comment to explain GPU-only device check - CMakeLists: message(STATUS) before the if block; align option wording AI-assisted implementation. Reviewed and tested on dual Intel Arc Pro B70 (32 GB each): test-backend-ops OK on both GPUs, single/dual-GPU Q4_K_M and Q8_0 bench correct, zeMemAllocDevice GTT delta confirmed <5 MiB per 4 GiB allocation (vs ~4 GiB shadow with sycl::malloc_device). Co-Authored-By: Claude Sonnet 4.6 * SYCL: remove unused cstdio/cstdlib includes from common.cpp Leftover from the deleted ggml_sycl_queue_supports_level_zero helper. Co-authored-by: Claude Sonnet 4.6 * Apply suggestions from code review Co-authored-by: Neo Zhang * SYCL: preserve Level Zero allocation path during early malloc * ci: fix Level Zero package conflict in Intel Docker build * ci: find Level Zero loader in oneAPI package step * ci: allow Windows SYCL package without Level Zero DLL --------- Co-authored-by: Claude Opus 4.6 (1M context) Co-authored-by: Neo Zhang --- ggml/CMakeLists.txt | 1 + ggml/src/ggml-sycl/CMakeLists.txt | 29 +++++++++ ggml/src/ggml-sycl/common.cpp | 76 +++++++++++++++++++++++- ggml/src/ggml-sycl/common.hpp | 4 ++ ggml/src/ggml-sycl/ggml-sycl.cpp | 98 +++++++++++++++++++++++++------ 5 files changed, 187 insertions(+), 21 deletions(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 4e65cd68b4e..bdeca34bf9f 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -249,6 +249,7 @@ option(GGML_SYCL "ggml: use SYCL" option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF) option(GGML_SYCL_GRAPH "ggml: enable graphs in the SYCL backend" ON) option(GGML_SYCL_HOST_MEM_FALLBACK "ggml: allow host memory fallback in SYCL reorder (requires kernel 6.8+)" ON) +option(GGML_SYCL_SUPPORT_LEVEL_ZERO "ggml: use Level Zero API in SYCL backend" ON) option(GGML_SYCL_DNN "ggml: enable oneDNN in the SYCL backend" ON) set (GGML_SYCL_TARGET "INTEL" CACHE STRING "ggml: sycl target device") diff --git a/ggml/src/ggml-sycl/CMakeLists.txt b/ggml/src/ggml-sycl/CMakeLists.txt index 8f44c6ed080..180de92202d 100644 --- a/ggml/src/ggml-sycl/CMakeLists.txt +++ b/ggml/src/ggml-sycl/CMakeLists.txt @@ -39,6 +39,18 @@ if (WIN32) set(CMAKE_CXX_COMPILER "icx") set(CMAKE_CXX_COMPILER_ID "IntelLLVM") endif() + # Level Zero SDK path for Windows (only when GGML_SYCL_SUPPORT_LEVEL_ZERO is enabled) + if(GGML_SYCL_SUPPORT_LEVEL_ZERO) + if(DEFINED ENV{LEVEL_ZERO_V1_SDK_PATH}) + set(LEVEL_ZERO_V1_SDK_PATH $ENV{LEVEL_ZERO_V1_SDK_PATH}) + if(EXISTS "${LEVEL_ZERO_V1_SDK_PATH}") + target_include_directories(ggml-sycl PRIVATE "${LEVEL_ZERO_V1_SDK_PATH}/include") + set(LEVEL_ZERO_V1_SDK_LIB_PATH "${LEVEL_ZERO_V1_SDK_PATH}/lib") + else() + message(WARNING "LEVEL_ZERO_V1_SDK_PATH set but folder not found: ${LEVEL_ZERO_V1_SDK_PATH}") + endif() + endif() + endif() endif() macro(detect_and_find_package package_name) @@ -93,6 +105,23 @@ endif() target_compile_options(ggml-sycl PRIVATE "-Wno-narrowing") +message(STATUS "GGML_SYCL_SUPPORT_LEVEL_ZERO ${GGML_SYCL_SUPPORT_LEVEL_ZERO}") +if (GGML_SYCL_SUPPORT_LEVEL_ZERO) + # Link against Level Zero loader for direct device memory allocation. + # Avoids sycl::malloc_device triggering DMA-buf/TTM system RAM staging + # in the xe kernel driver during multi-GPU inference. + find_path(LEVEL_ZERO_INCLUDE_DIR level_zero/ze_api.h HINTS ${ONEAPI_ROOT}/include ${LEVEL_ZERO_V1_SDK_PATH}/include) + find_library(ZE_LOADER_LIB ze_loader HINTS ${ONEAPI_ROOT}/lib ${LEVEL_ZERO_V1_SDK_LIB_PATH} ENV LD_LIBRARY_PATH) + if(ZE_LOADER_LIB AND LEVEL_ZERO_INCLUDE_DIR) + target_link_libraries(ggml-sycl PRIVATE ${ZE_LOADER_LIB}) + target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_SUPPORT_LEVEL_ZERO) + message(STATUS "Level Zero loader found: ${ZE_LOADER_LIB}") + message(STATUS "Level Zero headers found: ${LEVEL_ZERO_INCLUDE_DIR}") + else() + message(WARNING "Level Zero loader or headers not found, Level Zero support disabled") + endif() +endif() + # Link against oneDNN set(GGML_SYCL_DNNL 0) if(GGML_SYCL_DNN) diff --git a/ggml/src/ggml-sycl/common.cpp b/ggml/src/ggml-sycl/common.cpp index 05fd5ef46c7..ae08abad81b 100644 --- a/ggml/src/ggml-sycl/common.cpp +++ b/ggml/src/ggml-sycl/common.cpp @@ -11,6 +11,10 @@ // #include "common.hpp" +#include +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO +#include +#endif #include "ggml-backend-impl.h" #include "ggml-impl.h" @@ -55,6 +59,20 @@ bool gpu_has_xmx(sycl::device &dev) { return dev.has(sycl::aspect::ext_intel_matrix); } +static int ggml_sycl_get_env(const char *env_name, int default_val) { + char *user_device_string = getenv(env_name); + int user_number = default_val; + + unsigned n; + if (user_device_string != NULL && + sscanf(user_device_string, " %u", &n) == 1) { + user_number = (int)n; + } else { + user_number = default_val; + } + return user_number; +} + int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size) { const int64_t max_range = std::numeric_limits::max(); int64_t sycl_down_blk_size = block_size; @@ -66,6 +84,61 @@ int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block return sycl_down_blk_size; } +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO +static bool ggml_sycl_use_level_zero_device_alloc(sycl::queue &q) { + return ggml_sycl_get_env("GGML_SYCL_ENABLE_LEVEL_ZERO", 1) && + q.get_device().is_gpu() && + q.get_backend() == sycl::backend::ext_oneapi_level_zero; +} +#endif + +// Use Level Zero zeMemAllocDevice to avoid sycl::malloc_device triggering +// DMA-buf/TTM system RAM staging in the xe kernel driver during multi-GPU inference. +// The decision is made from the queue and runtime env because large buffers can be +// allocated before ggml_check_sycl() initializes g_ggml_sycl_enable_level_zero. +void * ggml_sycl_malloc_device(size_t size, sycl::queue &q) { +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO + if (ggml_sycl_use_level_zero_device_alloc(q)) { + void *ptr = nullptr; + auto ze_ctx = sycl::get_native(q.get_context()); + auto ze_dev = sycl::get_native(q.get_device()); +#ifdef ZE_RELAXED_ALLOCATION_LIMITS_EXP_NAME + ze_relaxed_allocation_limits_exp_desc_t relaxed_desc = { + ZE_STRUCTURE_TYPE_RELAXED_ALLOCATION_LIMITS_EXP_DESC, + nullptr, + ZE_RELAXED_ALLOCATION_LIMITS_EXP_FLAG_MAX_SIZE, + }; + ze_device_mem_alloc_desc_t alloc_desc = { + ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC, + &relaxed_desc, + 0, + 0, + }; +#else + ze_device_mem_alloc_desc_t alloc_desc = {ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC, nullptr, 0, 0}; +#endif + ze_result_t r = zeMemAllocDevice(ze_ctx, &alloc_desc, size, 64, ze_dev, &ptr); + if (r == ZE_RESULT_SUCCESS && ptr) { + return ptr; + } + return nullptr; + } +#endif + return sycl::malloc_device(size, q); +} + +void ggml_sycl_free_device(void *ptr, sycl::queue &q) { + if (!ptr) return; +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO + if (ggml_sycl_use_level_zero_device_alloc(q)) { + auto ze_ctx = sycl::get_native(q.get_context()); + zeMemFree(ze_ctx, ptr); + return; + } +#endif + SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, q))); +} + void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector streams) { for (int i = 0; i < ggml_sycl_info().device_count; ++i) { for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) { @@ -75,8 +148,7 @@ void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector str } if (extra->data_device[i] != nullptr && streams.size()>0) { ggml_sycl_set_device(i); - SYCL_CHECK( - CHECK_TRY_ERROR(sycl::free(extra->data_device[i], *(streams[i])))); + SYCL_CHECK(CHECK_TRY_ERROR(ggml_sycl_free_device(extra->data_device[i], *(streams[i])))); } } delete extra; diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index eec36e8db9a..96bc1c98bd9 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -310,6 +310,10 @@ struct ggml_tensor_extra_gpu { optimize_feature optimized_feature; }; +extern int g_ggml_sycl_enable_level_zero; +void * ggml_sycl_malloc_device(size_t size, sycl::queue &q); +void ggml_sycl_free_device(void *ptr, sycl::queue &q); + void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector streams={}); namespace sycl_ex = sycl::ext::oneapi::experimental; diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 57cc4ffb6f7..f5d10b56de0 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -30,6 +30,10 @@ #include #include +#include +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO +#include +#endif #if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC # include #endif @@ -68,6 +72,7 @@ int g_ggml_sycl_disable_graph = 0; int g_ggml_sycl_disable_dnn = 0; int g_ggml_sycl_prioritize_dmmv = 0; int g_ggml_sycl_use_async_mem_op = 0; +int g_ggml_sycl_enable_level_zero = 0; int g_ggml_sycl_enable_flash_attention = 1; @@ -223,6 +228,27 @@ static void ggml_check_sycl() try { g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1); g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0); g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0); +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO + g_ggml_sycl_enable_level_zero = get_sycl_env("GGML_SYCL_ENABLE_LEVEL_ZERO", 1); +#else + g_ggml_sycl_enable_level_zero = 0; +#endif + if (g_ggml_sycl_enable_level_zero) { + // Verify all GPU devices use the Level Zero backend before enabling L0 APIs. + // Only check GPU devices; CPU devices use OpenCL and would otherwise + // disable Level Zero for the GPUs on systems without ONEAPI_DEVICE_SELECTOR set. + for (unsigned int i = 0; i < dpct::dev_mgr::instance().device_count(); i++) { + auto & q = dpct::dev_mgr::instance().get_device(i).default_queue(); + if (!q.get_device().is_gpu()) { + continue; + } + if (q.get_backend() != sycl::backend::ext_oneapi_level_zero) { + GGML_LOG_WARN("SYCL GPU device %d does not use Level Zero backend, disabling Level Zero memory API\n", i); + g_ggml_sycl_enable_level_zero = 0; + break; + } + } + } #ifdef SYCL_FLASH_ATTN g_ggml_sycl_enable_flash_attention = get_sycl_env("GGML_SYCL_ENABLE_FLASH_ATTN", 1); @@ -253,6 +279,11 @@ static void ggml_check_sycl() try { #else GGML_LOG_INFO(" GGML_SYCL_DNNL: no\n"); #endif +#if defined(GGML_SYCL_SUPPORT_LEVEL_ZERO) + GGML_LOG_INFO(" GGML_SYCL_SUPPORT_LEVEL_ZERO: yes\n"); +#else + GGML_LOG_INFO(" GGML_SYCL_SUPPORT_LEVEL_ZERO: no\n"); +#endif GGML_LOG_INFO("Running with Environment Variables:\n"); GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug); @@ -262,6 +293,11 @@ static void ggml_check_sycl() try { #else GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: graph disabled by compile flag\n"); #endif +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO + GGML_LOG_INFO(" GGML_SYCL_ENABLE_LEVEL_ZERO: %d\n", g_ggml_sycl_enable_level_zero); +#else + GGML_LOG_INFO(" GGML_SYCL_ENABLE_LEVEL_ZERO: Level Zero disabled by compile flag\n"); +#endif #if GGML_SYCL_DNNL GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: %d\n", g_ggml_sycl_disable_dnn); #else @@ -371,7 +407,7 @@ struct ggml_backend_sycl_buffer_context { ~ggml_backend_sycl_buffer_context() { if (dev_ptr != nullptr) { ggml_sycl_set_device(device); - SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(dev_ptr, *stream))); + SYCL_CHECK(CHECK_TRY_ERROR(ggml_sycl_free_device(dev_ptr, *stream))); } //release extra used by tensors @@ -504,8 +540,43 @@ catch (sycl::exception const &exc) { std::exit(1); } +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO +static bool ggml_sycl_is_l0_discrete_gpu(sycl::queue &q) { + if (!q.get_device().is_gpu() || q.get_backend() != sycl::backend::ext_oneapi_level_zero) { + return false; + } + + ze_device_handle_t ze_dev = sycl::get_native(q.get_device()); + ze_device_properties_t props = {}; + props.stype = ZE_STRUCTURE_TYPE_DEVICE_PROPERTIES; + ze_result_t r = zeDeviceGetProperties(ze_dev, &props); + return r == ZE_RESULT_SUCCESS && !(props.flags & ZE_DEVICE_PROPERTY_FLAG_INTEGRATED); +} +#endif + static void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst, const void *ptr_src, size_t size) { +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO + // Use Level Zero direct copy for dGPU-to-dGPU transfers. + const bool l0_copy_supported = + ggml_sycl_is_l0_discrete_gpu(q_dst) && ggml_sycl_is_l0_discrete_gpu(q_src); + if (g_ggml_sycl_enable_level_zero && l0_copy_supported) { + auto ze_ctx = sycl::get_native(q_dst.get_context()); + auto ze_dev = sycl::get_native(q_dst.get_device()); + ze_command_queue_desc_t cq_desc = {ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC, nullptr, 0, 0, + 0, ZE_COMMAND_QUEUE_MODE_SYNCHRONOUS, ZE_COMMAND_QUEUE_PRIORITY_NORMAL}; + ze_command_list_handle_t cl; + ze_result_t r = zeCommandListCreateImmediate(ze_ctx, ze_dev, &cq_desc, &cl); + if (r == ZE_RESULT_SUCCESS) { + r = zeCommandListAppendMemoryCopy(cl, ptr_dst, ptr_src, size, nullptr, 0, nullptr); + zeCommandListDestroy(cl); + if (r == ZE_RESULT_SUCCESS) { + return; + } + } + } +#endif + // Host-staged copy char *host_buf = (char *)malloc(size); q_src.memcpy(host_buf, (const char *)ptr_src, size).wait(); q_dst.memcpy((char *)ptr_dst, host_buf, size).wait(); @@ -675,8 +746,7 @@ ggml_backend_sycl_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size = std::max(size, (size_t)1); // syclMalloc returns null for size 0 void * dev_ptr; - SYCL_CHECK(CHECK_TRY_ERROR(dev_ptr = (void *)sycl::malloc_device( - size, *stream))); + SYCL_CHECK(CHECK_TRY_ERROR(dev_ptr = (void *)ggml_sycl_malloc_device(size, *stream))); if (!dev_ptr) { GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on device\n", __func__, size); return nullptr; @@ -917,18 +987,10 @@ ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer, size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); } - // FIXME: do not crash if SYCL Buffer alloc fails - // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first ggml_sycl_set_device(i); const queue_ptr stream = ctx->streams[i]; char * buf; - /* - DPCT1009:208: SYCL uses exceptions to report errors and does not use the - error codes. The original code was commented out and a warning string - was inserted. You need to rewrite this code. - */ - SYCL_CHECK(CHECK_TRY_ERROR(buf = (char *)sycl::malloc_device( - size, *stream))); + SYCL_CHECK(CHECK_TRY_ERROR(buf = (char *)ggml_sycl_malloc_device(size, *stream))); if (!buf) { char err_buf[1024]; snprintf(err_buf, 1023, "%s: can't allocate %lu Bytes of memory on device\n", __func__, size); @@ -1306,7 +1368,7 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool { for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) { ggml_sycl_buffer & b = buffer_pool[i]; if (b.ptr != nullptr) { - SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr))); + SYCL_CHECK(CHECK_TRY_ERROR(ggml_sycl_free_device(b.ptr, *qptr))); pool_size -= b.size; } } @@ -1374,9 +1436,7 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool { void * ptr; size_t look_ahead_size = (size_t) (1.05 * size); - SYCL_CHECK( - CHECK_TRY_ERROR(ptr = (void *)sycl::malloc_device( - look_ahead_size, *qptr))); + SYCL_CHECK(CHECK_TRY_ERROR(ptr = (void *)ggml_sycl_malloc_device(look_ahead_size, *qptr))); if (!ptr) { GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on device/GPU\n", __func__, look_ahead_size); return nullptr; @@ -1404,7 +1464,7 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool { } } GGML_LOG_WARN("WARNING: sycl buffer pool full, increase MAX_sycl_BUFFERS\n"); - SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, *qptr))); + SYCL_CHECK(CHECK_TRY_ERROR(ggml_sycl_free_device(ptr, *qptr))); pool_size -= size; } }; @@ -3405,7 +3465,7 @@ static inline void * sycl_ext_malloc_device(dpct::queue_ptr stream, size_t size) // If async allocation extension is not available, use_async should always be false. GGML_ASSERT(!use_async); #endif - return sycl::malloc(size, *stream, sycl::usm::alloc::device); + return ggml_sycl_malloc_device(size, *stream); } static inline void sycl_ext_free(dpct::queue_ptr stream, void * ptr) { @@ -3419,7 +3479,7 @@ static inline void sycl_ext_free(dpct::queue_ptr stream, void * ptr) { // If async allocation extension is not available, use_async should always be false. GGML_ASSERT(!use_async); #endif - sycl::free(ptr, *stream); + ggml_sycl_free_device(ptr, *stream); } // RAII wrapper for temporary reorder buffers with optional host memory fallback. From 610058ae17574623a169eb71eb34949c25996f58 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Thu, 14 May 2026 10:36:54 +0200 Subject: [PATCH 02/55] vulkan: fix matmul integer pipeline selection (llama/23005) * vulkan: fix matmul integer pipeline selection * gate pipeline creation with the right bools --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index a0a556206d5..8c4cf9ef1db 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -3954,13 +3954,13 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ #define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ - if (device->mul_mat ## ID ## _l[TYPE]) { \ + if (device->mul_mat ## ID ## _l_int[TYPE]) { \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ } \ - if (device->mul_mat ## ID ## _m[TYPE]) { \ + if (device->mul_mat ## ID ## _m_int[TYPE]) { \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ } \ - if (device->mul_mat ## ID ## _s[TYPE]) { \ + if (device->mul_mat ## ID ## _s_int[TYPE]) { \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ } \ @@ -4131,11 +4131,11 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ #define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ - if (device->mul_mat ## ID ## _l[TYPE]) \ + if (device->mul_mat ## ID ## _l_int[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC "_l", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ - if (device->mul_mat ## ID ## _m[TYPE]) \ + if (device->mul_mat ## ID ## _m_int[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC "_m", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ - if (device->mul_mat ## ID ## _s[TYPE]) \ + if (device->mul_mat ## ID ## _s_int[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC "_s", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); @@ -5716,12 +5716,12 @@ static vk_device ggml_vk_get_device(size_t idx) { break; } - device->mul_mat_l_int[i] = true; - device->mul_mat_m_int[i] = true; - device->mul_mat_s_int[i] = true; - device->mul_mat_id_l_int[i] = true; - device->mul_mat_id_m_int[i] = true; - device->mul_mat_id_s_int[i] = true; + device->mul_mat_l_int[i] = device->mul_mat_l[i]; + device->mul_mat_m_int[i] = device->mul_mat_m[i]; + device->mul_mat_s_int[i] = device->mul_mat_s[i]; + device->mul_mat_id_l_int[i] = device->mul_mat_id_l[i]; + device->mul_mat_id_m_int[i] = device->mul_mat_id_m[i]; + device->mul_mat_id_s_int[i] = device->mul_mat_id_s[i]; } From 6c502076aabfa753986b4fbadef10494f00ae426 Mon Sep 17 00:00:00 2001 From: alex-spacemit Date: Thu, 14 May 2026 17:39:30 +0800 Subject: [PATCH 03/55] ggml-cpu: Add IME2 Instruction Support for the SpacemiT Backend (llama/22863) --- ggml/src/ggml-cpu/CMakeLists.txt | 13 + ggml/src/ggml-cpu/cmake/FindSMTIME.cmake | 32 + ggml/src/ggml-cpu/ggml-cpu.c | 12 + ggml/src/ggml-cpu/spacemit/ime.cpp | 2089 ++++-- ggml/src/ggml-cpu/spacemit/ime.h | 8 + ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp | 3363 ++-------- ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp | 5768 +++++++++++++++++ ggml/src/ggml-cpu/spacemit/ime_env.cpp | 320 + ggml/src/ggml-cpu/spacemit/ime_env.h | 55 + ggml/src/ggml-cpu/spacemit/ime_kernels.h | 201 +- ggml/src/ggml-cpu/spacemit/repack.cpp | 1795 +++++ ggml/src/ggml-cpu/spacemit/repack.h | 14 + ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp | 3178 +++++++++ ggml/src/ggml-cpu/spacemit/rvv_kernels.h | 95 + ggml/src/ggml-cpu/spacemit/spine_barrier.h | 34 + ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp | 760 +++ ggml/src/ggml-cpu/spacemit/spine_mem_pool.h | 32 + ggml/src/ggml-cpu/spacemit/spine_tcm.h | 409 ++ 18 files changed, 14706 insertions(+), 3472 deletions(-) create mode 100644 ggml/src/ggml-cpu/cmake/FindSMTIME.cmake create mode 100644 ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp create mode 100644 ggml/src/ggml-cpu/spacemit/ime_env.cpp create mode 100644 ggml/src/ggml-cpu/spacemit/ime_env.h create mode 100644 ggml/src/ggml-cpu/spacemit/repack.cpp create mode 100644 ggml/src/ggml-cpu/spacemit/repack.h create mode 100644 ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp create mode 100644 ggml/src/ggml-cpu/spacemit/rvv_kernels.h create mode 100644 ggml/src/ggml-cpu/spacemit/spine_barrier.h create mode 100644 ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp create mode 100644 ggml/src/ggml-cpu/spacemit/spine_mem_pool.h create mode 100644 ggml/src/ggml-cpu/spacemit/spine_tcm.h diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 869c7b238bf..f3eccff7d72 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -450,12 +450,22 @@ function(ggml_add_cpu_backend_variant_impl tag_name) ggml-cpu/arch/riscv/repack.cpp ) if (GGML_CPU_RISCV64_SPACEMIT) + include(ggml-cpu/cmake/FindSMTIME.cmake) target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_RISCV64_SPACEMIT ${RISCV64_SPACEMIT_IME_SPEC}) list(APPEND GGML_CPU_SOURCES ggml-cpu/spacemit/ime.cpp ggml-cpu/spacemit/ime.h + ggml-cpu/spacemit/spine_mem_pool.cpp + ggml-cpu/spacemit/spine_mem_pool.h + ggml-cpu/spacemit/repack.cpp + ggml-cpu/spacemit/repack.h + ggml-cpu/spacemit/ime_env.cpp + ggml-cpu/spacemit/ime_env.h ggml-cpu/spacemit/ime1_kernels.cpp + ggml-cpu/spacemit/ime2_kernels.cpp ggml-cpu/spacemit/ime_kernels.h + ggml-cpu/spacemit/rvv_kernels.cpp + ggml-cpu/spacemit/rvv_kernels.h ) endif() if(NOT GGML_CPU_ALL_VARIANTS) @@ -485,6 +495,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name) if (GGML_RV_ZIHINTPAUSE) string(APPEND MARCH_STR "_zihintpause") endif() + if (GGML_RV_ZBA) + string(APPEND MARCH_STR "_zba") + endif() if (GGML_CPU_RISCV64_SPACEMIT) # `xsmtvdotii' is only required for GCC >= 15. if (CMAKE_C_COMPILER_ID STREQUAL "GNU" AND diff --git a/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake b/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake new file mode 100644 index 00000000000..c8a4d4b4ec9 --- /dev/null +++ b/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake @@ -0,0 +1,32 @@ +include(CheckCSourceRuns) + +if (CMAKE_SYSTEM_PROCESSOR MATCHES "^(riscv)" AND GGML_CPU_RISCV64_SPACEMIT) + set(SMT_MARCH_STR "-march=rv64gcv_zfh_zvfh_zba_zicbop") + if (CMAKE_C_COMPILER_ID STREQUAL "GNU" AND + CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 15) + string(APPEND SMT_MARCH_STR "_xsmtvdotii") + endif() + set(CMAKE_REQUIRED_FLAGS "${SMT_MARCH_STR}") + + check_c_source_compiles("int main() {__asm__ volatile(\"vmadot v2, v0, v1\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_IME1) + check_c_source_compiles("int main() {__asm__ volatile(\"vmadot v2, v0, v1, i4\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VMADOT_S4) + check_c_source_compiles("int main() {__asm__ volatile(\"vmadot v2, v0, v1, i8\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VMADOT_S8) + check_c_source_compiles("int main() {__asm__ volatile(\"vfwmadot v2, v0, v1, fp16\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VFWMADOT_FP16) + check_c_source_compiles("int main() {__asm__ volatile(\"vmadot.hp v2, v0, v1, v0, 0, i4\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VFMADOT_S4) + check_c_source_compiles("int main() {__asm__ volatile(\"vmadot.hp v2, v0, v1, v0, 0, i8\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VFMADOT_S8) + check_c_source_compiles("int main() {__asm__ volatile(\"vmadot1 v2, v0, v1\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VMADOTN) + check_c_source_compiles("int main() {__asm__ volatile(\"vpack.vv v2, v0, v1, 2\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VPACK) + check_c_source_compiles("int main() {__asm__ volatile(\"vnspack.vv v2, v0, v1, 2\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VNPACK) + unset(CMAKE_REQUIRED_FLAGS) + + list(APPEND RISCV64_SPACEMIT_IME_SPEC "") + if (SPACEMIT_RISCV_COMPILER_SUPPORT_IME1) + set(RISCV64_SPACEMIT_IME_SPEC "RISCV64_SPACEMIT_IME1") + endif() + + if (SPACEMIT_RISCV_COMPILER_SUPPORT_VMADOT_S4 AND SPACEMIT_RISCV_COMPILER_SUPPORT_VPACK AND SPACEMIT_RISCV_COMPILER_SUPPORT_VNPACK) + list(APPEND RISCV64_SPACEMIT_IME_SPEC "RISCV64_SPACEMIT_IME2") + endif() + + message("RISCV64_SPACEMIT_IME_SPEC: ${RISCV64_SPACEMIT_IME_SPEC}") +endif() diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 8b7acafdaa8..7b05edf6b75 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -50,6 +50,10 @@ #include "llamafile/sgemm.h" #endif +#ifdef GGML_USE_CPU_RISCV64_SPACEMIT +# include "spacemit/ime.h" +#endif + // Note: once we move threading into a separate C++ file // will use std::hardware_destructive_interference_size instead of hardcoding it here // and we'll use C++ attribute syntax. @@ -3011,7 +3015,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { const struct ggml_cgraph * cgraph = tp->cgraph; const struct ggml_cplan * cplan = tp->cplan; +#ifdef GGML_USE_CPU_RISCV64_SPACEMIT + ggml_backend_cpu_riscv64_spacemit_set_numa_thread_affinity(state->ith); +#else set_numa_thread_affinity(state->ith); +#endif struct ggml_compute_params params = { /*.ith =*/ state->ith, @@ -3068,6 +3076,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { ggml_barrier(state->threadpool); +#ifdef GGML_USE_CPU_RISCV64_SPACEMIT + ggml_backend_cpu_riscv64_spacemit_clear_numa_thread_affinity_threaded(state->ith); +#endif + return 0; } diff --git a/ggml/src/ggml-cpu/spacemit/ime.cpp b/ggml/src/ggml-cpu/spacemit/ime.cpp index 91fe1925eaa..9563ea3e4bd 100644 --- a/ggml/src/ggml-cpu/spacemit/ime.cpp +++ b/ggml/src/ggml-cpu/spacemit/ime.cpp @@ -3,19 +3,32 @@ #include "ime.h" +#include "binary-ops.h" +#include "common.h" #include "ggml-backend-impl.h" #include "ggml-common.h" #include "ggml-cpu.h" +#include "ime_env.h" #include "ime_kernels.h" +#include "ops.h" +#include "repack.h" +#include "rvv_kernels.h" +#include "spine_mem_pool.h" #include "traits.h" +#include "vec.h" + +#include +#include +#include #include +#include #include +#include #include #include // for GGML_ASSERT #include #include - // clang-format off #if defined(__riscv) @@ -25,13 +38,17 @@ #include #endif -#if !defined(__riscv_zfh) -#error "riscv zfh extension not enabled" +#if !defined(__riscv_zfh) || !defined(__riscv_zvfh) +#error "riscv zfh extension not enabled, GGML_RV_ZFH and GGML_RV_ZVFH must be defined to 1" #endif -#if defined(RISCV64_SPACEMIT_IME1) +#if !defined(__riscv_zba) +#error "riscv zba extension not enabled, GGML_RV_ZBA must be defined to 1" +#endif + +#if defined(RISCV64_SPACEMIT_IME1) || defined(RISCV64_SPACEMIT_IME2) #else -#error "RISCV64_SPACEMIT_IME1 not defined" +#error "RISCV64_SPACEMIT_IME1 or RISCV64_SPACEMIT_IME2 not defined" #endif #else @@ -46,382 +63,490 @@ #pragma GCC diagnostic ignored "-Wunused-parameter" #endif -#if defined(RISCV64_SPACEMIT_IME1) -#define QGEMM_STRIDEN_THREAD_ALIGN 16 -#else -#define QGEMM_STRIDEN_THREAD_ALIGN 32 -#endif - // clang-format on -struct qnbitgemm_spacemit_ime_args { - const float * a_ptr = nullptr; - size_t lda = 0; - const std::byte * packed_quant_b_data = nullptr; - const float * quant_b_scale = nullptr; - const void * quant_b_zp = nullptr; - const float * quant_b_blksum = nullptr; - const float * bias = nullptr; - float * c_ptr = nullptr; - size_t ldc = 0; -}; - -constexpr size_t div_round_up(size_t up, size_t down) { - return (up + down - 1) / down; -} - -constexpr size_t q8_blk_size(size_t blk_len) { - const size_t blk_size = sizeof(float) + blk_len * sizeof(int8_t); - // Currently, the strictest alignment requirement of a block is for a float. - // Ensure contiguous blocks are suitably aligned. - assert(blk_size % alignof(float) == 0); - return blk_size; +extern "C" { +extern void ggml_threadpool_chunk_set(struct ggml_threadpool * tp, int value); +extern int ggml_threadpool_chunk_add(struct ggml_threadpool * tp, int value); } namespace ggml::cpu::riscv64_spacemit { -const int num_ai_cores = std::thread::hardware_concurrency() / 2; - -} // namespace ggml::cpu::riscv64_spacemit +struct TLSContext { + int cpu_id{ -1 }; + cpu_set_t cpuset; + void * tcm_buffer{ nullptr }; + size_t tcm_buffer_size{ 0 }; +}; -static void sqnbitgemm_spacemit_ime_i8i4(const size_t blk_len, - const size_t gemm_k, - const qnbitgemm_spacemit_ime_args * gemm_args, - void * const per_gemm_ws, - const size_t m_start, - const size_t m_count, - const size_t n_start, - const size_t n_count) { - constexpr size_t scale_stride = sizeof(uint16_t); - constexpr size_t blk_bitwidth = 4; +thread_local TLSContext tls_context; + +template constexpr size_t get_repacked_block_type_size() { + if constexpr (std::is_same_v || std::is_same_v) { + return sizeof(block_q8_0); + } else if constexpr (std::is_same_v) { + return sizeof(block_q4_0) * INTER_SIZE / QK4_0; + } else if constexpr (std::is_same_v || std::is_same_v) { + return (sizeof(block_q4_0) + sizeof(uint8_t)) * INTER_SIZE / QK4_1; + } else if constexpr (std::is_same_v) { + return sizeof(spacemit_kernels::nrow_block_q2_k<1>); + } else if constexpr (std::is_same_v) { + return sizeof(spacemit_kernels::nrow_block_q3_k<1>); + } else if constexpr (std::is_same_v) { + return sizeof(spacemit_kernels::nrow_block_mxfp4<1>); + } else if constexpr (std::is_same_v || std::is_same_v) { + return sizeof(spacemit_kernels::nrow_block_q5_1<1>); + } else if constexpr (std::is_same_v) { + return sizeof(spacemit_kernels::nrow_block_q5_0<1>); + } else { + assert(false); + return 0; + } +} - const size_t k_blks = div_round_up(gemm_k, blk_len); +template constexpr bool block_type_has_zp() { + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { + return false; + } else if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v) { + return true; + } else { + assert(false); + return false; + } +} - const size_t lda = k_blks * q8_blk_size(blk_len); - const size_t ldc = gemm_args->ldc; - const size_t ldb = k_blks * (blk_len * blk_bitwidth / 8); - const std::byte * quant_a_ptr = static_cast(per_gemm_ws) + m_start * lda; +class tensor_traits_base : public ggml::cpu::tensor_traits { + public: + virtual int repack(ggml_tensor * t, const void * data, size_t data_size) = 0; +}; - const size_t zero_point_stride = gemm_args->quant_b_zp != nullptr ? sizeof(uint8_t) : 0; - const size_t packed_b_stride = ldb + k_blks * (scale_stride + zero_point_stride); - const std::byte * packed_quant_b_data = gemm_args->packed_quant_b_data + n_start * packed_b_stride; +template class tensor_traits : public tensor_traits_base { + bool work_size(int /* n_threads */, const ggml_tensor * op, size_t & size) override { + switch (op->op) { + case GGML_OP_MUL_MAT: + { + int64_t src1_nelements = ggml_nelements(op->src[1]); + + if constexpr (std::is_same_v || std::is_same_v) { + size = + spacemit_kernels::div_round_up(src1_nelements, QK_K) * spacemit_kernels::q8k_blk_size(QK_K); + } else if constexpr (INTER_SIZE == QK4_0) { + size = spacemit_kernels::div_round_up(src1_nelements, QK4_0) * + spacemit_kernels::q8_blk_size(QK4_0, true); + } else if constexpr (INTER_SIZE == 256) { + size = spacemit_kernels::div_round_up(src1_nelements, 256) * + spacemit_kernels::q8_hp_blk_size(256, true, true); + } else { + GGML_ABORT("unsupported block type"); + } - float * c_ptr = gemm_args->c_ptr + m_start * ldc + n_start; + size = GGML_PAD(size, sizeof(int64_t)); - size_t count_n = 0; - const size_t compute_block_count_n = m_count == 1 ? n_count : 16; - for (size_t n = 0; n < n_count; n += count_n) { - count_n = std::min(n_count - n, compute_block_count_n); + return true; + } + case GGML_OP_MUL_MAT_ID: + { + int64_t src1_nelements = ggml_nelements(op->src[1]); + + if constexpr (std::is_same_v || std::is_same_v) { + size = + spacemit_kernels::div_round_up(src1_nelements, QK_K) * spacemit_kernels::q8k_blk_size(QK_K); + } else if constexpr (INTER_SIZE == QK4_0) { + size = spacemit_kernels::div_round_up(src1_nelements, QK4_0) * + spacemit_kernels::q8_blk_size(QK4_0, true); + } else if constexpr (INTER_SIZE == 256) { + size = spacemit_kernels::div_round_up(src1_nelements, 256) * + spacemit_kernels::q8_hp_blk_size(256, true, true); + } else { + GGML_ABORT("unsupported block type"); + } - const std::byte * a_row = quant_a_ptr; - const std::byte * b_col = packed_quant_b_data + n * packed_b_stride; - const std::byte * b_col_zp = (zero_point_stride != 0) ? b_col : nullptr; - float * c_blk = c_ptr + n; + size = GGML_PAD(size, sizeof(int64_t)); - int32_t rows_remaining = m_count; + const int64_t ne02 = op->src[0]->ne[2]; // n_as, n_expert + const int64_t ne12 = op->src[1]->ne[2]; // n_tokens - while (rows_remaining > 0) { - const auto rows_handled = sqnbitgemm_spacemit_ime::ime1::gemm_kernel_i8i4( - blk_len, a_row, b_col, nullptr, b_col_zp, c_blk, rows_remaining, count_n, gemm_k, k_blks, ldc, nullptr, - scale_stride); + const size_t sizeof_mmid_row_mapping = sizeof(int64_t); + size += sizeof_mmid_row_mapping * ne02 * (ne12 + 1) + (ne02 + 1) * sizeof(int64_t); - c_blk += rows_handled * ldc; - a_row += rows_handled * lda; + size = GGML_PAD(size, sizeof(int64_t)); - rows_remaining -= rows_handled; + return true; + } + default: + // GGML_ABORT("fatal error"); + break; } + return false; } -} -template constexpr int QK_0() { - if constexpr (K == 4) { - return QK4_0; - } - if constexpr (K == 8) { - return QK8_0; + bool compute_forward(ggml_compute_params * params, ggml_tensor * op) override { + switch (op->op) { + case GGML_OP_MUL_MAT: + switch (op->src[0]->type) { + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q5_K: + //case GGML_TYPE_MXFP4: + forward_mul_mat(params, op); + return true; + default: + // GGML_ABORT("fatal error: unsupported type for src0 in MUL_MAT"); + return false; + } + break; + case GGML_OP_MUL_MAT_ID: + switch (op->src[0]->type) { + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q5_K: + //case GGML_TYPE_MXFP4: + forward_mul_mat_id(params, op); + return true; + default: + // GGML_ABORT("fatal error: unsupported type for src0 in MUL_MAT_ID"); + return false; + } + break; + default: + // GGML_ABORT("fatal error"); + break; + } + return false; } - return -1; -} -template struct block { - ggml_half d[N]; // deltas for N qK_0 blocks - uint8_t qs[(QK_0() * N * K) / 8]; // quants for N qK_0 blocks -}; + void forward_mul_mat(ggml_compute_params * params, ggml_tensor * op) { + constexpr size_t a_blk_len = INTER_SIZE; + constexpr size_t b_blk_len = INTER_SIZE; -template struct block_with_zp { - ggml_half d[N]; // deltas for N qK_1 blocks - uint8_t zp[N]; // zero points for N qK_1 blocks - uint8_t qs[(QK_0() * N * K) / 8]; // quants for N qK_1 blocks -}; + const ggml_tensor * src0 = op->src[0]; + const ggml_tensor * src1 = op->src[1]; + ggml_tensor * dst = op; -// control size -static_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8, "wrong block<4,16> size/padding"); -static_assert(sizeof(block_with_zp<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8 + 16 * sizeof(uint8_t), - "wrong block_with_zp<4,16> size/padding"); -static_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 16, "wrong block<8,16> size/padding"); + GGML_TENSOR_BINARY_OP_LOCALS -using block_q4_0x16 = block<4, 16>; -using block_q4_1x16 = block_with_zp<4, 16>; -using block_q8_0x16 = block<8, 16>; + int ith = params->ith; + int nth = params->nth; -static block_q4_0x16 make_block_q4_0x16(block_q4_0 * in, unsigned int blck_size_interleave) { - block_q4_0x16 out; - GGML_ASSERT(QK4_0 / blck_size_interleave == 2); + [[maybe_unused]] const enum ggml_type type = src0->type; - for (int i = 0; i < 16; i++) { - out.d[i] = in[i].d; - } + void * w_data = (void *) src0->data; + const float * feature = (const float *) src1->data; + float * output = (float *) dst->data; - for (int i = 0; i < 16; i++) { - // [0, 15], in.d & 0x0F - for (int j = 0; j < QK4_0 / 4; j++) { - //src [b0 b16] ......... [b8 b24] ......... [b15 b31] - //dst [b0 b8] ......... [b7 b15] - out.qs[i * QK4_0 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_0 / 4] & 0x0F) << 4); + const int64_t gemm_m = ne11 * ne12 * ne13; + const int64_t gemm_k = ne10; + const int64_t gemm_n = ne01; + + spacemit_kernels::quantize_a_row_def quantize_a_row_i8; + spacemit_kernels::quantize_a_row_def quantize_a_4row_i8; + spacemit_kernels::gemm_kernel_quantize_def gemm_kernel; + bool set_kernel_impl = false; + + int64_t block_stride_a = spacemit_kernels::q8_blk_size(a_blk_len); + +#if defined(RISCV64_SPACEMIT_IME2) + if (!set_kernel_impl && (global_spine_env_info.use_ime2)) { + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8; + quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8; + block_stride_a = spacemit_kernels::q8_blk_size(a_blk_len, true); + + if constexpr (std::is_same_v || std::is_same_v) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i8; + set_kernel_impl = true; + } else if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + if constexpr (INTER_SIZE == 256) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i4_hp; + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8_hp; + quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8_hp; + block_stride_a = spacemit_kernels::q8_hp_blk_size(a_blk_len, true, true); + set_kernel_impl = true; + } else { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i4; + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8; + quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8; + block_stride_a = spacemit_kernels::q8_blk_size(a_blk_len, true); + set_kernel_impl = true; + } + } else if constexpr (std::is_same_v) { + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8k; + quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8k; + block_stride_a = spacemit_kernels::q8k_blk_size(a_blk_len); + + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i2k; + set_kernel_impl = true; + } else if constexpr (std::is_same_v) { + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8k; + quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8k; + block_stride_a = spacemit_kernels::q8k_blk_size(a_blk_len); + + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i3k; + set_kernel_impl = true; + } else if constexpr (std::is_same_v) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8mxfp4; + set_kernel_impl = true; + } else if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i5; + set_kernel_impl = true; + } } - } +#endif - for (int i = 0; i < 16; i++) { - // [16, 31], in.d & 0xF0 - for (int j = 0; j < QK4_0 / 4; j++) { - //src [b0 b16] ......... [b8 b24] ......... [b15 b31] - //dst [b16 b24] ......... [b23 b31] - out.qs[4 * QK4_0 + i * QK4_0 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_0 / 4] & 0xF0); +#if defined(RISCV64_SPACEMIT_IME1) + if (!set_kernel_impl && (global_spine_env_info.use_ime1)) { + quantize_a_row_i8 = spacemit_kernels::ime1::quantize_a_row_i8; + quantize_a_4row_i8 = spacemit_kernels::ime1::quantize_a_4row_i8; + + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + gemm_kernel = spacemit_kernels::ime1::gemm_kernel_i8i4; + set_kernel_impl = true; + } + } +#endif + if (!set_kernel_impl) { + GGML_ABORT("no kernel implementation found for the block type"); } - } - return out; -} + const int64_t a_k_blks = spacemit_kernels::div_round_up(gemm_k, a_blk_len); + const int64_t b_k_blks = spacemit_kernels::div_round_up(gemm_k, b_blk_len); -static block_q4_1x16 make_block_q4_1x16(block_q4_1 * in, unsigned int blck_size_interleave) { - block_q4_1x16 out; - GGML_ASSERT(QK4_1 / blck_size_interleave == 2); - - for (int i = 0; i < 16; i++) { - float d = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); - float m = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m); - float mid = -std::nearbyintf(m / d); - mid = std::min(15.0f, std::max(0.0f, mid)); - out.d[i] = GGML_FP32_TO_FP16(d); - out.zp[i] = static_cast(mid); - } + const int64_t row_stride_a = a_k_blks * block_stride_a; + const int64_t gemm_workspace_size = GGML_PAD(gemm_m * row_stride_a, alignof(int64_t)); - for (int i = 0; i < 16; i++) { - // [0, 15], in.d & 0x0F - for (int j = 0; j < QK4_1 / 4; j++) { - //src [b0 b16] ......... [b8 b24] ......... [b15 b31] - //dst [b0 b8] ......... [b7 b15] - out.qs[i * QK4_1 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_1 / 4] & 0x0F) << 4); + if (ith == 0 && params->wsize < gemm_workspace_size) { + GGML_ABORT("wsize less than gemm_workspace_size"); } - } - for (int i = 0; i < 16; i++) { - // [16, 31], in.d & 0xF0 - for (int j = 0; j < QK4_1 / 4; j++) { - //src [b0 b16] ......... [b8 b24] ......... [b15 b31] - //dst [b16 b24] ......... [b23 b31] - out.qs[4 * QK4_1 + i * QK4_1 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_1 / 4] & 0xF0); - } - } + uintptr_t ws_ptr = reinterpret_cast(params->wdata); - return out; -} + void * tcm_buffer = ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer; + const int64_t tcm_buffer_size = ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size; -static int repack_q4_0_to_q4_0_16_bl(struct ggml_tensor * t, - int interleave_block, - const void * GGML_RESTRICT data, - size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_Q4_0); - GGML_ASSERT(interleave_block == 16); + auto * quant_a_buffer = reinterpret_cast(ws_ptr); - constexpr int nrows_interleaved = 16; + constexpr int64_t row_align = 4; + const int64_t row_blks = spacemit_kernels::div_round_up(gemm_m, row_align); - block_q4_0x16 * dst = (block_q4_0x16 *) t->data; - const block_q4_0 * src = (const block_q4_0 *) data; - block_q4_0 dst_tmp[16]; - int nrow = ggml_nrows(t); - int nblocks = t->ne[0] / QK4_0; + const int64_t row_stride_b = b_k_blks * get_repacked_block_type_size(); + const int64_t per_mb_rows_wsize = row_align * row_stride_a; + const int64_t per_nb_cols_wsize = NB_COLS * row_stride_b; - GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + const int64_t barrier_idx = static_cast(ith / 2); - if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { - return -1; - } + GGML_ASSERT(global_spine_env_info.init_barrier != nullptr); + GGML_ASSERT(barrier_idx < spine_init_barrier_count); + spine_barrier_t * cur_barrier = &global_spine_env_info.init_barrier[barrier_idx]; - for (int b = 0; b < nrow; b += nrows_interleaved) { - for (int64_t x = 0; x < nblocks; x++) { - for (int i = 0; i < nrows_interleaved; i++) { - dst_tmp[i] = src[x + i * nblocks]; + if (gemm_m == 1) { + int task_per_thread = spacemit_kernels::div_round_up(a_k_blks, nth); + int a_blk_start = ith * task_per_thread; + int a_blk_end = std::min(a_blk_start + task_per_thread, (int) a_k_blks); + if (a_blk_start < a_blk_end) { + quantize_a_row_i8(a_blk_len, feature + a_blk_start * a_blk_len, (a_blk_end - a_blk_start) * a_blk_len, + quant_a_buffer + a_blk_start * block_stride_a); + } + } else { + int task_per_thread = spacemit_kernels::div_round_up(row_blks, nth); + int m_row_blk_start = ith * task_per_thread; + int m_row_blk_end = std::min(m_row_blk_start + task_per_thread, (int) row_blks); + for (int m_row_blk = m_row_blk_start; m_row_blk < m_row_blk_end; m_row_blk++) { + int m_idx = m_row_blk * row_align; + int rows_tobe_handled = (gemm_m - m_idx) > row_align ? row_align : (gemm_m - m_idx); + + if (rows_tobe_handled == row_align && quantize_a_4row_i8 != nullptr) { + const float * a_row_ptr = feature + m_idx * gemm_k; + auto * quant_a_row_ptr = quant_a_buffer + m_idx * row_stride_a; + quantize_a_4row_i8(a_blk_len, a_row_ptr, gemm_k, quant_a_row_ptr); + } else { + while (rows_tobe_handled) { + const float * a_row_ptr = feature + m_idx * gemm_k; + auto * quant_a_row_ptr = quant_a_buffer + m_idx * row_stride_a; + quantize_a_row_i8(a_blk_len, a_row_ptr, gemm_k, quant_a_row_ptr); + rows_tobe_handled -= 1; + m_idx += 1; + } + } } - *dst++ = make_block_q4_0x16(dst_tmp, interleave_block); } - src += nrows_interleaved * nblocks; - } - return 0; - GGML_UNUSED(data_size); -} + ggml_barrier(params->threadpool); -static int repack_q4_1_to_q4_1_16_bl(struct ggml_tensor * t, - int interleave_block, - const void * GGML_RESTRICT data, - size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_Q4_1); - GGML_ASSERT(interleave_block == 16); + const int64_t gemm_m_stride = gemm_n / gemm_m > 64 ? gemm_m : 16; + const int64_t gemm_m_blocked = spacemit_kernels::div_round_up(gemm_m, gemm_m_stride); + const int64_t max_gemm_n_stride = spacemit_kernels::div_round_up(gemm_n * gemm_m_blocked, nth); - constexpr int nrows_interleaved = 16; + int64_t gemm_n_stride = gemm_n; + if (max_gemm_n_stride < gemm_n) { + gemm_n_stride = + std::min(gemm_n_stride, spacemit_kernels::div_round_up(max_gemm_n_stride, NB_COLS) * NB_COLS); + } - block_q4_1x16 * dst = (block_q4_1x16 *) t->data; - const block_q4_1 * src = (const block_q4_1 *) data; - block_q4_1 dst_tmp[16]; - int nrow = ggml_nrows(t); - int nblocks = t->ne[0] / QK4_1; + if (gemm_n_stride == gemm_n && tcm_buffer != nullptr && per_mb_rows_wsize <= tcm_buffer_size) { + for (int64_t m_start = ith * row_align; m_start < gemm_m; m_start += row_align * nth) { + uint8_t * b_col = reinterpret_cast(w_data); + uint8_t * b_col_zp = block_type_has_zp() ? b_col : nullptr; - GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1)); + int64_t m_row_real = std::min(gemm_m - m_start, row_align); - if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_1 != 0) { - return -1; - } + spacemit_kernels::rvv::memcpy1d(tcm_buffer, quant_a_buffer + m_start * row_stride_a, + m_row_real * row_stride_a); - for (int b = 0; b < nrow; b += nrows_interleaved) { - for (int64_t x = 0; x < nblocks; x++) { - for (int i = 0; i < nrows_interleaved; i++) { - dst_tmp[i] = src[x + i * nblocks]; + int64_t n_blk_real = 0; + for (int64_t ni = 0; ni < gemm_n; ni += n_blk_real, b_col += n_blk_real * row_stride_b) { + n_blk_real = std::min(gemm_n - ni, (int64_t) NB_COLS); + + uint8_t * a_row_ptr = (uint8_t *) tcm_buffer; + float * c_blk = output + m_start * gemm_n + ni; + + int32_t rows_remaining = m_row_real; + + while (rows_remaining > 0) { + auto rows_handled = gemm_kernel(b_blk_len, a_row_ptr, b_col, b_col_zp, c_blk, rows_remaining, + n_blk_real, b_k_blks, gemm_n); + + c_blk += rows_handled * gemm_n; + a_row_ptr += rows_handled * row_stride_a; + + rows_remaining -= rows_handled; + } + } } - *dst++ = make_block_q4_1x16(dst_tmp, interleave_block); - } - src += nrows_interleaved * nblocks; - } - return 0; + } else if (tcm_buffer != nullptr && per_nb_cols_wsize <= tcm_buffer_size) { + uint8_t * a_row = quant_a_buffer; + uint8_t * b_col = reinterpret_cast(tcm_buffer); + if ((gemm_workspace_size + per_nb_cols_wsize) <= tcm_buffer_size) { + a_row = (uint8_t *) tcm_buffer; + b_col = reinterpret_cast(tcm_buffer) + gemm_workspace_size; + } + uint8_t * b_col_zp = block_type_has_zp() ? b_col : nullptr; - GGML_UNUSED(data_size); -} + int64_t ni = ith * NB_COLS; + int64_t nb_real = std::min(gemm_n - ni, NB_COLS); -static inline void get_scale_min_k4(int j, - const uint8_t * GGML_RESTRICT q, - uint8_t * GGML_RESTRICT d, - uint8_t * GGML_RESTRICT m) { - if (j < 4) { - *d = q[j] & 63; - *m = q[j + 4] & 63; - } else { - *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4); - *m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4); - } -} + if (ith % 2 == 0 && nb_real > 0) { + spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast(w_data) + ni * row_stride_b, + nb_real * row_stride_b); + if (a_row != quant_a_buffer) { + spacemit_kernels::rvv::memcpy1d(a_row, quant_a_buffer, gemm_workspace_size); + } + } -static int repack_q4_k_to_q4_1_16_bl(struct ggml_tensor * t, - int interleave_block, - const void * GGML_RESTRICT data, - size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_Q4_K); - GGML_ASSERT(interleave_block == 16); - GGML_ASSERT(QK_K / QK4_1 == 8); + spine_barrier_wait(cur_barrier); - constexpr int nrows_interleaved = 16; + if (ith % 2 != 0 && nb_real > 0) { + if (a_row != quant_a_buffer) { + spacemit_kernels::rvv::memcpy1d(a_row, quant_a_buffer, gemm_workspace_size); + } + spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast(w_data) + ni * row_stride_b, + nb_real * row_stride_b); + } - block_q4_1x16 * dst = (block_q4_1x16 *) t->data; - const block_q4_K * src = (const block_q4_K *) data; - block_q4_1 dst_tmp[16]; - int nrow = ggml_nrows(t); - int nblocks = t->ne[0] / QK_K; + for (; ni < gemm_n; ni += NB_COLS * nth) { + int64_t rows_remaining = gemm_m; + float * c_blk = output + ni; + auto * a_row_cur = a_row; - if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { - return -1; - } + if (ith % 2 != 0) { + spine_barrier_wait(cur_barrier); + } - for (int b = 0; b < nrow; b += nrows_interleaved) { - for (int64_t x = 0; x < nblocks; x++) { - for (int j = 0; j < 8; j++) { - for (int i = 0; i < nrows_interleaved; i++) { - uint8_t sc, m; - const float d = GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); - const float min = - GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin); - get_scale_min_k4(j, src[x + i * nblocks].scales, &sc, &m); - const float d1 = d * sc; - const float m1 = min * m; - - dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d = GGML_FP32_TO_FP16(d1); - dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m = GGML_FP32_TO_FP16(-m1); - // src -> [b0, b32] [b1, b33] ... [b31, b63] - // dst -> [b0, b16] [b1, b17] ... [b15, b31] [b32, b48] [b33, b49] ... [b47, b63] - const uint8_t * q = src[x + i * nblocks].qs + (j / 2) * QK4_1; - if (j % 2 == 0) { - for (int ii = 0; ii < 16; ii++) { - dst_tmp[i].qs[ii] = (q[ii] & 0x0F) | ((q[ii + 16] & 0x0F) << 4); - } - } else { - for (int ii = 0; ii < 16; ii++) { - dst_tmp[i].qs[ii] = ((q[ii] & 0xF0) >> 4) | (q[ii + 16] & 0xF0); - } - } + while (rows_remaining > 0) { + auto rows_handled = gemm_kernel(b_blk_len, a_row_cur, b_col, b_col_zp, c_blk, rows_remaining, + nb_real, b_k_blks, gemm_n); + + c_blk += rows_handled * gemm_n; + a_row_cur += rows_handled * row_stride_a; + + rows_remaining -= rows_handled; + } + + if (ith % 2 == 0) { + spine_barrier_wait(cur_barrier); + } + + const int64_t next_ni = ni + NB_COLS * nth; + if (next_ni < gemm_n) { + nb_real = std::min(gemm_n - next_ni, NB_COLS); + spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast(w_data) + next_ni * row_stride_b, + nb_real * row_stride_b); } - *dst++ = make_block_q4_1x16(dst_tmp, interleave_block); } - } - src += nrows_interleaved * nblocks; - } - return 0; + } else { + const int64_t task_count_m = spacemit_kernels::div_round_up(gemm_m, gemm_m_stride); + const int64_t task_count_n = spacemit_kernels::div_round_up(gemm_n, gemm_n_stride); - GGML_UNUSED(data_size); -} + int64_t task_count = task_count_m * task_count_n; + int64_t task_per_thread = (task_count + nth - 1) / nth; + int64_t start = ith * task_per_thread; + int64_t end = std::min((ith + 1) * task_per_thread, task_count); + for (int64_t compute_idx = start; compute_idx < end; compute_idx++) { + const auto tid_n = compute_idx / task_count_m; + const auto tid_m = compute_idx % task_count_m; -namespace ggml::cpu::riscv64_spacemit { + const int64_t m_start = tid_m * gemm_m_stride; + const int64_t m_count = std::min(gemm_m - m_start, (int64_t) gemm_m_stride); -template -int repack(struct ggml_tensor *, const void *, size_t); + const int64_t n_start = tid_n * gemm_n_stride; + const int64_t n_count = std::min(gemm_n - n_start, (int64_t) gemm_n_stride); -template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_q4_0_to_q4_0_16_bl(t, 16, data, data_size); -} + const int64_t n_blk = m_count == 1 ? n_count : NB_COLS; -template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_q4_1_to_q4_1_16_bl(t, 16, data, data_size); -} + uint8_t * b_col = reinterpret_cast(w_data) + n_start * row_stride_b; + uint8_t * b_col_zp = block_type_has_zp() ? b_col : nullptr; -template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_q4_k_to_q4_1_16_bl(t, 16, data, data_size); -} + int64_t n_blk_real = 0; + for (int64_t ni = 0; ni < n_count; ni += n_blk_real, b_col += n_blk_real * row_stride_b) { + n_blk_real = std::min(n_count - ni, n_blk); -class tensor_traits_base : public ggml::cpu::tensor_traits { - public: - virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0; -}; + uint8_t * a_row = quant_a_buffer + m_start * row_stride_a; -template class tensor_traits : public tensor_traits_base { - bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override { - switch (op->op) { - case GGML_OP_MUL_MAT: - size = ggml_row_size(GGML_TYPE_Q8_0, ggml_nelements(op->src[1])) * 4; - size = ((size + QK4_0 - 1) / QK4_0) * (QK4_0 * sizeof(float) + sizeof(float)); - return true; - default: - // GGML_ABORT("fatal error"); - break; - } - return false; - } + float * c_blk = output + m_start * gemm_n + n_start + ni; - bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override { - switch (op->op) { - case GGML_OP_MUL_MAT: - if (op->src[0]->type == GGML_TYPE_Q4_0 || // - op->src[0]->type == GGML_TYPE_Q4_1 || // - op->src[0]->type == GGML_TYPE_Q4_K) { - forward_mul_mat_q4(params, op); - return true; + int64_t rows_remaining = m_count; + + uint8_t * b_col_cur = b_col; + uint8_t * b_col_zp_cur = b_col_zp; + + while (rows_remaining > 0) { + auto rows_handled = gemm_kernel(b_blk_len, a_row, b_col_cur, b_col_zp_cur, c_blk, + rows_remaining, n_blk_real, b_k_blks, gemm_n); + + c_blk += rows_handled * gemm_n; + a_row += rows_handled * row_stride_a; + + rows_remaining -= rows_handled; + } } - default: - // GGML_ABORT("fatal error"); - break; + } } - return false; } - void forward_mul_mat_q4(ggml_compute_params * params, ggml_tensor * op) { + void forward_mul_mat_id(ggml_compute_params * params, ggml_tensor * op) { + constexpr size_t a_blk_len = INTER_SIZE; + constexpr size_t b_blk_len = INTER_SIZE; + const ggml_tensor * src0 = op->src[0]; const ggml_tensor * src1 = op->src[1]; + const ggml_tensor * ids = op->src[2]; ggml_tensor * dst = op; GGML_TENSOR_BINARY_OP_LOCALS @@ -429,133 +554,381 @@ template class tensor_ int ith = params->ith; int nth = params->nth; - [[maybe_unused]] const enum ggml_type type = src0->type; + // row groups + const int n_ids = ids->ne[0]; // n_expert_used + const int n_as = ne02; // n_expert + + struct mmid_row_mapping { + int32_t i1; + int32_t i2; + }; + + spacemit_kernels::quantize_a_row_def quantize_a_row_i8; + spacemit_kernels::gemm_kernel_quantize_def gemm_kernel; + spacemit_kernels::moe_gemm_kernel_quantize_def moe_gemm_kernel_m2; + bool set_kernel_impl = false; + size_t block_stride_a = spacemit_kernels::q8_blk_size(QK4_0); + +#if defined(RISCV64_SPACEMIT_IME2) + if (!set_kernel_impl && (global_spine_env_info.use_ime2)) { + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8; + block_stride_a = spacemit_kernels::q8_blk_size(QK4_0, true); + + if constexpr (std::is_same_v || std::is_same_v) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i8; + set_kernel_impl = true; + } else if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + if constexpr (INTER_SIZE == 256) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i4_hp; + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8_hp; + block_stride_a = spacemit_kernels::q8_hp_blk_size(a_blk_len, true, true); + set_kernel_impl = true; + } else { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i4; + moe_gemm_kernel_m2 = spacemit_kernels::ime2::moe_m2_gemm_kernel_i8i4; + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8; + block_stride_a = spacemit_kernels::q8_blk_size(a_blk_len, true); + set_kernel_impl = true; + } + } else if constexpr (std::is_same_v) { + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8k; + block_stride_a = spacemit_kernels::q8k_blk_size(a_blk_len); + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i2k; + set_kernel_impl = true; + } else if constexpr (std::is_same_v) { + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8k; + block_stride_a = spacemit_kernels::q8k_blk_size(a_blk_len); + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i3k; + set_kernel_impl = true; + } else if constexpr (std::is_same_v) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8mxfp4; + moe_gemm_kernel_m2 = spacemit_kernels::ime2::moe_m2_gemm_kernel_i8mxfp4; + set_kernel_impl = true; + } else if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i5; + moe_gemm_kernel_m2 = spacemit_kernels::ime2::moe_m2_gemm_kernel_i8i5; + set_kernel_impl = true; + } + } +#endif - void * w_data = (void *) src0->data; - const float * feature = (const float *) src1->data; - float * output = (float *) dst->data; +#if defined(RISCV64_SPACEMIT_IME1) + if (!set_kernel_impl && (global_spine_env_info.use_ime1)) { + quantize_a_row_i8 = spacemit_kernels::ime1::quantize_a_row_i8; + + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + gemm_kernel = spacemit_kernels::ime1::gemm_kernel_i8i4; + set_kernel_impl = true; + } + } +#endif + if (!set_kernel_impl) { + GGML_ABORT("no kernel implementation found for the block type"); + } - const size_t batch_feature = ne12 * ne13; - [[maybe_unused]] const size_t batch_weight = ne02 * ne03; - const size_t gemm_m = ne11; - const size_t gemm_k = ne10; - const size_t gemm_n = ne01; + const size_t a_k_blks = spacemit_kernels::div_round_up(ne10, a_blk_len); + const size_t b_k_blks = spacemit_kernels::div_round_up(ne10, b_blk_len); - GGML_ASSERT(batch_weight == 1); + const size_t nbw1 = a_k_blks * block_stride_a; + const size_t nbw2 = ne11 * nbw1; + const size_t nbw3 = nbw2 * ne12; + const size_t gemm_workspace_size = GGML_PAD(nbw3, alignof(int64_t)); - const size_t block_count_k = div_round_up(gemm_k, QK4_0); - const size_t per_gemm_workspace_size = gemm_m * block_count_k * q8_blk_size(QK4_0); - const size_t per_gemm_workspace_stride = - div_round_up(per_gemm_workspace_size, alignof(uint64_t)) * alignof(uint64_t); - const size_t gemm_workspace_size = batch_feature * per_gemm_workspace_stride; - const size_t desired_wsize = gemm_workspace_size + alignof(uint64_t) - 1; + const uintptr_t ws_ptr = reinterpret_cast(params->wdata); + auto * quant_a_buffer = reinterpret_cast(ws_ptr); - if (ith == 0 && params->wsize < desired_wsize) { - throw std::runtime_error("wsize less than desired_wsize"); + if (ne11 == 1) { + for (int64_t ii = ith; ii < ne12 * a_k_blks; ii += nth) { + int64_t i12 = ii / a_k_blks; + int64_t ak_blk_id = ii % a_k_blks; + quantize_a_row_i8(a_blk_len, (float *) ((char *) src1->data + i12 * nb12) + ak_blk_id * a_blk_len, + a_blk_len, quant_a_buffer + i12 * nbw2 + ak_blk_id * block_stride_a); + } + } else { + for (int64_t ii = ith; ii < ne12 * ne11; ii += nth) { + int64_t i12 = ii / ne11; + int64_t i11 = ii % ne11; + quantize_a_row_i8(a_blk_len, (float *) ((char *) src1->data + i12 * nb12 + i11 * nb11), ne10, + quant_a_buffer + i12 * nbw2 + i11 * nbw1); + } } - std::vector qnbitgemm_args(batch_feature); +#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) *ne12 + (i1)] - for (size_t i = 0; i < batch_feature; i++) { - qnbitgemm_args[i].a_ptr = feature + gemm_m * gemm_k * i; - qnbitgemm_args[i].lda = gemm_k; - qnbitgemm_args[i].packed_quant_b_data = (const std::byte *) w_data; - qnbitgemm_args[i].quant_b_scale = nullptr; + int64_t * matrix_row_counts = (int64_t *) (ws_ptr + gemm_workspace_size); + int32_t * valid_ep_count = (int32_t *) (matrix_row_counts + n_as); + int32_t * valid_act_count = (int32_t *) (valid_ep_count + 1); + int64_t * valid_matrix_row_counts = (int64_t *) (valid_act_count + 1); + mmid_row_mapping * matrix_rows = (mmid_row_mapping *) (valid_matrix_row_counts + n_as); - if constexpr (std::is_same_v) { - qnbitgemm_args[i].quant_b_zp = nullptr; - } else { - qnbitgemm_args[i].quant_b_zp = w_data; + if (ith == 0) { + // initialize matrix_row_counts + memset(matrix_row_counts, 0, n_as * sizeof(int64_t)); + + // group rows by src0 matrix + for (int32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { + for (int32_t id = 0; id < n_ids; ++id) { + const int32_t i02 = + *(const int32_t *) ((const char *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]); + + GGML_ASSERT(i02 >= 0 && i02 < n_as); + + MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = { id, iid1 }; + matrix_row_counts[i02] += 1; + } } - qnbitgemm_args[i].bias = nullptr; - qnbitgemm_args[i].c_ptr = output + gemm_m * gemm_n * i; - qnbitgemm_args[i].ldc = gemm_n; + int32_t valid_ep_count_t = 0; + int32_t valid_act_count_t = 0; + for (int cur_a = 0; cur_a < n_as; ++cur_a) { + const int64_t cne1 = matrix_row_counts[cur_a]; + if (cne1 == 0) { + continue; + } + valid_matrix_row_counts[valid_ep_count_t] = cur_a; + valid_act_count_t += cne1; + valid_ep_count_t += 1; + } + valid_ep_count[0] = valid_ep_count_t; + valid_act_count[0] = valid_act_count_t; } - const uintptr_t ws_ptr = reinterpret_cast(params->wdata); - void * ws = reinterpret_cast((ws_ptr + alignof(uint64_t) - 1) & (~(alignof(uint64_t) - 1))); - const size_t quant_a_stride = block_count_k * q8_blk_size(QK4_0); + const int64_t barrier_idx = static_cast(ith / 2); - { - constexpr size_t block_size_m = 4; - size_t per_gemm_block_count_m = div_round_up(gemm_m, block_size_m); - int32_t task_count = batch_feature * per_gemm_block_count_m; - int32_t task_per_thread = (task_count + nth - 1) / nth; - int32_t start = ith * task_per_thread; - int32_t end = std::min((ith + 1) * task_per_thread, task_count); - for (int32_t compute_idx = start; compute_idx < end; compute_idx++) { - int32_t gemm_idx = compute_idx / per_gemm_block_count_m; - int32_t block_idx_in_gemm = compute_idx % per_gemm_block_count_m; - int32_t m_idx = block_idx_in_gemm * block_size_m; - const qnbitgemm_spacemit_ime_args & data = qnbitgemm_args[gemm_idx]; - int32_t rows_tobe_handled = (gemm_m - m_idx) > block_size_m ? block_size_m : (gemm_m - m_idx); - - if (rows_tobe_handled == block_size_m) { - const float * a_row_ptr = data.a_ptr + m_idx * data.lda; - std::byte * quant_a_row_ptr = - static_cast(ws) + gemm_idx * per_gemm_workspace_stride + m_idx * quant_a_stride; - sqnbitgemm_spacemit_ime::ime1::quantize_a_4row_i8(QK4_0, a_row_ptr, gemm_k, quant_a_row_ptr); - } else { - while (rows_tobe_handled) { - const float * a_row_ptr = data.a_ptr + m_idx * data.lda; - std::byte * quant_a_row_ptr = static_cast(ws) + - gemm_idx * per_gemm_workspace_stride + m_idx * quant_a_stride; - sqnbitgemm_spacemit_ime::ime1::quantize_a_row_i8(QK4_0, a_row_ptr, gemm_k, quant_a_row_ptr); - rows_tobe_handled -= 1; - m_idx += 1; + GGML_ASSERT(global_spine_env_info.init_barrier != nullptr); + GGML_ASSERT(barrier_idx < spine_init_barrier_count); + spine_barrier_t * cur_barrier = &global_spine_env_info.init_barrier[barrier_idx]; + + ggml_barrier(params->threadpool); + + const size_t row_stride_b = b_k_blks * get_repacked_block_type_size(); + const size_t expert_b_stride = ne01 * row_stride_b; + const size_t per_nb_cols_wsize = NB_COLS * row_stride_b; + + std::array src_workspaces; + std::array dst_workspaces; + + auto * tcm_buffer = ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer; + const auto tcm_buffer_size = ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size; + + const auto valid_ep_count_t = valid_ep_count[0]; + const auto valid_act_count_t = valid_act_count[0]; + + int nth_es = 1; + int nth_n = nth; + + int ith_es = ith % nth_es; + int ith_n = (ith / nth_es) % nth_n; + + if (valid_ep_count_t % nth == 0 && tcm_buffer != nullptr && valid_ep_count_t == n_as && + valid_act_count_t == n_as && per_nb_cols_wsize <= tcm_buffer_size) { + for (int64_t valid_id = ith; valid_id < valid_ep_count_t; valid_id += nth) { + const int64_t cur_a = valid_matrix_row_counts[valid_id]; + + auto * src0_cur = (uint8_t *) src0->data + cur_a * expert_b_stride; + + mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, 0); + const int id = row_mapping.i1; + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; + const int64_t i1 = id; + const int64_t i2 = i12; + + auto * src1_col = quant_a_buffer + (i11 * nbw1 + i12 * nbw2); + float * c_blk = (float *) ((char *) dst->data + (i1 * nb1 + i2 * nb2)); + + uint8_t * a_row = src1_col; + uint8_t * b_col = reinterpret_cast(tcm_buffer); + if ((nbw1 + per_nb_cols_wsize) <= tcm_buffer_size) { + a_row = (uint8_t *) tcm_buffer; + b_col = reinterpret_cast(tcm_buffer) + nbw1; + } + uint8_t * b_col_zp = block_type_has_zp() ? b_col : nullptr; + + if (ith % 2 == 0) { + spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast(src0_cur), per_nb_cols_wsize); + + if (a_row != src1_col) { + spacemit_kernels::rvv::memcpy1d(a_row, src1_col, nbw1); + } + } + + spine_barrier_wait(cur_barrier); + + if (ith % 2 != 0) { + if (a_row != src1_col) { + spacemit_kernels::rvv::memcpy1d(a_row, src1_col, nbw1); + } + + spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast(src0_cur), per_nb_cols_wsize); + } + + int64_t nb_real = std::min(ne01, NB_COLS); + for (int64_t ni = 0; ni < ne01; ni += NB_COLS) { + if (ith % 2 != 0) { + spine_barrier_wait(cur_barrier); + } + + gemm_kernel(b_blk_len, a_row, b_col, b_col_zp, c_blk + ni, 1, nb_real, b_k_blks, ne01); + + if (ith % 2 == 0) { + spine_barrier_wait(cur_barrier); + } + + const int64_t next_ni = ni + NB_COLS; + if (next_ni < ne01) { + nb_real = std::min(ne01 - next_ni, NB_COLS); + spacemit_kernels::rvv::memcpy1d( + b_col, reinterpret_cast(src0_cur) + next_ni * row_stride_b, per_nb_cols_wsize); } } } - } + } else { + for (int64_t valid_id = ith_es; valid_id < valid_ep_count_t; valid_id += nth_es) { + const int64_t cur_a = valid_matrix_row_counts[valid_id]; + const int64_t cne1 = matrix_row_counts[cur_a]; - ggml_barrier(params->threadpool); + int64_t src1_cur_start = 0; + int64_t src1_cur_end = cne1; - if (ith >= ggml::cpu::riscv64_spacemit::num_ai_cores) { - return; - } - nth = std::min(nth, int{ ggml::cpu::riscv64_spacemit::num_ai_cores }); - - size_t threads_per_gemm = nth / batch_feature; - constexpr size_t gemm_m_stride = 128; - size_t nc = gemm_n; - const size_t gemm_m_blocked = div_round_up(gemm_m, gemm_m_stride); - const size_t max_nc = div_round_up(gemm_n * gemm_m_blocked, threads_per_gemm); - if (max_nc < nc) { - nc = std::min(nc, div_round_up(max_nc, QGEMM_STRIDEN_THREAD_ALIGN) * QGEMM_STRIDEN_THREAD_ALIGN); - } - const size_t gemm_n_stride = nc; - const size_t thread_count_m = div_round_up(gemm_m, gemm_m_stride); - const size_t thread_count_n = div_round_up(gemm_n, gemm_n_stride); - threads_per_gemm = thread_count_m * thread_count_n; + int64_t src0_cur_start = (ith_n * ne01) / nth_n; + int64_t src0_cur_end = MIN(((ith_n + 1) * ne01) / nth_n, ne01); - { - int task_count = batch_feature * threads_per_gemm; - int task_per_thread = (task_count + nth - 1) / nth; - int start = ith * task_per_thread; - int end = std::min((ith + 1) * task_per_thread, task_count); - for (int compute_idx = start; compute_idx < end; compute_idx++) { - const auto gemm_i = compute_idx / threads_per_gemm; - const auto blk_i = compute_idx % threads_per_gemm; - const auto * data = &qnbitgemm_args[gemm_i]; + if (src1_cur_start >= src1_cur_end || src0_cur_start >= src0_cur_end) { + continue; + } + + src0_cur_start = + (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start; + src0_cur_end = + (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end; + + auto * src0_cur = (uint8_t *) src0->data + cur_a * expert_b_stride + src0_cur_start * row_stride_b; + uint8_t * b_col_zp = block_type_has_zp() ? src0_cur : nullptr; + + size_t extra_tcm_buffer_size = tcm_buffer_size; + void * extra_tcm_buffer = tcm_buffer; + if (tcm_buffer != nullptr && (src1_cur_end - src1_cur_start) >= 4 && + (src0_cur_end - src0_cur_start) * row_stride_b <= tcm_buffer_size) { + spacemit_kernels::rvv::memcpy1d(tcm_buffer, src0_cur, + (src0_cur_end - src0_cur_start) * row_stride_b); + src0_cur = reinterpret_cast(tcm_buffer); + b_col_zp = block_type_has_zp() ? src0_cur : nullptr; + extra_tcm_buffer_size -= (src0_cur_end - src0_cur_start) * row_stride_b; + extra_tcm_buffer = reinterpret_cast(reinterpret_cast(tcm_buffer) + + (src0_cur_end - src0_cur_start) * row_stride_b); + } - const auto tid_n = blk_i / thread_count_m; - const auto tid_m = blk_i % thread_count_m; + int ir1 = src1_cur_start; - const size_t m_start = tid_m * gemm_m_stride; - const size_t m_count = std::min(gemm_m - m_start, (size_t) gemm_m_stride); + if (extra_tcm_buffer_size >= nbw1 && extra_tcm_buffer != nullptr) { + int64_t quant_a_tile_size = extra_tcm_buffer_size / nbw1; + do { + quant_a_tile_size = MIN(quant_a_tile_size, src1_cur_end - ir1); - const size_t n_start = tid_n * gemm_n_stride; - const size_t n_count = std::min(gemm_n - n_start, (size_t) gemm_n_stride); + uint8_t * quant_a_tile_buffer = reinterpret_cast(extra_tcm_buffer); - void * per_gemm_ws = reinterpret_cast(ws) + gemm_i * per_gemm_workspace_stride; + int iir1 = ir1; + for (; iir1 < (ir1 + quant_a_tile_size); ++iir1) { + mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, iir1); - sqnbitgemm_spacemit_ime_i8i4(QK4_0, gemm_k, data, per_gemm_ws, m_start, m_count, n_start, n_count); + const int id = row_mapping.i1; // selected expert index + + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; // row index in src1 + + auto * src1_col = quant_a_buffer + (i11 * nbw1 + i12 * nbw2); + spacemit_kernels::rvv::memcpy1d(quant_a_tile_buffer, src1_col, nbw1); + quant_a_tile_buffer = quant_a_tile_buffer + nbw1; + } + + quant_a_tile_buffer = reinterpret_cast(extra_tcm_buffer); + iir1 = ir1; + + if (moe_gemm_kernel_m2 != nullptr) { + for (; iir1 < (ir1 + quant_a_tile_size - 1); iir1 += 2, quant_a_tile_buffer += 2 * nbw1) { + mmid_row_mapping row_mapping_0 = MMID_MATRIX_ROW(cur_a, iir1); + mmid_row_mapping row_mapping_1 = MMID_MATRIX_ROW(cur_a, iir1 + 1); + + src_workspaces[0] = quant_a_tile_buffer; + src_workspaces[1] = quant_a_tile_buffer + nbw1; + + dst_workspaces[0] = + (float *) ((char *) dst->data + (row_mapping_0.i1 * nb1 + row_mapping_0.i2 * nb2)) + + src0_cur_start; + dst_workspaces[1] = (float *) ((char *) dst->data + + ((row_mapping_1.i1) * nb1 + (row_mapping_1.i2) * nb2)) + + src0_cur_start; + moe_gemm_kernel_m2(b_blk_len, src_workspaces.data(), src0_cur, b_col_zp, + dst_workspaces.data(), 1, src0_cur_end - src0_cur_start, b_k_blks, + ne01); + } + } + + for (; iir1 < (ir1 + quant_a_tile_size); iir1++, quant_a_tile_buffer += nbw1) { + mmid_row_mapping row_mapping_0 = MMID_MATRIX_ROW(cur_a, iir1); + + gemm_kernel( + b_blk_len, quant_a_tile_buffer, src0_cur, b_col_zp, + (float *) ((char *) dst->data + (row_mapping_0.i1 * nb1 + row_mapping_0.i2 * nb2)) + + src0_cur_start, + 1, src0_cur_end - src0_cur_start, b_k_blks, ne01); + } + + ir1 += quant_a_tile_size; + } while (ir1 < src1_cur_end); + } else { + if (moe_gemm_kernel_m2 != nullptr) { + for (; ir1 < src1_cur_end - 1; ir1 += 2) { + for (int iir1 = 0; iir1 < 2; ++iir1) { + mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1 + iir1); + + const int id = row_mapping.i1; // selected expert index + + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; // row index in src1 + + const int64_t i1 = id; // selected expert index + const int64_t i2 = i12; // row + + src_workspaces[iir1] = quant_a_buffer + (i11 * nbw1 + i12 * nbw2); + + dst_workspaces[iir1] = + (float *) ((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start; + } + + moe_gemm_kernel_m2(b_blk_len, src_workspaces.data(), src0_cur, b_col_zp, + dst_workspaces.data(), 1, src0_cur_end - src0_cur_start, b_k_blks, ne01); + } + } + + for (; ir1 < src1_cur_end; ir1++) { + mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1); + + const int id = row_mapping.i1; // selected expert index + + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; // row index in src1 + + const int64_t i1 = id; // selected expert index + const int64_t i2 = i12; // row + + auto * src1_col = quant_a_buffer + (i11 * nbw1 + i12 * nbw2); + + gemm_kernel(b_blk_len, src1_col, src0_cur, b_col_zp, + (float *) ((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, 1, + src0_cur_end - src0_cur_start, b_k_blks, ne01); + } + } } } +#undef MMID_MATRIX_ROW } - int repack(struct ggml_tensor * t, const void * data, size_t data_size) override { + int repack(ggml_tensor * t, const void * data, size_t data_size) override { GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type), (int) NB_COLS, (int) INTER_SIZE); return ggml::cpu::riscv64_spacemit::repack(t, data, data_size); @@ -563,309 +936,464 @@ template class tensor_ }; class tensor_traits_common : public tensor_traits_base { - bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override { + bool work_size(int n_threads, const ggml_tensor * op, size_t & size) override { switch (op->op) { - case GGML_OP_NORM: - case GGML_OP_RMS_NORM: - size = 0; + case GGML_OP_FLASH_ATTN_EXT: + { + const int n_tasks = n_threads; + const int64_t neq2 = op->src[0]->ne[2]; // number of query heads + const int64_t DK = op->src[1]->ne[0]; + const int64_t DV = op->src[2]->ne[0]; // DV + + // Tiled flash attention scratch (tile sizes defined in common.h) + // Per-thread: Q_q + KQ + mask + VKQ32 + V32 + K_f32 + padding + size_t prefill = sizeof(float) * + (GGML_FA_TILE_Q * DK + 2 * GGML_FA_TILE_Q * GGML_FA_TILE_KV + GGML_FA_TILE_Q * DV + + GGML_FA_TILE_KV * DV + GGML_FA_TILE_KV * DK) * + n_tasks; + + // Decode path: n_kv_chunks = n_tasks (one chunk per thread) + // Per-thread: VKQ accmulator (DV), partial M, partial S + intra-thread scratch for V, Q and VKQ + size_t n_chunks = n_tasks; + size_t decode = sizeof(float) * (neq2 * n_chunks * (2 + DV) + n_tasks * (DK + 2 * DV)); + + size = MAX(prefill, decode); + } return true; default: - // GGML_ABORT("fatal error"); break; } return false; } - bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override { + bool compute_forward(ggml_compute_params * params, ggml_tensor * op) override { switch (op->op) { case GGML_OP_NORM: - forward_norm_f32(params, op); - return true; + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_norm_f32(params, op); + return true; + default: + GGML_ABORT("fatal error"); + } case GGML_OP_RMS_NORM: - forward_rms_norm_f32(params, op); + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_rms_norm_f32(params, op); + return true; + default: + GGML_ABORT("fatal error"); + } + case GGML_OP_ADD: + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_binary(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_binary(params, op); + return true; + default: + ggml_compute_forward_add(params, op); + return true; + } + case GGML_OP_SUB: + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_binary(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_binary(params, op); + return true; + default: + ggml_compute_forward_sub(params, op); + return true; + } + case GGML_OP_MUL: + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_binary(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_binary(params, op); + return true; + default: + ggml_compute_forward_mul(params, op); + return true; + } + case GGML_OP_DIV: + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_binary(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_binary(params, op); + return true; + default: + ggml_compute_forward_div(params, op); + return true; + } + case GGML_OP_FLASH_ATTN_EXT: + forward_flash_attn_ext_f16(params, op); + return true; + case GGML_OP_CONT: + { + const ggml_tensor * src0 = op->src[0]; + if (op->type == src0->type && op->nb[0] != src0->nb[0] && op->nb[0] == src0->nb[1] && + op->ne[3] * op->ne[2] * op->nb[2] == src0->ne[3] * src0->ne[2] * src0->nb[2]) { + spacemit_kernels::rvv::forward_cont_with_permute(params, op); + } else { + ggml_compute_forward_cont(params, op); + } + return true; + } + case GGML_OP_CPY: + { + const ggml_tensor * src0 = op->src[0]; + if (op->type == src0->type && op->nb[0] == src0->nb[1] && src0->nb[0] != src0->nb[1] && + ggml_nelements(src0) == ggml_nelements(op)) { + spacemit_kernels::rvv::forward_cpy_with_permute(params, op); + } else { + ggml_compute_forward_cpy(params, op); + } + return true; + } + case GGML_OP_REPEAT: + { + const bool rows_equal = ggml_nrows(op->src[0]) == ggml_nrows(op); + const bool broadcast_or_equal = op->src[0]->ne[0] == 1 || op->src[0]->ne[0] == op->ne[0]; + + if (rows_equal && broadcast_or_equal) { + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_repeat_nrows(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_repeat_nrows(params, op); + return true; + default: + break; + } + } + + if (op->src[0]->ne[1] == 1 && op->src[0]->ne[0] == op->ne[0]) { + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_repeat_dim1(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_repeat_dim1(params, op); + return true; + default: + break; + } + } + + ggml_compute_forward_repeat(params, op); + } + return true; + case GGML_OP_SUM_ROWS: + { + if (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) { + spacemit_kernels::rvv::forward_sum_rows(params, op); + } else { + ggml_compute_forward_sum_rows(params, op); + } + } + return true; + case GGML_OP_GET_ROWS: + { + if (op->src[0]->type == op->type) { + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_get_rows(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_get_rows(params, op); + return true; + default: + break; + } + } + + ggml_compute_forward_get_rows(params, op); + } return true; + case GGML_OP_CONCAT: + { + const int32_t dim = ggml_get_op_params_i32(op, 0); + if (dim == 0 && op->type == op->src[0]->type) { + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_concat(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_concat(params, op); + return true; + default: + break; + } + } + + ggml_compute_forward_concat(params, op); + } + return true; + // TODO For GGML_OP_GATED_DELTA_NET + // case GGML_OP_GATED_DELTA_NET: + // return true; default: - // GGML_ABORT("fatal error"); break; } return false; } - void forward_norm_f32(ggml_compute_params * params, ggml_tensor * op) { - const ggml_tensor * src0 = op->src[0]; - ggml_tensor * dst = op; - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(src0->nb[0] == sizeof(float)); + void forward_flash_attn_ext_f16(const ggml_compute_params * params, ggml_tensor * dst) { + const ggml_tensor * q = dst->src[0]; + const ggml_tensor * k = dst->src[1]; + const ggml_tensor * v = dst->src[2]; + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int64_t DK = nek0; + const int64_t DV = nev0; + + const bool supported_prec = (dst->op_params[3] == GGML_PREC_F32 || dst->op_params[3] == GGML_PREC_DEFAULT); + const bool supported_types = (q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16 && v->type == GGML_TYPE_F16); + const bool supported_shape = (DK > 0 && DK <= 128 && DV > 0 && DV <= 128); + const bool supported_vlen = (__riscv_vlenb() == 128); + + if (!(supported_prec && supported_types && supported_shape && supported_vlen)) { + ggml_compute_forward_flash_attn_ext(params, dst); + return; + } + + // total rows in q + const int64_t nr = neq1 * neq2 * neq3; + // rows per thread const int ith = params->ith; const int nth = params->nth; - GGML_TENSOR_UNARY_OP_LOCALS + static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q; + const bool use_tiled = !params->use_ref && (neq1 >= Q_TILE_SZ); - float epsilon; - memcpy(&epsilon, dst->op_params, sizeof(float)); + // 4x chunks per thread + // int nth_scaled = nth * 4; + // int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled; + // int64_t nchunk = (nr + chunk_size - 1) / chunk_size; - GGML_ASSERT(epsilon > 0.0f); + // if (nth == 1 || nchunk < nth) { + // nchunk = nth; + // } - auto * input = (float *) src0->data; - auto * output = (float *) dst->data; + int64_t nchunk = nth; - const auto hidden_size = ne00; - const auto task_count = ne01 * ne02 * ne03; - const auto task_per_thread = (task_count + nth - 1) / nth; - - const auto task_begin = ith * task_per_thread; - const auto task_end = std::min((ith + 1) * task_per_thread, task_count); + if (ith == 0) { + // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start. + ggml_threadpool_chunk_set(params->threadpool, nth); + } - for (auto task_idx = task_begin; task_idx < task_end; task_idx++) { - auto offset = task_idx * hidden_size; - auto * p_input = const_cast(input + offset); + ggml_barrier(params->threadpool); - auto * p_output = output + offset; - auto * p_temp_output = p_output; - auto * p_gamma_data = (const float *) nullptr; - auto * p_beta_data = (const float *) nullptr; - size_t gvl = __riscv_vsetvlmax_e32m4(); - vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl); - vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl); - int64_t length = hidden_size; - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - // load data - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl); + // The number of elements in each chunk + const int64_t dr = (nr + nchunk - 1) / nchunk; - sum = __riscv_vfadd_vv_f32m4(sum, src_data, gvl); - sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl); + // The first chunk comes from our thread_id, the rest will get auto-assigned. + int current_chunk = ith; - __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl); + while (current_chunk < nchunk) { + const int64_t ir0 = dr * current_chunk; + const int64_t ir1 = MIN(ir0 + dr, nr); - p_input += gvl; - p_temp_output += gvl; - length -= gvl; + if (use_tiled) { + spacemit_kernels::rvv::forward_flash_attn_ext_f16_tiled_vlen1024_vf16( + params, dst, ir0, ir1, ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer, + ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size); + } else { + spacemit_kernels::rvv::forward_flash_attn_ext_f16_one_chunk_vlen1024_vf16( + params, dst, ir0, ir1, ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer, + ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size); } - gvl = __riscv_vsetvlmax_e32m1(); - - float mean = 0.f; - vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl); - vfloat32m1_t mean_v = - __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum, 0), __riscv_vget_v_f32m4_f32m1(sum, 1), gvl); - mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 2), gvl); - mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 3), gvl); - mean_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_v, zero_v, gvl); - mean = __riscv_vfmv_f_s_f32m1_f32(mean_v); - mean /= hidden_size; - - vfloat32m1_t mean_square_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0), - __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl); - mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl); - mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl); - mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl); - - float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v); - mean_square /= hidden_size; - mean_square = sqrt(mean_square - mean * mean + epsilon); - - mean_square = 1.0f / mean_square; - length = hidden_size; - p_temp_output = p_output; - - if (p_gamma_data == nullptr && p_beta_data == nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - length -= gvl; - } - } else if (p_beta_data == nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); - src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); - src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - p_gamma_data += gvl; - length -= gvl; - } - } else if (p_gamma_data != nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); - src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); - src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); - vfloat32m4_t beta_data_v = __riscv_vle32_v_f32m4(p_beta_data, gvl); - src_data = __riscv_vfadd_vv_f32m4(src_data, beta_data_v, gvl); - p_beta_data += gvl; - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - p_gamma_data += gvl; - length -= gvl; - } - } + current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1); } } - void forward_rms_norm_f32(ggml_compute_params * params, ggml_tensor * op) { - const ggml_tensor * src0 = op->src[0]; - ggml_tensor * dst = op; - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(src0->nb[0] == sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - GGML_TENSOR_UNARY_OP_LOCALS - - float epsilon; - memcpy(&epsilon, dst->op_params, sizeof(float)); - - GGML_ASSERT(epsilon > 0.0f); - - auto * input = (float *) src0->data; - auto * output = (float *) dst->data; - - const auto hidden_size = ne00; - const auto task_count = ne01 * ne02 * ne03; - const auto task_per_thread = (task_count + nth - 1) / nth; - - const auto task_begin = ith * task_per_thread; - const auto task_end = std::min((ith + 1) * task_per_thread, task_count); - - for (auto task_idx = task_begin; task_idx < task_end; task_idx++) { - auto offset = task_idx * hidden_size; - auto * p_input = const_cast(input + offset); - auto * p_output = output + offset; - auto * p_temp_output = p_output; - auto * p_gamma_data = (const float *) nullptr; - auto * p_beta_data = (const float *) nullptr; - - size_t gvl = __riscv_vsetvlmax_e32m4(); - // vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl); - vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl); - int64_t length = hidden_size; - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - // load data - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl); + int repack(ggml_tensor * t, const void * data, size_t data_size) override { + memcpy(t->data, data, data_size); + return 0; + } +}; - sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl); +// Impl By IME1 +static const tensor_traits q4_0_16x32_q8_0; +static const tensor_traits q4_1_16x32_q8_0; +static const tensor_traits q4_k_16x32_q8_0; +// Impl By IME2 +static const tensor_traits q2_k_32x256_q8_0; +static const tensor_traits q3_k_32x256_q8_0; +static const tensor_traits q4_0_32x32_q8_0; +static const tensor_traits q4_1_32x32_q8_0; +static const tensor_traits q4_0_32x256_q8_0; +static const tensor_traits q4_1_32x256_q8_0; +static const tensor_traits q4_k_32x32_q8_0; +static const tensor_traits q6_k_32x32_q8_0; +static const tensor_traits q8_0_32x32_q8_0; +static const tensor_traits mxfp4_32x32_q8_0; +static const tensor_traits q5_k_32x32_q8_0; +static const tensor_traits q5_1_32x32_q8_0; +static const tensor_traits q5_0_32x32_q8_0; +// Impl By RVV +static const tensor_traits_common rvv_impl; - __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl); +} // namespace ggml::cpu::riscv64_spacemit - p_input += gvl; - p_temp_output += gvl; - length -= gvl; +static const ggml::cpu::tensor_traits * ggml_riscv64_spacemit_get_optimal_repack_type(const ggml_tensor * cur) { + switch (cur->type) { + case GGML_TYPE_Q2_K: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q2_k_32x256_q8_0; + } +#endif } + break; + case GGML_TYPE_Q3_K: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q3_k_32x256_q8_0; + } +#endif + } + break; + case GGML_TYPE_Q4_0: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && cur->ne[0] % 256 == 0 && + (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q4_0_32x256_q8_0; + } - gvl = __riscv_vsetvlmax_e32m1(); - - // float mean = 0.f; - vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl); - - vfloat32m1_t mean_square_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0), - __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl); - mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl); - mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl); - mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl); - - float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v); - mean_square /= hidden_size; + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q4_0_32x32_q8_0; + } +#endif - mean_square = sqrt(mean_square + epsilon); +#if defined(RISCV64_SPACEMIT_IME1) + if (cur->ne[1] % 16 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime1)) { + return &ggml::cpu::riscv64_spacemit::q4_0_16x32_q8_0; + } +#endif + } + break; + case GGML_TYPE_Q4_1: + { +#if defined(RISCV64_SPACEMIT_IME2) + // TODO + // if (cur->ne[1] % 32 == 0 && cur->ne[0] % 256 == 0 && + // (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + // return &ggml::cpu::riscv64_spacemit::q4_1_32x256_q8_0; + // } + + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q4_1_32x32_q8_0; + } +#endif - mean_square = 1.0f / mean_square; - length = hidden_size; - p_temp_output = p_output; +#if defined(RISCV64_SPACEMIT_IME1) + if (cur->ne[1] % 16 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime1)) { + return &ggml::cpu::riscv64_spacemit::q4_1_16x32_q8_0; + } +#endif + } + break; + case GGML_TYPE_Q4_K: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q4_k_32x32_q8_0; + } +#endif - if (p_gamma_data == nullptr && p_beta_data == nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - length -= gvl; +#if defined(RISCV64_SPACEMIT_IME1) + if (cur->ne[1] % 16 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime1)) { + return &ggml::cpu::riscv64_spacemit::q4_k_16x32_q8_0; } - } else if (p_beta_data == nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); - src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - p_gamma_data += gvl; - length -= gvl; +#endif + } + break; + case GGML_TYPE_Q6_K: + { +#if defined(RISCV64_SPACEMIT_IME2) + if ((ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q6_k_32x32_q8_0; } - } else if (p_gamma_data != nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); - src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); - vfloat32m4_t beta_data_v = __riscv_vle32_v_f32m4(p_beta_data, gvl); - src_data = __riscv_vfadd_vv_f32m4(src_data, beta_data_v, gvl); - p_beta_data += gvl; - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - p_gamma_data += gvl; - length -= gvl; +#endif + } + break; + case GGML_TYPE_Q8_0: + { +#if defined(RISCV64_SPACEMIT_IME2) + if ((ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q8_0_32x32_q8_0; } +#endif } - } - } - - int repack(struct ggml_tensor * t, const void * data, size_t data_size) override { - memcpy(t->data, data, data_size); - return 0; - } -}; - -static const tensor_traits q4_0_16x8_q8_0; -static const tensor_traits q4_1_16x8_q8_0; -static const tensor_traits q4_k_16x8_q8_0; -static const tensor_traits_common rvv_impl; - -} // namespace ggml::cpu::riscv64_spacemit - -static const ggml::cpu::tensor_traits * ggml_riscv64_spacemit_get_optimal_repack_type(const struct ggml_tensor * cur) { - if (cur->type == GGML_TYPE_Q4_0) { - if (cur->ne[1] % 16 == 0) { - return &ggml::cpu::riscv64_spacemit::q4_0_16x8_q8_0; - } - } else if (cur->type == GGML_TYPE_Q4_1) { - if (cur->ne[1] % 16 == 0) { - return &ggml::cpu::riscv64_spacemit::q4_1_16x8_q8_0; - } - } else if (cur->type == GGML_TYPE_Q4_K) { - if (cur->ne[1] % 16 == 0) { - return &ggml::cpu::riscv64_spacemit::q4_k_16x8_q8_0; - } - } else if (cur->type == GGML_TYPE_F32) { - return &ggml::cpu::riscv64_spacemit::rvv_impl; + break; + case GGML_TYPE_MXFP4: + { +#if defined(RISCV64_SPACEMIT_IME2) + // TODO + // if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + // return &ggml::cpu::riscv64_spacemit::mxfp4_32x32_q8_0; + // } +#endif + } + break; + case GGML_TYPE_Q5_K: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q5_k_32x32_q8_0; + } +#endif + } + break; + case GGML_TYPE_Q5_1: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q5_1_32x32_q8_0; + } +#endif + } + break; + case GGML_TYPE_Q5_0: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q5_0_32x32_q8_0; + } +#endif + } + break; + default: + break; } return nullptr; } static enum ggml_status ggml_backend_riscv64_spacemit_buffer_init_tensor(ggml_backend_buffer_t buffer, - struct ggml_tensor * tensor) { + ggml_tensor * tensor) { tensor->extra = (void *) const_cast(ggml_riscv64_spacemit_get_optimal_repack_type(tensor)); @@ -874,8 +1402,46 @@ static enum ggml_status ggml_backend_riscv64_spacemit_buffer_init_tensor(ggml_ba return GGML_STATUS_SUCCESS; } +static void ggml_backend_riscv64_spacemit_buffer_free_buffer(ggml_backend_buffer_t buffer) { + GGML_ASSERT(buffer); + + void * base = buffer->context; + if (base == nullptr) { + return; + } + + ggml::cpu::riscv64_spacemit::spine_mem_pool_free(base); +} + +static void * ggml_backend_riscv64_spacemit_buffer_get_base(ggml_backend_buffer_t buffer) { + GGML_ASSERT(buffer); + + void * base = buffer->context; + GGML_ASSERT(base != nullptr); + return base; +} + +static void ggml_backend_riscv64_spacemit_buffer_memset_tensor(ggml_backend_buffer_t buffer, + ggml_tensor * tensor, + uint8_t value, + size_t offset, + size_t size) { + GGML_ASSERT(tensor); + memset((char *) tensor->data + offset, value, size); + + GGML_UNUSED(buffer); +} + +static void ggml_backend_riscv64_spacemit_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + GGML_ASSERT(buffer); + + void * base = buffer->context; + GGML_ASSERT(base != nullptr); + memset(base, value, buffer->size); +} + static void ggml_backend_riscv64_spacemit_buffer_set_tensor(ggml_backend_buffer_t buffer, - struct ggml_tensor * tensor, + ggml_tensor * tensor, const void * data, size_t offset, size_t size) { @@ -891,6 +1457,20 @@ static void ggml_backend_riscv64_spacemit_buffer_set_tensor(ggml_backend_buffer_ GGML_UNUSED(buffer); } +static const ggml_backend_buffer_i ggml_backend_riscv64_spacemit_buffer_i = { + /* .free_buffer = */ ggml_backend_riscv64_spacemit_buffer_free_buffer, + /* .get_base = */ ggml_backend_riscv64_spacemit_buffer_get_base, + /* .init_tensor = */ ggml_backend_riscv64_spacemit_buffer_init_tensor, + /* .memset_tensor = */ ggml_backend_riscv64_spacemit_buffer_memset_tensor, + /* .set_tensor = */ ggml_backend_riscv64_spacemit_buffer_set_tensor, + /* .get_tensor = */ nullptr, + /* .set_tensor_2d = */ nullptr, + /* .get_tensor_2d = */ nullptr, + /* .cpy_tensor = */ nullptr, + /* .clear = */ ggml_backend_riscv64_spacemit_buffer_clear, + /* .reset = */ nullptr, +}; + static const char * ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name(ggml_backend_buffer_type_t buft) { return "CPU_RISCV64_SPACEMIT"; @@ -899,18 +1479,12 @@ static const char * ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name(ggml_ static ggml_backend_buffer_t ggml_backend_cpu_riscv64_spacemit_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { - ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size); - - if (buffer == nullptr) { + void * base = ggml::cpu::riscv64_spacemit::spine_mem_pool_alloc(size, 64); + if (base == nullptr) { return nullptr; } - buffer->buft = buft; - buffer->iface.init_tensor = ggml_backend_riscv64_spacemit_buffer_init_tensor; - buffer->iface.set_tensor = ggml_backend_riscv64_spacemit_buffer_set_tensor; - buffer->iface.get_tensor = nullptr; - buffer->iface.cpy_tensor = nullptr; - return buffer; + return ggml_backend_buffer_init(buft, ggml_backend_riscv64_spacemit_buffer_i, base, size); } static size_t ggml_backend_cpu_riscv64_spacemit_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { @@ -919,44 +1493,91 @@ static size_t ggml_backend_cpu_riscv64_spacemit_buffer_type_get_alignment(ggml_b GGML_UNUSED(buft); } -static size_t ggml_backend_cpu_riscv64_spacemit_nbytes(ggml_backend_buffer_type_t buft, - const struct ggml_tensor * tensor) { +static size_t ggml_backend_cpu_riscv64_spacemit_nbytes(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { for (int i = 0; i < GGML_MAX_DIMS; ++i) { if (tensor->ne[i] <= 0) { return 0; } } - size_t nbytes; + GGML_UNUSED(buft); + + const auto plain_nbytes = [&]() { + size_t total = ggml_type_size(tensor->type); + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + total += (tensor->ne[i] - 1) * tensor->nb[i]; + } + return total; + }; + const size_t blck_size = ggml_blck_size(tensor->type); if (blck_size == 1) { - nbytes = ggml_type_size(tensor->type); - for (int i = 0; i < GGML_MAX_DIMS; ++i) { - nbytes += (tensor->ne[i] - 1) * tensor->nb[i]; + return plain_nbytes(); + } + + const size_t row_nbytes = tensor->ne[0] * tensor->nb[0] / blck_size; + + const auto add_strided_nbytes = [&](size_t total, size_t src_block_size, size_t dst_block_size) { + for (int i = 1; i < GGML_MAX_DIMS; ++i) { + total += (tensor->ne[i] - 1) * (tensor->nb[i] / src_block_size) * dst_block_size; } - } else { - nbytes = tensor->ne[0] * tensor->nb[0] / blck_size; - if (tensor->type == GGML_TYPE_Q4_K) { - GGML_ASSERT(nbytes % sizeof(block_q4_K) == 0); - nbytes = (nbytes / sizeof(block_q4_K)) * sizeof(block_q4_1) * 8; - for (int i = 1; i < GGML_MAX_DIMS; ++i) { - nbytes += (tensor->ne[i] - 1) * (tensor->nb[i] / sizeof(block_q4_K)) * sizeof(block_q4_1) * 8; - } - } else { - for (int i = 1; i < GGML_MAX_DIMS; ++i) { - nbytes += (tensor->ne[i] - 1) * tensor->nb[i]; - } + return total; + }; + + const auto remap_block_nbytes = [&](size_t src_block_size, size_t dst_block_size, int64_t padded_rows = 0) { + GGML_ASSERT(row_nbytes % src_block_size == 0); + + size_t total = + add_strided_nbytes((row_nbytes / src_block_size) * dst_block_size, src_block_size, dst_block_size); + + if (padded_rows > 0 && tensor->ne[1] % padded_rows != 0) { + total += (padded_rows - tensor->ne[1] % padded_rows) * (tensor->nb[1] / src_block_size) * dst_block_size; } + + return total; + }; + + size_t nbytes = row_nbytes; + switch (tensor->type) { + case GGML_TYPE_Q4_K: + nbytes = remap_block_nbytes(sizeof(block_q4_K), sizeof(block_q4_1) * 8); + break; + case GGML_TYPE_Q6_K: + nbytes = remap_block_nbytes(sizeof(block_q6_K), sizeof(block_q8_0) * 8, 32); + break; + case GGML_TYPE_Q8_0: + nbytes = remap_block_nbytes(sizeof(block_q8_0), sizeof(block_q8_0), 32); + break; + case GGML_TYPE_Q2_K: + nbytes = remap_block_nbytes(sizeof(block_q2_K), sizeof(spacemit_kernels::nrow_block_q2_k<1>)); + break; + case GGML_TYPE_Q3_K: + nbytes = remap_block_nbytes(sizeof(block_q3_K), sizeof(spacemit_kernels::nrow_block_q3_k<1>)); + break; + case GGML_TYPE_MXFP4: + nbytes = remap_block_nbytes(sizeof(block_mxfp4), sizeof(spacemit_kernels::nrow_block_mxfp4<1>)); + break; + case GGML_TYPE_Q5_K: + nbytes = remap_block_nbytes(sizeof(block_q5_K), sizeof(spacemit_kernels::nrow_block_q5_1<1>) * 8); + break; + case GGML_TYPE_Q5_1: + nbytes = remap_block_nbytes(sizeof(block_q5_1), sizeof(spacemit_kernels::nrow_block_q5_1<1>)); + break; + case GGML_TYPE_Q5_0: + nbytes = remap_block_nbytes(sizeof(block_q5_0), sizeof(spacemit_kernels::nrow_block_q5_0<1>)); + break; + default: + nbytes = add_strided_nbytes(row_nbytes, 1, 1); + break; } - GGML_UNUSED(buft); return nbytes; } namespace ggml::cpu::riscv64_spacemit { class extra_buffer_type : ggml::cpu::extra_buffer_type { - bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override { + bool supports_op(ggml_backend_dev_t, const ggml_tensor * op) override { switch (op->op) { case GGML_OP_MUL_MAT: if (op->src[0]->buffer && (ggml_n_dims(op->src[0]) == 2) && @@ -970,10 +1591,16 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { } } break; - case GGML_OP_NORM: - case GGML_OP_RMS_NORM: - if (op->src[0]->type == GGML_TYPE_F32) { - return true; + case GGML_OP_MUL_MAT_ID: + if (op->src[0]->buffer && (ggml_n_dims(op->src[0]) == 3) && + op->src[0]->buffer->buft == ggml_backend_cpu_riscv64_spacemit_buffer_type() && + ggml_riscv64_spacemit_get_optimal_repack_type(op->src[0])) { + if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { + return false; + } + if (op->src[1]->type == GGML_TYPE_F32) { + return true; + } } break; default: @@ -983,15 +1610,28 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { return false; } - ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override { + ggml::cpu::tensor_traits * get_tensor_traits(const ggml_tensor * op) override { switch (op->op) { case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_riscv64_spacemit_buffer_type()) { return (ggml::cpu::tensor_traits *) op->src[0]->extra; } break; case GGML_OP_NORM: case GGML_OP_RMS_NORM: + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_FLASH_ATTN_EXT: + case GGML_OP_CONT: + case GGML_OP_CPY: + case GGML_OP_REPEAT: + case GGML_OP_SUM_ROWS: + case GGML_OP_GET_ROWS: + case GGML_OP_CONCAT: + // case GGML_OP_GATED_DELTA_NET: return (ggml::cpu::tensor_traits *) (&ggml::cpu::riscv64_spacemit::rvv_impl); default: // GGML_ABORT("fatal error"); @@ -1005,7 +1645,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { } // namespace ggml::cpu::riscv64_spacemit ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void) { - static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_riscv64_spacemit = { + static ggml_backend_buffer_type ggml_backend_cpu_buffer_type_riscv64_spacemit = { /* .iface = */ { /* .get_name = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name, @@ -1023,3 +1663,78 @@ ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void) { return &ggml_backend_cpu_buffer_type_riscv64_spacemit; } + +extern "C" { +static int bind_ai_thread() { + int fd, bytes; + char str[32]; + + fd = open("/proc/set_ai_thread", O_WRONLY); + if (fd < 0) { + GGML_LOG_ERROR("try open /proc/set_ai_thread failed\n"); + return -1; + } + + snprintf(str, 16, "%d", 0); + bytes = write(fd, str, strlen(str)); + if (bytes < 0) { + GGML_LOG_ERROR("try write /proc/set_ai_thread failed\n"); + close(fd); + return -1; + } + + close(fd); + return 0; +} + +void ggml_backend_cpu_riscv64_spacemit_set_numa_thread_affinity(int thread_n) { + int cpu_id = sched_getcpu(); + if (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2 && + !((1 << cpu_id) & ggml::cpu::riscv64_spacemit::global_spine_env_info.cpu_mask)) { + GGML_PRINT_DEBUG("bind_ai_thread for thread %d, pid %d\n", thread_n, getpid()); + bind_ai_thread(); + } + + if (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_tcm && + ggml::cpu::riscv64_spacemit::tls_context.cpu_id == -1) { + CPU_ZERO(&(ggml::cpu::riscv64_spacemit::tls_context.cpuset)); + pthread_t main_thread = pthread_self(); + const auto & perfer_core_ids = ggml::cpu::riscv64_spacemit::global_spine_env_info.perfer_core_ids; + if (thread_n < 0 || static_cast(thread_n) >= perfer_core_ids.size()) { + GGML_ABORT("thread_n %d exceeds perfer_core_ids size %zu\n", thread_n, perfer_core_ids.size()); + } + auto perfer_cpu_id = perfer_core_ids[static_cast(thread_n)]; + CPU_SET(perfer_cpu_id, &(ggml::cpu::riscv64_spacemit::tls_context.cpuset)); + int s = + pthread_setaffinity_np(main_thread, sizeof(cpu_set_t), &(ggml::cpu::riscv64_spacemit::tls_context.cpuset)); + if (s != 0) { + GGML_ABORT("set thread affinity error for thread_n %d, cpu_id %d\n", thread_n, perfer_cpu_id); + } + + int ai_cpu_id = perfer_cpu_id - ggml::cpu::riscv64_spacemit::global_spine_env_info.aicpu_id_offset; + ggml::cpu::riscv64_spacemit::tls_context.cpu_id = ai_cpu_id; + ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer = + ggml::cpu::riscv64_spacemit::spine_mem_pool_tcm_mem_get(ai_cpu_id); + ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size = + ggml::cpu::riscv64_spacemit::global_spine_env_info.tcm_blk_size; + } + + if (ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer != nullptr) { + void * rt = + ggml::cpu::riscv64_spacemit::spine_mem_pool_tcm_mem_wait(ggml::cpu::riscv64_spacemit::tls_context.cpu_id); + if (rt == nullptr) { + GGML_ABORT("wait tcm buffer failed for cpu_id: %d", ggml::cpu::riscv64_spacemit::tls_context.cpu_id); + } + } +} + +void ggml_backend_cpu_riscv64_spacemit_clear_numa_thread_affinity_threaded(int thread_n) { + if (ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer != nullptr) { + auto rt = ggml::cpu::riscv64_spacemit::spine_mem_pool_tcm_mem_release( + ggml::cpu::riscv64_spacemit::tls_context.cpu_id); + if (rt != 0) { + GGML_ABORT("release tcm buffer failed for cpu_id: %d", ggml::cpu::riscv64_spacemit::tls_context.cpu_id); + } + } +} +} diff --git a/ggml/src/ggml-cpu/spacemit/ime.h b/ggml/src/ggml-cpu/spacemit/ime.h index 800d91acdae..6849dd95e05 100644 --- a/ggml/src/ggml-cpu/spacemit/ime.h +++ b/ggml/src/ggml-cpu/spacemit/ime.h @@ -8,6 +8,14 @@ extern "C" { ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void); +void ggml_backend_cpu_riscv64_spacemit_set_numa_thread_affinity(int thread_n); + +void ggml_backend_cpu_riscv64_spacemit_clear_numa_thread_affinity_threaded(int thread_n); + +void * ggml_backend_cpu_riscv64_spacemit_alloc_shared(size_t size, size_t alignment); + +void ggml_backend_cpu_riscv64_spacemit_free_shared(void * ptr); + #ifdef __cplusplus } #endif diff --git a/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp b/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp index cbbb6cd9160..6acc6819dfb 100644 --- a/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +++ b/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp @@ -1,8 +1,26 @@ +#include "ggml-impl.h" #include "ggml.h" #include "ime_kernels.h" +#include "rvv_kernels.h" #include #include +#include + +#if !defined(__riscv_v) || !defined(__riscv_v_intrinsic) +# error "riscv v extension or v_intrinsic not enabled" +#else +# include +#endif + +#if !defined(__riscv_zfh) +# error "riscv zfh extension not enabled" +#endif + +#if defined(RISCV64_SPACEMIT_IME1) +#else +# error "RISCV64_SPACEMIT_IME1 not defined" +#endif // clang-format off #if defined(__GNUC__) @@ -11,7 +29,7 @@ #pragma GCC diagnostic ignored "-Wunused-parameter" #endif // clang-format on -namespace sqnbitgemm_spacemit_ime { +namespace spacemit_kernels { #define QUANTIZEM4ROW_KERNEL \ "vmv.s.x v16, zero \n\t" \ @@ -76,1093 +94,208 @@ namespace sqnbitgemm_spacemit_ime { "vse8.v v31, (s1) \n\t" namespace ime1 { -void quantize_a_4row_i8(size_t BlkLen, const float * A, size_t CountK, std::byte * QuantA) { +void quantize_a_4row_i8(size_t BlkLen, const float * A, size_t CountK, uint8_t * QuantA) { constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1); const float fone = 1.0f; - if (BlkLen == 16 || BlkLen == 32 || BlkLen == 64) { - for (size_t row_index = 0; row_index < 4; ++row_index) { - const float * SRC = A + row_index * CountK; - std::byte * DST = QuantA + row_index * sizeof(float); + for (size_t row_index = 0; row_index < 4; ++row_index) { + const float * SRC = A + row_index * CountK; + uint8_t * DST = QuantA + row_index * sizeof(float); - const size_t offset = (4 - row_index) * 4 + row_index * 8; - const size_t stride = 4 * (sizeof(float) + BlkLen); - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "addi t2, %[CountK], 0 \n\t" - "addi a1, %[DST], 0 \n\t" - "blt t2, %[BlkLen], TAIL%= \n\t" - - "LOOP%=: \n\t" - "vsetvli t0, %[BlkLen], e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "sub t2, t2, t0 \n\t" - "slli t1, t0, 2 \n\t" - "add %[SRC], %[SRC], t1 \n\t" - "add s1, a1, %[OFFSET] \n\t" - - QUANTIZEM4ROW_KERNEL QUANTIZEM4ROW_STORE - - "add a1, a1, %[STRIDE] \n\t" - "bge t2, %[BlkLen], LOOP%= \n\t" - - "TAIL%=: \n\t" - "blez t2, QUIT%= \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "vsetvli t0, t2, e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "add s1, a1, %[OFFSET] \n\t" - - QUANTIZEM4ROW_KERNEL - - "addi t3, %[BlkLen], 0 \n\t" - "addi s2, s1, 0 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vxor.vv v8, v8, v8 \n\t" - "SET_ZERO%=: \n\t" - "vse8.v v8, (s2) \n\t" - "addi s2, s2, 32 \n\t" - "addi t3, t3, -8 \n\t" - "bnez t3, SET_ZERO%= \n\t" - - QUANTIZEM4ROW_STORE - - "QUIT%=: \n\t" - : [SRC] "+r"(SRC) - : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride), - [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal) - : "cc", "t0", "t1", "t2", "t3", "a1", "s1", "s2", "f10", "f11"); - } - } else if (BlkLen == 128) { - for (size_t row_index = 0; row_index < 4; ++row_index) { - const float * SRC = A + row_index * CountK; - std::byte * DST = QuantA + row_index * sizeof(float); - - const size_t offset = (4 - row_index) * 4 + row_index * 8; - const size_t stride = 4 * (sizeof(float) + BlkLen); - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "li t6, 32 \n\t" - "addi t2, %[CountK], 0 \n\t" - "addi a1, %[DST], 0 \n\t" - "add s1, a1, %[OFFSET] \n\t" - "blt t2, %[BlkLen], TAIL%= \n\t" - - "LOOP%=: \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "addi t2, t2, -128 \n\t" - - "QUANTIZE%=: \n\t" - "add s1, a1, %[OFFSET] \n\t" - "vfabs.v v16, v0 \n\t" - "vfabs.v v24, v8 \n\t" - "vfmax.vv v16, v24, v16 \n\t" - "vfredmax.vs v24, v16, v24 \n\t" - "vfmv.f.s f10, v24 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (a1) \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vfmul.vf v16, v0, f11 \n\t" - "vfmul.vf v24, v8, f11 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v20, v24, zero \n\t" - "vsetvli t0, zero, e8, m4 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vsetvli t0, zero, e64, m4 \n\t" - "vsse64.v v16, (s1), t6 \n\t" - "add a1, a1, %[STRIDE] \n\t" - "bge t2, %[BlkLen], LOOP%= \n\t" - - "TAIL%=: \n\t" - "blez t2, QUIT%= \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vxor.vv v8, v8, v8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "vsetvli t0, t2, e32, m8 \n\t" - "sub t2, t2, t0 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vsetvli t0, t2, e32, m8 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "sub t2, t2, t2 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "jal x0, QUANTIZE%= \n\t" - - "QUIT%=: \n\t" - : [SRC] "+r"(SRC) - : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride), - [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal) - : "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11"); - } - } else if (BlkLen == 256) { - for (size_t row_index = 0; row_index < 4; ++row_index) { - const float * SRC = A + row_index * CountK; - std::byte * DST = QuantA + row_index * sizeof(float); - const size_t offset = (4 - row_index) * 4 + row_index * 8; - const size_t stride = 4 * (sizeof(float) + BlkLen); - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "li t6, 32 \n\t" - "addi t2, %[CountK], 0 \n\t" - "addi a1, %[DST], 0 \n\t" - "add s1, a1, %[OFFSET] \n\t" - "blt t2, %[BlkLen], TAIL%= \n\t" - - "LOOP%=: \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v16, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v24, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], -768 \n\t" - "addi t2, t2, -256 \n\t" - "vfabs.v v0, v0 \n\t" - "vfabs.v v8, v8 \n\t" - "vfabs.v v16, v16 \n\t" - "vfabs.v v24, v24 \n\t" - "vfmax.vv v8, v0, v8 \n\t" - "vfmax.vv v24, v24, v16 \n\t" - "vfmax.vv v8, v8, v24 \n\t" - "vfredmax.vs v24, v8, v24 \n\t" - "vfmv.f.s f10, v24 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v16, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v24, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - - "QUANTIZE%=: \n\t" - "add s1, a1, %[OFFSET] \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (a1) \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vfmul.vf v0, v0, f11 \n\t" - "vfmul.vf v8, v8, f11 \n\t" - "vfmul.vf v16, v16, f11 \n\t" - "vfmul.vf v24, v24, f11 \n\t" - "vfcvt.x.f.v v0, v0 \n\t" - "vfcvt.x.f.v v8, v8 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v0, v0, zero \n\t" - "vnclip.wx v4, v8, zero \n\t" - "vnclip.wx v8, v16, zero \n\t" - "vnclip.wx v12, v24, zero \n\t" - "vsetvli t0, zero, e8, m4 \n\t" - "vnclip.wx v0, v0, zero \n\t" - "vnclip.wx v4, v8, zero \n\t" - "vsetvli t0, zero, e64, m8 \n\t" - "vsse64.v v0, (s1), t6 \n\t" - "add a1, a1, %[STRIDE] \n\t" - "bge t2, %[BlkLen], LOOP%= \n\t" - - "TAIL%=: \n\t" - "blez t2, QUIT%= \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vxor.vv v8, v8, v8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "addi t1, t2, 0 \n\t" - "vsetvli t0, t1, e32, m8 \n\t" - "sub t1, t1, t0 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vsetvli t0, t1, e32, m8 \n\t" - "sub t1, t1, t0 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vsetvli t0, t1, e32, m8 \n\t" - "sub t1, t1, t0 \n\t" - "vle32.v v16, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vsetvli t0, t1, e32, m8 \n\t" - "vle32.v v24, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], -768 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vfabs.v v0, v0 \n\t" - "vfabs.v v8, v8 \n\t" - "vfabs.v v16, v16 \n\t" - "vfabs.v v24, v24 \n\t" - "vfmax.vv v8, v0, v8 \n\t" - "vfmax.vv v24, v16, v24 \n\t" - "vfmax.vv v8, v8, v24 \n\t" - "vfredmax.vs v24, v8, v24 \n\t" - "vfmv.f.s f10, v24 \n\t" - "add s1, a1, %[OFFSET] \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (a1) \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vsetvli t0, zero, e64, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vsse64.v v0, (s1), t6 \n\t" - - "TAIL_LOOP%=: \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vsetvli t0, t2, e32, m1 \n\t" - "sub t2, t2, t0 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 32 \n\t" - "vfmul.vf v1, v0, f11 \n\t" - "vfcvt.x.f.v v2, v1 \n\t" - "vsetvli t0, zero, e16, mf2 \n\t" - "vnclip.wx v3, v2, zero \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vnclip.wx v3, v3, zero \n\t" - "vse8.v v3, (s1) \n\t" - "addi s1, s1, 32 \n\t" - "bnez t2, TAIL_LOOP%= \n\t" - - "QUIT%=: \n\t" - : [SRC] "+r"(SRC) - : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride), - [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal) - : "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11"); - } + const size_t offset = (4 - row_index) * 4 + row_index * 8; + const size_t stride = 4 * (sizeof(float) + BlkLen); + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "addi t2, %[CountK], 0 \n\t" + "addi a1, %[DST], 0 \n\t" + "blt t2, %[BlkLen], TAIL%= \n\t" + + "LOOP%=: \n\t" + "vsetvli t0, %[BlkLen], e32, m8 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "sub t2, t2, t0 \n\t" + "slli t1, t0, 2 \n\t" + "add %[SRC], %[SRC], t1 \n\t" + "add s1, a1, %[OFFSET] \n\t" + + QUANTIZEM4ROW_KERNEL QUANTIZEM4ROW_STORE + + "add a1, a1, %[STRIDE] \n\t" + "bge t2, %[BlkLen], LOOP%= \n\t" + + "TAIL%=: \n\t" + "blez t2, QUIT%= \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vsetvli t0, t2, e32, m8 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "add s1, a1, %[OFFSET] \n\t" + + QUANTIZEM4ROW_KERNEL + + "addi t3, %[BlkLen], 0 \n\t" + "addi s2, s1, 0 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vxor.vv v8, v8, v8 \n\t" + "SET_ZERO%=: \n\t" + "vse8.v v8, (s2) \n\t" + "addi s2, s2, 32 \n\t" + "addi t3, t3, -8 \n\t" + "bnez t3, SET_ZERO%= \n\t" + + QUANTIZEM4ROW_STORE + + "QUIT%=: \n\t" + : [SRC] "+r"(SRC) + : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride), [CountK] "r"(CountK), + [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal) + : "cc", "t0", "t1", "t2", "t3", "a1", "s1", "s2", "f10", "f11"); } } -void quantize_a_row_i8(size_t BlkLen, const float * A, size_t CountK, std::byte * QuantA) { +void quantize_a_row_i8(size_t BlkLen, const float * A, size_t CountK, uint8_t * QuantA) { const float * SRC = A; - std::byte * DST = QuantA; + uint8_t * DST = QuantA; constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1); const float fone = 1.0f; - std::byte * QuantA_offset = QuantA + CountK + 4 * ((CountK + BlkLen - 1) / BlkLen); + uint8_t * QuantA_offset = QuantA + CountK + 4 * ((CountK + BlkLen - 1) / BlkLen); size_t offset = (CountK + BlkLen - 1) / BlkLen * BlkLen - CountK; - if (CountK <= BlkLen) { - float max_abs_A = 0.0f; - for (size_t k = 0; k < CountK; k++) { - max_abs_A = std::max(max_abs_A, fabsf(A[k])); - } - float scale_A = max_abs_A * range_max_reciprocal; - - ((float *) QuantA)[0] = scale_A; - - auto * QuantAData_offset = (int8_t *) (QuantA + sizeof(float)); - - for (size_t k = 0; k < CountK; k++) { - QuantAData_offset[k] = - (int8_t) std::clamp(roundf(A[k] / scale_A), (float) std::numeric_limits::lowest(), - (float) std::numeric_limits::max()); - } - for (size_t k = CountK; k < BlkLen; k++) { - QuantAData_offset[k] = 0; - } - - return; - } - - if (BlkLen != 32 || BlkLen != 64 || BlkLen != 128) { - __asm__ volatile( - "vsetvli t0, zero, e8, m8 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "LOOP%=: \n\t" - "vsetvli t0, %[CNT], e8, m8 \n\t" - "vse8.v v24, (%[DST]) \n\t" - "addi %[DST], %[DST], 128 \n\t" - "sub %[CNT], %[CNT], t0 \n\t" - "bnez %[CNT], LOOP%= \n\t" - : [DST] "+r"(QuantA_offset), [CNT] "+r"(offset) - : - : "cc", "t0"); - } - if (BlkLen == 16) { - float buffer[64] = { 0.0f }; - __asm__ volatile( - "addi t3, zero, 16*8 \n\t" - "addi t2, zero, 16 \n\t" - "blt %[K], t3, LOOP_K%= \n\t" - "blt %[K], t2, TAIL%= \n\t" - "LOOP_MAIN%=: \n\t" - "vsetvli t1, zero, e32, m2 \n\t" - "addi %[K], %[K], -128 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v2, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v4, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v6, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v10, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v12, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v14, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "addi a1, %[BUFFER], 0 \n\t" - "vfabs.v v16, v0 \n\t" - "vfabs.v v18, v2 \n\t" - "vfabs.v v20, v4 \n\t" - "vfabs.v v22, v6 \n\t" - "vfabs.v v24, v8 \n\t" - "vfabs.v v26, v10 \n\t" - "vfabs.v v28, v12 \n\t" - "vfabs.v v30, v14 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v16, v16, v17 \n\t" - "vfmax.vv v18, v18, v19 \n\t" - "vfmax.vv v20, v20, v21 \n\t" - "vfmax.vv v22, v22, v23 \n\t" - "vfmax.vv v24, v24, v25 \n\t" - "vfmax.vv v26, v26, v27 \n\t" - "vfmax.vv v28, v28, v29 \n\t" - "vfmax.vv v30, v30, v31 \n\t" - "vse32.v v16, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v18, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v20, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v22, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v24, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v26, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v28, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v30, (a1) \n\t" - "addi a1, %[BUFFER], 0 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f10, f3, f7 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f10, %[FONE], f10 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f11, f3, f7 \n\t" - "fmul.s f11, f11, %[RMAXREC] \n\t" - "fsw f11, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f11, %[FONE], f11 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f12, f3, f7 \n\t" - "fmul.s f12, f12, %[RMAXREC] \n\t" - "fsw f12, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f12, %[FONE], f12 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f13, f3, f7 \n\t" - "fmul.s f13, f13, %[RMAXREC] \n\t" - "fsw f13, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f13, %[FONE], f13 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f14, f3, f7 \n\t" - "fmul.s f14, f14, %[RMAXREC] \n\t" - "fsw f14, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f14, %[FONE], f14 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f15, f3, f7 \n\t" - "fmul.s f15, f15, %[RMAXREC] \n\t" - "fsw f15, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f15, %[FONE], f15 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f16, f3, f7 \n\t" - "fmul.s f16, f16, %[RMAXREC] \n\t" - "fsw f16, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f16, %[FONE], f16 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f17, f3, f7 \n\t" - "fmul.s f17, f17, %[RMAXREC] \n\t" - "fsw f17, (%[DST]) \n\t" - "addi %[DST], %[DST], -136 \n\t" - "fdiv.s f17, %[FONE], f17 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmul.vf v16, v0, f10 \n\t" - "vfmul.vf v18, v2, f11 \n\t" - "vfmul.vf v20, v4, f12 \n\t" - "vfmul.vf v22, v6, f13 \n\t" - "vfmul.vf v24, v8, f14 \n\t" - "vfmul.vf v26, v10, f15 \n\t" - "vfmul.vf v28, v12, f16 \n\t" - "vfmul.vf v30, v14, f17 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v18, v18 \n\t" - "vfcvt.x.f.v v20, v20 \n\t" - "vfcvt.x.f.v v22, v22 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vfcvt.x.f.v v26, v26 \n\t" - "vfcvt.x.f.v v28, v28 \n\t" - "vfcvt.x.f.v v30, v30 \n\t" - "vsetvli t0, zero, e16, m1 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v18, v18, zero \n\t" - "vnclip.wx v20, v20, zero \n\t" - "vnclip.wx v22, v22, zero \n\t" - "vnclip.wx v24, v24, zero \n\t" - "vnclip.wx v26, v26, zero \n\t" - "vnclip.wx v28, v28, zero \n\t" - "vnclip.wx v30, v30, zero \n\t" - "vsetvli t0, t1, e8, mf2 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v18, v18, zero \n\t" - "vnclip.wx v20, v20, zero \n\t" - "vnclip.wx v22, v22, zero \n\t" - "vnclip.wx v24, v24, zero \n\t" - "vnclip.wx v26, v26, zero \n\t" - "vnclip.wx v28, v28, zero \n\t" - "vnclip.wx v30, v30, zero \n\t" - "vse8.v v16, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v18, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v20, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v22, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v24, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v26, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v28, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v30, (%[DST]) \n\t" - "addi %[DST], %[DST], 16 \n\t" - "bge %[K], t3, LOOP_MAIN%= \n\t" - "blt %[K], t2, TAIL%= \n\t" - "LOOP_K%=: \n\t" - "vsetvli t1, %[K], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "sub %[K], %[K], t1 \n\t" - "vfabs.v v16, v0 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v16, v16, v17 \n\t" - "vse32.v v16, (%[BUFFER]) \n\t" - "flw f0, (%[BUFFER]) \n\t" - "flw f1, 4(%[BUFFER]) \n\t" - "flw f2, 8(%[BUFFER]) \n\t" - "flw f3, 12(%[BUFFER]) \n\t" - "flw f4, 16(%[BUFFER]) \n\t" - "flw f5, 20(%[BUFFER]) \n\t" - "flw f6, 24(%[BUFFER]) \n\t" - "flw f7, 28(%[BUFFER]) \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f10, f3, f7 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (%[DST]) \n\t" - "addi %[DST], %[DST], 4 \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmul.vf v16, v0, f11 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vsetvli t0, zero, e16, m1 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vsetvli t0, t1, e8, mf2 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vse8.v v16, (%[DST]) \n\t" - "addi %[DST], %[DST], 16 \n\t" - "bge %[K], t2, LOOP_K%= \n\t" - "TAIL%=: \n\t" - "blez %[K], END%= \n\t" - "vsetvli t0, t3, e32, m2 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "jal x0, LOOP_K%= \n\t" - "END%=: \n\t" - : [SRC] "+r"(SRC), [DST] "+r"(DST), [K] "+r"(CountK) - : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [BUFFER] "r"(buffer) - : "cc", "t3", "t2", "t1", "t0", "a1", "f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7", "f10", "f11", "f12", - "f13", "f14", "f15", "f16", "f17"); - } else if (BlkLen == 32) { - __asm__ volatile( - "addi t3, zero, 32*4 \n\t" - "addi t2, zero, 32 \n\t" - - "addi a1, %[SRC], 0 \n\t" - "addi a2, %[SRC], 128 \n\t" - "addi a3, %[SRC], 256 \n\t" - "addi a4, %[SRC], 384 \n\t" - - "addi s1, %[DST], 0 \n\t" - "addi s2, %[DST], 36 \n\t" - "addi s3, %[DST], 72 \n\t" - "addi s4, %[DST], 108 \n\t" - "blt %[K], t3, LOOP_K%= \n\t" - "blt %[K], t2, TAIL%= \n\t" - - "LOOP_MAIN%=: \n\t" - "vsetvli t1, zero, e32, m4 \n\t" - "addi %[K], %[K], -128 \n\t" - "vle32.v v0, (a1) \n\t" - "addi a1, a1, 512 \n\t" - "vle32.v v4, (a2) \n\t" - "addi a2, a2, 512 \n\t" - "vle32.v v8, (a3) \n\t" - "addi a3, a3, 512 \n\t" - "vle32.v v12, (a4) \n\t" - "addi a4, a4, 512 \n\t" - "vfabs.v v16, v0 \n\t" - "vfabs.v v20, v4 \n\t" - "vfabs.v v24, v8 \n\t" - "vfabs.v v28, v12 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v16, v16, v18 \n\t" - "vfmax.vv v20, v20, v22 \n\t" - "vfmax.vv v24, v24, v26 \n\t" - "vfmax.vv v28, v28, v30 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v16, v16, v17 \n\t" - "vfmax.vv v20, v20, v21 \n\t" - "vfmax.vv v24, v24, v25 \n\t" - "vfmax.vv v28, v28, v29 \n\t" - - "vfredmax.vs v17, v16, v17 \n\t" - "vfredmax.vs v21, v20, v21 \n\t" - "vfredmax.vs v25, v24, v25 \n\t" - "vfredmax.vs v29, v28, v29 \n\t" - "vfmv.f.s f10, v17 \n\t" - "vfmv.f.s f11, v21 \n\t" - "vfmv.f.s f12, v25 \n\t" - "vfmv.f.s f13, v29 \n\t" - - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fmul.s f11, f11, %[RMAXREC] \n\t" - "fmul.s f12, f12, %[RMAXREC] \n\t" - "fmul.s f13, f13, %[RMAXREC] \n\t" - "fsw f10, (s1) \n\t" - "addi s1, s1, 4 \n\t" - - "fsw f11, (s2) \n\t" - "addi s2, s2, 4 \n\t" - "fsw f12, (s3) \n\t" - "addi s3, s3, 4 \n\t" - "fsw f13, (s4) \n\t" - "addi s4, s4, 4 \n\t" - "fdiv.s f10, %[FONE], f10 \n\t" - "fdiv.s f11, %[FONE], f11 \n\t" - "fdiv.s f12, %[FONE], f12 \n\t" - "fdiv.s f13, %[FONE], f13 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vfmul.vf v16, v0, f10 \n\t" - "vfmul.vf v20, v4, f11 \n\t" - "vfmul.vf v24, v8, f12 \n\t" - "vfmul.vf v28, v12, f13 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v20, v20 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vfcvt.x.f.v v28, v28 \n\t" - "vsetvli t0, zero, e16, m2 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v20, v20, zero \n\t" - "vnclip.wx v24, v24, zero \n\t" - "vnclip.wx v28, v28, zero \n\t" - "vsetvli t0, t1, e8, m1 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v20, v20, zero \n\t" - "vnclip.wx v24, v24, zero \n\t" - "vnclip.wx v28, v28, zero \n\t" - "vse8.v v16, (s1) \n\t" - "addi s1, s1, 140 \n\t" - "vse8.v v20, (s2) \n\t" - "addi s2, s2, 140 \n\t" - "vse8.v v24, (s3) \n\t" - "addi s3, s3, 140 \n\t" - "vse8.v v28, (s4) \n\t" - "addi s4, s4, 140 \n\t" - "bge %[K], t3, LOOP_MAIN%= \n\t" - "blt %[K], t2, TAIL%= \n\t" - "LOOP_K%=: \n\t" - "vsetvli t1, %[K], e32, m4 \n\t" - "vle32.v v0, (a1) \n\t" - "addi a1, a1, 128 \n\t" - "sub %[K], %[K], t1 \n\t" - "vfabs.v v16, v0 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v16, v16, v18 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v16, v16, v17 \n\t" - "vfredmax.vs v17, v16, v17 \n\t" - "vfmv.f.s f10, v17 \n\t" - - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (s1) \n\t" - "addi s1, s1, 4 \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vfmul.vf v16, v0, f11 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vsetvli t0, zero, e16, m2 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vse8.v v16, (s1) \n\t" - "addi s1, s1, 32 \n\t" - "bge %[K], t2, LOOP_K%= \n\t" - "TAIL%=: \n\t" - "blez %[K], END%= \n\t" - "vsetvli t0, t3, e32, m4 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "jal x0, LOOP_K%= \n\t" - "END%=: \n\t" - : [K] "+r"(CountK) - : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC), [DST] "r"(DST) - : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "a3", "a4", "s1", "s2", "s3", "s4", "f10", "f11", "f12", "f13"); - } else if (BlkLen == 64) { - __asm__ volatile( - "addi t3, zero, 64*2 \n\t" - "addi t2, zero, 64 \n\t" - "addi a1, %[SRC], 0 \n\t" - "addi a2, %[SRC], 256 \n\t" - "addi s1, %[DST], 0 \n\t" - "addi s2, %[DST], 68 \n\t" - "blt %[K], t3, LOOP_K%= \n\t" - "blt %[K], t2, TAIL%= \n\t" - "LOOP_MAIN%=: \n\t" - "vsetvli t1, zero, e32, m8 \n\t" - "addi %[K], %[K], -128 \n\t" - "vle32.v v0, (a1) \n\t" - "addi a1, a1, 512 \n\t" - "vle32.v v8, (a2) \n\t" - "addi a2, a2, 512 \n\t" - "vfabs.v v16, v0 \n\t" - "vfabs.v v24, v8 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vfmax.vv v16, v16, v20 \n\t" - "vfmax.vv v24, v24, v28 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v16, v16, v18 \n\t" - "vfmax.vv v24, v24, v26 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v16, v16, v17 \n\t" - "vfmax.vv v24, v24, v25 \n\t" - "vfredmax.vs v17, v16, v17 \n\t" - "vfredmax.vs v25, v24, v25 \n\t" - "vfmv.f.s f10, v17 \n\t" - "vfmv.f.s f11, v25 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fmul.s f11, f11, %[RMAXREC] \n\t" - "fsw f10, (s1) \n\t" - "addi s1, s1, 4 \n\t" - "fsw f11, (s2) \n\t" - "addi s2, s2, 4 \n\t" - "fdiv.s f10, %[FONE], f10 \n\t" - "fdiv.s f11, %[FONE], f11 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vfmul.vf v16, v0, f10 \n\t" - "vfmul.vf v24, v8, f11 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v24, v24, zero \n\t" - "vsetvli t0, t1, e8, m2 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v24, v24, zero \n\t" - "vse8.v v16, (s1) \n\t" - "addi s1, s1, 132 \n\t" - "vse8.v v24, (s2) \n\t" - "addi s2, s2, 132 \n\t" - "bge %[K], t3, LOOP_MAIN%= \n\t" - "blt %[K], t2, TAIL%= \n\t" - "LOOP_K%=: \n\t" - "vsetvli t1, %[K], e32, m8 \n\t" - "vle32.v v0, (a1) \n\t" - "addi a1, a1, 256 \n\t" - "sub %[K], %[K], t1 \n\t" - "vfabs.v v16, v0 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vfmax.vv v16, v16, v20 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v16, v16, v18 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v16, v16, v17 \n\t" - "vfredmax.vs v17, v16, v17 \n\t" - "vfmv.f.s f10, v17 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (s1) \n\t" - "addi s1, s1, 4 \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vfmul.vf v16, v0, f11 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vsetvli t0, zero, e8, m2 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vse8.v v16, (s1) \n\t" - "addi s1, s1, 64 \n\t" - "bge %[K], t2, LOOP_K%= \n\t" - "TAIL%=: \n\t" - "blez %[K], END%= \n\t" - "vsetvli t0, t3, e32, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "jal x0, LOOP_K%= \n\t" - "END%=: \n\t" - : [K] "+r"(CountK) - : [SRC] "r"(SRC), [DST] "r"(DST), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal) - : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "s1", "s2", "f10", "f11"); - } else if (BlkLen == 128) { - __asm__ volatile( - "addi t2, zero, 128 \n\t" - "addi a1, %[SRC], 0 \n\t" - "addi a2, %[SRC], 256 \n\t" - "blt %[K], t2, TAIL%= \n\t" - "LOOP_K%=: \n\t" - "vsetvli t1, zero, e32, m8 \n\t" - "vle32.v v0, (a1) \n\t" - "addi a1, a1, 512 \n\t" - "vle32.v v8, (a2) \n\t" - "addi a2, a2, 512 \n\t" - "sub %[K], %[K], t2 \n\t" - "QUANT%=: \n\t" - "vfabs.v v16, v0 \n\t" - "vfabs.v v24, v8 \n\t" - "vfmax.vv v24, v16, v24 \n\t" - "vsetvli t1, zero, e32, m4 \n\t" - "vfmax.vv v28, v24, v28 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v30, v28, v30 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v30, v30, v31 \n\t" - "vfredmax.vs v31, v30, v31 \n\t" - "vfmv.f.s f10, v31 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (%[DST]) \n\t" - "addi %[DST], %[DST], 4 \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vfmul.vf v16, v0, f11 \n\t" - "vfmul.vf v24, v8, f11 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v20, v24, zero \n\t" - "vsetvli t0, zero, e8, m4 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vse8.v v16, (%[DST]) \n\t" - "addi %[DST], %[DST], 128 \n\t" - "bge %[K], t2, LOOP_K%= \n\t" - "TAIL%=: \n\t" - "blez %[K], END%= \n\t" - "vsetvli t1, zero, e32, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vxor.vv v8, v8, v8 \n\t" - "vsetvli t0, %[K], e32, m8 \n\t" - "vle32.v v0, (a1) \n\t" - "sub %[K], %[K], t0 \n\t" - "vsetvli t0, %[K], e32, m8 \n\t" - "vle32.v v8, (a2) \n\t" - "sub %[K], %[K], t0 \n\t" - "vsetvli t1, zero, e32, m8 \n\t" - "jal x0, QUANT%= \n\t" - "END%=: \n\t" - - : [DST] "+r"(DST), [K] "+r"(CountK) - : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC) - : "cc", "t2", "t1", "t0", "a1", "a2", "f10", "f11"); - } else { - float buffer[8] = { 0.0f }; - size_t cnt = BlkLen / 256; - - __asm__ volatile( - "slli t3, %[BLK], 2 \n\t" - "blt %[K], %[BLK], LOOP_TAIL%= \n\t" - "LOOP_MAIN%=: \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vxor.vv v31, v31, v31 \n\t" - "vse32.v v31, (%[BUFFER]) \n\t" - "addi t6, %[CNT], 0 \n\t" - "LOOP_CMP%=: \n\t" - "addi t6, t6, -1 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v16, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v24, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vfabs.v v0, v0 \n\t" - "vfabs.v v8, v8 \n\t" - "vfabs.v v16, v16 \n\t" - "vfabs.v v24, v24 \n\t" - "vfmax.vv v8, v0, v8 \n\t" - "vfmax.vv v16, v16, v24 \n\t" - "vfmax.vv v0, v0, v16 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vfmax.vv v0, v0, v4 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v0, v0, v2 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v0, v0, v1 \n\t" - "vle32.v v30, (%[BUFFER]) \n\t" - "vfmax.vv v31, v30, v0 \n\t" - "vse32.v v31, (%[BUFFER]) \n\t" - "bnez t6, LOOP_CMP%= \n\t" - "sub %[SRC], %[SRC], t3 \n\t" - "addi t6, %[CNT], 0 \n\t" - "flw f0, (%[BUFFER]) \n\t" - "flw f1, 4(%[BUFFER]) \n\t" - "flw f2, 8(%[BUFFER]) \n\t" - "flw f3, 12(%[BUFFER]) \n\t" - "flw f4, 16(%[BUFFER]) \n\t" - "flw f5, 20(%[BUFFER]) \n\t" - "flw f6, 24(%[BUFFER]) \n\t" - "flw f7, 28(%[BUFFER]) \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f10, f3, f7 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (%[DST]) \n\t" - "addi %[DST], %[DST], 4 \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "addi t6, %[CNT], 0 \n\t" - "LOOP_QUANT%=: \n\t" - "addi t6, t6, -1 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v16, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v24, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vfmul.vf v0, v0, f11 \n\t" - "vfmul.vf v8, v8, f11 \n\t" - "vfmul.vf v16, v16, f11 \n\t" - "vfmul.vf v24, v24, f11 \n\t" - "vfcvt.x.f.v v0, v0 \n\t" - "vfcvt.x.f.v v8, v8 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v0, v0, zero \n\t" - "vnclip.wx v4, v8, zero \n\t" - "vnclip.wx v8, v16, zero \n\t" - "vnclip.wx v12, v24, zero \n\t" - "vsetvli t0, zero, e8, m4 \n\t" - "vnclip.wx v0, v0, zero \n\t" - "vnclip.wx v4, v8, zero \n\t" - "vse8.v v0, (%[DST]) \n\t" - "addi %[DST], %[DST], 128 \n\t" - "vse8.v v4, (%[DST]) \n\t" - "addi %[DST], %[DST], 128 \n\t" - "bnez t6, LOOP_QUANT%= \n\t" - "sub %[K], %[K], %[BLK] \n\t" - "bge %[K], %[BLK], LOOP_MAIN%= \n\t" - "blez %[K], END%= \n\t" - "LOOP_TAIL%=: \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vxor.vv v31, v31, v31 \n\t" - "vse32.v v31, (%[BUFFER]) \n\t" - "addi t6, %[K], 0 \n\t" - "addi s1, %[SRC], 0 \n\t" - "TAIL_CMP%=: \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vsetvli t0, t6, e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "sub t6, t6, t0 \n\t" - "vfabs.v v0, v0 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vfmax.vv v0, v0, v4 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v0, v0, v2 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v0, v0, v1 \n\t" - "vle32.v v30, (%[BUFFER]) \n\t" - "vfmax.vv v31, v30, v0 \n\t" - "vse32.v v31, (%[BUFFER]) \n\t" - "bnez t6, TAIL_CMP%= \n\t" - "addi t6, %[K], 0 \n\t" - "flw f0, (%[BUFFER]) \n\t" - "flw f1, 4(%[BUFFER]) \n\t" - "flw f2, 8(%[BUFFER]) \n\t" - "flw f3, 12(%[BUFFER]) \n\t" - "flw f4, 16(%[BUFFER]) \n\t" - "flw f5, 20(%[BUFFER]) \n\t" - "flw f6, 24(%[BUFFER]) \n\t" - "flw f7, 28(%[BUFFER]) \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f10, f3, f7 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (%[DST]) \n\t" - "addi %[DST], %[DST], 4 \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "addi t6, %[K], 0 \n\t" - "TAIL_QUANT%=: \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vsetvli t1, t6, e32, m8 \n\t" - "vle32.v v0, (s1) \n\t" - "addi s1, s1, 256 \n\t" - "sub t6, t6, t1 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vfmul.vf v0, v0, f11 \n\t" - "vfcvt.x.f.v v0, v0 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v0, v0, zero \n\t" - "vsetvli t0, t1, e8, m2 \n\t" - "vnclip.wx v0, v0, zero \n\t" - "vse8.v v0, (%[DST]) \n\t" - "addi %[DST], %[DST], 64 \n\t" - "bnez t6, TAIL_QUANT%= \n\t" - "END%=: \n\t" - : [SRC] "+r"(SRC), [DST] "+r"(DST), [K] "+r"(CountK) - : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [BLK] "r"(BlkLen), [BUFFER] "r"(buffer), - [CNT] "r"(cnt) - : "cc", "t1", "t0", "t6", "s1", "f0", "f1", "f2", "f3", "f4", "f5", "f6"); - } + __asm__ volatile( + "addi t3, zero, 32*4 \n\t" + "addi t2, zero, 32 \n\t" + + "addi a1, %[SRC], 0 \n\t" + "addi a2, %[SRC], 128 \n\t" + "addi a3, %[SRC], 256 \n\t" + "addi a4, %[SRC], 384 \n\t" + + "addi s1, %[DST], 0 \n\t" + "addi s2, %[DST], 36 \n\t" + "addi s3, %[DST], 72 \n\t" + "addi s4, %[DST], 108 \n\t" + "blt %[K], t3, LOOP_K%= \n\t" + "blt %[K], t2, TAIL%= \n\t" + + "LOOP_MAIN%=: \n\t" + "vsetvli t1, zero, e32, m4 \n\t" + "addi %[K], %[K], -128 \n\t" + "vle32.v v0, (a1) \n\t" + "addi a1, a1, 512 \n\t" + "vle32.v v4, (a2) \n\t" + "addi a2, a2, 512 \n\t" + "vle32.v v8, (a3) \n\t" + "addi a3, a3, 512 \n\t" + "vle32.v v12, (a4) \n\t" + "addi a4, a4, 512 \n\t" + "vfabs.v v16, v0 \n\t" + "vfabs.v v20, v4 \n\t" + "vfabs.v v24, v8 \n\t" + "vfabs.v v28, v12 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmax.vv v16, v16, v18 \n\t" + "vfmax.vv v20, v20, v22 \n\t" + "vfmax.vv v24, v24, v26 \n\t" + "vfmax.vv v28, v28, v30 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v16, v16, v17 \n\t" + "vfmax.vv v20, v20, v21 \n\t" + "vfmax.vv v24, v24, v25 \n\t" + "vfmax.vv v28, v28, v29 \n\t" + + "vfredmax.vs v17, v16, v17 \n\t" + "vfredmax.vs v21, v20, v21 \n\t" + "vfredmax.vs v25, v24, v25 \n\t" + "vfredmax.vs v29, v28, v29 \n\t" + "vfmv.f.s f10, v17 \n\t" + "vfmv.f.s f11, v21 \n\t" + "vfmv.f.s f12, v25 \n\t" + "vfmv.f.s f13, v29 \n\t" + + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fmul.s f11, f11, %[RMAXREC] \n\t" + "fmul.s f12, f12, %[RMAXREC] \n\t" + "fmul.s f13, f13, %[RMAXREC] \n\t" + "fsw f10, (s1) \n\t" + "addi s1, s1, 4 \n\t" + + "fsw f11, (s2) \n\t" + "addi s2, s2, 4 \n\t" + "fsw f12, (s3) \n\t" + "addi s3, s3, 4 \n\t" + "fsw f13, (s4) \n\t" + "addi s4, s4, 4 \n\t" + "fdiv.s f10, %[FONE], f10 \n\t" + "fdiv.s f11, %[FONE], f11 \n\t" + "fdiv.s f12, %[FONE], f12 \n\t" + "fdiv.s f13, %[FONE], f13 \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vfmul.vf v16, v0, f10 \n\t" + "vfmul.vf v20, v4, f11 \n\t" + "vfmul.vf v24, v8, f12 \n\t" + "vfmul.vf v28, v12, f13 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vfcvt.x.f.v v20, v20 \n\t" + "vfcvt.x.f.v v24, v24 \n\t" + "vfcvt.x.f.v v28, v28 \n\t" + "vsetvli t0, zero, e16, m2 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v20, v20, zero \n\t" + "vnclip.wx v24, v24, zero \n\t" + "vnclip.wx v28, v28, zero \n\t" + "vsetvli t0, t1, e8, m1 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v20, v20, zero \n\t" + "vnclip.wx v24, v24, zero \n\t" + "vnclip.wx v28, v28, zero \n\t" + "vse8.v v16, (s1) \n\t" + "addi s1, s1, 140 \n\t" + "vse8.v v20, (s2) \n\t" + "addi s2, s2, 140 \n\t" + "vse8.v v24, (s3) \n\t" + "addi s3, s3, 140 \n\t" + "vse8.v v28, (s4) \n\t" + "addi s4, s4, 140 \n\t" + "bge %[K], t3, LOOP_MAIN%= \n\t" + "blt %[K], t2, TAIL%= \n\t" + "LOOP_K%=: \n\t" + "vsetvli t1, %[K], e32, m4 \n\t" + "vle32.v v0, (a1) \n\t" + "addi a1, a1, 128 \n\t" + "sub %[K], %[K], t1 \n\t" + "vfabs.v v16, v0 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmax.vv v16, v16, v18 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v16, v16, v17 \n\t" + "vfredmax.vs v17, v16, v17 \n\t" + "vfmv.f.s f10, v17 \n\t" + + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (s1) \n\t" + "addi s1, s1, 4 \n\t" + "fdiv.s f11, %[FONE], f10 \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vfmul.vf v16, v0, f11 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vsetvli t0, zero, e16, m2 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vse8.v v16, (s1) \n\t" + "addi s1, s1, 32 \n\t" + "bge %[K], t2, LOOP_K%= \n\t" + "TAIL%=: \n\t" + "blez %[K], END%= \n\t" + "vsetvli t0, t3, e32, m4 \n\t" + "vxor.vv v0, v0, v0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "jal x0, LOOP_K%= \n\t" + "END%=: \n\t" + : [K] "+r"(CountK) + : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC), [DST] "r"(DST) + : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "a3", "a4", "s1", "s2", "s3", "s4", "f10", "f11", "f12", "f13"); } } // namespace ime1 @@ -1451,1746 +584,444 @@ namespace { "vadd.vi v1, v1, -12 \n\t" template -void SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountN, - size_t BlockCountK, - const float * Bias, - const size_t ldc) { - GGML_UNUSED(QuantBScale); - GGML_UNUSED(QuantBZeroPoint); +void SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, + const uint8_t * QuantA, + const uint8_t * QuantBData, + float * C, + size_t CountN, + size_t BlockCountK, + const size_t ldc) { size_t LDC = ldc * sizeof(float); const size_t INNER = BlkLen / 16; float tmp[4 * 16]; if constexpr (HasZeroPoint) { for (size_t n = 0; n < CountN; n += 16) { - size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(uint8_t) + // zp - n * BlockCountK * sizeof(_Float16); // scale + size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; + uint8_t * QuantBDataPtr = (uint8_t *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(uint8_t) + // zp + n * BlockCountK * sizeof(_Float16); // scale float * CPtr = C + n; if (NBLKS < 16) { CPtr = tmp; LDC = 16 * sizeof(float); } - if (Bias != nullptr) { - const float * bias = Bias + n; - if (NBLKS < 16) { - __asm__ volatile( - "vsetvli t0, %[N], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "vse32.v v0, (%[DST]) \n\t" - : - : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS) - : "cc", "t0"); - bias = tmp; - } - __asm__ volatile(LOAD_BIAS - - "addi t3, %[BlockCountK], 0 \n\t" - - "vsetvli t0, zero, e8, m1 \n\t" - "li s1, 24 \n\t" - "vmv.v.i v1, 3 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v1, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v1, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v1, 0 \n\t" - - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - - "BLOCK_COUNTK_LOOP%=: \n\t" - // scale offset - "addi s5, s1, 0 \n\t" - // zp offset - "addi s6, s1, 32 \n\t" - "addi s1, s6, 16 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 - - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vsub.vv v2, v2, v12 \n\t" - "vsub.vv v6, v6, v12 \n\t" - "vsub.vv v3, v3, v13 \n\t" - "vsub.vv v7, v7, v13 \n\t" - "vsub.vv v4, v4, v14 \n\t" - "vsub.vv v8, v8, v14 \n\t" - "vsub.vv v5, v5, v15 \n\t" - "vsub.vv v9, v9, v15 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16_FP16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", - "s2", "s3", "s4", "s5", "s6"); - - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "addi t3, %[BlockCountK], 0 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "li s1, 24 \n\t" - "vmv.v.i v1, 3 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v1, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v1, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v1, 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - // scale offset - "addi s5, s1, 0 \n\t" - // zp offset - "addi s6, s1, 32 \n\t" - "addi s1, s6, 16 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 - - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vsub.vv v2, v2, v12 \n\t" - "vsub.vv v6, v6, v12 \n\t" - "vsub.vv v3, v3, v13 \n\t" - "vsub.vv v7, v7, v13 \n\t" - "vsub.vv v4, v4, v14 \n\t" - "vsub.vv v8, v8, v14 \n\t" - "vsub.vv v5, v5, v15 \n\t" - "vsub.vv v9, v9, v15 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16_FP16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", - "s4", "s5", "s6"); - } - } - } else { - for (size_t n = 0; n < CountN; n += 16) { - size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(_Float16); // scale - float * CPtr = C + n; - if (NBLKS < 16) { - CPtr = tmp; - LDC = 16 * sizeof(float); - } - if (Bias != nullptr) { - const float * bias = Bias + n; - if (NBLKS < 16) { - __asm__ volatile( - "vsetvli t0, %[N], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "vse32.v v0, (%[DST]) \n\t" - : - : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS) - : "cc", "t0"); - bias = tmp; - } - __asm__ volatile(LOAD_BIAS - - "addi t3, %[BlockCountK], 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - "addi s5, s1, 0 \n\t" - "addi s1, s5, 32 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vsetvli t0, zero, e8, m1 \n\t" - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - "vadd.vi v8, v8, -8 \n\t" - "vadd.vi v9, v9, -8 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16_FP16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", - "s2", "s3", "s4", "s5", "s6"); - - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "addi t3, %[BlockCountK], 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - "addi s5, s1, 0 \n\t" - "addi s1, s5, 32 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vsetvli t0, zero, e8, m1 \n\t" - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - "vadd.vi v8, v8, -8 \n\t" - "vadd.vi v9, v9, -8 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16_FP16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", - "s4", "s5", "s6"); - } - } - } - if (CountN % 16 != 0) { - // stroe output from tmp to C when NBLKS less than 16. - float * CPtr = C + CountN / 16 * 16; - const size_t N = CountN % 16; - LDC = ldc * sizeof(float); - __asm__ volatile( - "vsetvli t0, %[N], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi s2, %[SRC], 64 \n\t" - "addi s3, %[SRC], 64*2 \n\t" - "addi s4, %[SRC], 64*3 \n\t" - "vle32.v v2, (s2) \n\t" - "vle32.v v4, (s3) \n\t" - "vle32.v v6, (s4) \n\t" - "add t2, %[DST], %[LDC] \n\t" - "add t3, t2, %[LDC] \n\t" - "add t4, t3, %[LDC] \n\t" - "vse32.v v0, (%[DST]) \n\t" - "vse32.v v2, (t2) \n\t" - "vse32.v v4, (t3) \n\t" - "vse32.v v6, (t4) \n\t" - : - : [N] "r"(N), [SRC] "r"(tmp), [DST] "r"(CPtr), [LDC] "r"(LDC) - : "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4"); - } -} -template -void SQ4BitGemmM4Kernel_CompInt8_Impl(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountN, - size_t BlockCountK, - const float * Bias, - const size_t ldc) { - GGML_UNUSED(QuantBScale); - GGML_UNUSED(QuantBZeroPoint); - size_t LDC = ldc * sizeof(float); - const size_t INNER = BlkLen / 16; - float tmp[4 * 16]; - - if constexpr (HasZeroPoint) { - for (size_t n = 0; n < CountN; n += 16) { - size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(uint8_t) + // zp - n * BlockCountK * sizeof(float); // scale - float * CPtr = C + n; - if (NBLKS < 16) { - CPtr = tmp; - LDC = 16 * sizeof(float); - } - if (Bias != nullptr) { - const float * bias = Bias + n; - if (NBLKS < 16) { - __asm__ volatile( - "vsetvli t0, %[N], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "vse32.v v0, (%[DST]) \n\t" - : - : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS) - : "cc", "t0"); - bias = tmp; - } - - __asm__ volatile(LOAD_BIAS - "addi t3, %[BlockCountK], 0 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "li s1, 24 \n\t" - "vmv.v.i v1, 3 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v1, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v1, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v1, 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - // scale offset - "addi s5, s1, 0 \n\t" - // zp offset - "addi s6, s1, 64 \n\t" - "addi s1, s6, 16 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 - - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vsub.vv v2, v2, v12 \n\t" - "vsub.vv v6, v6, v12 \n\t" - "vsub.vv v3, v3, v13 \n\t" - "vsub.vv v7, v7, v13 \n\t" - "vsub.vv v4, v4, v14 \n\t" - "vsub.vv v8, v8, v14 \n\t" - "vsub.vv v5, v5, v15 \n\t" - "vsub.vv v9, v9, v15 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", - "s2", "s3", "s4", "s5", "s6"); - - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "addi t3, %[BlockCountK], 0 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "li s1, 24 \n\t" - "vmv.v.i v1, 3 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v1, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v1, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v1, 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - // scale offset - "addi s5, s1, 0 \n\t" - // zp offset - "addi s6, s1, 64 \n\t" - "addi s1, s6, 16 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 - - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vsub.vv v2, v2, v12 \n\t" - "vsub.vv v6, v6, v12 \n\t" - "vsub.vv v3, v3, v13 \n\t" - "vsub.vv v7, v7, v13 \n\t" - "vsub.vv v4, v4, v14 \n\t" - "vsub.vv v8, v8, v14 \n\t" - "vsub.vv v5, v5, v15 \n\t" - "vsub.vv v9, v9, v15 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", - "s4", "s5", "s6"); - } + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "addi t3, %[BlockCountK], 0 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "li s1, 24 \n\t" + "vmv.v.i v1, 3 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v1, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v1, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v1, 0 \n\t" + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "BLOCK_COUNTK_LOOP%=: \n\t" + // scale offset + "addi s5, s1, 0 \n\t" + // zp offset + "addi s6, s1, 32 \n\t" + "addi s1, s6, 16 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" + + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" + + SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 + + "BLOCK_INNER_LOOP%=: \n\t" + + LOAD_B_16x8x2 + + "vle8.v v10, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vle8.v v11, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vsub.vv v2, v2, v12 \n\t" + "vsub.vv v6, v6, v12 \n\t" + "vsub.vv v3, v3, v13 \n\t" + "vsub.vv v7, v7, v13 \n\t" + "vsub.vv v4, v4, v14 \n\t" + "vsub.vv v8, v8, v14 \n\t" + "vsub.vv v5, v5, v15 \n\t" + "vsub.vv v9, v9, v15 \n\t" + + SQ4BIT_KERNEL_COMP_4x16x16 + + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" + + LOAD_SCALE_4x16_FP16 + + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" + + "RESULT_SAVE%=: \n\t" + + SAVE_RESULT_4x16 + + : + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), + [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", "s4", + "s5", "s6"); } } else { for (size_t n = 0; n < CountN; n += 16) { - size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(float); // scale + size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; + uint8_t * QuantBDataPtr = (uint8_t *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(_Float16); // scale float * CPtr = C + n; if (NBLKS < 16) { CPtr = tmp; LDC = 16 * sizeof(float); } - if (Bias != nullptr) { - const float * bias = Bias + n; - if (NBLKS < 16) { - __asm__ volatile( - "vsetvli t0, %[N], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "vse32.v v0, (%[DST]) \n\t" - : - : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS) - : "cc", "t0"); - bias = tmp; - } - __asm__ volatile(LOAD_BIAS - "addi t3, %[BlockCountK], 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - "addi s5, s1, 0 \n\t" - "addi s1, s5, 64 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vsetvli t0, zero, e8, m1 \n\t" - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - "vadd.vi v8, v8, -8 \n\t" - "vadd.vi v9, v9, -8 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", - "s2", "s3", "s4", "s5", "s6"); - - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "addi t3, %[BlockCountK], 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - "addi s5, s1, 0 \n\t" - "addi s1, s5, 64 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vsetvli t0, zero, e8, m1 \n\t" - "vle8.v v10, (a1) \n\t" - - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - "vadd.vi v8, v8, -8 \n\t" - "vadd.vi v9, v9, -8 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", - "s4", "s5", "s6"); - } + + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "addi t3, %[BlockCountK], 0 \n\t" + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "BLOCK_COUNTK_LOOP%=: \n\t" + "addi s5, s1, 0 \n\t" + "addi s1, s5, 32 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" + "BLOCK_INNER_LOOP%=: \n\t" + + LOAD_B_16x8x2 + + "vsetvli t0, zero, e8, m1 \n\t" + "vle8.v v10, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vle8.v v11, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + "vadd.vi v8, v8, -8 \n\t" + "vadd.vi v9, v9, -8 \n\t" + + SQ4BIT_KERNEL_COMP_4x16x16 + + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" + + LOAD_SCALE_4x16_FP16 + + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" + "RESULT_SAVE%=: \n\t" + + SAVE_RESULT_4x16 + + : + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), + [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", "s4", + "s5", "s6"); } } - if (CountN % 16 != 0) { - // stroe output from tmp to C when NBLKS less than 16. - float * CPtr = C + CountN / 16 * 16; - const size_t N = CountN % 16; - LDC = ldc * sizeof(float); - __asm__ volatile( - "vsetvli t0, %[N], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi s2, %[SRC], 64 \n\t" - "addi s3, %[SRC], 64*2 \n\t" - "addi s4, %[SRC], 64*3 \n\t" - "vle32.v v2, (s2) \n\t" - "vle32.v v4, (s3) \n\t" - "vle32.v v6, (s4) \n\t" - "add t2, %[DST], %[LDC] \n\t" - "add t3, t2, %[LDC] \n\t" - "add t4, t3, %[LDC] \n\t" - "vse32.v v0, (%[DST]) \n\t" - "vse32.v v2, (t2) \n\t" - "vse32.v v4, (t3) \n\t" - "vse32.v v6, (t4) \n\t" - : - : [N] "r"(N), [SRC] "r"(tmp), [DST] "r"(CPtr), [LDC] "r"(LDC) - : "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4"); - } } template -void SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountN, - size_t BlockCountK, - const float * Bias) { - GGML_UNUSED(QuantBScale); - GGML_UNUSED(QuantBZeroPoint); +void SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, + const uint8_t * QuantA, + const uint8_t * QuantBData, + float * C, + size_t CountN, + size_t BlockCountK, + const size_t ldc) { + GGML_UNUSED(ldc); size_t INNER = BlkLen / 16; if constexpr (HasZeroPoint) { for (size_t n = 0; n < CountN; n += 16) { - size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(uint8_t) + // zp - n * BlockCountK * sizeof(_Float16); // scale - float * CPtr = C + n; - size_t cnt = BlockCountK; - if (Bias != nullptr) { - const float * bias = Bias + n; - __asm__ volatile( - "addi t3, %[NBLKS], 0 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - - "vmv.v.i v13, 3 \n\t" - "li s1, 24 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v13, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v13, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v13, 0 \n\t" - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 8 \n\t" - "addi s3, %[B], 16 \n\t" - "addi s4, %[B], 24 \n\t" - // zp offset - "addi s7, %[B], 32 \n\t" - // a offset - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v28, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v29, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v30, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v31, (%[BIAS]) \n\t" - - "LOOP_K%=: \n\t" - "vsetvli t0, zero, e16, mf4 \n\t" - - "vle16.v v4, (s1) \n\t" - "addi s1, s1, 48 \n\t" - "vle16.v v5, (s2) \n\t" - "addi s2, s2, 72 \n\t" - "vle16.v v6, (s3) \n\t" - "addi s3, s3, 96 \n\t" - "vle16.v v7, (s4) \n\t" - "addi s4, s4, 120 \n\t" - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - "vfwcvt.f.f.v v8, v4 \n\t" - "vfwcvt.f.f.v v9, v5 \n\t" - "vfwcvt.f.f.v v10, v6 \n\t" - "vfwcvt.f.f.v v11, v7 \n\t" - - "vsetvli t0, zero, e32, mf2 \n\t" - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1 - - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vsub.vv v0, v0, v8 \n\t" - "vsub.vv v4, v4, v8 \n\t" - "vsub.vv v1, v1, v9 \n\t" - "vsub.vv v5, v5, v9 \n\t" - "vsub.vv v2, v2, v10 \n\t" - "vsub.vv v6, v6, v10 \n\t" - "vsub.vv v3, v3, v11 \n\t" - "vsub.vv v7, v7, v11 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_F16_1X4X4 - "addi s7, s1, 32 \n\t" - - "bnez %[CNT], LOOP_K%= \n\t" - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m4 \n\t" - "vxor.vv v28, v28, v28 \n\t" - - "vsetvli t0, zero, e8, m1 \n\t" - "vmv.v.i v13, 3 \n\t" - "li s1, 24 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v13, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v13, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v13, 0 \n\t" - - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 8 \n\t" - "addi s3, %[B], 16 \n\t" - "addi s4, %[B], 24 \n\t" - - "addi s7, %[B], 32 \n\t" - - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - "LOOP_K%=: \n\t" - "vsetvli t0, zero, e16, mf4 \n\t" - "vle16.v v4, (s1) \n\t" - "addi s1, s1, 48 \n\t" - "vle16.v v5, (s2) \n\t" - "addi s2, s2, 72 \n\t" - "vle16.v v6, (s3) \n\t" - "addi s3, s3, 96 \n\t" - "vle16.v v7, (s4) \n\t" - "addi s4, s4, 120 \n\t" - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - - "vfwcvt.f.f.v v8, v4 \n\t" - "vfwcvt.f.f.v v9, v5 \n\t" - "vfwcvt.f.f.v v10, v6 \n\t" - "vfwcvt.f.f.v v11, v7 \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1 - - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vsub.vv v0, v0, v8 \n\t" - "vsub.vv v4, v4, v8 \n\t" - "vsub.vv v1, v1, v9 \n\t" - "vsub.vv v5, v5, v9 \n\t" - "vsub.vv v2, v2, v10 \n\t" - "vsub.vv v6, v6, v10 \n\t" - "vsub.vv v3, v3, v11 \n\t" - "vsub.vv v7, v7, v11 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_F16_1X4X4 - "addi s7, s1, 32 \n\t" - - "bnez %[CNT], LOOP_K%= \n\t" - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); - } - } - } else { - for (size_t n = 0; n < CountN; n += 16) { - size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(_Float16); // scale + size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; + uint8_t * QuantBDataPtr = (uint8_t *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(uint8_t) + // zp + n * BlockCountK * sizeof(_Float16); // scale float * CPtr = C + n; size_t cnt = BlockCountK; - if (Bias != nullptr) { - const float * bias = Bias + n; - __asm__ volatile( - "addi t3, %[NBLKS], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 8 \n\t" - "addi s3, %[B], 16 \n\t" - "addi s4, %[B], 24 \n\t" - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v28, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v29, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v30, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v31, (%[BIAS]) \n\t" - - "LOOP_K%=: \n\t" - "vsetvli t0, zero, e16, mf4 \n\t" - - "vle16.v v4, (s1) \n\t" - "addi s1, s1, 32 \n\t" - "vle16.v v5, (s2) \n\t" - "addi s2, s2, 56 \n\t" - "vle16.v v6, (s3) \n\t" - "addi s3, s3, 80 \n\t" - "vle16.v v7, (s4) \n\t" - "addi s4, s4, 104 \n\t" - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - "vfwcvt.f.f.v v8, v4 \n\t" - "vfwcvt.f.f.v v9, v5 \n\t" - "vfwcvt.f.f.v v10, v6 \n\t" - "vfwcvt.f.f.v v11, v7 \n\t" - - "vsetvli t0, zero, e32, mf2 \n\t" - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vadd.vi v0, v0, -8 \n\t" - "vadd.vi v1, v1, -8 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_F16_1X4X4 - - "bnez %[CNT], LOOP_K%= \n\t" - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m4 \n\t" - "vxor.vv v28, v28, v28 \n\t" - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 8 \n\t" - "addi s3, %[B], 16 \n\t" - "addi s4, %[B], 24 \n\t" - - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - "LOOP_K%=: \n\t" - "vsetvli t0, zero, e16, mf4 \n\t" - "vle16.v v4, (s1) \n\t" - "addi s1, s1, 32 \n\t" - "vle16.v v5, (s2) \n\t" - "addi s2, s2, 56 \n\t" - "vle16.v v6, (s3) \n\t" - "addi s3, s3, 80 \n\t" - "vle16.v v7, (s4) \n\t" - "addi s4, s4, 104 \n\t" - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - - "vfwcvt.f.f.v v8, v4 \n\t" - "vfwcvt.f.f.v v9, v5 \n\t" - "vfwcvt.f.f.v v10, v6 \n\t" - "vfwcvt.f.f.v v11, v7 \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vadd.vi v0, v0, -8 \n\t" - "vadd.vi v1, v1, -8 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_F16_1X4X4 - - "bnez %[CNT], LOOP_K%= \n\t" - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); - } - } - } -} -template -void SQ4BitGemmM1Kernel_CompInt8_Impl(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountN, - size_t BlockCountK, - const float * Bias) { - GGML_UNUSED(QuantBScale); - GGML_UNUSED(QuantBZeroPoint); - const size_t INNER = BlkLen / 16; - if constexpr (HasZeroPoint) { - for (size_t n = 0; n < CountN; n += 16) { - size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(uint8_t) + // zp - n * BlockCountK * sizeof(float); // scale - float * CPtr = C + n; - size_t cnt = BlockCountK; - if (Bias != nullptr) { - const float * bias = Bias + n; - __asm__ volatile( - "addi t3, %[NBLKS], 0 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "vmv.v.i v13, 3 \n\t" - "li s1, 24 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v13, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v13, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v13, 0 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vxor.vv v28, v28, v28 \n\t" - - // scale offset, scale0.0, scale1.0, scale2.0, scale3.0....scale15.0 - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 16 \n\t" - "addi s3, %[B], 32 \n\t" - "addi s4, %[B], 48 \n\t" - // zp offset - "addi s7, %[B], 64 \n\t" - // a offset - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v28, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v29, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v30, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v31, (%[BIAS]) \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - "LOOP_K%=: \n\t" - - // load scale - "vle32.v v8, (s1) \n\t" - "addi s1, s1, 80 \n\t" - "vle32.v v9, (s2) \n\t" - "addi s2, s2, 96 \n\t" - "vle32.v v10, (s3) \n\t" - "addi s3, s3, 112 \n\t" - "vle32.v v11, (s4) \n\t" - "addi s4, s4, 128 \n\t" - - // load a scale - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - - // a scale * b scale - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1 - - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vsub.vv v0, v0, v8 \n\t" - "vsub.vv v4, v4, v8 \n\t" - "vsub.vv v1, v1, v9 \n\t" - "vsub.vv v5, v5, v9 \n\t" - "vsub.vv v2, v2, v10 \n\t" - "vsub.vv v6, v6, v10 \n\t" - "vsub.vv v3, v3, v11 \n\t" - "vsub.vv v7, v7, v11 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_1X4X4 - "addi s7, s1, 64 \n\t" - - "bnez %[CNT], LOOP_K%= \n\t" - - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m4 \n\t" - "vxor.vv v28, v28, v28 \n\t" - - "vsetvli t0, zero, e8, m1 \n\t" - "vmv.v.i v13, 3 \n\t" - "li s1, 24 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v13, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v13, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v13, 0 \n\t" - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 16 \n\t" - "addi s3, %[B], 32 \n\t" - "addi s4, %[B], 48 \n\t" - - "addi s7, %[B], 64 \n\t" - - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - "LOOP_K%=: \n\t" - "vle32.v v8, (s1) \n\t" - "addi s1, s1, 80 \n\t" - "vle32.v v9, (s2) \n\t" - "addi s2, s2, 96 \n\t" - "vle32.v v10, (s3) \n\t" - "addi s3, s3, 112 \n\t" - "vle32.v v11, (s4) \n\t" - "addi s4, s4, 128 \n\t" - - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1 - - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vsub.vv v0, v0, v8 \n\t" - "vsub.vv v4, v4, v8 \n\t" - "vsub.vv v1, v1, v9 \n\t" - "vsub.vv v5, v5, v9 \n\t" - "vsub.vv v2, v2, v10 \n\t" - "vsub.vv v6, v6, v10 \n\t" - "vsub.vv v3, v3, v11 \n\t" - "vsub.vv v7, v7, v11 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_1X4X4 - "addi s7, s1, 64 \n\t" - - "bnez %[CNT], LOOP_K%= \n\t" - - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); - } + __asm__ volatile( + "vsetvli t0, zero, e32, m4 \n\t" + "vxor.vv v28, v28, v28 \n\t" + + "vsetvli t0, zero, e8, m1 \n\t" + "vmv.v.i v13, 3 \n\t" + "li s1, 24 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v13, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v13, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v13, 0 \n\t" + + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 8 \n\t" + "addi s3, %[B], 16 \n\t" + "addi s4, %[B], 24 \n\t" + + "addi s7, %[B], 32 \n\t" + + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + "LOOP_K%=: \n\t" + "vsetvli t0, zero, e16, mf4 \n\t" + "vle16.v v4, (s1) \n\t" + "addi s1, s1, 48 \n\t" + "vle16.v v5, (s2) \n\t" + "addi s2, s2, 72 \n\t" + "vle16.v v6, (s3) \n\t" + "addi s3, s3, 96 \n\t" + "vle16.v v7, (s4) \n\t" + "addi s4, s4, 120 \n\t" + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + + "vfwcvt.f.f.v v8, v4 \n\t" + "vfwcvt.f.f.v v9, v5 \n\t" + "vfwcvt.f.f.v v10, v6 \n\t" + "vfwcvt.f.f.v v11, v7 \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + + SQ4BIT_KERNEL_LOAD_ZP_16X1 + + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vsub.vv v0, v0, v8 \n\t" + "vsub.vv v4, v4, v8 \n\t" + "vsub.vv v1, v1, v9 \n\t" + "vsub.vv v5, v5, v9 \n\t" + "vsub.vv v2, v2, v10 \n\t" + "vsub.vv v6, v6, v10 \n\t" + "vsub.vv v3, v3, v11 \n\t" + "vsub.vv v7, v7, v11 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_F16_1X4X4 + "addi s7, s1, 32 \n\t" + + "bnez %[CNT], LOOP_K%= \n\t" + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); } } else { for (size_t n = 0; n < CountN; n += 16) { - size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(float); // scale + size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; + uint8_t * QuantBDataPtr = (uint8_t *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(_Float16); // scale float * CPtr = C + n; size_t cnt = BlockCountK; - if (Bias != nullptr) { - const float * bias = Bias + n; - __asm__ volatile( - "addi t3, %[NBLKS], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 16 \n\t" - "addi s3, %[B], 32 \n\t" - "addi s4, %[B], 48 \n\t" - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v28, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v29, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v30, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v31, (%[BIAS]) \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - "LOOP_K%=: \n\t" - "vle32.v v8, (s1) \n\t" - "addi s1, s1, 64 \n\t" - "vle32.v v9, (s2) \n\t" - "addi s2, s2, 80 \n\t" - "vle32.v v10, (s3) \n\t" - "addi s3, s3, 96 \n\t" - "vle32.v v11, (s4) \n\t" - "addi s4, s4, 112 \n\t" - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vadd.vi v0, v0, -8 \n\t" - "vadd.vi v1, v1, -8 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_1X4X4 - - "bnez %[CNT], LOOP_K%= \n\t" - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m4 \n\t" - "vxor.vv v28, v28, v28 \n\t" - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 16 \n\t" - "addi s3, %[B], 32 \n\t" - "addi s4, %[B], 48 \n\t" - - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - "LOOP_K%=: \n\t" - "vle32.v v8, (s1) \n\t" - "addi s1, s1, 64 \n\t" - "vle32.v v9, (s2) \n\t" - "addi s2, s2, 80 \n\t" - "vle32.v v10, (s3) \n\t" - "addi s3, s3, 96 \n\t" - "vle32.v v11, (s4) \n\t" - "addi s4, s4, 112 \n\t" - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vadd.vi v0, v0, -8 \n\t" - "vadd.vi v1, v1, -8 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_1X4X4 - - "bnez %[CNT], LOOP_K%= \n\t" - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); - } - } - } -} - -template -inline void SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountM, - size_t CountN, - size_t BlockStrideQuantB, - const float * Bias, - const size_t ldc, - const size_t scalestride) { - if (scalestride == 4) { - SQ4BitGemmM4Kernel_CompInt8_Impl(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, - CountN, BlockStrideQuantB, Bias, ldc); - - } else if (scalestride == 2) { - SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl( - BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias, ldc); - } -} -template -inline void SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountM, - size_t CountN, - size_t BlockStrideQuantB, - const float * Bias, - const size_t ldc, - const size_t scalestride) { - if (scalestride == 4) { - SQ4BitGemmM1Kernel_CompInt8_Impl(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, - CountN, BlockStrideQuantB, Bias); - } else if (scalestride == 2) { - SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(BlkLen, QuantA, QuantBData, QuantBScale, - QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias); + __asm__ volatile( + "vsetvli t0, zero, e32, m4 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 8 \n\t" + "addi s3, %[B], 16 \n\t" + "addi s4, %[B], 24 \n\t" + + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + "LOOP_K%=: \n\t" + "vsetvli t0, zero, e16, mf4 \n\t" + "vle16.v v4, (s1) \n\t" + "addi s1, s1, 32 \n\t" + "vle16.v v5, (s2) \n\t" + "addi s2, s2, 56 \n\t" + "vle16.v v6, (s3) \n\t" + "addi s3, s3, 80 \n\t" + "vle16.v v7, (s4) \n\t" + "addi s4, s4, 104 \n\t" + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + + "vfwcvt.f.f.v v8, v4 \n\t" + "vfwcvt.f.f.v v9, v5 \n\t" + "vfwcvt.f.f.v v10, v6 \n\t" + "vfwcvt.f.f.v v11, v7 \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vadd.vi v0, v0, -8 \n\t" + "vadd.vi v1, v1, -8 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_F16_1X4X4 + + "bnez %[CNT], LOOP_K%= \n\t" + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); + } } } - } // namespace namespace ime1 { -size_t gemm_kernel_i8i4(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountM, - size_t CountN, - size_t CountK, - size_t BlockCountK, - size_t ldc, - const float * Bias, - const size_t ScaleStride) { - GGML_UNUSED(CountM); - GGML_UNUSED(CountK); - GGML_UNUSED(ldc); - if (CountM >= 4) { - if (QuantBZeroPoint != nullptr) { - SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, - C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride); +size_t gemm_kernel_i8i4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { + if (quant_b_zp != nullptr) { + SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_n, k_blks, + ldc); } else { - SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen(BlkLen, QuantA, QuantBData, QuantBScale, - QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias, - ldc, ScaleStride); + SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_n, + k_blks, ldc); } return 4; } else { - if (QuantBZeroPoint != nullptr) { - SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, - C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride); + if (quant_b_zp != nullptr) { + SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_n, k_blks, + ldc); } else { - SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(BlkLen, QuantA, QuantBData, QuantBScale, - QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias, - ldc, ScaleStride); + SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_n, + k_blks, ldc); } return 1; } } } // namespace ime1 -} // namespace sqnbitgemm_spacemit_ime +} // namespace spacemit_kernels diff --git a/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp b/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp new file mode 100644 index 00000000000..0c7a036a92a --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp @@ -0,0 +1,5768 @@ +#include "ggml-impl.h" +#include "ggml.h" +#include "ime_kernels.h" +#include "rvv_kernels.h" +#include "string.h" + +#include +#include +#include + +#if !defined(__riscv_v) || !defined(__riscv_v_intrinsic) +# error "riscv v extension or v_intrinsic not enabled" +#else +# include +#endif + +#if !defined(__riscv_zfh) +# error "riscv zfh extension not enabled" +#endif + +#if defined(RISCV64_SPACEMIT_IME2) +#else +# error "RISCV64_SPACEMIT_IME2 not defined" +#endif + +#if defined(__GNUC__) +# pragma GCC diagnostic ignored "-Woverlength-strings" +# pragma GCC diagnostic ignored "-Wcast-qual" +# pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +namespace spacemit_kernels { +namespace ime2 { + +template +void gemm_kernel_i8i2k_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + using blk_type = nrow_block_q2_k; + constexpr float refactor_scale = 16.0f; + constexpr float factor_scale = 1.0f / refactor_scale; + + int64_t a_blk_stride = q8k_blk_size(256); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + int64_t b_ncol_block_stride = sizeof(blk_type); + + float output[MB_ROWS * NB_COLS] = { 0 }; + _Float16 output_f16[MB_ROWS * NB_COLS] = { 0 }; + blk_type * quant_b_blk_data = (blk_type *) (quant_b_data); + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS * 16; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, quant_b_blk_data++, a_data += a_nrow_block_stride) { + uint8_t * b_data = quant_b_blk_data->qs; + uint8_t * scales = quant_b_blk_data->scales; + uint8_t * scales16 = (uint8_t *) (quant_b_blk_data->scales16); + uint8_t * zeros16 = (uint8_t *) (quant_b_blk_data->zeros16); + + _Float16 * scales_fp16 = (_Float16 *) scales16; + _Float16 * zeros_fp16 = (_Float16 *) zeros16; + + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS * 16); + int16_t * a_sum_row = (int16_t *) (a_data - sizeof(int16_t) * MB_ROWS * 16); + + memset(output_f16, 0, sizeof(output_f16)); + + uint8_t * scales_temp = scales; + uint8_t * zps_temp = scales; + for (size_t kii = 0; kii < 16; kii++, scales_temp += NB_COLS, zps_temp++) { + size_t b_shift = (kii % 4) * 2; + + uint8_t * b_data_col = b_data + (kii / 4) * NB_COLS * 16; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + int16_t a_sum = a_sum_row[mi * 16 + kii]; + for (size_t ci = 0; ci < NB_COLS; ci++) { + _Float16 acc_0 = 0.0; + + uint8_t b_zp = zps_temp[ci * 16] >> 4; + uint8_t b_scale = scales_temp[ci] & 0x0F; + for (size_t bi = 0; bi < 16; bi++) { + int8_t a0 = a_data[mi * 256 + bi + kii * 16]; + uint8_t b0 = b_data_col[ci * 16 + bi]; + acc_0 += static_cast(a0) * static_cast((b0 >> b_shift) & 0x03); + } + + _Float16 scale_item = + static_cast<_Float16>(b_scale) * static_cast<_Float16>(factor_scale) * scales_fp16[ci]; + + output_f16[ci + mi * NB_COLS] += acc_0 * scale_item; + output[ci + mi * NB_COLS] += b_zp * a_sum * a_scale_row[mi] * zeros_fp16[ci]; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + auto a_scale = a_scale_row[mi] * refactor_scale; + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] += output_f16[ci + mi * NB_COLS] * a_scale; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } +} + +template +void gemm_kernel_i8i3k_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + using blk_type = nrow_block_q2_k; + constexpr float refactor_scale = 16.0f; + constexpr float factor_scale = 1.0f / refactor_scale; + + int64_t a_blk_stride = q8k_blk_size(256); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + int64_t b_ncol_block_stride = sizeof(blk_type); + + float output[MB_ROWS * NB_COLS] = { 0 }; + _Float16 output_f16[MB_ROWS * NB_COLS] = { 0 }; + + blk_type * quant_b_blk_data = (blk_type *) (quant_b_data); + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS * 16; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, quant_b_blk_data++, a_data += a_nrow_block_stride) { + uint8_t * b_data = quant_b_blk_data->qs; + uint8_t * b_hmask = quant_b_blk_data->hmask; + int8_t * scales = quant_b_blk_data->scales; + uint8_t * scales16 = (uint8_t *) (quant_b_blk_data->scales16); + + _Float16 * scales_fp16 = (_Float16 *) scales16; + + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS * 16); + int16_t * a_sum_row = (int16_t *) (a_data - sizeof(int16_t) * MB_ROWS * 16); + + memset(output_f16, 0, sizeof(output_f16)); + + int8_t * scales_temp = scales; + uint16_t * b_mask_col = (uint16_t *) b_hmask; + + float acc_0_max = 0.0f; + for (size_t kii = 0; kii < 16; kii++, scales_temp += NB_COLS, b_mask_col += NB_COLS) { + size_t b_shift = (kii % 4) * 2; + + uint8_t * b_data_col = b_data + (kii / 4) * NB_COLS * 16; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + _Float16 acc_0 = 0; + // blk 2 * kii + 0 + uint16_t b_shift_mask = 1; + for (size_t bi = 0; bi < 16; bi++, b_shift_mask <<= 1) { + int8_t a0 = a_data[mi * 256 + bi + kii * 16]; + int8_t b0 = static_cast((b_data_col[ci * 16 + bi] >> b_shift) & 0x03); + b0 -= b_mask_col[ci] & b_shift_mask ? 0 : 4; + acc_0 += static_cast(a0) * static_cast(b0); + } + + _Float16 scale_item = static_cast<_Float16>(scales_temp[ci]) * scales_fp16[ci] * + static_cast<_Float16>(factor_scale); + + output_f16[ci + mi * NB_COLS] += acc_0 * scale_item; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + auto a_scale = a_scale_row[mi] * refactor_scale; + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] += output_f16[ci + mi * NB_COLS] * a_scale; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } +} + +template +void gemm_kernel_i8i4_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t kblks_per_blk = 16; + GGML_ASSERT(k_blks % kblks_per_blk == 0); + + int64_t b_blk_stride = (sizeof(_Float16) + (blk_len / 2) + (quant_b_zp ? sizeof(uint8_t) : 0)); + int64_t b_stride = k_blks * b_blk_stride; + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + int64_t b_ncol_block_stride = b_blk_stride * NB_COLS; + + float output[MB_ROWS * NB_COLS] = { 0 }; + _Float16 output_f16[MB_ROWS * NB_COLS] = { 0 }; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_stride + NB_COLS * sizeof(_Float16); + if (quant_b_zp) { + b_data += NB_COLS * sizeof(uint8_t); + } + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0.0f; + output_f16[ci + mi * NB_COLS] = static_cast<_Float16>(0.0f); + } + } + + size_t kii = 0; + for (size_t ki = 0; ki < k_blks; ki++, a_data += a_nrow_block_stride, b_data += b_ncol_block_stride) { + _Float16 * b_scale_fp16 = (_Float16 *) (b_data - NB_COLS * sizeof(_Float16)); + uint8_t * b_zp = nullptr; + if (quant_b_zp) { + b_scale_fp16 = (_Float16 *) (b_data - NB_COLS * sizeof(_Float16) - NB_COLS * sizeof(uint8_t)); + b_zp = (uint8_t *) (b_data - NB_COLS * sizeof(uint8_t)); + } + + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS); + int16_t * a_sum_row = (int16_t *) (a_data - sizeof(int16_t) * MB_ROWS); + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + _Float16 a_scale = a_scale_row[mi]; + int16_t a_sum = a_sum_row[mi]; + + for (size_t ci = 0; ci < NB_COLS; ci++) { + _Float16 b_scale = b_scale_fp16[ci]; + int32_t acc = 0; + if (b_zp) { + acc += a_sum * b_zp[ci]; + } else { + acc += a_sum * 8; + } + for (size_t bi = 0; bi < blk_len / 2; bi++) { + int8_t a0 = a_data[mi * blk_len + 2 * bi]; + int8_t a1 = a_data[mi * blk_len + 2 * bi + 1]; + uint8_t b = b_data[ci * blk_len / 2 + bi]; + int8_t b0 = static_cast(b & 0x0F); + int8_t b1 = static_cast((b & 0xF0) >> 4); + acc += static_cast(a0) * static_cast(b0) + + static_cast(a1) * static_cast(b1); + } + output_f16[ci + mi * NB_COLS] += + static_cast(acc) * static_cast(a_scale) * static_cast(b_scale); + } + } + + if (kii == kblks_per_blk - 1) { + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] += static_cast(output_f16[ci + mi * NB_COLS]); + output_f16[ci + mi * NB_COLS] = 0.0f; + } + } + kii = 0; + } else { + kii++; + } + } + + if (kii == kblks_per_blk - 1) { + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] += static_cast(output_f16[ci + mi * NB_COLS]); + output_f16[ci + mi * NB_COLS] = 0.0f; + } + } + kii = 0; + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } +} + +template +void gemm_kernel_i8i4_hp_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t k_subblks_per_superblk = 8; + + struct block_q4_0x32_layout { + _Float16 d[NB_COLS]; + uint8_t qs[16 * NB_COLS]; + }; + + GGML_ASSERT(blk_len == 256); + + const size_t b_superblk_stride = sizeof(block_q4_0x32_layout) * k_subblks_per_superblk + + (quant_b_zp ? NB_COLS * k_subblks_per_superblk * sizeof(uint8_t) : 0); + const size_t b_tile_stride = k_blks * b_superblk_stride; + + const size_t a_nrow_block_stride = q8_hp_blk_size(blk_len, true, true) * MB_ROWS; + const size_t a_subblk_stride = q8_hp_blk_size(32, false, false) * MB_ROWS; + + float output[MB_ROWS * NB_COLS] = { 0 }; + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + const uint8_t * b_tile_base = quant_b_data + (ni / NB_COLS) * b_tile_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0.0f; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, a_data += a_nrow_block_stride) { + _Float16 output_f16[MB_ROWS * NB_COLS] = { 0 }; + + const uint8_t * b_superblk_ptr = b_tile_base + ki * b_superblk_stride; + const block_q4_0x32_layout * b_blocks = reinterpret_cast(b_superblk_ptr); + const uint8_t * b_zps = + quant_b_zp ? b_superblk_ptr + sizeof(block_q4_0x32_layout) * k_subblks_per_superblk : nullptr; + + _Float16 * a_sum_row = (_Float16 *) (a_data + a_subblk_stride * k_subblks_per_superblk); + _Float16 * a_scale_avg_row = (_Float16 *) (a_data + a_nrow_block_stride - sizeof(_Float16) * MB_ROWS); + _Float16 scale_factor = a_scale_avg_row[0]; + + for (size_t ksi = 0; ksi < k_subblks_per_superblk; ++ksi) { + const _Float16 * a_scale_row = reinterpret_cast(a_data + a_subblk_stride * ksi); + int8_t * a_subblk = a_data + a_subblk_stride * ksi + MB_ROWS * sizeof(_Float16); + const _Float16 a_scale = a_scale_row[0]; + const block_q4_0x32_layout & b_block = b_blocks[ksi]; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + const uint8_t * b_qs = b_block.qs + ci * 16; + _Float16 b_scale = b_block.d[ci] * a_scale; + + int16_t acc = 0; + for (size_t bi = 0; bi < 16; bi++) { + uint8_t b = b_qs[bi]; + int8_t b0 = static_cast(b & 0x0F); + int8_t b1 = static_cast((b & 0xF0) >> 4); + + acc += static_cast(a_subblk[mi * 32 + 2 * bi]) * static_cast(b0) + + static_cast(a_subblk[mi * 32 + 2 * bi + 1]) * static_cast(b1); + } + + const _Float16 scaled_acc = static_cast<_Float16>(acc) * b_scale; + output_f16[ci + mi * NB_COLS] += scaled_acc; + } + } + } + + for (size_t ksi = 0; ksi < k_subblks_per_superblk; ++ksi) { + const _Float16 * a_scale_row = reinterpret_cast(a_data + a_subblk_stride * ksi); + const block_q4_0x32_layout & b_block = b_blocks[ksi]; + const uint8_t * b_zp_row = b_zps ? b_zps + ksi * NB_COLS : nullptr; + const _Float16 a_scale = a_scale_row[0]; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + const _Float16 a_sum = a_sum_row[mi * k_subblks_per_superblk + ksi]; + for (size_t ci = 0; ci < NB_COLS; ci++) { + _Float16 b_scale = b_block.d[ci] * a_scale; + _Float16 a_sum_bzp = a_sum; + if (b_zp_row) { + a_sum_bzp = a_sum * static_cast<_Float16>(0.125f) * static_cast<_Float16>(b_zp_row[ci]); + } + + const _Float16 scaled_acc = a_sum_bzp * b_scale; + output[ci + mi * NB_COLS] += scaled_acc * scale_factor; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + auto val = static_cast(output_f16[ci + mi * NB_COLS]) * static_cast(scale_factor); + output[ci + mi * NB_COLS] += val; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } +} + +template +void moe_gemm_kernel_i8i4_mrow_ref(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + int64_t b_blk_stride = (sizeof(ggml_fp16_t) + (blk_len / 2) + (quant_b_zp ? sizeof(uint8_t) : 0)); + int64_t b_stride = k_blks * b_blk_stride; + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t b_ncol_block_stride = b_blk_stride * NB_COLS; + + float output[MB_ROWS * NB_COLS] = { 0 }; + std::array a_data; + std::array c_data; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + c_data[mi] = c_ptr[mi]; + } + + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_stride + NB_COLS * sizeof(ggml_fp16_t); + if (quant_b_zp) { + b_data += NB_COLS * sizeof(uint8_t); + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + a_data[mi] = (int8_t *) quant_a_ptr[mi] + sizeof(float) + sizeof(int16_t); + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, b_data += b_ncol_block_stride) { + ggml_fp16_t * b_scale_fp16 = (ggml_fp16_t *) (b_data - NB_COLS * sizeof(ggml_fp16_t)); + uint8_t * b_zp = nullptr; + if (quant_b_zp) { + b_scale_fp16 = (ggml_fp16_t *) (b_data - NB_COLS * sizeof(ggml_fp16_t) - NB_COLS * sizeof(uint8_t)); + b_zp = (uint8_t *) (b_data - NB_COLS * sizeof(uint8_t)); + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + float * a_scale_row = (float *) (a_data[mi] - sizeof(float) - sizeof(int16_t)); + int16_t * a_sum_row = (int16_t *) (a_data[mi] - sizeof(int16_t)); + + float a_scale = *a_scale_row; + int16_t a_sum = *a_sum_row; + + for (size_t ci = 0; ci < NB_COLS; ci++) { + float b_scale = ggml_fp16_to_fp32(b_scale_fp16[ci]); + int32_t acc = 0; + if (b_zp) { + acc += a_sum * b_zp[ci]; + } else { + acc += a_sum * 8; + } + for (size_t bi = 0; bi < blk_len / 2; bi++) { + int8_t a0 = (a_data[mi])[2 * bi]; + int8_t a1 = (a_data[mi])[2 * bi + 1]; + uint8_t b = b_data[ci * blk_len / 2 + bi]; + int8_t b0 = static_cast(b & 0x0F); + int8_t b1 = static_cast((b & 0xF0) >> 4); + acc += static_cast(a0) * static_cast(b0) + + static_cast(a1) * static_cast(b1); + } + output[ci + mi * NB_COLS] += static_cast(acc) * a_scale * b_scale; + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + a_data[mi] += a_blk_stride; + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + (c_data[mi])[ci] = output[mi * NB_COLS + ci]; + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + c_data[mi] += NB_COLS; + } + } +} + +template +void moe_gemm_kernel_i8i5_mrow_ref(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + GGML_UNUSED(count_m); + GGML_UNUSED(ldc); + + // blk_len is expected to be 32 for Q5 types. + int64_t a_blk_stride = q8_blk_size(blk_len, true); + + float output[MB_ROWS * NB_COLS] = { 0 }; + std::array a_data; + std::array c_data; + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + c_data[mi] = c_ptr[mi]; + } + + if (quant_b_zp) { + using blk_type = nrow_block_q5_1; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + blk_type * quant_b_blk_data = (blk_type *) quant_b_data + (ni / NB_COLS) * k_blks; + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + a_data[mi] = (int8_t *) quant_a_ptr[mi] + sizeof(float) + sizeof(int16_t); + } + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + for (size_t ci = 0; ci < NB_COLS; ++ci) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ++ki, ++quant_b_blk_data) { + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + float * a_scale_row = (float *) (a_data[mi] - sizeof(float) - sizeof(int16_t)); + int16_t * a_sum_row = (int16_t *) (a_data[mi] - sizeof(int16_t)); + float a_scale = *a_scale_row; + int16_t a_sum = *a_sum_row; + + for (size_t ci = 0; ci < NB_COLS; ++ci) { + float b_scale = ggml_fp16_to_fp32(quant_b_blk_data->scales16[ci]); + uint8_t b_zp_val = quant_b_blk_data->zp[ci]; + int32_t acc = a_sum * static_cast(b_zp_val); + + for (size_t bi = 0; bi < blk_len / 2; ++bi) { + int8_t a0 = a_data[mi][2 * bi]; + int8_t a1 = a_data[mi][2 * bi + 1]; + uint8_t qs_byte = quant_b_blk_data->qs[ci * (blk_len / 2) + bi]; + int8_t b0 = static_cast(qs_byte & 0x0F); + int8_t b1 = static_cast((qs_byte >> 4) & 0x0F); + uint8_t qh_byte0 = quant_b_blk_data->qh[ci * 4 + (2 * bi) / 8]; + uint8_t qh_byte1 = quant_b_blk_data->qh[ci * 4 + (2 * bi + 1) / 8]; + uint8_t h0 = (qh_byte0 >> ((2 * bi) % 8)) & 1; + uint8_t h1 = (qh_byte1 >> ((2 * bi + 1) % 8)) & 1; + + b0 |= (h0 << 4); + b1 |= (h1 << 4); + + acc += static_cast(a0) * static_cast(b0) + + static_cast(a1) * static_cast(b1); + } + + output[ci + mi * NB_COLS] += static_cast(acc) * a_scale * b_scale; + } + + a_data[mi] += a_blk_stride; + } + } + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + for (size_t ci = 0; ci < nb_real; ++ci) { + c_data[mi][ci] = output[mi * NB_COLS + ci]; + } + c_data[mi] += NB_COLS; + } + } + } else { + using blk_type = nrow_block_q5_0; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + blk_type * quant_b_blk_data = (blk_type *) quant_b_data + (ni / NB_COLS) * k_blks; + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + a_data[mi] = (int8_t *) quant_a_ptr[mi] + sizeof(float) + sizeof(int16_t); + } + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + for (size_t ci = 0; ci < NB_COLS; ++ci) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ++ki, ++quant_b_blk_data) { + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + float * a_scale_row = (float *) (a_data[mi] - sizeof(float) - sizeof(int16_t)); + int16_t * a_sum_row = (int16_t *) (a_data[mi] - sizeof(int16_t)); + float a_scale = *a_scale_row; + int16_t a_sum = *a_sum_row; + + for (size_t ci = 0; ci < NB_COLS; ++ci) { + float b_scale = ggml_fp16_to_fp32(quant_b_blk_data->scales16[ci]); + int32_t acc = a_sum * 16; + + for (size_t bi = 0; bi < blk_len / 2; ++bi) { + int8_t a0 = a_data[mi][2 * bi]; + int8_t a1 = a_data[mi][2 * bi + 1]; + uint8_t qs_byte = quant_b_blk_data->qs[ci * (blk_len / 2) + bi]; + int8_t b0 = static_cast(qs_byte & 0x0F); + int8_t b1 = static_cast((qs_byte >> 4) & 0x0F); + uint8_t qh_byte0 = quant_b_blk_data->qh[ci * 4 + (2 * bi) / 8]; + uint8_t qh_byte1 = quant_b_blk_data->qh[ci * 4 + (2 * bi + 1) / 8]; + uint8_t h0 = (qh_byte0 >> ((2 * bi) % 8)) & 1; + uint8_t h1 = (qh_byte1 >> ((2 * bi + 1) % 8)) & 1; + + b0 |= (h0 << 4); + b1 |= (h1 << 4); + + acc += static_cast(a0) * static_cast(b0) + + static_cast(a1) * static_cast(b1); + } + + output[ci + mi * NB_COLS] += static_cast(acc) * a_scale * b_scale; + } + + a_data[mi] += a_blk_stride; + } + } + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + for (size_t ci = 0; ci < nb_real; ++ci) { + c_data[mi][ci] = output[mi * NB_COLS + ci]; + } + c_data[mi] += NB_COLS; + } + } + } +} + +template +void gemm_kernel_i8i8_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + int64_t b_blk_stride = (sizeof(ggml_fp16_t) + blk_len); + int64_t b_stride = k_blks * b_blk_stride; + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + int64_t b_ncol_block_stride = b_blk_stride * NB_COLS; + + float output[MB_ROWS * NB_COLS] = { 0 }; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + int8_t * b_data = (int8_t *) quant_b_data + ni * b_stride + NB_COLS * sizeof(ggml_fp16_t); + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, a_data += a_nrow_block_stride, b_data += b_ncol_block_stride) { + ggml_fp16_t * b_scale_fp16 = (ggml_fp16_t *) (b_data - NB_COLS * sizeof(ggml_fp16_t)); + + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS); + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + float a_scale = a_scale_row[mi]; + for (size_t ci = 0; ci < NB_COLS; ci++) { + float b_scale = ggml_fp16_to_fp32(b_scale_fp16[ci]); + int32_t acc = 0; + for (size_t bi = 0; bi < blk_len; bi++) { + int8_t a0 = a_data[mi * blk_len + bi]; + int8_t b0 = b_data[ci * blk_len + bi]; + acc += static_cast(a0) * static_cast(b0); + } + output[ci + mi * NB_COLS] += static_cast(acc) * a_scale * b_scale; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } +} + +template +void gemm_kernel_i8i5_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + // blk_len is expected to be 32 for Q5 types + // quant_b_zp != nullptr => nrow_block_q5_1 (has zp) + // quant_b_zp == nullptr => nrow_block_q5_0 (no zp) + + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + + float output[MB_ROWS * NB_COLS] = { 0 }; + + if (quant_b_zp) { + // nrow_block_q5_1: scales16[NB_COLS] + zp[NB_COLS] + qh[4*NB_COLS] + qs[16*NB_COLS] + using blk_type = nrow_block_q5_1; + int64_t b_ncol_block_stride = sizeof(blk_type); + blk_type * quant_b_blk_data = (blk_type *) quant_b_data; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, quant_b_blk_data++, a_data += a_nrow_block_stride) { + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS); + int16_t * a_sum_row = (int16_t *) (a_data - sizeof(int16_t) * MB_ROWS); + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + float a_scale = a_scale_row[mi]; + int16_t a_sum = a_sum_row[mi]; + + for (size_t ci = 0; ci < NB_COLS; ci++) { + float b_scale = ggml_fp16_to_fp32(quant_b_blk_data->scales16[ci]); + uint8_t b_zp_val = quant_b_blk_data->zp[ci]; + int32_t acc = a_sum * static_cast(b_zp_val); + + for (size_t bi = 0; bi < blk_len / 2; bi++) { + int8_t a0 = a_data[mi * blk_len + 2 * bi]; + int8_t a1 = a_data[mi * blk_len + 2 * bi + 1]; + uint8_t qs_byte = quant_b_blk_data->qs[ci * (blk_len / 2) + bi]; + int8_t b0 = static_cast(qs_byte & 0x0F); + int8_t b1 = static_cast((qs_byte >> 4) & 0x0F); + + // Extract high bits from qh + // qh is packed as 4 bytes per column (32 bits for 32 elements) + uint8_t qh_byte0 = quant_b_blk_data->qh[ci * 4 + (2 * bi) / 8]; + uint8_t qh_byte1 = quant_b_blk_data->qh[ci * 4 + (2 * bi + 1) / 8]; + uint8_t h0 = (qh_byte0 >> ((2 * bi) % 8)) & 1; + uint8_t h1 = (qh_byte1 >> ((2 * bi + 1) % 8)) & 1; + + b0 |= (h0 << 4); + b1 |= (h1 << 4); + + acc += static_cast(a0) * static_cast(b0) + + static_cast(a1) * static_cast(b1); + } + output[ci + mi * NB_COLS] += static_cast(acc) * a_scale * b_scale; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } + } else { + // nrow_block_q5_0: scales16[NB_COLS] + qh[4*NB_COLS] + qs[16*NB_COLS] + using blk_type = nrow_block_q5_0; + int64_t b_ncol_block_stride = sizeof(blk_type); + blk_type * quant_b_blk_data = (blk_type *) quant_b_data; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, quant_b_blk_data++, a_data += a_nrow_block_stride) { + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS); + int16_t * a_sum_row = (int16_t *) (a_data - sizeof(int16_t) * MB_ROWS); + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + float a_scale = a_scale_row[mi]; + int16_t a_sum = a_sum_row[mi]; + + for (size_t ci = 0; ci < NB_COLS; ci++) { + float b_scale = ggml_fp16_to_fp32(quant_b_blk_data->scales16[ci]); + // Q5_0 has no zp, use default offset 16 (midpoint of 5-bit unsigned range) + int32_t acc = a_sum * 16; + + for (size_t bi = 0; bi < blk_len / 2; bi++) { + int8_t a0 = a_data[mi * blk_len + 2 * bi]; + int8_t a1 = a_data[mi * blk_len + 2 * bi + 1]; + uint8_t qs_byte = quant_b_blk_data->qs[ci * (blk_len / 2) + bi]; + int8_t b0 = static_cast(qs_byte & 0x0F); + int8_t b1 = static_cast((qs_byte >> 4) & 0x0F); + + // Extract high bits from qh + uint8_t qh_byte0 = quant_b_blk_data->qh[ci * 4 + (2 * bi) / 8]; + uint8_t qh_byte1 = quant_b_blk_data->qh[ci * 4 + (2 * bi + 1) / 8]; + uint8_t h0 = (qh_byte0 >> ((2 * bi) % 8)) & 1; + uint8_t h1 = (qh_byte1 >> ((2 * bi + 1) % 8)) & 1; + + b0 |= (h0 << 4); + b1 |= (h1 << 4); + + acc += static_cast(a0) * static_cast(b0) + + static_cast(a1) * static_cast(b1); + } + output[ci + mi * NB_COLS] += static_cast(acc) * a_scale * b_scale; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } + } +} + +template +void gemm_kernel_i8mxfp4_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + // blk_len is expected to be 32 (QK_MXFP4) + // quant_b_zp is unused for MXFP4 (symmetric quantization) + GGML_UNUSED(quant_b_zp); + + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + + float output[MB_ROWS * NB_COLS] = { 0 }; + + using blk_type = nrow_block_mxfp4; + blk_type * quant_b_blk_data = (blk_type *) quant_b_data; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, quant_b_blk_data++, a_data += a_nrow_block_stride) { + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS); + int16_t * a_sum_row = (int16_t *) (a_data - sizeof(int16_t) * MB_ROWS); + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + float a_scale = a_scale_row[mi]; + + for (size_t ci = 0; ci < NB_COLS; ci++) { + float b_scale = GGML_E8M0_TO_FP32_HALF(quant_b_blk_data->e[ci]); + + // Read 32 sign bits for this column + uint32_t sign_bits; + memcpy(&sign_bits, &quant_b_blk_data->qh[ci * 4], 4); + + int32_t acc = 0; + for (size_t bi = 0; bi < blk_len / 2; bi++) { + int8_t a0 = a_data[mi * blk_len + 2 * bi]; + int8_t a1 = a_data[mi * blk_len + 2 * bi + 1]; + + // qs[ci*16 + bi] stores abs(vals[bi*2]) in low 4 bits + // and abs(vals[bi*2+1]) in high 4 bits + uint8_t qs_byte = quant_b_blk_data->qs[ci * 16 + bi]; + int8_t b_abs0 = static_cast(qs_byte & 0x0F); + int8_t b_abs1 = static_cast((qs_byte >> 4) & 0x0F); + + // Extract sign bits: bit (2*bi) for vals[2*bi], bit (2*bi+1) for vals[2*bi+1] + int8_t b0 = (sign_bits >> (2 * bi)) & 1 ? -b_abs0 : b_abs0; + int8_t b1 = (sign_bits >> (2 * bi + 1)) & 1 ? -b_abs1 : b_abs1; + + acc += static_cast(a0) * static_cast(b0) + + static_cast(a1) * static_cast(b1); + } + output[ci + mi * NB_COLS] += static_cast(acc) * a_scale * b_scale; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } +} + +void gemm_kernel_i8i2k_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + using blk_type = nrow_block_q2_k; + + int64_t b_ncol_block_stride = sizeof(blk_type) * k_blks; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * b_ncol_block_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + float * dst_c = (float *) c_ptr + ni; + + asm volatile( + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "mv s1, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A + "flw fa0, (%[A]) \n\t" + "addi %[A], %[A], 4 \n\t" + + "li t1, 4 \n\t" + "addi t2, %[B], 512 \n\t" // B data addr + "addi t3, %[A], 32 \n\t" // A data addr + "addi s3, %[B], 0 \n\t" + "vxor.vv v30, v29, v29 \n\t" // tmp result + + "INNER_K_LOOP%=: \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vxor.vv v2, v2, v2 \n\t" + "vxor.vv v3, v3, v3 \n\t" + "vxor.vv v4, v4, v4 \n\t" + "vxor.vv v5, v5, v5 \n\t" + "vxor.vv v6, v6, v6 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + + // load scale B + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (%[B]) \n\t" + "addi %[B], %[B], 128 \n\t" + + // A data, 1x64@i8 + "vsetivli t0, 16, e8, mf4 \n\t" + "vle8.v v2, (t3) \n\t" + "addi t3, t3, 16 \n\t" + + "vsetivli t0, 16, e8, mf4 \n\t" + "vle8.v v4, (t3) \n\t" + "addi t3, t3, 16 \n\t" + + "vsetivli t0, 16, e8, mf4 \n\t" + "vle8.v v5, (t3) \n\t" + "addi t3, t3, 16 \n\t" + + "vsetivli t0, 16, e8, mf4 \n\t" + "vle8.v v6, (t3) \n\t" + "addi t3, t3, 16 \n\t" + + "vsetvli t0, x0, e64, mf2 \n\t" + "vslideup.vi v3, v4, 2 \n\t" + "vslideup.vi v28, v5, 4 \n\t" + "vslideup.vi v29, v6, 6 \n\t" + + // init the accumu to zero + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v20, v18, v18 \n\t" + "vxor.vv v22, v18, v18 \n\t" + "vxor.vv v24, v18, v18 \n\t" + "vxor.vv v26, v18, v18 \n\t" + + // B data, 32x64@i2 + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (t2) \n\t" + "addi t2, t2, 512 \n\t" + "vand.vi v8, v4, 0x3 \n\t" // 0-15 + "vsrl.vi v9, v4, 2 \n\t" + "vsrl.vi v10, v4, 4 \n\t" + "vsrl.vi v11, v4, 6 \n\t" // 48-63 + "vand.vi v9, v9, 0x3 \n\t" // 16-31 + "vand.vi v10, v10, 0x3 \n\t" // 32-47 + + "vand.vi v12, v5, 0x3 \n\t" // 0-15 + "vsrl.vi v13, v5, 2 \n\t" + "vsrl.vi v14, v5, 4 \n\t" + "vsrl.vi v15, v5, 6 \n\t" // 48-63 + "vand.vi v13, v13, 0x3 \n\t" // 16-31 + "vand.vi v14, v14, 0x3 \n\t" // 32-47 + + "vand.vi v16, v6, 0x3 \n\t" // 0-15 + "vsrl.vi v17, v6, 2 \n\t" + "vsrl.vi v18, v6, 4 \n\t" + "vsrl.vi v19, v6, 6 \n\t" // 48-63 + "vand.vi v17, v17, 0x3 \n\t" // 16-31 + "vand.vi v18, v18, 0x3 \n\t" // 32-47 + + "vand.vi v4, v7, 0x3 \n\t" // 0-15 + "vsrl.vi v5, v7, 2 \n\t" + "vsrl.vi v6, v7, 4 \n\t" + "vsrl.vi v7, v7, 6 \n\t" // 48-63 + "vand.vi v5, v5, 0x3 \n\t" // 16-31 + "vand.vi v6, v6, 0x3 \n\t" // 32-47 + + // i2 * i8 vmadot + "vsetvli t0, x0, e8, m1 \n\t" + "vmadotsu v20, v2, v8, i8 \n\t" + "vmadotsu v22, v2, v12, i8 \n\t" + "vmadotsu v24, v2, v16, i8 \n\t" + "vmadotsu v26, v2, v4, i8 \n\t" + + "vmadotsu v20, v3, v9, i8 \n\t" + "vmadotsu v22, v3, v13, i8 \n\t" + "vmadotsu v24, v3, v17, i8 \n\t" + "vmadotsu v26, v3, v5, i8 \n\t" + + "vmadotsu v20, v28, v10, i8 \n\t" + "vmadotsu v22, v28, v14, i8 \n\t" + "vmadotsu v24, v28, v18, i8 \n\t" + "vmadotsu v26, v28, v6, i8 \n\t" + + "vmadotsu v20, v29, v11, i8 \n\t" + "vmadotsu v22, v29, v15, i8 \n\t" + "vmadotsu v24, v29, v19, i8 \n\t" + "vmadotsu v26, v29, v7, i8 \n\t" + + "vand.vi v10, v0, 0xf \n\t" // scale + "vwadd.vx v12, v10, x0 \n\t" + "vsetvli t0, x0, e16, m2 \n\t" + "vwadd.vx v16, v12, x0 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vpack.vv v2, v20, v22, 2 \n\t" + "vpack.vv v4, v24, v26, 2 \n\t" + "vpack.vv v6, v2, v4, 3 \n\t" // 0,1 + "vpack.vv v8, v3, v5, 3 \n\t" // 2,3 + + // mul scale + "vmacc.vv v30, v6, v16 \n\t" + "vmacc.vv v30, v7, v17 \n\t" + "vmacc.vv v30, v8, v18 \n\t" + "vmacc.vv v30, v9, v19 \n\t" + + "addi t1, t1, -1 \n\t" + "bgtz t1, INNER_K_LOOP%= \n\t" + + // load zp B + "vsetvli t0, x0, e8, m4 \n\t" + "vle8.v v4, (s3) \n\t" + "vsrl.vi v8, v4, 4 \n\t" // zp + + // asum * zp + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vxor.vv v26, v26, v26 \n\t" + + "vsetvli t0, x0, e16, mf4 \n\t" + "vle16.v v2, (%[A]) \n\t" + "vsetvli t0, x0, e8, mf4 \n\t" + "vnsrl.wi v12, v2, 0 \n\t" // low 8 + "vnsra.wi v13, v2, 8 \n\t" // high 8 + + "vsetvli t0, x0, e32, m1 \n\t" + "vmadotsu v20, v13, v8, i8 \n\t" + "vmadotsu v22, v13, v9, i8 \n\t" + "vmadotsu v24, v13, v10, i8 \n\t" + "vmadotsu v26, v13, v11, i8 \n\t" + + "vsll.vi v20, v20, 8 \n\t" + "vsll.vi v22, v22, 8 \n\t" + "vsll.vi v24, v24, 8 \n\t" + "vsll.vi v26, v26, 8 \n\t" + + "vmadotu v20, v12, v8, i8 \n\t" + "vmadotu v22, v12, v9, i8 \n\t" + "vmadotu v24, v12, v10, i8 \n\t" + "vmadotu v26, v12, v11, i8 \n\t" + + "vpack.vv v2, v20, v22, 2 \n\t" + "vpack.vv v4, v24, v26, 2 \n\t" + "vpack.vv v28, v2, v4, 3 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v0, (t2) \n\t" // scale16 + "addi t2, t2, 64 \n\t" + "vle16.v v1, (t2) \n\t" // zero16 + "vfwcvt.f.f.v v2, v0 \n\t" + "vfwcvt.f.f.v v4, v1 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v30, v30 \n\t" + "vfcvt.f.x.v v28, v28 \n\t" + "addi %[B], t2, 64 \n\t" + "mv %[A], t3 \n\t" + + "vfmul.vv v30, v30, v2 \n\t" // mul scale16 + "vfmacc.vv v30, v28, v4 \n\t" // + mul zero16 + "vfmacc.vf v31, fa0, v30 \n\t" + "addi s1, s1, -1 \n\t" + "bgtz s1, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "vse32.v v31, (%[DST]) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", + "v28", "v29", "v30", "v31", "fa0", "t4", "t5", "t6", "s1", "s2", "s3"); + } +} + +void gemm_kernel_i8i2k_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + using blk_type = nrow_block_q2_k; + + int64_t b_ncol_block_stride = sizeof(blk_type) * k_blks; + _Float16 scale = 0.0625f; + _Float16 scale_1 = 16.0f; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * b_ncol_block_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + float * dst_c = (float *) c_ptr + ni; + + asm volatile( + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v28, v31, v31 \n\t" // init result + "vxor.vv v29, v31, v31 \n\t" + "vxor.vv v30, v31, v31 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "mv s1, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A + "flw fa0, (%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "addi %[A], %[A], 16 \n\t" + + "li t1, 4 \n\t" + "addi t2, %[B], 512 \n\t" // B data addr + "addi t3, %[A], 128 \n\t" // A data addr + "addi s4, t2, 1024 \n\t" // scale16 addr + "addi s4, s4, 1024 \n\t" // TODO + "addi s3, %[B], 0 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v1, (s4) \n\t" // load scale16 + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v22, v1, v1, 3 \n\t" + + "addi s4, t3, 256 \n\t" // addr 1 + "addi s5, t3, 512 \n\t" // addr 2 + "addi s6, t3, 768 \n\t" // addr 3 + + // init the accu to 0 + "vxor.vv v24, v24, v24 \n\t" + "vxor.vv v25, v25, v25 \n\t" + "vxor.vv v26, v26, v26 \n\t" + "vxor.vv v27, v27, v27 \n\t" + + "INNER_K_LOOP%=: \n\t" + // load scale B + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v1, (%[B]) \n\t" + "addi %[B], %[B], 128 \n\t" + "vand.vi v1, v1, 0xf \n\t" + + "vfwcvt.f.x.v v20, v1 \n\t" // f16 scale B + "vsetvli t0, x0, e16, m1 \n\t" + "vfmul.vv v0, v20, v22 \n\t" // mul scale16 + "vfmul.vv v1, v21, v22 \n\t" // mul scale16 + "vfmul.vf v0, v0, %[SCALE] \n\t" // mul magic + "vfmul.vf v1, v1, %[SCALE] \n\t" // mul magic + + // A data, 4x64@i8 + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (t3) \n\t" + "addi t3, t3, 64 \n\t" + "vle8.v v3, (s4) \n\t" + "addi s4, s4, 64 \n\t" + "vle8.v v4, (s5) \n\t" + "addi s5, s5, 64 \n\t" + "vle8.v v5, (s6) \n\t" + "addi s6, s6, 64 \n\t" + + // 4x64 => 4x16x4 + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v6, v2, v3, 1 \n\t" + "vpack.vv v8, v4, v5, 1 \n\t" + "vpack.vv v2, v6, v8, 2 \n\t" // 0, 2 + + "vpack.vv v20, v2, v2, 3 \n\t" // 1 + "vor.vv v23, v21, v21 \n\t" + "vpack.vv v20, v3, v3, 3 \n\t" // 3 + + // B data, 32x64@i2 + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (t2) \n\t" + "addi t2, t2, 512 \n\t" + "vand.vi v8, v4, 0x3 \n\t" // 0-15 + "vsrl.vi v9, v4, 2 \n\t" + "vsrl.vi v10, v4, 4 \n\t" + "vsrl.vi v11, v4, 6 \n\t" // 48-63 + "vand.vi v9, v9, 0x3 \n\t" // 16-31 + "vand.vi v10, v10, 0x3 \n\t" // 32-47 + + "vand.vi v12, v5, 0x3 \n\t" // 0-15 + "vsrl.vi v13, v5, 2 \n\t" + "vsrl.vi v14, v5, 4 \n\t" + "vsrl.vi v15, v5, 6 \n\t" // 48-63 + "vand.vi v13, v13, 0x3 \n\t" // 16-31 + "vand.vi v14, v14, 0x3 \n\t" // 32-47 + + "vand.vi v16, v6, 0x3 \n\t" // 0-15 + "vsrl.vi v17, v6, 2 \n\t" + "vsrl.vi v18, v6, 4 \n\t" + "vsrl.vi v19, v6, 6 \n\t" // 48-63 + "vand.vi v17, v17, 0x3 \n\t" // 16-31 + "vand.vi v18, v18, 0x3 \n\t" // 32-47 + + "vand.vi v4, v7, 0x3 \n\t" // 0-15 + "vsrl.vi v5, v7, 2 \n\t" + "vsrl.vi v6, v7, 4 \n\t" + "vsrl.vi v7, v7, 6 \n\t" // 48-63 + "vand.vi v5, v5, 0x3 \n\t" // 16-31 + "vand.vi v6, v6, 0x3 \n\t" // 32-47 + + // i2 * i8 vmadot + "vsetvli t0, x0, e8, m1 \n\t" + "vmadotsu.hp v24, v2, v8, v0, 0, i8 \n\t" + "vmadotsu.hp v25, v2, v12, v0, 1, i8 \n\t" + "vmadotsu.hp v26, v2, v16, v0, 2, i8 \n\t" + "vmadotsu.hp v27, v2, v4, v0, 3, i8 \n\t" + + "vmadotsu.hp v24, v23, v9, v0, 4, i8 \n\t" + "vmadotsu.hp v25, v23, v13, v0, 5, i8\n\t" + "vmadotsu.hp v26, v23, v17, v0, 6, i8\n\t" + "vmadotsu.hp v27, v23, v5, v0, 7, i8 \n\t" + + "vmadotsu.hp v24, v3, v10, v1, 0, i8 \n\t" + "vmadotsu.hp v25, v3, v14, v1, 1, i8 \n\t" + "vmadotsu.hp v26, v3, v18, v1, 2, i8 \n\t" + "vmadotsu.hp v27, v3, v6, v1, 3, i8 \n\t" + + "vmadotsu.hp v24, v21, v11, v1, 4, i8\n\t" + "vmadotsu.hp v25, v21, v15, v1, 5, i8\n\t" + "vmadotsu.hp v26, v21, v19, v1, 6, i8\n\t" + "vmadotsu.hp v27, v21, v7, v1, 7, i8 \n\t" + + "addi t1, t1, -1 \n\t" + "bgtz t1, INNER_K_LOOP%= \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v2, v24, v25, 1 \n\t" + "vpack.vv v4, v26, v27, 1 \n\t" + "vpack.vv v6, v2, v4, 2 \n\t" // 0,1,2,3 + + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vxor.vv v24, v24, v24 \n\t" + // load zp B, 16x8x4@int4 + "vsetvli t0, x0, e8, m4 \n\t" + "vle8.v v0, (s3) \n\t" + "vsrl.vi v0, v0, 4 \n\t" // zp + + // 4x16@int16 + "vsetvli t0, x0, e16, m1 \n\t" // a sum + "vle16.v v12, (%[A]) \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vnsrl.wi v10, v12, 0 \n\t" // low 8 + "vnsra.wi v11, v12, 8 \n\t" // high 8 + + // asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vmadotsu v18, v11, v0, i8 \n\t" + "vmadotsu v20, v11, v1, i8 \n\t" + "vmadotsu v22, v11, v2, i8 \n\t" + "vmadotsu v24, v11, v3, i8 \n\t" + "vsll.vi v18, v18, 8 \n\t" + "vsll.vi v20, v20, 8 \n\t" + "vsll.vi v22, v22, 8 \n\t" + "vsll.vi v24, v24, 8 \n\t" + "vmadotu v18, v10, v0, i8 \n\t" + "vmadotu v20, v10, v1, i8 \n\t" + "vmadotu v22, v10, v2, i8 \n\t" + "vmadotu v24, v10, v3, i8 \n\t" + + "vpack.vv v10, v18, v20, 2 \n\t" + "vpack.vv v12, v22, v24, 2 \n\t" + "vpack.vv v14, v10, v12, 3 \n\t" + "vpack.vv v16, v11, v13, 3 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "addi t2, t2, 64 \n\t" + "vle16.v v20, (t2) \n\t" // zero16 + "vfwcvt.f.f.v v22, v20 \n\t" + + // mul 1/magic + "vsetvli t0, x0, e16, m1 \n\t" + "vfwmul.vf v0, v6, %[SCALE_1] \n\t" + "vfwmul.vf v2, v7, %[SCALE_1] \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v14, v14 \n\t" + "vfcvt.f.x.v v15, v15 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfcvt.f.x.v v17, v17 \n\t" + + "addi %[B], t2, 64 \n\t" + "mv %[A], s6 \n\t" + + "vfmacc.vv v0, v14, v22 \n\t" // + mul zero16 + "vfmacc.vv v1, v15, v22 \n\t" + "vfmacc.vv v2, v16, v22 \n\t" + "vfmacc.vv v3, v17, v22 \n\t" + + "vfmacc.vf v28, fa0, v0 \n\t" // mul a scale + "vfmacc.vf v29, fa1, v1 \n\t" + "vfmacc.vf v30, fa2, v2 \n\t" + "vfmacc.vf v31, fa3, v3 \n\t" + + "addi s1, s1, -1 \n\t" + "bgtz s1, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "add t1, %[LDC], %[DST] \n\t" + "vse32.v v28, (%[DST]) \n\t" + "vse32.v v29, (t1) \n\t" + "add t1, t1, %[LDC] \n\t" + "vse32.v v30, (t1) \n\t" + "add t1, t1, %[LDC] \n\t" + "vse32.v v31, (t1) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [BK] "r"(k_blks), [LDC] "r"(ldc * 4), [SCALE] "f"(scale), [SCALE_1] "f"(scale_1) + : "t0", "t1", "t2", "t3", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", + "v28", "v29", "v30", "v31", "fa0", "t4", "t5", "t6", "s1", "s2", "s3", "s4", "s5", "s6"); + } +} + +void gemm_kernel_i8i3k_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; //only support 32 in ASM + using blk_type = nrow_block_q3_k; + + const blk_type * b_base = reinterpret_cast(quant_b_data); + + int64_t a_blk_stride = q8k_blk_size(256); + int64_t a_nrow_block_stride = a_blk_stride; + int64_t b_ncol_block_stride = sizeof(blk_type); + + // Constants used by q3_k scaling in HP branch: + // - k_q3k_scale_step: per-nibble scale factor (1/16). + // - k_a_scale_post_mul: A_scale needs an extra *16 at the end (pairs with 1/16 above). + const _Float16 k_q3k_scale_step = (_Float16) 0.0625f; // 1 / 16 + const float k_a_scale_post_mul = 16.0f; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + const blk_type * quant_b_blk_data = b_base + (ni / NB_COLS) * k_blks; +#if 0 + //------------------------------------------------------------------------------ + // A format + // Ascale fp32 * 1 32bit + // Asum int16 * 16 256bit + // A M1K256 int8 2048bit + //------------------------------------------------------------------------------ + // B format + // B_scl uint8*N32*16 4096bit + // B_Hmask N32K16*16 1bit 8192bit + // B_Qs N32K16*16 2bit 16384bit + // B scl16 fp16 * N32 512bit; + //------------------------------------------------------------------------------ + //bias always be nullptr + __asm__ volatile( + // t2 = k_blks (each is K256 superblock) + "mv t2, %[KBLKS] \n\t" + // t3 = 256/64 = 4 (K64 iterations per superblock) + "li t3, 4 \n\t" + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 4+32 \n\t" // s3 = pAData, (pA+AScl+ASum) + + // B block layout for nrow_block_q3_k<32>: + // scales: 512B, hmask: 1024B, qs: 2048B, scales16: 64B + "addi s5, %[pB], 32*16 \n\t" // s5 = pB_hmask + "mv s4, %[pB] \n\t" // s4 = pB_scales + "addi s6, s5, 1024 \n\t" // s6 = pB_qs + "mv s7, %[pB] \n\t" // s7 = pB_base + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v31, v0, v0 \n\t" // clear acc + "vxor.vv v30, v0, v0 \n\t" // clear acc of K256 + + // ordinary vmadot: vle*10 vecIns*78 vmadot*16 + ".align 4 \n\t" + "BLK_LPST%=: \n\t" + "K64_LPST%=: \n\t" + + // K0-15 + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v2, (s4) \n\t" + "addi s4, s4, 128 \n\t" + + // load B qs chunk (128B per K16, 16 times => 2048B) + "vle8.v v4, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v5, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v6, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v7, (s6) \n\t" + "addi s6, s6, 128 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v1, (s3) \n\t" + "addi s3, s3, 64 \n\t" + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vnot.v v0, v0 \n\t" + "vand.vi v12, v4, 0x3 \n\t" + "vand.vi v13, v5, 0x3 \n\t" + "vand.vi v14, v6, 0x3 \n\t" + "vand.vi v15, v7, 0x3 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v1, v12, i8 \n\t" + "vmadot v18, v1, v13, i8 \n\t" + "vmadot v20, v1, v14, i8 \n\t" + "vmadot v22, v1, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" // N0-N31 in v16 + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v18, v2, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v19, v18, x0 \n\t" // int8 -> int16 + + // static_cast(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v30, v16, v19 \n\t" + + //K16-31 + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v2, v2, 4 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, mf2 \n\t" + "vslidedown.vi v1, v1, 2 \n\t" + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vsll.vi v8, v4, 4 \n\t" + "vsll.vi v9, v5, 4 \n\t" + "vsll.vi v10, v6, 4 \n\t" + "vsll.vi v11, v7, 4 \n\t" + "vnot.v v0, v0 \n\t" + + "vsrl.vi v12, v8, 6 \n\t" + "vsrl.vi v13, v9, 6 \n\t" + "vsrl.vi v14, v10, 6 \n\t" + "vsrl.vi v15, v11, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v1, v12, i8 \n\t" + "vmadot v18, v1, v13, i8 \n\t" + "vmadot v20, v1, v14, i8 \n\t" + "vmadot v22, v1, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" // N0-N31 in v16 + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v18, v2, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v19, v18, x0 \n\t" // int8 -> int16 + + // static_cast(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v30, v16, v19 \n\t" + + //K32-47 + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v2, v2, 4 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, mf2 \n\t" + "vslidedown.vi v1, v1, 2 \n\t" + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vsll.vi v8, v4, 2 \n\t" + "vsll.vi v9, v5, 2 \n\t" + "vsll.vi v10, v6, 2 \n\t" + "vsll.vi v11, v7, 2 \n\t" + "vnot.v v0, v0 \n\t" + + "vsrl.vi v12, v8, 6 \n\t" + "vsrl.vi v13, v9, 6 \n\t" + "vsrl.vi v14, v10, 6 \n\t" + "vsrl.vi v15, v11, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v1, v12, i8 \n\t" + "vmadot v18, v1, v13, i8 \n\t" + "vmadot v20, v1, v14, i8 \n\t" + "vmadot v22, v1, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v18, v2, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v19, v18, x0 \n\t" // int8 -> int16 + + // static_cast(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v30, v16, v19 \n\t" + + // K48-63 + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v2, v2, 4 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, mf2 \n\t" + "vslidedown.vi v1, v1, 2 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vnot.v v0, v0 \n\t" + "vsrl.vi v12, v4, 6 \n\t" + "vsrl.vi v13, v5, 6 \n\t" + "vsrl.vi v14, v6, 6 \n\t" + "vsrl.vi v15, v7, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v1, v12, i8 \n\t" + "vmadot v18, v1, v13, i8 \n\t" + "vmadot v20, v1, v14, i8 \n\t" + "vmadot v22, v1, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v18, v2, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v19, v18, x0 \n\t" // int8 -> int16 + + // static_cast(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v30, v16, v19 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, K64_LPST%= \n\t" + "K64_LPND%=: \n\t" + + // load A scale (fp32) and advance A to next superblock + "flw f0, (s2) \n\t" + "addi s2, s2, 4+32+256 \n\t" + "add t4, s7, %[B_STR] \n\t" // t4 = next B blk base + "addi s3, s2, 4+32 \n\t" + + // load B scales16[32] (fp16) at end of qs region + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v2, (s6) \n\t" + + // pointer modify + "addi s5, t4, 32*16 \n\t" + "mv s4, t4 \n\t" + "addi s6, s5, 32*32 \n\t" + "addi s7, t4, 0 \n\t" + + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v24, v2 \n\t" + + // a_scale * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v30 \n\t" + "vfmul.vf v1, v24, f0 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v31, v1, v26 \n\t" + + // next K-superblock + "addi t2, t2, -1 \n\t" + "vxor.vv v30, v0, v0 \n\t" // clear acc of K256 + "li t3, 4 \n\t" + "bgtz t2, BLK_LPST%= \n\t" + + "BLK_LPND%=: \n\t" + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "vse32.v v31, (%[pC]) \n\t" + "FUNC_END%=: \n\t" + + : + : [KBLKS] "r"(k_blks), [NBLKS] "r"(nb_real), [pA] "r"(quant_a_ptr), [pB] "r"(quant_b_blk_data), + [pC] "r"(c_ptr), [B_STR] "r"(b_ncol_block_stride) + : "cc", "memory", "t0", "t2", "t3", "t4", "t5", "f0", "s2", "s3", "s4", "s5", "s6", "s7"); +#else + + __asm__ volatile( + // ========================= + // Kernel overview (M1 x N32) + // ========================= + // Process one output row (M=1) and 32 columns (N=32) per call. + // + // Loop structure: + // - Outer loop: K superblocks of size K=256 (k_blks times) + // - Each K256 superblock is broken into 4 x K64 + // - Each K64 is processed as 4 x K16 "sub-blocks" (via unpack+dot) + // + // Data layout (high level): + // A (q8k K=256, per superblock): + // [ fp32 a_scale ][ int16 a_sum[16] ][ int8 a_qs[256] ] + // B (nrow_block_q3_k<32>, per superblock): + // [ int8 scales[32*16] ][ hmask[1024] ][ qs[2048] ][ fp16 scales16[32] ] + // + // Registers/pointers: + // s2: pA (points at A superblock header; used to load fp32 a_scale) + // s3: pA_qs (points at A int8 data within the current superblock) + // s4: pB_scales (points at B int8 per-K16 scales) + // s5: pB_hmask (points at B sign mask area) + // s6: pB_qs (points at B 2-bit packed qs area) + // s8: pB_scales16 (points at B fp16 scales16[32] at the end of block) + // s7: pB_base (base pointer to current B block; used for block-to-block stride) + + // t2 = number of K256 superblocks + "mv t2, %[KBLKS] \n\t" + // t3 = number of K64 chunks per K256 superblock (256 / 64) + "li t3, 4 \n\t" + + // A pointers + "mv s2, %[pA] \n\t" // s2 = pA_superblock (a_scale at +0) + "addi s3, %[pA], 4+32 \n\t" // s3 = pA_qs (skip a_scale + a_sum[16]) + + // B pointers for nrow_block_q3_k<32> + "addi s5, %[pB], 32*16 \n\t" // s5 = pB_hmask (skip scales[32*16]) + "mv s4, %[pB] \n\t" // s4 = pB_scales + "addi s6, s5, 1024 \n\t" // s6 = pB_qs (skip hmask) + // scales16 is at the end of the block: qs(2048) after hmask + "addi s8, s6, 1024 \n\t" + "addi s8, s8, 1024 \n\t" // s8 = pB_scales16 (fp16 scales16[32]) + "mv s7, %[pB] \n\t" // s7 = pB_base (for next-block address calc) + + // v31: final FP32 accumulator for N=32 + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v31, v0, v0 \n\t" + + // ---- Preload B scales16[32] and build FP16 scale vector used by vmadot.hp ---- + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v1, (s8) \n\t" // load fp16 scales16[32] + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v26, v1, v1, 3 \n\t" // broadcast/pack to match lanes + "vmv.v.v v17, v26 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vfmul.vf v30, v17, %[q3_step] \n\t" // v30 = scales16 * (1/16) + + // v24-v27: fp16 partial accumulators for a K64 chunk (vmadot.hp outputs) + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v25, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vxor.vv v27, v16, v16 \n\t" + + // HP vmadot: vle*10 vecIns*38 vmadot.hp*16 + ".align 4 \n\t" + "BLK_LPST%=: \n\t" // loop over K256 superblocks + "K64_LPST%=: \n\t" // loop over 4 x K64 chunks + + // ------------------------------------------------------------ + // K0-15: load B scales + {hmask, qs} + A data; unpack and dot + // ------------------------------------------------------------ + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v2, (s4) \n\t" // B int8 scales for this K16 + "addi s4, s4, 128 \n\t" + + "vle8.v v4, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v5, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v6, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v7, (s6) \n\t" + "addi s6, s6, 128 \n\t" + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" // B hmask for this K16 + "addi s5, s5, 64 \n\t" + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v3, (s3) \n\t" // A int8 data for this K16 + "addi s3, s3, 64 \n\t" + + // Convert B int8 scales to FP16 and apply scales16*(1/16) + "vsetvli t0, x0, e8, m1 \n\t" + "vfwcvt.f.x.v v28, v2 \n\t" // int8 -> fp16 + "vsetvli t0, x0, e16, m1 \n\t" + "vfmul.vv v1, v28, v30 \n\t" // v1: FP16 scale vector for vmadot.hp + "vfmul.vv v29, v29, v30 \n\t" + + // Unpack B 2-bit qs + hmask -> signed int8 in v12..v15 + "vsetvli t0, x0, e8, m1 \n\t" + "vnot.v v0, v0 \n\t" + "vand.vi v12, v4, 0x3 \n\t" + "vand.vi v13, v5, 0x3 \n\t" + "vand.vi v14, v6, 0x3 \n\t" + "vand.vi v15, v7, 0x3 \n\t" + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + // (Next K16 unpack path uses a fresh hmask load) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // Prepare another group from packed qs (bit shifts) + apply sign from hmask + "vsetvli t0, x0, e8, m1 \n\t" + "vsll.vi v8, v4, 4 \n\t" + "vsll.vi v9, v5, 4 \n\t" + "vsll.vi v10, v6, 4 \n\t" + "vsll.vi v11, v7, 4 \n\t" + "vsrl.vi v16, v8, 6 \n\t" + "vsrl.vi v17, v9, 6 \n\t" + "vnot.v v0, v0 \n\t" + "vsrl.vi v18, v10, 6 \n\t" + "vsrl.vi v19, v11, 6 \n\t" + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v16, v16, -4, v0.t \n\t" + + // A shift for the second dot within this K64 + "vsetvli t0, x0, e64, mf2 \n\t" + "vslidedown.vi v2, v3, 2 \n\t" + + // Dot products with FP16 scaling (accumulate into v24..v27) + "vsetvli t0, x0, e32, m1 \n\t" + "vmadot.hp v24, v3, v12, v1, 0, i8 \n\t" + "vmadot.hp v25, v3, v13, v1, 1, i8 \n\t" + "vmadot.hp v26, v3, v14, v1, 2, i8 \n\t" + "vmadot.hp v27, v3, v15, v1, 3, i8 \n\t" + "vmadot.hp v24, v2, v16, v1, 4, i8 \n\t" + "vmadot.hp v25, v2, v17, v1, 5, i8 \n\t" + "vmadot.hp v26, v2, v18, v1, 6, i8 \n\t" + "vmadot.hp v27, v2, v19, v1, 7, i8 \n\t" + + // (K32-47 / K48-63 blocks continue unchanged...) + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vmv.v.v v1, v29 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, mf2 \n\t" + "vslidedown.vi v3, v3, 4 \n\t" + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vsll.vi v8, v4, 2 \n\t" + "vsll.vi v9, v5, 2 \n\t" + "vsll.vi v10, v6, 2 \n\t" + "vsll.vi v11, v7, 2 \n\t" + + "vsrl.vi v20, v8, 6 \n\t" + "vsrl.vi v21, v9, 6 \n\t" + "vnot.v v0, v0 \n\t" + "vsrl.vi v22, v10, 6 \n\t" + "vsrl.vi v23, v11, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v20, v20, -4, v0.t \n\t" + + // K48-63 + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vsrl.vi v8, v4, 6 \n\t" + "vsrl.vi v9, v5, 6 \n\t" + "vnot.v v0, v0 \n\t" + "vsrl.vi v10, v6, 6 \n\t" + "vsrl.vi v11, v7, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v8, v8, -4, v0.t \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, mf2 \n\t" + "vslidedown.vi v2, v3, 2 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vmadot.hp v24, v3, v20, v1, 0, i8 \n\t" + "vmadot.hp v25, v3, v21, v1, 1, i8 \n\t" + "vmadot.hp v26, v3, v22, v1, 2, i8 \n\t" + "vmadot.hp v27, v3, v23, v1, 3, i8 \n\t" + "vmadot.hp v24, v2, v8, v1, 4, i8 \n\t" + "vmadot.hp v25, v2, v9, v1, 5, i8 \n\t" + "vmadot.hp v26, v2, v10, v1, 6, i8 \n\t" + "vmadot.hp v27, v2, v11, v1, 7, i8 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, K64_LPST%= \n\t" + "K64_LPND%=: \n\t" + + // ---- End of K64 chunk: reduce fp16 accumulators -> fp32 and scale by A ---- + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v12, v24, v25, 1 \n\t" + "vpack.vv v14, v26, v27, 1 \n\t" + "vpack.vv v16, v12, v14, 2 \n\t" + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v26, v16 \n\t" // fp16 -> fp32 vector (qsum * b_scales) + + // Load A scale and advance A pointer to next K256 superblock + "flw f0, (s2) \n\t" + "addi s2, s2, 4+32+256 \n\t" + "add t4, s7, %[B_STR] \n\t" // next B block base + "addi s3, s2, 4+32 \n\t" // reset A data pointer for next block + + // Advance B pointers to next K256 superblock + "addi s5, t4, 32*16 \n\t" + "mv s4, t4 \n\t" + "addi s6, s5, 32*32 \n\t" + "addi s8, s6, 1024 \n\t" + "addi s8, s8, 1024 \n\t" + "addi s7, t4, 0 \n\t" + "addi t2, t2, -1 \n\t" + + // Final per-block scaling: a_scale * 16.0f + "fmul.s f0, f0, %[a_post_mul] \n\t" + // acc += (qsum * b_scales) * (a_scale*16) + "vsetvli t0, x0, e32, m1 \n\t" + "vfmacc.vf v31, f0, v26 \n\t" + + "beqz t2, BLK_LPND%= \n\t" + + // Preload next block's scales16 and rebuild v30 for vmadot.hp + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v1, (s8) \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v26, v1, v1, 3 \n\t" + "vmv.v.v v17, v26 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vfmul.vf v30, v17, %[q3_step] \n\t" + + // Reset fp16 partial accumulators for next K64 loop(s) + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v25, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vxor.vv v27, v16, v16 \n\t" + + "li t3, 4 \n\t" + "bgtz t2, BLK_LPST%= \n\t" + + "BLK_LPND%=: \n\t" + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "vse32.v v31, (%[pC]) \n\t" + + : + : [KBLKS] "r"(k_blks), [NBLKS] "r"(nb_real), [pA] "r"(quant_a_ptr), [pB] "r"(quant_b_blk_data), + [pC] "r"(c_ptr), [B_STR] "r"(b_ncol_block_stride), [q3_step] "f"(k_q3k_scale_step), + [a_post_mul] "f"(k_a_scale_post_mul) + : "cc", "memory", "t0", "t2", "t3", "t4", "t5", "f0", "f1", "s2", "s3", "s4", "s5", "s6", "s7", "s8"); +#endif + } +} + +void gemm_kernel_i8i3k_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + using blk_type = nrow_block_q3_k<32>; + constexpr size_t NB_COLS = 32; //only support 32 in ASM + + const blk_type * b_base = reinterpret_cast(quant_b_data); + + int64_t a_blk_stride = q8k_blk_size(256); + int64_t a_nrow_block_stride = a_blk_stride * 4; + int64_t b_ncol_block_stride = sizeof(blk_type); + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + const blk_type * quant_b_blk_data = b_base + (ni / NB_COLS) * k_blks; + + //------------------------------------------------------------------------------ + // A format + // Ascale fp32 * 1* 4row 128bit + // Asum int16 * 16 4row 1024bit + // A M1K256 int8 4row 8192bit + //------------------------------------------------------------------------------ + // B format + // B_scl uint8*N32*16 4096bit + // B_Hmask N32K16*16 1bit 8192bit + // B_Qs N32K16*16 2bit 16384bit + // B scl16 fp16 * N32 512bit; + //------------------------------------------------------------------------------ + //bias always be nullptr + __asm__ volatile( + // t2 = k_blks (each is K256 superblock) + "mv t2, %[KBLKS] \n\t" + // t3 = 256/64 = 4 (K64 iterations per superblock) + "li t3, 4 \n\t" + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 16+128 \n\t" // s3 = pAData, (pA+AScl+ASum) + + // B block layout for nrow_block_q3_k<32>: + // scales: 512B, hmask: 1024B, qs: 2048B, scales16: 64B + "addi s5, %[pB], 32*16 \n\t" // s5 = pB_hmask (skip scales) + "mv s4, %[pB] \n\t" // s4 = pB_scales + "addi s6, s5, 1024 \n\t" // s6 = pB_qs (skip hmask) + "mv s7, %[pB] \n\t" // s7 = pB_base + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v0, v0 \n\t" // v24-v27: K256 temp accumulator + "vxor.vv v25, v0, v0 \n\t" + "vxor.vv v26, v0, v0 \n\t" + "vxor.vv v27, v0, v0 \n\t" + "vxor.vv v28, v0, v0 \n\t" // v28-v31: final accumulator + "vxor.vv v29, v0, v0 \n\t" + "vxor.vv v30, v0, v0 \n\t" + "vxor.vv v31, v0, v0 \n\t" + + // ordinary vmadot: vle*13 vecIns*96 vmadot*16 + ".align 4 \n\t" + "BLK_LPST%=: \n\t" + "K64_LPST%=: \n\t" + + // ========== K0-15: First K16 sub-block ========== + // Load B INT8 scale factors (32 cols × 16 K16 blocks) + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v8, (s4) \n\t" + "addi s4, s4, 128 \n\t" + + // Load B quantized data (32 cols × 16 elements × 2bit, stored in 4 groups) + "vle8.v v4, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v5, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v6, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v7, (s6) \n\t" + "addi s6, s6, 128 \n\t" + + // Load B hmask (32 cols × 16bit sign mask) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // Load A data (4 rows × 16 elements × INT8) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v12, (s3) \n\t" + "addi s3, s3, 256 \n\t" // Jump to next row + "vle8.v v13, (s3) \n\t" + "addi s3, s3, 256 \n\t" + "vle8.v v14, (s3) \n\t" + "addi s3, s3, 256 \n\t" + "vle8.v v15, (s3) \n\t" + "addi s3, s3, -768+64 \n\t" // Back to first row, advance 16 elements + + // Pack A data: merge 4 rows into 2 vectors + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v12, v13, 1 \n\t" + "vpack.vv v18, v14, v15, 1 \n\t" + "vpack.vv v2, v16, v18, 2 \n\t" + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vnot.v v0, v0 \n\t" + "vand.vi v12, v4, 0x3 \n\t" + "vand.vi v13, v5, 0x3 \n\t" + "vand.vi v14, v6, 0x3 \n\t" + "vand.vi v15, v7, 0x3 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v2, v12, i8 \n\t" // 4 rows × cols 0-7 + "vmadot v18, v2, v13, i8 \n\t" // 4 rows × cols 8-15 + "vmadot v20, v2, v14, i8 \n\t" // 4 rows × cols 16-23 + "vmadot v22, v2, v15, i8 \n\t" // 4 rows × cols 24-31 + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v12, v16, v18, 2 \n\t" // Merge cols 0-15 + "vpack.vv v14, v20, v22, 2 \n\t" // Merge cols 16-31 + "vpack.vv v16, v12, v14, 3 \n\t" // Inter-row results (INT16) + "vpack.vv v18, v13, v15, 3 \n\t" + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v21, v8, x0 \n\t" // INT8 → INT16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v23, v21, x0 \n\t" // INT16 → INT32 + + // Accumulate to K256 accumulator: qsum * b_scale + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v24, v16, v23 \n\t" // Row 0 + "vmacc.vv v25, v17, v23 \n\t" // Row 1 + "vmacc.vv v26, v18, v23 \n\t" // Row 2 + "vmacc.vv v27, v19, v23 \n\t" + + // ========== K16-31, K32-47, K48-63: Similar processing ========== + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v8, v8, 4 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v2, v2, 8 \n\t" + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vsll.vi v12, v4, 4 \n\t" + "vsll.vi v13, v5, 4 \n\t" + "vsll.vi v14, v6, 4 \n\t" + "vsll.vi v15, v7, 4 \n\t" + "vnot.v v0, v0 \n\t" + + "vsrl.vi v12, v12, 6 \n\t" + "vsrl.vi v13, v13, 6 \n\t" + "vsrl.vi v14, v14, 6 \n\t" + "vsrl.vi v15, v15, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v2, v12, i8 \n\t" + "vmadot v18, v2, v13, i8 \n\t" + "vmadot v20, v2, v14, i8 \n\t" + "vmadot v22, v2, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v12, v16, v18, 2 \n\t" + "vpack.vv v14, v20, v22, 2 \n\t" + "vpack.vv v16, v12, v14, 3 \n\t" // N0-N31 in v16 + "vpack.vv v18, v13, v15, 3 \n\t" + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v21, v8, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v23, v21, x0 \n\t" // int8 -> int16 + + // static_cast(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v24, v16, v23 \n\t" + "vmacc.vv v25, v17, v23 \n\t" + "vmacc.vv v26, v18, v23 \n\t" + "vmacc.vv v27, v19, v23 \n\t" + + //K32-47 + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v8, v8, 4 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vsll.vi v12, v4, 2 \n\t" + "vsll.vi v13, v5, 2 \n\t" + "vsll.vi v14, v6, 2 \n\t" + "vsll.vi v15, v7, 2 \n\t" + "vnot.v v0, v0 \n\t" + + "vsrl.vi v12, v12, 6 \n\t" + "vsrl.vi v13, v13, 6 \n\t" + "vsrl.vi v14, v14, 6 \n\t" + "vsrl.vi v15, v15, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v3, v12, i8 \n\t" + "vmadot v18, v3, v13, i8 \n\t" + "vmadot v20, v3, v14, i8 \n\t" + "vmadot v22, v3, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v12, v16, v18, 2 \n\t" + "vpack.vv v14, v20, v22, 2 \n\t" + "vpack.vv v16, v12, v14, 3 \n\t" // N0-N31 in v16 + "vpack.vv v18, v13, v15, 3 \n\t" + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v21, v8, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v23, v21, x0 \n\t" // int8 -> int16 + + // static_cast(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v24, v16, v23 \n\t" + "vmacc.vv v25, v17, v23 \n\t" + "vmacc.vv v26, v18, v23 \n\t" + "vmacc.vv v27, v19, v23 \n\t" + + // K48-63 + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v8, v8, 4 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v3, v3, 8 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vnot.v v0, v0 \n\t" + "vsrl.vi v12, v4, 6 \n\t" + "vsrl.vi v13, v5, 6 \n\t" + "vsrl.vi v14, v6, 6 \n\t" + "vsrl.vi v15, v7, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v3, v12, i8 \n\t" + "vmadot v18, v3, v13, i8 \n\t" + "vmadot v20, v3, v14, i8 \n\t" + "vmadot v22, v3, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v12, v16, v18, 2 \n\t" + "vpack.vv v14, v20, v22, 2 \n\t" + "vpack.vv v16, v12, v14, 3 \n\t" // N0-N31 in v16 + "vpack.vv v18, v13, v15, 3 \n\t" + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v21, v8, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v23, v21, x0 \n\t" // int8 -> int16 + + // static_cast(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v24, v16, v23 \n\t" + "vmacc.vv v25, v17, v23 \n\t" + "vmacc.vv v26, v18, v23 \n\t" + "vmacc.vv v27, v19, v23 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, K64_LPST%= \n\t" + "K64_LPND%=: \n\t" + + // ========== K256 superblock complete, apply scale factors ========== + // Load A's 4 row scale factors (FP32) + "flw f0, (s2) \n\t" + "flw f1, 4(s2) \n\t" + "flw f2, 8(s2) \n\t" + "flw f3, 12(s2) \n\t" + "add s2, s2, %[A_STR] \n\t" // Advance to next superblock + "add t4, s7, %[B_STR] \n\t" // t4 = next B block address + "addi s3, s2, (4+32)*4 \n\t" + + // Load B FP16 global scale factors (32 cols) + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v8, (s6) \n\t" + + // Update B pointers to next block + "addi s5, t4, 32*16 \n\t" + "mv s4, t4 \n\t" + "addi s6, s5, 32*32 \n\t" + "addi s7, t4, 0 \n\t" + + // ========== Type conversion and final scaling ========== + // FP16 → FP32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v9, v8 \n\t" + + // INT32 → FP32 + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v24, v24 \n\t" + "vfcvt.f.x.v v25, v25 \n\t" + "vfcvt.f.x.v v26, v26 \n\t" + "vfcvt.f.x.v v27, v27 \n\t" + + // Compute a_scale * b_scale (4 rows) + "vfmul.vf v12, v9, f0 \n\t" + "vfmul.vf v13, v9, f1 \n\t" + "vfmul.vf v14, v9, f2 \n\t" + "vfmul.vf v15, v9, f3 \n\t" + + // Final accumulation: result += qsum * a_scale * b_scale + "vsetvli t0, x0, e32, m1 \n\t" + "vfmacc.vv v28, v12, v24 \n\t" + "vfmacc.vv v29, v13, v25 \n\t" + "vfmacc.vv v30, v14, v26 \n\t" + "vfmacc.vv v31, v15, v27 \n\t" + + // Prepare for next K superblock + "addi t2, t2, -1 \n\t" + "vxor.vv v24, v0, v0 \n\t" // Clear K256 accumulator + "vxor.vv v25, v0, v0 \n\t" + "vxor.vv v26, v0, v0 \n\t" + "vxor.vv v27, v0, v0 \n\t" + "li t3, 4 \n\t" + "bgtz t2, BLK_LPST%= \n\t" + + "BLK_LPND%=: \n\t" + + // ========== Store results (4 rows × 32 cols) ========== + "mv t5, %[pC] \n\t" + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "vse32.v v28, (%[pC]) \n\t" + "add t5, t5, %[LDC] \n\t" + "vse32.v v29, (t5) \n\t" + "add t5, t5, %[LDC] \n\t" + "vse32.v v30, (t5) \n\t" + "add t5, t5, %[LDC] \n\t" + "vse32.v v31, (t5) \n\t" + "add t5, t5, %[LDC] \n\t" + "FUNC_END%=: \n\t" + + : + : [KBLKS] "r"(k_blks), [NBLKS] "r"(nb_real), [pA] "r"(quant_a_ptr), [pB] "r"(quant_b_blk_data), + [pC] "r"(c_ptr), [B_STR] "r"(b_ncol_block_stride), [A_STR] "r"(a_nrow_block_stride), [LDC] "r"(ldc * 4) + : "cc", "memory", "t0", "t2", "t3", "t4", "t5", "f0", "f1", "f2", "f3", "s2", "s3", "s4", "s5", "s6", "s7"); + } +} + +void gemm_kernel_i8i4_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (quant_b_zp == NULL) { + for (size_t n = 0; n < count_n; n += 32) { + size_t nblks = (count_n - n) > 32 ? 32 : count_n - n; + uint8_t * QuantBDataPtr = (uint8_t *) quant_b_data + // + n * k_blks * blk_len / 2 + // b data + n * k_blks * sizeof(_Float16); // scale + float * CPtr = c_ptr + n; + size_t cnt = k_blks; + + // A format Version_1 (FP32 SCALE FOR Normal VMADOTins of IME2) + // A M1K32 int8 256bit + // Ascale fp32 * 1 32bit + // || scl*1(fp32) | Asum(int16) | blk0 || scl*1(fp32) | Asum(int16) | blk0 || ... + // || Element || Element || ... + // B format + // B N8K32 int4 1024bit + // 4VRF, N32K32, 4096bit + // Bscale fp16 * N32 512bit; + // || scl*32..(fp16) | blk0 blk1 ... blk31 || scl*32..(fp16) | blk0 blk1 ... blk31 || ... + // || Element || Element || ... +#if 0 + //bias always be nullptr + __asm__ volatile( + + // t3 = k/32 + "mv t3, %[BCK] \n\t" + "mv t4, %[NBLKS] \n\t" + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 4+2 \n\t" // s3 = pAData, (pA+AScl+ASum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*2 \n\t" // s5 = pBdata; + "mv s6, %[pC] \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v2, v0, v0 \n\t" // clear acc + + // ordinary vmadot: vle*6 flw*1 vecIns*21 vmadot*8 + ".align 4 \n\t" + "_K_LPST%=: \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (s5) \n\t" // B Data 4VRF * 8Row * 32 + "addi s5, s5, 128*4+64 \n\t" // 1024bit + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s4) \n\t" // B Scale 4VRF*8Row*FP16 = 512bit + "addi s4, s4, 64+128*4 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v3, (s3) \n\t" // A Data M1*K32*int8 = 256bit + "addi s3, s3, 32+6 \n\t" + + "flw f0, (s2) \n\t" // A Scale fp32 + "lh t2, 4(s2) \n\t" // A sum of int16 + "addi s2, s2, 6+32 \n\t" + + "vsetvli t0, zero, e8, m1 \n\t" + "vsrl.vi v24, v3, 4 \n\t" + + "vnpack4.vv v8, v3, v3, 3 \n\t" // lo4 of A + "vnpack4.vv v10, v24, v24, 3 \n\t" // hi4 of A + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadotsu v16, v10, v4, i4 \n\t" // M0 N0 - N7 INT32(256bit) + "vmadotsu v18, v10, v5, i4 \n\t" // M0 N8 - N15 + "vmadotsu v20, v10, v6, i4 \n\t" // M0 N16 - N23 + "vmadotsu v22, v10, v7, i4 \n\t" // M0 N24 - N31 + + "vsll.vi v16, v16, 4 \n\t" + "vsll.vi v18, v18, 4 \n\t" + "vsll.vi v20, v20, 4 \n\t" + "vsll.vi v22, v22, 4 \n\t" + + "vmadotu v16, v8, v4, i4 \n\t" + "vmadotu v18, v8, v5, i4 \n\t" + "vmadotu v20, v8, v6, i4 \n\t" + "vmadotu v22, v8, v7, i4 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v28, 8 \n\t" + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" + + "vwmul.vx v24, v28, t2 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vadd.vv v16, v16, v24 \n\t" + + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v24, v0 \n\t" + // mac result i32 -> fp32 + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v16 \n\t" + // a_scale * b_scale; + "vfmul.vf v1, v24, f0 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v2, v1, v26 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, _K_LPST%= \n\t" + "_K_LPND%=: \n\t" + + //----------------------------------------- + // STORE Equal 32N------------------------- + "_ST32%=: \n\t" + "vsetvli t0, t4, e32, m1 \n\t" + "vse32.v v2, (s6) \n\t" // M0 [N0 : N32]; FP32(1024bit) + + "_FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "t0", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6"); +#else + __asm__ volatile( + + // t3 = k/32 + "mv t3, %[BCK] \n\t" + "mv t4, %[NBLKS] \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v0, 1 \n\t" // init the scale + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 4+2 \n\t" // s3 = pAData, (pA+AScl+ASum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*2 \n\t" // s5 = pBdata; + "mv s6, %[pC] \n\t" + + "vsll.vi v1, v0, 4 \n\t" + "vxor.vv v2, v0, v0 \n\t" // clear acc + "vfcvt.f.x.v v0, v0 \n\t" + "vfcvt.f.x.v v1, v1 \n\t" + + // vmadot hp: vle*7 flw*1 vecIns*14 vmadot*8 + ".align 4 \n\t" + "_K_LPST%=: \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (s5) \n\t" // B Data 4VRF * 8Row * 32 + "addi s5, s5, 128*4+64 \n\t" // 1024bit + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v30, (s4) \n\t" // B Scale 4VRF*8Row*FP16 = 512bit + "addi s4, s4, 64+128*4 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v3, (s3) \n\t" // A Data M1*K32*int8 = 256bit + "addi s3, s3, 32+6 \n\t" + + "flw f0, (s2) \n\t" // A Scale fp32 + "lh t2, 4(s2) \n\t" // A sum of int16 + "addi s2, s2, 6+32 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v28, 8 \n\t" // Bzp u8 -> u16 + "vsetvli t0, x0, e8, m1 \n\t" + "vsrl.vi v24, v3, 4 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vmul.vx v26, v28, t2 \n\t" // asum*zp i16*i16 + "vnpack4.vv v8, v3, v3, 3 \n\t" // lo4 of A + "vnpack4.vv v10, v24, v24, 3 \n\t" // hi4 of A + + "vfcvt.f.x.v v16, v26 \n\t" // zp i16 -> fp16 + "vadd.vi v18, v16, 0 \n\t" + "vadd.vi v20, v16, 0 \n\t" + "vadd.vi v22, v16, 0 \n\t" + + "vmadotsu.hp v16, v10, v4, v1, 0, i4 \n\t" // high 4 + "vmadotsu.hp v18, v10, v5, v1, 0, i4 \n\t" + "vmadotsu.hp v20, v10, v6, v1, 0, i4 \n\t" + "vmadotsu.hp v22, v10, v7, v1, 0, i4 \n\t" + "vmadotu.hp v16, v8, v4, v0, 0, i4 \n\t" // low 4 + "vmadotu.hp v18, v8, v5, v0, 0, i4 \n\t" + "vmadotu.hp v20, v8, v6, v0, 0, i4 \n\t" + "vmadotu.hp v22, v8, v7, v0, 0, i4 \n\t" + + "vpack.vv v24, v16, v18, 1 \n\t" + "vpack.vv v26, v20, v22, 1 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + // mac result * b_scale; f16*f16->f32 + "vfwmul.vv v31, v30, v16 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + // static_cast(qsum * b_scale) * a_scale; + "vfmacc.vf v2, f0, v31 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, _K_LPST%= \n\t" + "_K_LPND%=: \n\t" + + //----------------------------------------- + // STORE Equal 32N------------------------- + "_ST32%=: \n\t" + "vsetvli t0, t4, e32, m1 \n\t" + "vse32.v v2, (s6) \n\t" // M0 [N0 : N32]; FP32(1024bit) + + "_FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "t0", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6"); + +#endif + } + } else { + for (size_t n = 0; n < count_n; n += 32) { + size_t nblks = (count_n - n) > 32 ? 32 : count_n - n; + uint8_t * QuantBDataPtr = (uint8_t *) quant_b_data + // + n * k_blks * blk_len / 2 + // b data + n * k_blks * sizeof(uint8_t) + // b zp + n * k_blks * sizeof(_Float16); // scale + float * CPtr = c_ptr + n; + size_t cnt = k_blks; + + // A format Version_1 (FP32 SCALE FOR Normal VMADOTins of IME2) + // A M1K32 int8 256bit + // Ascale fp32 * 1 32bit + // || scl*1(fp32) | Asum(int16) | blk0 || scl*1(fp32) | Asum(int16) | blk0 || ... + // || Element || Element || ... + // B format + // B N8K32 int4 1024bit + // 4VRF, N32K32, 4096bit + // Bscale fp16 * N32 512bit; + // Bzp uint8_t * N32 256bit; + // || scl*32..(fp16) | zp*32(uint8) | blk0 blk1 ... blk31 || scl*32..(fp16) ... + // || Element || Element ... + + //bias always be nullptr +#if 0 + __asm__ volatile( + + // t3 = k/32 + "mv t3, %[BCK] \n\t" + "mv t4, %[NBLKS] \n\t" + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 4+2 \n\t" // s3 = pAData, (pA+AScl+ASum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*3 \n\t" // s5 = pBdata, (pB+BScl+Bzp) + "mv s6, %[pC] \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v2, v0, v0 \n\t" // clear acc + + // ordinary vmadot: vle*6 flw*1 vecIns*21 vmadot*8 + ".align 4 \n\t" + "_K_LPST%=: \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (s5) \n\t" // B Data 4VRF * 8Row * 32 + "addi s5, s5, 128*4+96 \n\t" // 1024bit + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s4) \n\t" // B Scale 4VRF*8Row*FP16 = 512bit + "addi s4, s4, 64 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v3, (s3) \n\t" // A Data M1*K32*int8 = 256bit + "addi s3, s3, 32+6 \n\t" + + "flw f0, (s2) \n\t" // A Scale fp32 + "lh t2, 4(s2) \n\t" // A sum of int16 + "addi s2, s2, 6+32 \n\t" + + "vsetvli t0, zero, e8, m1 \n\t" + "vsrl.vi v24, v3, 4 \n\t" + + "vnpack4.vv v8, v3, v3, 3 \n\t" // lo4 of A + "vnpack4.vv v10, v24, v24, 3 \n\t" // hi4 of A + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadotsu v16, v10, v4, i4 \n\t" // M0 N0 - N7 INT32(256bit) + "vmadotsu v18, v10, v5, i4 \n\t" // M0 N8 - N15 + "vmadotsu v20, v10, v6, i4 \n\t" // M0 N16 - N23 + "vmadotsu v22, v10, v7, i4 \n\t" // M0 N24 - N31 + + "vsll.vi v16, v16, 4 \n\t" + "vsll.vi v18, v18, 4 \n\t" + "vsll.vi v20, v20, 4 \n\t" + "vsll.vi v22, v22, 4 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v1, (s4) \n\t" // Bzp + "addi s4, s4, 32+128*4 \n\t" + + "vmadotu v16, v8, v4, i4 \n\t" + "vmadotu v18, v8, v5, i4 \n\t" + "vmadotu v20, v8, v6, i4 \n\t" + "vmadotu v22, v8, v7, i4 \n\t" + + "vwaddu.vx v28, v1, x0 \n\t" // uint8 -> uint16 + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vwmul.vx v24, v28, t2 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vadd.vv v16, v16, v24 \n\t" + + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v24, v0 \n\t" + // mac result i32 -> fp32 + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v16 \n\t" + // a_scale * b_scale; + "vfmul.vf v1, v24, f0 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v2, v1, v26 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, _K_LPST%= \n\t" + "_K_LPND%=: \n\t" + + //----------------------------------------- + // STORE Equal 32N------------------------- + "_ST32%=: \n\t" + "vsetvli t0, t4, e32, m1 \n\t" + "vse32.v v2, (s6) \n\t" // M0 [N0 : N32]; FP32(1024bit) + + "_FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "t0", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6"); +#else + __asm__ volatile( + + // t3 = k/32 + "mv t3, %[BCK] \n\t" + "mv t4, %[NBLKS] \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v0, 1 \n\t" // init the scale + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 4+2 \n\t" // s3 = pAData, (pA+AScl+ASum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*3 \n\t" // s5 = pBdata, (pB+BScl+Bzp) + "mv s6, %[pC] \n\t" + + "vsll.vi v1, v0, 4 \n\t" + "vxor.vv v2, v0, v0 \n\t" // clear acc + "vfcvt.f.x.v v0, v0 \n\t" + "vfcvt.f.x.v v1, v1 \n\t" + + // vmadot hp: vle*6 flw*1 vecIns*14 vmadot*8 + ".align 4 \n\t" + "_K_LPST%=: \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (s5) \n\t" // B Data 4VRF * 8Row * 32 + "addi s5, s5, 128*4+96 \n\t" // 1024bit + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v30, (s4) \n\t" // B Scale 4VRF*8Row*FP16 = 512bit + "addi s4, s4, 64 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v31, (s4) \n\t" // B zp 32Row*uint8 = 256bit + "addi s4, s4, 32+128*4 \n\t" + + "vle8.v v3, (s3) \n\t" // A Data M1*K32*int8 = 256bit + "addi s3, s3, 32+6 \n\t" + + "flw f0, (s2) \n\t" // A Scale fp32 + "lh t2, 4(s2) \n\t" // A sum of int16 + "addi s2, s2, 6+32 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vsrl.vi v24, v3, 4 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vnpack4.vv v8, v3, v3, 3 \n\t" // lo4 of A + "vnpack4.vv v10, v24, v24, 3 \n\t" // hi4 of A + + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadotsu.hp v16, v10, v4, v1, 0, i4 \n\t" // high 4 + "vmadotsu.hp v18, v10, v5, v1, 0, i4 \n\t" + "vmadotsu.hp v20, v10, v6, v1, 0, i4 \n\t" + "vmadotsu.hp v22, v10, v7, v1, 0, i4 \n\t" + "vmadotu.hp v16, v8, v4, v0, 0, i4 \n\t" // low 4 + "vmadotu.hp v18, v8, v5, v0, 0, i4 \n\t" + "vmadotu.hp v20, v8, v6, v0, 0, i4 \n\t" + "vmadotu.hp v22, v8, v7, v0, 0, i4 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" + "vwaddu.vx v28, v31, x0 \n\t" // Bzp u8 -> u16 + + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v24, v16, v18, 1 \n\t" + "vpack.vv v26, v20, v22, 1 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "vmul.vx v26, v28, t2 \n\t" // asum*zp i16*i16 + "vfwcvt.f.f.v v22, v30 \n\t" // b_scale fp16 -> fp32 + "vfcvt.f.x.v v18, v26 \n\t" // zp i16 -> fp16 + "vsetvli t0, x0, e16, m1 \n\t" + "vfwadd.vv v20, v18, v16 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + // mac result * b_scale; f32*f32->f32 + "vfmul.vv v31, v22, v20 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + // static_cast(qsum * b_scale) * a_scale; + "vfmacc.vf v2, f0, v31 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, _K_LPST%= \n\t" + "_K_LPND%=: \n\t" + + //----------------------------------------- + // STORE Equal 32N------------------------- + "_ST32%=: \n\t" + "vsetvli t0, t4, e32, m1 \n\t" + "vse32.v v2, (s6) \n\t" // M0 [N0 : N32]; FP32(1024bit) + + "_FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "t0", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6"); +#endif + } + } +} + +void gemm_kernel_i8i4_hp_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + constexpr size_t k_subblks_per_superblk = 8; + + struct block_q4_0x32_layout { + _Float16 d[NB_COLS]; + uint8_t qs[16 * NB_COLS]; + }; + + GGML_ASSERT(blk_len == 256); + + const size_t b_superblk_stride = sizeof(block_q4_0x32_layout) * k_subblks_per_superblk + + (quant_b_zp ? NB_COLS * k_subblks_per_superblk * sizeof(uint8_t) : 0); + const size_t b_tile_stride = k_blks * b_superblk_stride; + + if (quant_b_zp == NULL) { + for (size_t ni = 0; ni < count_n; ni += 32) { + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * b_tile_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + + asm volatile( + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v31, v31, v31 \n\t" // init acc to zero + "mv t4, %[BK] \n\t" + "li t0, 0x4c00 \n\t" // 16 in fp16 + "fmv.h.x fa0, t0 \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + "li t5, 8 \n\t" + "addi t6, %[A], 288 \n\t" // point to blk scale + "flh ft1, (t6) \n\t" + "addi t6, %[A], 272 \n\t" // point to asum + + // init the acc fp16 + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v16, v18, v18 \n\t" + "vxor.vv v17, v18, v18 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v19, v18, v18 \n\t" + + "INNER_BLK_LOOP%=: \n\t" + // load a sum and scale + "flh fa1, (t6) \n\t" + "addi t6, t6, 2 \n\t" + "flh ft0, (%[A]) \n\t" + "addi %[A], %[A], 2 \n\t" + // load A + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v3, (%[A]) \n\t" // 1x32@i8 + "addi %[A], %[A], 32 \n\t" + + // load scale B and B + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v8, (%[B]) \n\t" // b_scale fp16 + "addi %[B], %[B], 64 \n\t" + "vl4r.v v4, (%[B]) \n\t" // 32*32@i4 + "addi %[B], %[B], 512 \n\t" + "vfmul.vf v8, v8, ft0 \n\t" // scale b * scale a + "vfmul.vf v9, v8, fa0 \n\t" + "vfmul.vf v10, v8, fa1 \n\t" // scale b * scale a * asm + "vfwmacc.vf v31, ft1, v10 \n\t" // asum * scale a * scale b * blk scale + + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v0, v8, v9, 3 \n\t" + "vsrl.vi v28, v3, 4 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vnpack4.vv v2, v3, v3, 3 \n\t" // lo4 of A + "vnpack4.vv v3, v28, v28, 3 \n\t" // hi4 of A + + // i4 * i4 vmadot + "vsetvli t0, x0, e16, m1 \n\t" + "vmadotsu.hp v16, v3, v4, v0, 4, i4 \n\t" // high 4 + "vmadotsu.hp v17, v3, v5, v0, 5, i4 \n\t" + "vmadotsu.hp v18, v3, v6, v0, 6, i4 \n\t" + "vmadotsu.hp v19, v3, v7, v0, 7, i4 \n\t" + "vmadotu.hp v16, v2, v4, v0, 0, i4 \n\t" // low 4 + "vmadotu.hp v17, v2, v5, v0, 1, i4 \n\t" + "vmadotu.hp v18, v2, v6, v0, 2, i4 \n\t" + "vmadotu.hp v19, v2, v7, v0, 3, i4 \n\t" + + "addi t5, t5, -1 \n\t" + "bgtz t5, INNER_BLK_LOOP%= \n\t" + + "vpack.vv v8, v16, v17, 1 \n\t" + "vpack.vv v12, v18, v19, 1 \n\t" + "vpack.vv v20, v8, v12, 2 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "addi t4, t4, -1 \n\t" + "vfwmacc.vf v31, ft1, v20 \n\t" + //"vsetvli t0, x0, e32, m1 \n\t" + //"vfmul.vf v31, v31, ft1 \n\t" // blk scale + + // update A ptr + "addi %[A], t6, 2 \n\t" + + "bgtz t4, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "vse32.v v31, (%[DST]) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "t4", "t5", "t6", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", + "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "ft0", "ft1"); + } + } else { + // TODO: support quant_b_zp for i8i4 hp kernel + GGML_ABORT("gemm_kernel_i8i4_hp_m1 with quant_b_zp is not supported yet"); + } +} + +void gemm_kernel_i8i4_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + int64_t b_data_stride = + k_blks * (sizeof(ggml_fp16_t) + 16 * sizeof(int8_t) + (quant_b_zp != NULL ? sizeof(int8_t) : 0)); + if (quant_b_zp == NULL) { + for (size_t ni = 0; ni < count_n; ni += 32) { + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_data_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; +#if 0 + asm volatile( + "li t1, 8 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "mv t4, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A + "flw fa0, (%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "addi %[A], %[A], 16 \n\t" + + // load scale B + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (%[B]) \n\t" + "addi %[B], %[B], 64 \n\t" + "vfwcvt.f.f.v v14, v12 \n\t" + + "vsetivli t0, 4, e16, mf2 \n\t" + "vle16.v v8, (%[A]) \n\t" // asum + "addi %[A], %[A], 8 \n\t" + "vwmul.vx v10, v8, t1 \n\t" // 8*asum + + "vsetvli t0, x0, e8, m1 \n\t" + "vl1r.v v0, (%[A]) \n\t" + "addi %[A], %[A], 128 \n\t" // 4*32@i8 + "vl4r.v v4, (%[B]) \n\t" // 32*32@i4 + "addi %[B], %[B], 512 \n\t" + "vsrl.vi v1, v0, 4 \n\t" + "vnpack4.vv v12, v0, v1, 3 \n\t" // A low u4 + "vupack.vv v2, v12, v12, 2 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + // i4 * i4 vmadot + "vsetvli t0, x0, e32, m1 \n\t" + "vmadotsu v16, v3, v4, i4 \n\t" // high 4 + "vmadotsu v18, v3, v5, i4 \n\t" + "vmadotsu v20, v3, v6, i4 \n\t" + "vmadotsu v22, v3, v7, i4 \n\t" + "vsll.vi v16, v16, 4 \n\t" + "vsll.vi v18, v18, 4 \n\t" + "vsll.vi v20, v20, 4 \n\t" + "vsll.vi v22, v22, 4 \n\t" + "vmadotu v16, v2, v4, i4 \n\t" // low 4 + "vmadotu v18, v2, v5, i4 \n\t" + "vmadotu v20, v2, v6, i4 \n\t" + "vmadotu v22, v2, v7, i4 \n\t" + + "vpack.vv v0, v16, v18, 2 \n\t" + "vpack.vv v2, v20, v22, 2 \n\t" + "vpack.vv v16, v0, v2, 3 \n\t" + "vpack.vv v18, v1, v3, 3 \n\t" + + "vrgather.vi v0, v10, 0 \n\t" + "vrgather.vi v1, v10, 1 \n\t" + "vrgather.vi v2, v10, 2 \n\t" + "vrgather.vi v3, v10, 3 \n\t" + + "vadd.vv v16, v16, v0 \n\t" + "vadd.vv v17, v17, v1 \n\t" + "vadd.vv v18, v18, v2 \n\t" + "vadd.vv v19, v19, v3 \n\t" + + "vfcvt.f.x.v v16, v16 \n\t" + "vfcvt.f.x.v v17, v17 \n\t" + "vfcvt.f.x.v v18, v18 \n\t" + "vfcvt.f.x.v v19, v19 \n\t" + + // mul scale + "vfmul.vv v16, v16, v14 \n\t" + "vfmul.vv v17, v17, v14 \n\t" + "vfmul.vv v18, v18, v14 \n\t" + "vfmul.vv v19, v19, v14 \n\t" + + "addi t4, t4, -1 \n\t" + "vfmacc.vf v28, fa0, v16 \n\t" + "vfmacc.vf v29, fa1, v17 \n\t" + "vfmacc.vf v30, fa2, v18 \n\t" + "vfmacc.vf v31, fa3, v19 \n\t" + + "bgtz t4, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "add t2, %[LDC], %[DST] \n\t" + "vse32.v v28, (%[DST]) \n\t" + "add t3, %[LDC], t2 \n\t" + "vse32.v v29, (t2) \n\t" + "add t2, %[LDC], t3 \n\t" + "vse32.v v30, (t3) \n\t" + "vse32.v v31, (t2) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [LDC] "r"(ldc*4), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "t4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "fa3"); +#else + asm volatile( + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "vmv.v.i v0, 1 \n\t" // init the scale + "vsll.vi v1, v0, 4 \n\t" + "vfcvt.f.x.v v0, v0 \n\t" + "vfcvt.f.x.v v1, v1 \n\t" + "mv t4, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A + "flw fa0, (%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "addi %[A], %[A], 16 \n\t" + + // load scale B + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (%[B]) \n\t" + "addi %[B], %[B], 64 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v14, v12, v12, 3 \n\t" + + "vsetivli t0, 4, e16, mf2 \n\t" + "vle16.v v8, (%[A]) \n\t" // asum + "addi %[A], %[A], 8 \n\t" + "vsll.vi v8, v8, 3 \n\t" // asum * 8 + "vfcvt.f.x.v v9, v8 \n\t" + "vsetvli t0, x0, e64, m1 \n\t" + "vrgather.vi v10, v9, 0 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl1r.v v16, (%[A]) \n\t" + "addi %[A], %[A], 128 \n\t" // 4*32@i8 + "vl4r.v v4, (%[B]) \n\t" // 32*32@i4 + "addi %[B], %[B], 512 \n\t" + "vsrl.vi v17, v16, 4 \n\t" + "vnpack4.vv v12, v16, v17, 3 \n\t" // A low u4 + "vupack.vv v2, v12, v12, 2 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v16, v10, v10,0 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vpack.vv v20, v16, v16,0 \n\t" + "vsetvli t0, x0, e64, m1 \n\t" + "vpack.vv v18, v20, v20, 0 \n\t" + "vor.vv v20, v18, v18 \n\t" + "vor.vv v21, v18, v18 \n\t" + + // i4 * i4 vmadot + "vsetvli t0, x0, e16, m1 \n\t" + "vmadotsu.hp v18, v3, v4, v1, 0, i4 \n\t" // high 4 + "vmadotsu.hp v19, v3, v5, v1, 0, i4 \n\t" + "vmadotsu.hp v20, v3, v6, v1, 0, i4 \n\t" + "vmadotsu.hp v21, v3, v7, v1, 0, i4 \n\t" + "vmadotu.hp v18, v2, v4, v0, 0, i4 \n\t" // low 4 + "vmadotu.hp v19, v2, v5, v0, 0, i4 \n\t" + "vmadotu.hp v20, v2, v6, v0, 0, i4 \n\t" + "vmadotu.hp v21, v2, v7, v0, 0, i4 \n\t" + + "vpack.vv v8, v18, v19, 1 \n\t" + "vpack.vv v12, v20, v21, 1 \n\t" + "vpack.vv v20, v8, v12, 2 \n\t" + + "vfwmul.vv v16, v20, v14 \n\t" + "vfwmul.vv v18, v21, v14 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + + "addi t4, t4, -1 \n\t" + "vfmacc.vf v28, fa0, v16 \n\t" + "vfmacc.vf v29, fa1, v17 \n\t" + "vfmacc.vf v30, fa2, v18 \n\t" + "vfmacc.vf v31, fa3, v19 \n\t" + + "bgtz t4, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "add t2, %[LDC], %[DST] \n\t" + "vse32.v v28, (%[DST]) \n\t" + "add t3, %[LDC], t2 \n\t" + "vse32.v v29, (t2) \n\t" + "add t2, %[LDC], t3 \n\t" + "vse32.v v30, (t3) \n\t" + "vse32.v v31, (t2) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "t4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", + "v25", "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "fa3"); +#endif + } + } else { + for (size_t ni = 0; ni < count_n; ni += 32) { + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_data_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + + asm volatile( + "li t1, 8 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "mv t4, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A + "flw fa0, (%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "addi %[A], %[A], 16 \n\t" + + // load scale B + "vsetvli t0, x0, e16, mf2\n\t" + "vle16.v v12, (%[B]) \n\t" + "addi %[B], %[B], 64 \n\t" + "vfwcvt.f.f.v v14, v12 \n\t" + + // load zp + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v8, (%[B]) \n\t" + "addi %[B], %[B], 32 \n\t" + "vwaddu.vx v10, v8, x0 \n\t" + + // load a sum + "lh s1, (%[A]) \n\t" + "lh s2, 2(%[A]) \n\t" + "lh s3, 4(%[A]) \n\t" + "lh s4, 6(%[A]) \n\t" + "addi %[A], %[A], 8 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl1r.v v0, (%[A]) \n\t" + "addi %[A], %[A], 128 \n\t" // 4*32@i8 + "vl4r.v v4, (%[B]) \n\t" // 32*32@i4 + "addi %[B], %[B], 512 \n\t" + "vsrl.vi v1, v0, 4 \n\t" + "vnpack4.vv v12, v0, v1, 3 \n\t" // A low u4 + "vupack.vv v2, v12, v12, 2 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + // i4 * i4 vmadot + "vsetvli t0, x0, e32, m1 \n\t" + "vmadotsu v16, v3, v4, i4 \n\t" // high 4 + "vmadotsu v18, v3, v5, i4 \n\t" + "vmadotsu v20, v3, v6, i4 \n\t" + "vmadotsu v22, v3, v7, i4 \n\t" + "vsll.vi v16, v16, 4 \n\t" + "vsll.vi v18, v18, 4 \n\t" + "vsll.vi v20, v20, 4 \n\t" + "vsll.vi v22, v22, 4 \n\t" + "vmadotu v16, v2, v4, i4 \n\t" // low 4 + "vmadotu v18, v2, v5, i4 \n\t" + "vmadotu v20, v2, v6, i4 \n\t" + "vmadotu v22, v2, v7, i4 \n\t" + + "vpack.vv v0, v16, v18, 2 \n\t" + "vpack.vv v2, v20, v22, 2 \n\t" + "vpack.vv v16, v0, v2, 3 \n\t" + "vpack.vv v18, v1, v3, 3 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vwmul.vx v0, v10, s1 \n\t" + "vwmul.vx v2, v10, s2 \n\t" + "vwmul.vx v4, v10, s3 \n\t" + "vwmul.vx v6, v10, s4 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vadd.vv v16, v16, v0 \n\t" + "vadd.vv v17, v17, v2 \n\t" + "vadd.vv v18, v18, v4 \n\t" + "vadd.vv v19, v19, v6 \n\t" + + "vfcvt.f.x.v v16, v16 \n\t" + "vfcvt.f.x.v v17, v17 \n\t" + "vfcvt.f.x.v v18, v18 \n\t" + "vfcvt.f.x.v v19, v19 \n\t" + + // mul scale + "vfmul.vv v16, v16, v14 \n\t" + "vfmul.vv v17, v17, v14 \n\t" + "vfmul.vv v18, v18, v14 \n\t" + "vfmul.vv v19, v19, v14 \n\t" + + "addi t4, t4, -1 \n\t" + "vfmacc.vf v28, fa0, v16 \n\t" + "vfmacc.vf v29, fa1, v17 \n\t" + "vfmacc.vf v30, fa2, v18 \n\t" + "vfmacc.vf v31, fa3, v19 \n\t" + + "bgtz t4, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "add t2, %[LDC], %[DST]\n\t" + "vse32.v v28, (%[DST]) \n\t" + "add t3, %[LDC], t2 \n\t" + "vse32.v v29, (t2) \n\t" + "add t2, %[LDC], t3 \n\t" + "vse32.v v30, (t3) \n\t" + "vse32.v v31, (t2) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "t4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", + "v25", "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "fa3", "s1", "s2", "s3", "s4"); + } + } +} + +void gemm_kernel_i8i4_hp_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + constexpr size_t K_SUBBLKS_PER_SUPERBLK = 8; + constexpr size_t K_SUBBLK_LEN = 32; + + struct block_q4_0x32_layout { + _Float16 d[NB_COLS]; + uint8_t qs[16 * NB_COLS]; + }; + + GGML_ASSERT(blk_len == 256); + GGML_ASSERT(count_m >= 4); + + // Contract: + // - computes a 4-row x 32-col tile per inner invocation + // - A is q8 HP packed in m4 layout, one logical K256 block at a time + // - B is q4 HP packed in N32 tiles, optionally with a separate zp area + // - tail-N is currently not handled here; the caller must provide full N32 tiles + + const size_t b_superblk_stride = sizeof(block_q4_0x32_layout) * K_SUBBLKS_PER_SUPERBLK + + (quant_b_zp ? NB_COLS * K_SUBBLKS_PER_SUPERBLK * sizeof(uint8_t) : 0); + const size_t b_tile_stride = k_blks * b_superblk_stride; + const size_t a_nrow_block_stride = q8_hp_blk_size(blk_len, true, true) * 4; + const size_t a_subblk_stride = q8_hp_blk_size(K_SUBBLK_LEN, false, false) * 4; + + if (quant_b_zp != nullptr) { + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + const size_t nb_real = std::min(NB_COLS, count_n - ni); + if (nb_real != NB_COLS) { + break; + } + + uint8_t * b_tile_base = (uint8_t *) quant_b_data + (ni / NB_COLS) * b_tile_stride; + uint8_t * a_block = (uint8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + + // Data layout summary for the with-zp path. + // + // A: M4 x K256 q8 HP block + // - split into 8 x K32 subblocks + // - each K32 subblock is 136B: + // 8B = 4 x fp16 row scales + // 128B = 4 x int8[32] row payloads + // - trailer after 8 subblocks is 72B: + // 4 rows x fp16[8] a_sum values, indexed as [row][ksi] + // 4 rows x fp16 scale_avg tail + // + // B: N32 x K256 q4 HP block with explicit zp area + // - each K32 subblock is 576B: + // 64B = fp16 scale[32] + // 512B = packed q4 payload for 32 columns x 32 k-elements + // - zp is stored separately, not interleaved with the 576B payload block + // - one K256 superblock is laid out as: + // 8 x (scale + qs) blocks = 4608B + // 8 x zp[32] = 256B + // + // C: 4 rows x 32 fp32 outputs + // + // ASM pointer convention: + // - t6: current A K32 subblock base + // - t2: current A a_sum base for this ksi + // row1/row2/row3 are at +16/+32/+48 bytes + // - s5: current B (scale + qs) K32 subblock base + // - s6: current B zp[32] base for this ksi + // + // Loop progression: + // - per ksi: A += 136, a_sum += 2, B_data += 576, B_zp += 32 + // - per ki : skip the 72B A trailer and advance B to the next 4864B superblock + + const _Float16 hp_scale_16 = (_Float16) 16.0f; + const _Float16 hp_scale_1 = (_Float16) 1.0f; + const _Float16 hp_scale_0125 = (_Float16) 0.125f; + + // VPR grouping used below: + // - v4-v7 : B q4 payload for N32 split as 4 x N8 groups + // - v8/v10 : zp u8 / widened fp16 + // - v12 : B fp16 scale[32] + // - v14-v15 : packed (Bscale * Ascale) for rows [0,1] / [2,3] + // - v16-v19 : temporary per-row scaled B scales + // - v28-v31 : final fp32 accumulators for rows 0..3 + + asm volatile( + "mv t5, %[BK] \n\t" + "mv t6, %[A] \n\t" + "mv s5, %[B] \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "li t4, 8 \n\t" + "li t1, 4608 \n\t" + "addi t2, t6, 1088 \n\t" // 8 * 136B A K32 subblocks, a_sum trailer starts here + "add s6, s5, t1 \n\t" // 8 * 576B B(scale+qs), zp area starts here + + ".align 4 \n\t" + "_BLK_LPST%=: \n\t" + "flh fa1, 64(t2) \n\t" // a_scale_avg_row[0] + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v18, v30, v30 \n\t" + "vxor.vv v19, v31, v31 \n\t" + "vxor.vv v20, v30, v30 \n\t" + "vxor.vv v21, v31, v31 \n\t" + "_KsubBLK_LPST%=: \n\t" + // load first subblock scales for 4 rows + "flh fa0, 0(t6) \n\t" // ascale_fp16 + + // load B fp16 scales[32] + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (s5) \n\t" + + // load Bzp[32] for the current ksi from the dedicated zp area + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v8, (s6) \n\t" + + "fmul.h fa2, fa0, %[HP16] \n\t" + "vfwcvt.f.xu.v v10, v8 \n\t" // uint8 -> fp16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vfmul.vf v16, v12, fa0 \n\t" // row0: Bscale * Ascale + "vfmul.vf v17, v12, fa2 \n\t" + + // load a_sum[row][ksi] from the trailer; t2 points to row0[ksi] + "flh ft1, 0(t2) \n\t" + "flh ft2, 16(t2) \n\t" + "flh ft3, 32(t2) \n\t" + "flh ft4, 48(t2) \n\t" + + "fmul.h ft1, ft1, %[HP0125] \n\t" + "fmul.h ft2, ft2, %[HP0125] \n\t" + "fmul.h ft3, ft3, %[HP0125] \n\t" + "fmul.h ft4, ft4, %[HP0125] \n\t" + + // load A payload from current K32 subblock and B q4 payload from current 576B block + "addi t3, t6, 8 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vl1r.v v0, (t3) \n\t" //A + "addi t3, s5, 64 \n\t" + "vl4r.v v4, (t3) \n\t" //B + + "vsetvli t0, x0, e8, m1 \n\t" + "vsrl.vi v1, v0, 4 \n\t" + "vnpack4.vv v12, v0, v1, 3 \n\t" + "vpack.vv v0, v17, v16, 3 \n\t" + "vupack.vv v2, v12, v12, 2 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" // mf2 -> mf2 + "vfmul.vv v10, v10, v16 \n\t" // zp * ascale * bscale; fp16*fp16 + + "vsetvli t0, x0, e16, mf2 \n\t" // mf2 -> m1 + "vfmul.vf v12, v10, ft1 \n\t" // zp(1:n)* abscale * asum_m0; fp16*fp16 + "vfmul.vf v13, v10, ft2 \n\t" // zp(1:n)* abscale * asum_m1; fp16*fp16 + "vfmul.vf v24, v10, ft3 \n\t" // zp(1:n)* abscale * asum_m2; fp16*fp16 + "vfmul.vf v25, v10, ft4 \n\t" // zp(1:n)* abscale * asum_m3; fp16*fp16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwmacc.vf v28, fa1, v12 \n\t" // row0/1 accum += dot * packed scale + "vfwmacc.vf v29, fa1, v13 \n\t" + "vfwmacc.vf v30, fa1, v24 \n\t" + "vfwmacc.vf v31, fa1, v25 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vmadotsu.hp v18, v3, v4, v0, 0, i4 \n\t" //lo4;n0n7 + "vmadotsu.hp v19, v3, v5, v0, 1, i4 \n\t" //lo4;n8n15 + "vmadotsu.hp v20, v3, v6, v0, 2, i4 \n\t" //lo4;n16n23 + "vmadotsu.hp v21, v3, v7, v0, 3, i4 \n\t" //lo4;n24n31 + "vmadotu.hp v18, v2, v4, v0, 4, i4 \n\t" //hi4;n0n7 + "vmadotu.hp v19, v2, v5, v0, 5, i4 \n\t" //hi4;n8n15 + "vmadotu.hp v20, v2, v6, v0, 6, i4 \n\t" //hi4;n16n23 + "vmadotu.hp v21, v2, v7, v0, 7, i4 \n\t" //hi4;n24n31 + + "addi t4, t4, -1 \n\t" + "addi t6, t6, 8+128 \n\t" // next A K32 subblock + "addi t2, t2, 2 \n\t" // next ksi entry in each a_sum row + "addi s5, s5, 64+512 \n\t" // next B (scale + qs) K32 block + "addi s6, s6, 32 \n\t" // next zp[32] + "bgtz t4, _KsubBLK_LPST%= \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v8, v18, v19, 1 \n\t" // 128(16*8)->256(16*16) + "vpack.vv v12, v20, v21, 1 \n\t" + "vpack.vv v26, v8, v12, 2 \n\t" // 256(16*16)->512(16*32) + + "vsetvli t0, x0, e16, m1 \n\t" + "vfwmacc.vf v28, fa1, v26 \n\t" // row0/1 accum += dot * packed scale + "vfwmacc.vf v30, fa1, v27 \n\t" + + "li t4, 8 \n\t" + "addi t5, t5, -1 \n\t" + "addi t6, t6, 72 \n\t" // skip A trailer after 8 subblocks and scale_avg tail + "mv s5, s6 \n\t" // s6 already points to next B superblock base + "addi t2, t6, 1088 \n\t" // 8 * 136B A K32 subblocks, a_sum trailer starts here + "add s6, s5, t1 \n\t" // 8 * 576B B(scale+qs), zp area starts here + "bgtz t5, _BLK_LPST%= \n\t" + + "_BLK_LPND%=: \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "add t2, %[LDC], %[DST] \n\t" + "vse32.v v28, (%[DST]) \n\t" + "add t3, %[LDC], t2 \n\t" + "vse32.v v29, (t2) \n\t" + "add t2, %[LDC], t3 \n\t" + "vse32.v v30, (t3) \n\t" + "vse32.v v31, (t2) \n\t" + : [A] "+r"(a_block), [B] "+r"(b_tile_base) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [BK] "r"(k_blks), [HP16] "f"(hp_scale_16), + [HP1] "f"(hp_scale_1), [HP0125] "f"(hp_scale_0125) + : "t0", "t1", "t2", "t3", "t4", "t5", "t6", "s5", "s6", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v10", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v24", + "v25", "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "ft1", "ft2", "ft3", "ft4", + "memory"); + } + return; + } else { + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + const size_t nb_real = std::min(NB_COLS, count_n - ni); + if (nb_real != NB_COLS) { + break; + } + + uint8_t * b_tile_base = (uint8_t *) quant_b_data + (ni / NB_COLS) * b_tile_stride; + uint8_t * a_block = (uint8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + + // Data layout summary for the no-zp path. + // + // A layout is identical to the with-zp branch. + // + // B: N32 x K256 q4 HP block without explicit zp storage + // - each K32 subblock is still 576B: + // 64B = fp16 scale[32] + // 512B = packed q4 payload + // - zp is implicit and treated as a constant value 8 in the kernel + // - one K256 superblock therefore contains only: + // 8 x (scale + qs) blocks = 4608B + // + // C: 4 rows x 32 fp32 outputs + // + // ASM pointer convention: + // - t6: current A K32 subblock base + // - t2: current A a_sum base for this ksi + // - s5: current B (scale + qs) K32 subblock base + // + // Loop progression: + // - per ksi: A += 136, a_sum += 2, B_data += 576 + // - per ki : skip the 72B A trailer and advance B to the next 4608B superblock + + const _Float16 hp_scale_16 = (_Float16) 16.0f; + const _Float16 hp_scale_1 = (_Float16) 1.0f; + + // VPR grouping used below matches the with-zp path: + // - v4-v7 : B q4 payload for N32 split as 4 x N8 groups + // - v8/v10 : implicit zp lane / widened fp16 + // - v12 : B fp16 scale[32] + // - v14-v15 : packed (Bscale * Ascale) for rows [0,1] / [2,3] + // - v16-v19 : temporary per-row scaled B scales + // - v28-v31 : final fp32 accumulators for rows 0..3 + + asm volatile( + "mv t5, %[BK] \n\t" + "mv t6, %[A] \n\t" + "mv s5, %[B] \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "li t4, 8 \n\t" + "addi t2, t6, 1088 \n\t" // 8 * 136B A K32 subblocks, a_sum trailer starts here + + ".align 4 \n\t" + "_BLK_LPST%=: \n\t" + "flh fa1, 64(t2) \n\t" // a_scale_avg_row[0] + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v18, v30, v30 \n\t" + "vxor.vv v19, v31, v31 \n\t" + "vxor.vv v20, v30, v30 \n\t" + "vxor.vv v21, v31, v31 \n\t" + "_KsubBLK_LPST%=: \n\t" + // load first subblock scales for 4 rows + "flh fa0, 0(t6) \n\t" // ascale_fp16 + + // load B fp16 scales[32] + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (s5) \n\t" + + "fmul.h fa2, fa0, %[HP16] \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "vfmul.vf v16, v12, fa0 \n\t" // row0: Bscale * Ascale + "vfmul.vf v17, v12, fa2 \n\t" + + // load a_sum[row][ksi] from the trailer; t2 points to row0[ksi] + "flh ft1, 0(t2) \n\t" + "flh ft2, 16(t2) \n\t" + "flh ft3, 32(t2) \n\t" + "flh ft4, 48(t2) \n\t" + + // load A payload from current K32 subblock and B q4 payload from current 576B block + "addi t3, t6, 8 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vl1r.v v0, (t3) \n\t" //A + "addi t3, s5, 64 \n\t" + "vl4r.v v4, (t3) \n\t" //B + + "vsetvli t0, x0, e8, m1 \n\t" + "vsrl.vi v1, v0, 4 \n\t" + "vnpack4.vv v12, v0, v1, 3 \n\t" + "vpack.vv v0, v17, v16, 3 \n\t" + "vupack.vv v2, v12, v12, 2 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" // mf2 -> m1 + "vfmul.vf v12, v16, ft1 \n\t" // zp(1:n)* abscale * asum_m0; fp16*fp16 + "vfmul.vf v13, v16, ft2 \n\t" // zp(1:n)* abscale * asum_m1; fp16*fp16 + "vfmul.vf v24, v16, ft3 \n\t" // zp(1:n)* abscale * asum_m2; fp16*fp16 + "vfmul.vf v25, v16, ft4 \n\t" // zp(1:n)* abscale * asum_m3; fp16*fp16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwmacc.vf v28, fa1, v12 \n\t" + "vfwmacc.vf v29, fa1, v13 \n\t" + "vfwmacc.vf v30, fa1, v24 \n\t" + "vfwmacc.vf v31, fa1, v25 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vmadotsu.hp v18, v3, v4, v0, 0, i4 \n\t" //lo4;n0n7 + "vmadotsu.hp v19, v3, v5, v0, 1, i4 \n\t" //lo4;n8n15 + "vmadotsu.hp v20, v3, v6, v0, 2, i4 \n\t" //lo4;n16n23 + "vmadotsu.hp v21, v3, v7, v0, 3, i4 \n\t" //lo4;n24n31 + "vmadotu.hp v18, v2, v4, v0, 4, i4 \n\t" //hi4;n0n7 + "vmadotu.hp v19, v2, v5, v0, 5, i4 \n\t" //hi4;n8n15 + "vmadotu.hp v20, v2, v6, v0, 6, i4 \n\t" //hi4;n16n23 + "vmadotu.hp v21, v2, v7, v0, 7, i4 \n\t" //hi4;n24n31 + + "addi t4, t4, -1 \n\t" + + "addi t6, t6, 8+128 \n\t" // next A K32 subblock + "addi t2, t2, 2 \n\t" // next ksi entry in each a_sum row + "addi s5, s5, 64+512 \n\t" // next B (scale + qs) K32 block + "bgtz t4, _KsubBLK_LPST%= \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" //N32in1register + "vpack.vv v8, v18, v19, 1 \n\t" // 128(16*8)->256(16*16) + "vpack.vv v12, v20, v21, 1 \n\t" + "vpack.vv v26, v8, v12, 2 \n\t" // 256(16*16)->512(16*32) + + "vsetvli t0, x0, e16, m1 \n\t" + "vfwmacc.vf v28, fa1, v26 \n\t" // row0/1 accum += dot * packed scale + "vfwmacc.vf v30, fa1, v27 \n\t" + + "li t4, 8 \n\t" + "addi t5, t5, -1 \n\t" + "addi t6, t6, 72 \n\t" // skip A trailer after 8 subblocks and scale_avg tail + // s5 already points to next B superblock base + "addi t2, t6, 1088 \n\t" // 8 * 136B A K32 subblocks, a_sum trailer starts here + "bgtz t5, _BLK_LPST%= \n\t" + + "_BLK_LPND%=: \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "add t2, %[LDC], %[DST] \n\t" + "vse32.v v28, (%[DST]) \n\t" + "add t3, %[LDC], t2 \n\t" + "vse32.v v29, (t2) \n\t" + "add t2, %[LDC], t3 \n\t" + "vse32.v v30, (t3) \n\t" + "vse32.v v31, (t2) \n\t" + : [A] "+r"(a_block), [B] "+r"(b_tile_base) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [BK] "r"(k_blks), [HP16] "f"(hp_scale_16), [HP1] "f"(hp_scale_1) + : "t0", "t2", "t3", "t4", "t5", "t6", "s5", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v10", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v24", "v25", "v26", + "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "ft1", "ft2", "ft3", "ft4", "memory"); + } + return; + } +} + +void gemm_kernel_i8mxfp4_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + constexpr size_t K_TILE = 32; + using blk_type = nrow_block_mxfp4; + + GGML_ASSERT(blk_len == K_TILE); + GGML_ASSERT(count_m == 1); + GGML_UNUSED(quant_b_zp); + + const size_t a_blk_stride = q8_blk_size(blk_len, true); + const size_t b_blk_stride = sizeof(blk_type); + const size_t b_tile_stride = k_blks * b_blk_stride; + + if (quant_b_zp == NULL) { + for (size_t n = 0; n < count_n; n += 32) { + size_t nblks = (count_n - n) > 32 ? 32 : count_n - n; + // MXFP4 no-zp: per column per k-block stride = scale_e8m0(1B) + qs(16B) + qh(4B) = 21B + uint8_t * QuantBDataPtr = (uint8_t *) quant_b_data + // + n * k_blks * (blk_len / 8) + // qh sign/high-bit mask: n×k_blks×4 + n * k_blks * blk_len / 2 + // qs packed 4-bit magnitudes: n×k_blks×16 + n * k_blks * sizeof(uint8_t); // scale: n×k_blks×1 + float * CPtr = c_ptr + n; + size_t cnt = k_blks; + + // A format (q8 block with per-block scale and stored sum field): + // || scl(fp32,4B) | asum(int16,2B) | data(int8,32B) || × k_blks + // + // Register map: + // t3 = k_blks loop counter t4 = nblks (tail) + // f0 = A scale (fp32) + // s2 = pA (scale/asum) s3 = pA data + // s4 = pB scales (u8×32) + // s5 = pB qh (sign/high-bit mask, 128B) + // s6 = pB qs (packed 4-bit magnitudes, 512B) + // s7 = pC + // v3 = fp32 accumulator (N32) + // v2 = B scales u8 (loaded as bytes; later widened) + // v0 = qh mask bytes (also used as v0.t mask after load) + // v1 = A int8 (K32) + // v8..v15 / v16..v23 = qs unpack/pack temporaries (build signed vmadot lanes) + // v24/v26/v28/v30 = int32 dot accumulators & packing temps + + __asm__ volatile( + "mv t3, %[BCK] \n\t" // t3 = k_blks + "mv t4, %[NBLKS] \n\t" // t4 = nblks (tail guard) + + // ---- pre-loop: init fp16 constants in e16 m1 context ---- + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v0, 1 \n\t" // v0 = int16(1) + "vfcvt.f.x.v v0, v0 \n\t" // v0 = 1.0_fp16 + "vxor.vv v3, v16, v16 \n\t" + + // ---- pointer setup ---- + "mv s2, %[pA] \n\t" // s2 = pA (scale, fp32) + "addi s3, %[pA], 4+2 \n\t" // s3 = pA data (skip scale+asum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32 \n\t" // s5 = pBh (pB + 32B scale) + "addi s6, %[pB], 32+128 \n\t" // s6 = pBs (pB + 32 + 128 = pB+192) + "mv s7, %[pC] \n\t" // s7 = pC + + // ===================================================================== + // K-block loop: each iteration processes one N32×K32 block + // Stride per k-block = 672B = 32(scl) + 512(Bs) + 128(Bh) + // ===================================================================== + ".align 4 \n\t" + "BLK_LPST%=: \n\t" + + // ---- load qs (512B = 4 VRF) from s6, advance s6 by 672 ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v8, (s6) \n\t" // v8..v11 = qs N32K32 packed 4-bit magnitudes + "addi s6, s6, 128*4+128+32 \n\t" // s6 += 672 (512+128+32) + + // ---- load B scale (32B = 32×u8) from s4, advance s4 by 672 ---- + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (s4) \n\t" // v2 = scale_u8 × 32 + "addi s4, s4, 32+128*4+128 \n\t" // s4 += 672 (32+512+128) + + // ---- load qh (128B = 1 VRF) from s5, advance s5 by 672 ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (s5) \n\t" // v0 = qh N32K32 sign/high-bit packed + "addi s5, s5, 128+32+128*4 \n\t" // s5 += 672 (128+32+512) + + // ---- load A data (32B = K32 int8) from s3 ---- + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v1, (s3) \n\t" // v1 = A M1K32 int8 + "addi s3, s3, 32+6 \n\t" // s3 += 38 (data + scl + asum) + + // ---- load A scale (fp32) and asum (int16) from s2 ---- + "flw f0, (s2) \n\t" // f0 = A scale (fp32) + "addi s2, s2, 6+32 \n\t" // s2 += 38 + + // ---- Decode packed MXFP4 payload into a vmadot-friendly signed-lane layout ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vand.vi v12, v8, 0xF \n\t" //8bit(lo4) //[8*32] + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" //8bit(hi4) + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + // [4*32]*2 + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "vsetvli t0, x0, e8, m8 \n\t" + "vrsub.vi v16, v16, 0, v0.t \n\t" + + // [4*32]*2 -> [8*16] + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v16, v1, 2 \n\t" + + // init the accumu to 0 + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vxor.vv v28, v16, v16 \n\t" + "vxor.vv v30, v16, v16 \n\t" + + // ---- int8 dot products over the decoded MXFP4 lane groups ---- + "vmadot v24, v1, v8, i8 \n\t" // N0..7 + "vmadot v26, v1, v10, i8 \n\t" // N8..15 + "vmadot v28, v1, v12, i8 \n\t" // N16..23 + "vmadot v30, v1, v14, i8 \n\t" // N24..31 + "vmadot v24, v16, v9, i8 \n\t" // N0..7 + "vmadot v26, v16, v11, i8 \n\t" // N8..15 + "vmadot v28, v16, v13, i8 \n\t" // N16..23 + "vmadot v30, v16, v15, i8 \n\t" // N24..31 + + "vsetvli t0, x0, e32, m1 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" // v16 = N0..15 + "vpack.vv v18, v28, v30, 2 \n\t" // v18 = N16..31 + "vpack.vv v24, v16, v18, 3 \n\t" // v24 = N0..31 + + "lui t1, 0x00200 \n\t" + "vmv.v.x v30, t1 \n\t" + // b_scale e8m0 -> fp32 + "vsetvli t0, x0, e8, mf4 \n\t" + "vwaddu.vx v28, v2, x0 \n\t" + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v2, v28, x0 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vmsle.vi v0, v2, 1 \n\t" + "vadd.vi v28, v2, -1 \n\t" + "vsll.vi v28, v28, 23 \n\t" + "vsll.vv v28, v30, v2, v0.t \n\t" + + // a_scale * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v24 \n\t" + "vfmul.vf v30, v28, f0 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v3, v30, v26 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, BLK_LPST%= \n\t" + "BLK_LPND%=: \n\t" + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "vse32.v v3, (%[pC]) \n\t" + "FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6", "s7", "v0", "v1", + "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); + } + } +} + +void gemm_kernel_i8mxfp4_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + constexpr size_t K_TILE = 32; + using blk_type = nrow_block_mxfp4; + + GGML_ASSERT(blk_len == K_TILE); + GGML_ASSERT(count_m == 4); + GGML_UNUSED(quant_b_zp); + + const size_t a_blk_stride = q8_blk_size(blk_len, true); + const size_t b_blk_stride = sizeof(blk_type); + const size_t b_tile_stride = k_blks * b_blk_stride; + + if (quant_b_zp == NULL) { + // MXFP4 block layout per K32/N32 tile: + // [scale_e8m0 x 32][qh sign/high-bit mask x 128B][qs packed 4-bit magnitudes x 512B] + // There is no explicit zp stream; qh is combined with qs to reconstruct signed MXFP4 values. + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * b_tile_stride; + uint8_t * a_data = (uint8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + size_t cnt = k_blks; + + asm volatile( + // v4-v7 are the fp32 accumulators for rows 0..3 of the current N32 tile. + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v4, v4, v4 \n\t" + "vxor.vv v5, v5, v5 \n\t" + "vxor.vv v6, v6, v6 \n\t" + "vxor.vv v7, v7, v7 \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // Load the 4 A-row scales for this K32 block and build row data pointers. + "flw fa0, 0(%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "addi t3, %[A], 24 \n\t" + "addi t4, t3, 32 \n\t" + "addi t5, t3, 64 \n\t" + "addi t6, t3, 96 \n\t" + "addi %[A], %[A], 152 \n\t" + + // B-side pointers: + // t1 -> qh bitmask stream, t2 -> qs low-nibble stream. + "addi t1, %[B], 32 \n\t" + "addi t2, %[B], 160 \n\t" + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (%[B]) \n\t" + "addi %[B], %[B], 672 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (t1) \n\t" + "vl4r.v v8, (t2) \n\t" + + // Decode the packed MXFP4 payload once for the whole tile and expand it + // into a vmadot-friendly layout. + "vand.vi v12, v8, 0xF \n\t" + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "vsetvli t0, x0, e8, m8 \n\t" + "vrsub.vi v16, v16, 0, v0.t \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + "lui t1, 0x00200 \n\t" + "vmv.v.x v30, t1 \n\t" + // b_scale e8m0 -> fp32 + "vsetvli t0, x0, e8, mf4 \n\t" + "vwaddu.vx v28, v2, x0 \n\t" + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v26, v28, x0 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vmsle.vi v0, v26, 1 \n\t" + "vadd.vi v24, v26, -1 \n\t" + "vsll.vi v18, v24, 23 \n\t" + "vsll.vv v18, v30, v26, v0.t \n\t" + + // Row 0: dot(A0, decoded MXFP4 lane groups), accumulate in int32 and + // then apply A/B scaling. + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v1, (t3) \n\t" + "vsetvli t0, x0, e64, m1 \n\t" + "vupack.vv v16, v1, v2, 1 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vxor.vv v26, v26, v26 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vmadot v24, v16, v8, i8 \n\t" + "vmadot v26, v16, v10, i8 \n\t" + "vmadot v28, v16, v12, i8 \n\t" + "vmadot v30, v16, v14, i8 \n\t" + "vmadot v24, v17, v9, i8 \n\t" + "vmadot v26, v17, v11, i8 \n\t" + "vmadot v28, v17, v13, i8 \n\t" + "vmadot v30, v17, v15, i8 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" + "vpack.vv v20, v28, v30, 2 \n\t" + "vpack.vv v24, v16, v20, 3 \n\t" + "vpack.vv v26, v17, v21, 3 \n\t" + "vfcvt.f.x.v v24, v24 \n\t" + "vfcvt.f.x.v v25, v25 \n\t" + "vfcvt.f.x.v v26, v26 \n\t" + "vfcvt.f.x.v v27, v27 \n\t" + "vfmul.vv v24, v24, v18 \n\t" + "vfmul.vv v25, v25, v18 \n\t" + "vfmul.vv v26, v26, v18 \n\t" + "vfmul.vv v27, v27, v18 \n\t" + "vfmacc.vf v4, fa0, v24 \n\t" + "vfmacc.vf v5, fa1, v25 \n\t" + "vfmacc.vf v6, fa2, v26 \n\t" + "vfmacc.vf v7, fa3, v27 \n\t" + + "addi %[BK], %[BK], -1 \n\t" + "bgtz %[BK], BLK_LOOP%= \n\t" + + // Tail-aware store for the final N tile (`nb_real` may be < 32). + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "add t1, %[LDC], %[DST] \n\t" + "vse32.v v4, (%[DST]) \n\t" + "vse32.v v5, (t1) \n\t" + "add t2, t1, %[LDC] \n\t" + "vse32.v v6, (t2) \n\t" + "add t3, t2, %[LDC] \n\t" + "vse32.v v7, (t3) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data), [BK] "+r"(cnt) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [NBLKS] "r"(nb_real) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "t5", "t6", "s1", "s2", "s3", "s4", "v0", "v1", "v2", + "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", + "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", + "fa0", "fa1", "fa2", "fa3"); + } + } +} + +void gemm_kernel_i8i5_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + // ========================================================================= + // i8i5: 8-bit activation × 5-bit weight (4-bit low + 1-bit high mask) + // + // B layout per N32K32 k-block (no-zp): + // [0 .. 63 ] : scale_fp16 × 32 (64B) + // [64 .. 191] : Bh i1-high-bit × 32N × 32K (128B = 1 VRF) + // [192.. 703] : Bs i4-low-nibble × 32N × 32K (512B = 4 VRF) + // Total: 704B per k-block stride + // + // B layout per N32K32 k-block (with-zp): + // [0 .. 63 ] : scale_fp16 × 32 (64B) + // [64 .. 95 ] : zp_uint8 × 32 (32B) + // [96 .. 223] : Bh i1-high-bit × 32N × 32K (128B = 1 VRF) + // [224.. 735] : Bs i4-low-nibble × 32N × 32K (512B = 4 VRF) + // Total: 736B per k-block stride + // + // Bh format per N8K32 sub-block (32B): + // K rows × N cols × 1bit packed as bytes (8 cols per byte, K groups of 4B) + // Byte k gives 8 mask bits for columns N7..N0 at k-th K-element. + // + // Computation: + // B5bit_signed = (Bs | (Bh << 4)) - zp + // dot(A, B5) = dot(A, Bs_u4) + 16*dot(A, Bh_u1) - zp*asum + // No-zp: implicit zp = 16 (unsigned [0..31] centered at 16) + // With-zp: explicit zp from data + // + // ========================================================================= + + if (quant_b_zp == NULL) { + for (size_t n = 0; n < count_n; n += 32) { + size_t nblks = (count_n - n) > 32 ? 32 : count_n - n; + // i8i5 no-zp: per column per k-block stride = fp16(2B) + i4(16B) + i1(4B) = 22B + uint8_t * QuantBDataPtr = (uint8_t *) quant_b_data + // + n * k_blks * (blk_len / 8) + // Bh i1 mask: n×k_blks×4 + n * k_blks * blk_len / 2 + // Bs i4 data: n×k_blks×16 + n * k_blks * sizeof(_Float16); // scale: n×k_blks×2 + float * CPtr = c_ptr + n; + size_t cnt = k_blks; + + // A format (same as i8i4): + // || scl(fp32,4B) | asum(int16,2B) | data(int8,32B) || × k_blks + // + // Register map: + // t3 = k_blks loop counter t4 = nblks (tail) + // t2 = A asum (int16) << 4 f0 = A scale (fp32) + // s2 = pA (scale/asum) s3 = pA data + // s4 = pB scales (fp16×32) + // s5 = pB Bh (i1 mask, 128B) + // s6 = pB Bs (i4 packed, 512B) + // s7 = pC + // v3 = fp32 accumulator (N32) + // v2 = B scales fp16 (loaded as bytes; later widened) + // v0 = Bh mask bytes (also used as v0.t mask after load) + // v1 = A int8 (K32) + // v8..v15 / v16..v23 = Bs unpack/pack temporaries (build b5bit bytes) + // v24/v26/v28/v30 = int32 dot accumulators & packing temps + + __asm__ volatile( + "mv t3, %[BCK] \n\t" // t3 = k_blks + "mv t4, %[NBLKS] \n\t" // t4 = nblks (tail guard) + + // ---- pre-loop: init fp16 constants in e16 m1 context ---- + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v0, 1 \n\t" // v0 = int16(1) + "vfcvt.f.x.v v0, v0 \n\t" // v0 = 1.0_fp16 + "vxor.vv v3, v16, v16 \n\t" + + // ---- pointer setup ---- + "mv s2, %[pA] \n\t" // s2 = pA (scale, fp32) + "addi s3, %[pA], 4+2 \n\t" // s3 = pA data (skip scale+asum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*2 \n\t" // s5 = pBh (pB + 64B scale) + "addi s6, %[pB], 32*2+128 \n\t" // s6 = pBs (pB + 64 + 128 = pB+192) + "mv s7, %[pC] \n\t" // s7 = pC + + // ===================================================================== + // K-block loop: each iteration processes one N32×K32 block + // Stride per k-block = 704B = 64(scl) + 512(Bs) + 128(Bh) + // ===================================================================== + ".align 4 \n\t" + "BLK_LPST%=: \n\t" + + // ---- load Bs (512B = 4 VRF) from s6, advance s6 by 704 ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v8, (s6) \n\t" // v8..v11 = Bs N32K32 i4 + "addi s6, s6, 128*4+128+64 \n\t" // s6 += 704 (512+128+64) + + // ---- load B scale (64B = 32×fp16) from s4, advance s4 by 704 ---- + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (s4) \n\t" // v2 = scale_fp16 × 32 + "addi s4, s4, 64+128*4+128 \n\t" // s4 += 704 (64+512+128) + + // ---- load Bh (128B = 1 VRF) from s5, advance s5 by 704 ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (s5) \n\t" // v0 = Bh N32K32 1-bit packed + "addi s5, s5, 128+64+128*4 \n\t" // s5 += 704 (128+64+512) + + // ---- load A data (32B = K32 int8) from s3 ---- + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v1, (s3) \n\t" // v1 = A M1K32 int8 + "addi s3, s3, 32+6 \n\t" // s3 += 38 (data + scl + asum) + + // ---- load A scale (fp32) and asum (int16) from s2 ---- + "flw f0, (s2) \n\t" // f0 = A scale (fp32) + "lh t2, 4(s2) \n\t" // t2 = A asum (int16) + "addi s2, s2, 6+32 \n\t" // s2 += 38 + + //// ---- A nibble unpacking ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vand.vi v12, v8, 0xF \n\t" //8bit(lo4) //[8*32] + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" //8bit(hi4) + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + "slli t2, t2, 4 \n\t" // a_sum * 16; + // [4*32]*2 + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "li t1, 16 \n\t" + "vsetvli t0, x0, e8, m8 \n\t" + "vadd.vx v16, v16, t1, v0.t \n\t" + + // [4*32]*2 -> [8*16] + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v16, v1, 2 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vxor.vv v28, v16, v16 \n\t" + "vxor.vv v30, v16, v16 \n\t" + + // ---- i8 main dot products ---- + // vmadot: A × unsigned Bh × 16 → fp16 accumulate + "vmadot v24, v1, v8, i8 \n\t" // N0..7 + "vmadot v26, v1, v10, i8 \n\t" // N8..15 + "vmadot v28, v1, v12, i8 \n\t" // N16..23 + "vmadot v30, v1, v14, i8 \n\t" // N24..31 + //// vmadot: A × unsigned Bh × 1 → fp16 accumulate + "vmadot v24, v16, v9, i8 \n\t" // N0..7 + "vmadot v26, v16, v11, i8 \n\t" // N8..15 + "vmadot v28, v16, v13, i8 \n\t" // N16..23 + "vmadot v30, v16, v15, i8 \n\t" // N24..31 + + "vsetvli t0, x0, e32, m1 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" // v16 = N0..15 + "vpack.vv v18, v28, v30, 2 \n\t" // v18 = N16..31 + "vpack.vv v24, v16, v18, 3 \n\t" // v24 = N0..31 + + "vadd.vx v24, v24, t2 \n\t" + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v28, v2 \n\t" + + // a_scale * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v24 \n\t" + "vfmul.vf v30, v28, f0 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v3, v30, v26 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, BLK_LPST%= \n\t" + "BLK_LPND%=: \n\t" + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "vse32.v v3, (%[pC]) \n\t" + "FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6", "s7", "v0", "v1", + "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); + } + } else { + for (size_t n = 0; n < count_n; n += 32) { + size_t nblks = (count_n - n) > 32 ? 32 : count_n - n; + // i8i5 with-zp: per column per k-block stride = fp16(2B)+zp(1B)+i4(16B)+i1(4B)=23B + uint8_t * QuantBDataPtr = (uint8_t *) quant_b_data + // + n * k_blks * blk_len / 2 + // Bs i4: n×k_blks×16 + n * k_blks * (blk_len / 8) + // Bh i1: n×k_blks×4 + n * k_blks * sizeof(uint8_t) + // zp: n×k_blks×1 + n * k_blks * sizeof(_Float16); // scale: n×k_blks×2 + float * CPtr = c_ptr + n; + size_t cnt = k_blks; + + // A format (same as i8i4): + // || scl(fp32,4B) | asum(int16,2B) | data(int8,32B) || × k_blks + // + // Register map: + // t3 = k_blks loop counter t4 = nblks (tail) + // t2 = A asum (int16) << 4 f0 = A scale (fp32) + // s2 = pA (scale/asum) s3 = pA data + // s4 = pB scales (fp16×32); 每个 k-block 先 +64 指向 zp,再 +672 到下一个 block + // s5 = pB Bh (i1 mask, 128B) (offset +96) + // s6 = pB Bs (i4 packed, 512B) (offset +224) + // s7 = pC + // v3 = fp32 accumulator (N32) + // v2 = B scales fp16 (loaded as bytes; later widened) + // v0 = Bh mask bytes (also used as v0.t mask after load) + // v1 = A int8 (K32) / later reused to hold Bzp bytes + // v8..v15 / v16..v23 = Bs unpack/pack temporaries (build b5bit bytes) + // v24/v26/v28/v30 = int32 dot accumulators & packing temps + + __asm__ volatile( + "mv t3, %[BCK] \n\t" // t3 = k_blks + "mv t4, %[NBLKS] \n\t" // t4 = nblks (tail guard) + + // ---- pre-loop: init fp16 constants in e16 m1 context ---- + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v0, 1 \n\t" // v0 = int16(1) + "vfcvt.f.x.v v0, v0 \n\t" // v0 = 1.0_fp16 + "vxor.vv v3, v16, v16 \n\t" + + // ---- pointer setup ---- + "mv s2, %[pA] \n\t" // s2 = pA (scale, fp32) + "addi s3, %[pA], 4+2 \n\t" // s3 = pA data (skip scale+asum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*3 \n\t" // s5 = pBh (pB + 64B scale + 32B zp = pB+96) + "addi s6, %[pB], 32*3+128 \n\t" // s6 = pBs (pB + 96 + 128 = pB+224) + "mv s7, %[pC] \n\t" // s7 = pC + + // ===================================================================== + // K-block loop: each iteration processes one N32×K32 block + // Stride per k-block = 736B = 64(scale) + 32(zp) + 128(Bh) + 512(Bs) + // ===================================================================== + ".align 4 \n\t" + "BLK_LPST%=: \n\t" + + // ---- load Bs (512B = 4 VRF) from s6, advance s6 by 736 ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v8, (s6) \n\t" // v8..v11 = Bs N32K32 i4 + "addi s6, s6, 128*4+128+96 \n\t" // s6 += 736 (512+128+96) + + // ---- load B scale (64B = 32×fp16) from s4; then s4 points to zp[32] ---- + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (s4) \n\t" // v2 = scale_fp16 × 32 + "addi s4, s4, 64 \n\t" // s4 += 64 (now points to zp) + + // ---- load Bh (128B = 1 VRF) from s5, advance s5 by 736 ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (s5) \n\t" // v0 = Bh N32K32 1-bit packed + "addi s5, s5, 128+96+128*4 \n\t" // s5 += 736 (128+96+512) + + // ---- load A data (32B = K32 int8) from s3 ---- + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v1, (s3) \n\t" // v1 = A M1K32 int8 + "addi s3, s3, 32+6 \n\t" // s3 += 38 (data + scl + asum) + + // ---- load A scale (fp32) and asum (int16) from s2 ---- + "flw f0, (s2) \n\t" // f0 = A scale (fp32) + "lh t2, 4(s2) \n\t" // t2 = A asum (int16) + "addi s2, s2, 6+32 \n\t" // s2 += 38 + + //// ---- A nibble unpacking ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vand.vi v12, v8, 0xF \n\t" //8bit(lo4) //[8*32] + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" //8bit(hi4) + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + // [4*32]*2 + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "li t1, 16 \n\t" + "vsetvli t0, x0, e8, m8 \n\t" + "vadd.vx v16, v16, t1, v0.t \n\t" + + // [4*32]*2 -> [8*16] + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v16, v1, 2 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vxor.vv v28, v16, v16 \n\t" + "vxor.vv v30, v16, v16 \n\t" + + // ---- i8 main dot products ---- + // vmadot: A × unsigned Bh × 16 → fp16 accumulate + "vmadot v24, v1, v8, i8 \n\t" // N0..7 + "vmadot v26, v1, v10, i8 \n\t" // N8..15 + "vmadot v28, v1, v12, i8 \n\t" // N16..23 + "vmadot v30, v1, v14, i8 \n\t" // N24..31 + // vmadot: A × unsigned Bh × 1 → fp16 accumulate + "vmadot v24, v16, v9, i8 \n\t" // N0..7 + "vmadot v26, v16, v11, i8 \n\t" // N8..15 + "vmadot v28, v16, v13, i8 \n\t" // N16..23 + "vmadot v30, v16, v15, i8 \n\t" // N24..31 + + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v1, (s4) \n\t" // Bzp + "addi s4, s4, 32+128*4+128 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" // v16 = N0..15 + "vpack.vv v18, v28, v30, 2 \n\t" // v18 = N16..31 + "vpack.vv v24, v16, v18, 3 \n\t" // v24 = N0..31 + + "vwaddu.vx v28, v1, x0 \n\t" // uint8 -> uint16 + + "vsetvli t0, x0, e16, m1 \n\t" + "vwmul.vx v30, v28, t2 \n\t" + + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v28, v2 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vadd.vv v24, v24, v30 \n\t" + + // a_scale * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vfmul.vf v30, v28, f0 \n\t" + "vfcvt.f.x.v v26, v24 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v3, v30, v26 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, BLK_LPST%= \n\t" + "BLK_LPND%=: \n\t" + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "vse32.v v3, (%[pC]) \n\t" + "FUNC_END%=: \n\t" + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6", "s7", "v0", "v1", + "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); + } + } +} + +void gemm_kernel_i8i5_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + + GGML_UNUSED(count_m); + GGML_UNUSED(blk_len); + + // This kernel computes a 4x32 output tile. For each K32 block we decode the + // packed Q5 weights once and reuse the decoded vectors across the 4 A rows. + constexpr size_t B_Q50_BLK_STRIDE = sizeof(nrow_block_q5_0); + constexpr size_t B_Q51_BLK_STRIDE = sizeof(nrow_block_q5_1); + + if (quant_b_zp) { + // Q5_1 block layout per K32/N32 tile: + // [scale_fp16 x 32][zp_u8 x 32][qh high-bit mask x 128B][qs low nibbles x 512B] + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * k_blks * B_Q51_BLK_STRIDE; + uint8_t * a_data = (uint8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + size_t cnt = k_blks; + + asm volatile( + // v4-v7 are the fp32 accumulators for rows 0..3 of the current N32 tile. + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v4, v4, v4 \n\t" + "vxor.vv v5, v5, v5 \n\t" + "vxor.vv v6, v6, v6 \n\t" + "vxor.vv v7, v7, v7 \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // Load the 4 A-row scales/sums for this K32 block and build row data pointers. + "flw fa0, 0(%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "lh s1, 16(%[A]) \n\t" + "lh s2, 18(%[A]) \n\t" + "lh s3, 20(%[A]) \n\t" + "lh s4, 22(%[A]) \n\t" + "addi t3, %[A], 24 \n\t" + "addi t4, t3, 32 \n\t" + "addi t5, t3, 64 \n\t" + "addi t6, t3, 96 \n\t" + "addi %[A], %[A], 152 \n\t" + + // B-side pointers: + // t1 -> zp stream, t2 -> qh bitmask stream, s5 -> qs low-nibble stream. + "addi t1, %[B], 64 \n\t" + "addi t2, %[B], 96 \n\t" + "addi s5, %[B], 224 \n\t" + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (%[B]) \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (t2) \n\t" + "vl4r.v v8, (s5) \n\t" + "addi %[B], %[B], 736 \n\t" + + // Decode Q5 payload once for the whole tile: + // 1) split `qs` low/high nibbles, + // 2) repack into bytes, + // 3) use the `qh` mask to inject bit4 (+16) where needed, + // 4) expand into the vmadot-friendly layout reused by all 4 rows. + "vand.vi v12, v8, 0xF \n\t" + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "li t2, 16 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "vsetvli t0, x0, e8, m8 \n\t" + "vadd.vx v16, v16, t2, v0.t \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + // Convert per-column fp16 scales once; the same scale vector is shared by all 4 rows. + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v18, v2 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v3, (t1) \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + + // Row 0: dot(A0, decoded_q5) + a_sum0 * zp, then scale by A/B scales. + // The widen/mul correction sequence intentionally matches the proven m1 Q5_1 path. + "vle8.v v1, (t3) \n\t" + "vsetvli t0, x0, e64, m1 \n\t" + "vupack.vv v16, v1, v2, 1 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vxor.vv v26, v26, v26 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vmadot v24, v16, v8, i8 \n\t" + "vmadot v26, v16, v10, i8 \n\t" + "vmadot v28, v16, v12, i8 \n\t" + "vmadot v30, v16, v14, i8 \n\t" + "vmadot v24, v17, v9, i8 \n\t" + "vmadot v26, v17, v11, i8 \n\t" + "vmadot v28, v17, v13, i8 \n\t" + "vmadot v30, v17, v15, i8 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" + "vpack.vv v20, v28, v30, 2 \n\t" + "vpack.vv v24, v16, v20, 3 \n\t" + "vpack.vv v26, v17, v21, 3 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vwaddu.vx v28, v3, x0 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vwmul.vx v12, v28, s1 \n\t" + "vwmul.vx v14, v28, s2 \n\t" + "vwmul.vx v20, v28, s3 \n\t" + "vwmul.vx v22, v28, s4 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vadd.vv v24, v24, v12 \n\t" + "vadd.vv v25, v25, v14 \n\t" + "vadd.vv v26, v26, v20 \n\t" + "vadd.vv v27, v27, v22 \n\t" + "vfcvt.f.x.v v12, v24 \n\t" + "vfcvt.f.x.v v14, v25 \n\t" + "vfcvt.f.x.v v20, v26 \n\t" + "vfcvt.f.x.v v22, v27 \n\t" + "vfmul.vv v12, v12, v18 \n\t" + "vfmul.vv v14, v14, v18 \n\t" + "vfmul.vv v20, v20, v18 \n\t" + "vfmul.vv v22, v22, v18 \n\t" + "vfmacc.vf v4, fa0, v12 \n\t" + "vfmacc.vf v5, fa1, v14 \n\t" + "vfmacc.vf v6, fa2, v20 \n\t" + "vfmacc.vf v7, fa3, v22 \n\t" + + "addi %[BK], %[BK], -1 \n\t" + "bgtz %[BK], BLK_LOOP%= \n\t" + + // Tail-aware store for the final N tile (`nb_real` may be < 32). + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "add t1, %[LDC], %[DST] \n\t" + "vse32.v v4, (%[DST]) \n\t" + "vse32.v v5, (t1) \n\t" + "add t2, t1, %[LDC] \n\t" + "vse32.v v6, (t2) \n\t" + "add t3, t2, %[LDC] \n\t" + "vse32.v v7, (t3) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data), [BK] "+r"(cnt) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [NBLKS] "r"(nb_real) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "t5", "t6", "s1", "s2", "s3", "s4", "s5", "v0", "v1", + "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", + "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31", "fa0", "fa1", "fa2", "fa3"); + } + } else { + // Q5_0 block layout per K32/N32 tile: + // [scale_fp16 x 32][qh high-bit mask x 128B][qs low nibbles x 512B] + // There is no explicit zp stream; the implicit midpoint correction is +16. + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * k_blks * B_Q50_BLK_STRIDE; + uint8_t * a_data = (uint8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + size_t cnt = k_blks; + + asm volatile( + // v4-v7 are the fp32 accumulators for rows 0..3 of the current N32 tile. + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v4, v4, v4 \n\t" + "vxor.vv v5, v5, v5 \n\t" + "vxor.vv v6, v6, v6 \n\t" + "vxor.vv v7, v7, v7 \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // Load the 4 A-row scales/sums for this K32 block and build row data pointers. + "flw fa0, 0(%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "lh s1, 16(%[A]) \n\t" + "lh s2, 18(%[A]) \n\t" + "lh s3, 20(%[A]) \n\t" + "lh s4, 22(%[A]) \n\t" + "addi t3, %[A], 24 \n\t" + "addi t4, t3, 32 \n\t" + "addi t5, t3, 64 \n\t" + "addi t6, t3, 96 \n\t" + "addi %[A], %[A], 152 \n\t" + + // B-side pointers: + // t1 -> qh bitmask stream, t2 -> qs low-nibble stream. + "addi t1, %[B], 64 \n\t" + "addi t2, %[B], 192 \n\t" + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (%[B]) \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (t1) \n\t" + "vl4r.v v8, (t2) \n\t" + "addi %[B], %[B], 704 \n\t" + + // Decode Q5 payload once for the whole tile and expand it into the vmadot layout. + "vand.vi v12, v8, 0xF \n\t" + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "li t2, 16 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "vsetvli t0, x0, e8, m8 \n\t" + "vadd.vx v16, v16, t2, v0.t \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + // Convert per-column fp16 scales once; the same scale vector is shared by all 4 rows. + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v18, v2 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + + // Row 0: dot(A0, decoded_q5) + a_sum0 * 16 (implicit Q5_0 midpoint correction). + "vle8.v v1, (t3) \n\t" + "vsetvli t0, x0, e64, m1 \n\t" + "vupack.vv v16, v1, v2, 1 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vxor.vv v26, v26, v26 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vmadot v24, v16, v8, i8 \n\t" + "vmadot v26, v16, v10, i8 \n\t" + "vmadot v28, v16, v12, i8 \n\t" + "vmadot v30, v16, v14, i8 \n\t" + "vmadot v24, v17, v9, i8 \n\t" + "vmadot v26, v17, v11, i8 \n\t" + "vmadot v28, v17, v13, i8 \n\t" + "vmadot v30, v17, v15, i8 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" + "slli s1, s1, 4 \n\t" + "vpack.vv v20, v28, v30, 2 \n\t" + "slli s2, s2, 4 \n\t" + "vpack.vv v24, v16, v20, 3 \n\t" + "slli s3, s3, 4 \n\t" + "vpack.vv v26, v17, v21, 3 \n\t" + "slli s4, s4, 4 \n\t" + "vadd.vx v24, v24, s1 \n\t" + "vadd.vx v25, v25, s2 \n\t" + "vadd.vx v26, v26, s3 \n\t" + "vadd.vx v27, v27, s4 \n\t" + "vfcvt.f.x.v v24, v24 \n\t" + "vfcvt.f.x.v v25, v25 \n\t" + "vfcvt.f.x.v v26, v26 \n\t" + "vfcvt.f.x.v v27, v27 \n\t" + "vfmul.vv v24, v24, v18 \n\t" + "vfmul.vv v25, v25, v18 \n\t" + "vfmul.vv v26, v26, v18 \n\t" + "vfmul.vv v27, v27, v18 \n\t" + "vfmacc.vf v4, fa0, v24 \n\t" + "vfmacc.vf v5, fa1, v25 \n\t" + "vfmacc.vf v6, fa2, v26 \n\t" + "vfmacc.vf v7, fa3, v27 \n\t" + + "addi %[BK], %[BK], -1 \n\t" + "bgtz %[BK], BLK_LOOP%= \n\t" + + // Tail-aware store for the final N tile (`nb_real` may be < 32). + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "add t1, %[LDC], %[DST] \n\t" + "vse32.v v4, (%[DST]) \n\t" + "vse32.v v5, (t1) \n\t" + "add t2, t1, %[LDC] \n\t" + "vse32.v v6, (t2) \n\t" + "add t3, t2, %[LDC] \n\t" + "vse32.v v7, (t3) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data), [BK] "+r"(cnt) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [NBLKS] "r"(nb_real) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "t5", "t6", "s1", "s2", "s3", "s4", "v0", "v1", "v2", + "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", + "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", + "fa0", "fa1", "fa2", "fa3"); + } + } +} + +void gemm_kernel_i8i8_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + for (size_t n = 0; n < count_n; n += 32) { + size_t nblks = (count_n - n) > 32 ? 32 : count_n - n; + uint8_t * QuantBDataPtr = (uint8_t *) quant_b_data + // + n * k_blks * blk_len + // b data + n * k_blks * sizeof(_Float16); // scale + float * CPtr = c_ptr + n; + size_t cnt = k_blks; + + // A format Version_1 (FP32 SCALE FOR Normal VMADOTins of IME2) + // A M1K32 int8 256bit + // Ascale fp32 * 1 32bit + // || scl*1(fp32) | Asum(int16) | blk0 || scl*1(fp32) | Asum(int16) | blk0 || ... + // || Element || Element || ... + // B format + // B N8K32 int4 2048bit + // 4VRF, N32K32, 8192bit + // Bscale fp16 * N32 512bit; + // || scl*32..(fp16) | blk0 blk1 ... blk31 || scl*32..(fp16) | blk0 blk1 ... blk31 || ... + // || Element || Element || ... + + //bias always be nullptr + __asm__ volatile( + + // t3 = k/32 + "mv t3, %[BCK] \n\t" + "mv t4, %[NBLKS] \n\t" + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 4+2 \n\t" // s3 = pAData, (pA+AScl+ASum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*2 \n\t" // s5 = pBdata; + "mv s6, %[pC] \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v2, v0, v0 \n\t" // clear acc + + // ordinary vmadot: vle*6 flw*1 vecIns*64 vmadot*8 + ".align 4 \n\t" + "_K_LPST%=: \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (s5) \n\t" // B Data 4VRF * 8Row * 32 + "addi s5, s5, 128*4 \n\t" + "vl4r.v v8, (s5) \n\t" // B Data 4VRF * 8Row * 32 + "addi s5, s5, 128*4+64 \n\t" + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s4) \n\t" // B Scale 4VRF*8Row*FP16 = 512bit + "addi s4, s4, 64+128*8 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v3, (s3) \n\t" // A Data M1*K32*int8 = 256bit + "addi s3, s3, 32+6 \n\t" + + "flw f0, (s2) \n\t" // A Scale fp32 + "addi s2, s2, 6+32 \n\t" // AScale + Asum(FP32+i16) + + "vsetvli t0, zero, e32, m1 \n\t" + "vupack.vv v24, v4, v5, 1 \n\t" + "vupack.vv v26, v6, v7, 1 \n\t" + "vupack.vv v28, v8, v9, 1 \n\t" + "vupack.vv v30, v10, v11, 1 \n\t" + + "vslidedown.vi v4, v3, 4 \n\t" + + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v3, v24, i8 \n\t" // M0 N0 - N7 INT32(256bit) + "vmadot v18, v3, v26, i8 \n\t" // M0 N8 - N15 + "vmadot v20, v3, v28, i8 \n\t" // M0 N16 - N23 + "vmadot v22, v3, v30, i8 \n\t" // M0 N24 - N31 + + "vmadot v16, v4, v25, i8 \n\t" + "vmadot v18, v4, v27, i8 \n\t" + "vmadot v20, v4, v29, i8 \n\t" + "vmadot v22, v4, v31, i8 \n\t" + + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" + + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v24, v0 \n\t" + // mac result i32 -> fp32 + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v16 \n\t" + // a_scale * b_scale; + "vfmul.vf v1, v24, f0 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v2, v1, v26 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, _K_LPST%= \n\t" + "_K_LPND%=: \n\t" + + //----------------------------------------- + // STORE Equal 32N------------------------- + "_ST32%=: \n\t" + "vsetvli t0, t4, e32, m1 \n\t" + "vse32.v v2, (s6) \n\t" // M0 [N0 : N32]; FP32(1024bit) + + "_FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "t0", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6"); + } +} + +void gemm_kernel_i8i8_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + int64_t b_data_stride = k_blks * sizeof(ggml_fp16_t) + k_blks * blk_len; + for (size_t ni = 0; ni < count_n; ni += 32) { + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_data_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + + asm volatile( + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vxor.vv v31, v31, v31 \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A + "flw fa0, (%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "addi %[A], %[A], 16+8 \n\t" // Ascl+Asum; FP32*4+i16*4 + + // load scale B + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (%[B]) \n\t" + "addi %[B], %[B], 64 \n\t" + "vfwcvt.f.f.v v14, v12 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl1r.v v0, (%[A]) \n\t" + "addi %[A], %[A], 128 \n\t" // 4*32@i8 + "vl4r.v v4, (%[B]) \n\t" // 32*32@i8 + "addi %[B], %[B], 512 \n\t" + "vl4r.v v8, (%[B]) \n\t" // 32*32@i8 + "addi %[B], %[B], 512 \n\t" + + "vsetvli t0, zero, e32, m1 \n\t" + "vupack.vv v2, v0, v0, 1 \n\t" + + "vupack.vv v24, v4, v5, 1 \n\t" + "vupack.vv v26, v6, v7, 1 \n\t" + "vupack.vv v4, v8, v9, 1 \n\t" + "vupack.vv v6, v10, v11, 1 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + // i4 * i4 vmadot + "vsetvli t0, x0, e32, m1 \n\t" + "vmadot v16, v2, v24, i8 \n\t" + "vmadot v18, v2, v26, i8 \n\t" + "vmadot v20, v2, v4, i8 \n\t" + "vmadot v22, v2, v6, i8 \n\t" + "vmadot v16, v3, v25, i8 \n\t" + "vmadot v18, v3, v27, i8 \n\t" + "vmadot v20, v3, v5, i8 \n\t" + "vmadot v22, v3, v7, i8 \n\t" + + "vpack.vv v0, v16, v18, 2 \n\t" + "vpack.vv v2, v20, v22, 2 \n\t" + "vpack.vv v16, v0, v2, 3 \n\t" + "vpack.vv v18, v1, v3, 3 \n\t" + + "vfcvt.f.x.v v16, v16 \n\t" + "vfcvt.f.x.v v17, v17 \n\t" + "vfcvt.f.x.v v18, v18 \n\t" + "vfcvt.f.x.v v19, v19 \n\t" + + // mul scale + "vfmul.vv v16, v16, v14 \n\t" + "vfmul.vv v17, v17, v14 \n\t" + "vfmul.vv v18, v18, v14 \n\t" + "vfmul.vv v19, v19, v14 \n\t" + + "addi %[BK], %[BK], -1 \n\t" + "vfmacc.vf v28, fa0, v16 \n\t" + "vfmacc.vf v29, fa1, v17 \n\t" + "vfmacc.vf v30, fa2, v18 \n\t" + "vfmacc.vf v31, fa3, v19 \n\t" + + "bgtz %[BK], BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "add t2, %[LDC], %[DST] \n\t" + "vse32.v v28, (%[DST]) \n\t" + "add t3, %[LDC], t2 \n\t" + "vse32.v v29, (t2) \n\t" + "add t2, %[LDC], t3 \n\t" + "vse32.v v30, (t3) \n\t" + "vse32.v v31, (t2) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", + "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "fa3"); + } +} + +void moe_m2_gemm_kernel_i8i4_impl(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { +#if 0 + moe_gemm_kernel_i8i4_mrow_ref<2, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, + ldc); +#else + int64_t b_data_stride = + k_blks * (sizeof(ggml_fp16_t) + 16 * sizeof(int8_t) + (quant_b_zp != NULL ? sizeof(int8_t) : 0)); + if (quant_b_zp == NULL) { + for (size_t ni = 0; ni < count_n; ni += 32) { + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_data_stride; + int8_t * a_data0 = (int8_t *) quant_a_ptr[0]; + int8_t * a_data1 = (int8_t *) quant_a_ptr[1]; + float * dst_c0 = (float *) c_ptr[0] + ni; + float * dst_c1 = (float *) c_ptr[1] + ni; + + asm volatile( + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vmv.v.i v0, 1 \n\t" // init the scale + "vsll.vi v1, v0, 4 \n\t" + "vfcvt.f.x.v v0, v0 \n\t" + "vfcvt.f.x.v v1, v1 \n\t" + "mv t3, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A0 + "flw fa0, (%[A0]) \n\t" // A0 scale + "lh t1, 4(%[A0]) \n\t" // A0 asum + "addi %[A0], %[A0], 6 \n\t" + + // load scale B + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (%[B]) \n\t" + "addi %[B], %[B], 64 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v14, v12, v12, 3 \n\t" + + // load scale A1 + "flw fa1, (%[A1]) \n\t" // A1 scale + "lh t2, 4(%[A1]) \n\t" // A1 asum + "addi %[A1], %[A1], 6 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.x v10, t1 \n\t" + "vmv.v.x v11, t2 \n\t" + + "vpack.vv v18, v10, v11, 1 \n\t" + "vsll.vi v18, v18, 3 \n\t" // mul 8 + "vfcvt.f.x.v v18, v18 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" // A0 data + "vle8.v v16, (%[A0]) \n\t" + "addi %[A0], %[A0], 32 \n\t" // 1*32@i8 + "vle8.v v20, (%[A1]) \n\t" + "addi %[A1], %[A1], 32 \n\t" // 1*32@i8 + + "vl4r.v v4, (%[B]) \n\t" // 32*32@i4 + "addi %[B], %[B], 512 \n\t" + + "vsrl.vi v17, v16, 4 \n\t" + "vsrl.vi v21, v20, 4 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vnpack4.vv v2, v16, v20, 2 \n\t" // low u4 + "vnpack4.vv v3, v17, v21, 2 \n\t" // high s4 + + // init the accumu to asum * zp + "vsetvli t0, x0, e16, m1 \n\t" + "vor.vv v19, v18, v18 \n\t" + "vor.vv v20, v18, v18 \n\t" + "vor.vv v21, v18, v18 \n\t" + + // i4 * i4 vmadot + "vsetvli t0, x0, e16, m1 \n\t" + "vmadotsu.hp v18, v3, v4, v1, 0, i4 \n\t" // high 4 + "vmadotsu.hp v19, v3, v5, v1, 0, i4 \n\t" + "vmadotsu.hp v20, v3, v6, v1, 0, i4 \n\t" + "vmadotsu.hp v21, v3, v7, v1, 0, i4 \n\t" + "vmadotu.hp v18, v2, v4, v0, 0, i4 \n\t" // low 4 + "vmadotu.hp v19, v2, v5, v0, 0, i4 \n\t" + "vmadotu.hp v20, v2, v6, v0, 0, i4 \n\t" + "vmadotu.hp v21, v2, v7, v0, 0, i4 \n\t" + + "vpack.vv v8, v18, v19, 1 \n\t" + "vpack.vv v12, v20, v21, 1 \n\t" + "vpack.vv v20, v8, v12, 2 \n\t" + + "vfwmul.vv v16, v20, v14 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + + "addi t3, t3, -1 \n\t" + "vfmacc.vf v28, fa0, v16 \n\t" + "vfmacc.vf v29, fa1, v17 \n\t" + + "bgtz t3, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "vse32.v v28, (%[DST0]) \n\t" + "vse32.v v29, (%[DST1]) \n\t" + : [A0] "+r"(a_data0), [A1] "+r"(a_data1), [B] "+r"(b_data) + : [DST0] "r"(dst_c0), [DST1] "r"(dst_c1), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "fa3"); + } + } else { +# if 0 + moe_gemm_kernel_i8i4_mrow_ref<2, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +# else + for (size_t ni = 0; ni < count_n; ni += 32) { + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_data_stride; + int8_t * a_data0 = (int8_t *) quant_a_ptr[0]; + int8_t * a_data1 = (int8_t *) quant_a_ptr[1]; + float * dst_c0 = (float *) c_ptr[0] + ni; + float * dst_c1 = (float *) c_ptr[1] + ni; + + asm volatile( + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vmv.v.i v0, 1 \n\t" // init the scale + "vsll.vi v1, v0, 4 \n\t" + "vfcvt.f.x.v v0, v0 \n\t" + "vfcvt.f.x.v v1, v1 \n\t" + "mv t3, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A0 + "flw fa0, (%[A0]) \n\t" // A0 scale + "lh t1, 4(%[A0]) \n\t" // A0 asum + "addi %[A0], %[A0], 6 \n\t" + + // load scale B + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (%[B]) \n\t" + "addi %[B], %[B], 64 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v14, v12, v12, 3 \n\t" + + // load scale A1 + "flw fa1, (%[A1]) \n\t" // A1 scale + "lh t2, 4(%[A1]) \n\t" // A1 asum + "addi %[A1], %[A1], 6 \n\t" + + // load zp + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v8, (%[B]) \n\t" + "addi %[B], %[B], 32 \n\t" + "vwaddu.vx v10, v8, x0 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" // A0 data + "vle8.v v16, (%[A0]) \n\t" + "addi %[A0], %[A0], 32 \n\t" // 1*32@i8 + "vle8.v v20, (%[A1]) \n\t" + "addi %[A1], %[A1], 32 \n\t" // 1*32@i8 + + "vl4r.v v4, (%[B]) \n\t" // 32*32@i4 + "addi %[B], %[B], 512 \n\t" + + "vsrl.vi v17, v16, 4 \n\t" + "vsrl.vi v21, v20, 4 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vnpack4.vv v2, v16, v20, 2 \n\t" // low u4 + "vnpack4.vv v3, v17, v21, 2 \n\t" // high s4 + + // init the accumu to asum * zp + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v19, v19, v19 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v21, v21, v21 \n\t" + + // i4 * i4 vmadot + "vsetvli t0, x0, e16, m1 \n\t" + "vmadotsu.hp v18, v3, v4, v1, 0, i4 \n\t" // high 4 + "vmadotsu.hp v19, v3, v5, v1, 0, i4 \n\t" + "vmadotsu.hp v20, v3, v6, v1, 0, i4 \n\t" + "vmadotsu.hp v21, v3, v7, v1, 0, i4 \n\t" + "vmadotu.hp v18, v2, v4, v0, 0, i4 \n\t" // low 4 + "vmadotu.hp v19, v2, v5, v0, 0, i4 \n\t" + "vmadotu.hp v20, v2, v6, v0, 0, i4 \n\t" + "vmadotu.hp v21, v2, v7, v0, 0, i4 \n\t" + + "vpack.vv v8, v18, v19, 1 \n\t" + "vpack.vv v12, v20, v21, 1 \n\t" + "vpack.vv v20, v8, v12, 2 \n\t" + // asum*zp + "vsetvli t0, x0, e16, mf2 \n\t" + "vwmul.vx v2, v10, t1 \n\t" + "vwmul.vx v4, v10, t2 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + + "vfcvt.f.x.v v2, v2 \n\t" + "vfcvt.f.x.v v4, v4 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vfwcvt.f.f.v v16, v20 \n\t" + + "vfwcvt.f.f.v v18, v14 \n\t" + + // +asum*zp + "vsetvli t0, x0, e32, m1 \n\t" + "vfadd.vv v16, v16, v2 \n\t" + "vfadd.vv v17, v17, v4 \n\t" + "vfmul.vv v16, v16, v18 \n\t" + "vfmul.vv v17, v17, v18 \n\t" + + "addi t3, t3, -1 \n\t" + "vfmacc.vf v28, fa0, v16 \n\t" + "vfmacc.vf v29, fa1, v17 \n\t" + + "bgtz t3, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "vse32.v v28, (%[DST0]) \n\t" + "vse32.v v29, (%[DST1]) \n\t" + : [A0] "+r"(a_data0), [A1] "+r"(a_data1), [B] "+r"(b_data) + : [DST0] "r"(dst_c0), [DST1] "r"(dst_c1), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "fa3"); + } +# endif + } +#endif +} + +void moe_m2_gemm_kernel_i8i5_impl(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + constexpr size_t B_Q50_BLK_STRIDE = sizeof(nrow_block_q5_0); + constexpr size_t B_Q51_BLK_STRIDE = sizeof(nrow_block_q5_1); + + GGML_UNUSED(blk_len); + GGML_UNUSED(count_m); + GGML_UNUSED(ldc); + + if (quant_b_zp == NULL) { + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * k_blks * B_Q50_BLK_STRIDE; + int8_t * a_data0 = (int8_t *) quant_a_ptr[0]; + int8_t * a_data1 = (int8_t *) quant_a_ptr[1]; + float * dst_c0 = (float *) c_ptr[0] + ni; + float * dst_c1 = (float *) c_ptr[1] + ni; + + asm volatile( + "mv t4, %[BK] \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v2, v0, v0 \n\t" + "vxor.vv v3, v0, v0 \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // ---- load B scale/Bh/Bs and advance to the next q5_0 k-block ---- + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v1, (%[B]) \n\t" // v1 = scale_fp16 × 32 + "addi %[B], %[B], 64 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (%[B]) \n\t" // v0 = Bh N32K32 1-bit packed + "addi %[B], %[B], 128 \n\t" + "vl4r.v v8, (%[B]) \n\t" // v8..v11 = Bs N32K32 i4 + "addi %[B], %[B], 512 \n\t" + + // ---- load A0/A1 header then payload, each block stride = 38B ---- + "flw f0, (%[A0]) \n\t" // f0 = A0 scale (fp32) + "lh t2, 4(%[A0]) \n\t" // t2 = A0 asum (int16) + "addi %[A0], %[A0], 6 \n\t" + "flw f1, (%[A1]) \n\t" // f1 = A1 scale (fp32) + "lh t3, 4(%[A1]) \n\t" // t3 = A1 asum (int16) + "addi %[A1], %[A1], 6 \n\t" + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v4, (%[A0]) \n\t" // v4 = A0 M1K32 int8 + "addi %[A0], %[A0], 32 \n\t" + "vle8.v v5, (%[A1]) \n\t" // v5 = A1 M1K32 int8 + "addi %[A1], %[A1], 32 \n\t" + + //// ---- A nibble unpacking ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vand.vi v12, v8, 0xF \n\t" //8bit(lo4) //[8*32] + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" //8bit(hi4) + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + "slli t2, t2, 4 \n\t" // a_sum * 16; + "slli t3, t3, 4 \n\t" + // [4*32]*2 + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "li t1, 16 \n\t" + "vsetvli t0, x0, e8, m8 \n\t" + "vadd.vx v16, v16, t1, v0.t \n\t" + + // [4*32]*2 -> [8*16] + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + "vpack.vv v6, v4, v5, 2 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vupack.vv v4, v6, v7, 1 \n\t" + "vxor.vv v28, v16, v16 \n\t" + "vxor.vv v30, v16, v16 \n\t" + + // ---- i8 main dot products ---- + // vmadot: A × unsigned Bh × 16 → fp16 accumulate + "vmadot v24, v4, v8, i8 \n\t" // N0..7 + "vmadot v26, v4, v10, i8 \n\t" // N8..15 + "vmadot v28, v4, v12, i8 \n\t" // N16..23 + "vmadot v30, v4, v14, i8 \n\t" // N24..31 + // vmadot: A × unsigned Bh × 1 → fp16 accumulate + "vmadot v24, v5, v9, i8 \n\t" // N0..7 + "vmadot v26, v5, v11, i8 \n\t" // N8..15 + "vmadot v28, v5, v13, i8 \n\t" // N16..23 + "vmadot v30, v5, v15, i8 \n\t" // N24..31 + + "vpack.vv v16, v24, v26, 2 \n\t" // v16 = N0..15 + "vpack.vv v18, v28, v30, 2 \n\t" // v18 = N16..31 + "vpack.vv v24, v16, v18, 3 \n\t" // v24 = N0..31 + + "vadd.vx v24, v24, t2 \n\t" + "vadd.vx v25, v25, t3 \n\t" + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v28, v1 \n\t" + + // a_scale * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v24 \n\t" + "vfcvt.f.x.v v27, v25 \n\t" + "vfmul.vf v30, v28, f0 \n\t" + "vfmul.vf v31, v28, f1 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v2, v30, v26 \n\t" + "vfmacc.vv v3, v31, v27 \n\t" + + "addi t4, t4, -1 \n\t" + "bgtz t4, BLK_LOOP%= \n\t" + + "vsetvli t0, %[NR], e32, m1 \n\t" + "vse32.v v2, (%[DST0]) \n\t" + "vse32.v v3, (%[DST1]) \n\t" + : [A0] "+r"(a_data0), [A1] "+r"(a_data1), [B] "+r"(b_data) + : [DST0] "r"(dst_c0), [DST1] "r"(dst_c1), [BK] "r"(k_blks), [NR] "r"(nb_real) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", + "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", + "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "f0", "f1"); + } + } else { + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * k_blks * B_Q51_BLK_STRIDE; + int8_t * a_data0 = (int8_t *) quant_a_ptr[0]; + int8_t * a_data1 = (int8_t *) quant_a_ptr[1]; + float * dst_c0 = (float *) c_ptr[0] + ni; + float * dst_c1 = (float *) c_ptr[1] + ni; + + asm volatile( + "mv t4, %[BK] \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v2, v0, v0 \n\t" + "vxor.vv v3, v0, v0 \n\t" + "addi t5, %[B], 64 \n\t" // t5 = zp (32B) + "addi t6, %[B], 96 \n\t" // t6 = qh (128B) + "addi s1, %[B], 224 \n\t" // s1 = qs (512B) + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // ---- load B scale/zp/Bh/Bs and advance to the next q5_1 k-block ---- + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v1, (%[B]) \n\t" // v1 = scale_fp16 × 32 + "addi %[B], %[B], 736 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (t6) \n\t" // v0 = Bh N32K32 1-bit packed + "addi t6, t6, 736 \n\t" + "vl4r.v v8, (s1) \n\t" // v8..v11 = Bs N32K32 i4 + "addi s1, s1, 736 \n\t" + + // ---- load A0/A1 header then payload, each block stride = 38B ---- + "flw f0, (%[A0]) \n\t" // f0 = A0 scale (fp32) + "lh t2, 4(%[A0]) \n\t" // t2 = A0 asum (int16) + "addi %[A0], %[A0], 6 \n\t" + "flw f1, (%[A1]) \n\t" // f1 = A1 scale (fp32) + "lh t3, 4(%[A1]) \n\t" // t3 = A1 asum (int16) + "addi %[A1], %[A1], 6 \n\t" + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v4, (%[A0]) \n\t" // v4 = A0 M1K32 int8 + "addi %[A0], %[A0], 32 \n\t" + "vle8.v v5, (%[A1]) \n\t" // v5 = A1 M1K32 int8 + "addi %[A1], %[A1], 32 \n\t" + + //// ---- A nibble unpacking ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vand.vi v12, v8, 0xF \n\t" //8bit(lo4) //[8*32] + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" //8bit(hi4) + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + // q5_1 uses explicit zp, so keep a_sum unshifted here. + // [4*32]*2 + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "li t1, 16 \n\t" + "vsetvli t0, x0, e8, m8 \n\t" + "vadd.vx v16, v16, t1, v0.t \n\t" + + // [4*32]*2 -> [8*16] + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + "vpack.vv v6, v4, v5, 2 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vupack.vv v4, v6, v7, 1 \n\t" + "vxor.vv v28, v16, v16 \n\t" + "vxor.vv v30, v16, v16 \n\t" + + // ---- i8 main dot products ---- + // vmadot: A × unsigned Bh × 16 → fp16 accumulate + "vmadot v24, v4, v8, i8 \n\t" // N0..7 + "vmadot v26, v4, v10, i8 \n\t" // N8..15 + "vmadot v28, v4, v12, i8 \n\t" // N16..23 + "vmadot v30, v4, v14, i8 \n\t" // N24..31 + // vmadot: A × unsigned Bh × 1 → fp16 accumulate + "vmadot v24, v5, v9, i8 \n\t" // N0..7 + "vmadot v26, v5, v11, i8 \n\t" // N8..15 + "vmadot v28, v5, v13, i8 \n\t" // N16..23 + "vmadot v30, v5, v15, i8 \n\t" // N24..31 + + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v4, (t5) \n\t" // v4 = Bzp N32 uint8 + "addi t5, t5, 736 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" // v16 = N0..15 + "vpack.vv v18, v28, v30, 2 \n\t" // v18 = N16..31 + "vpack.vv v24, v16, v18, 3 \n\t" // v24 = N0..31 + + "vsetvli t0, x0, e8, mf4 \n\t" + "vwaddu.vx v28, v4, x0 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwmul.vx v30, v28, t2 \n\t" + "vwmul.vx v31, v28, t3 \n\t" + + // b_scale fp16 -> fp32 + "vfwcvt.f.f.v v28, v1 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vadd.vv v24, v24, v30 \n\t" + "vadd.vv v25, v25, v31 \n\t" + + // a_scale * b_scale; + "vfcvt.f.x.v v26, v24 \n\t" + "vfcvt.f.x.v v27, v25 \n\t" + "vfmul.vf v30, v28, f0 \n\t" + "vfmul.vf v31, v28, f1 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v2, v30, v26 \n\t" + "vfmacc.vv v3, v31, v27 \n\t" + + "addi t4, t4, -1 \n\t" + "bgtz t4, BLK_LOOP%= \n\t" + + "vsetvli t0, %[NR], e32, m1 \n\t" + "vse32.v v2, (%[DST0]) \n\t" + "vse32.v v3, (%[DST1]) \n\t" + : [A0] "+r"(a_data0), [A1] "+r"(a_data1), [B] "+r"(b_data) + : [DST0] "r"(dst_c0), [DST1] "r"(dst_c1), [BK] "r"(k_blks), [NR] "r"(nb_real) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "t5", "t6", "s1", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", + "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "f0", "f1"); + } + } +} + +size_t gemm_kernel_i8i2k(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 0 + gemm_kernel_i8i2k_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#else + gemm_kernel_i8i2k_m4(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 0 + gemm_kernel_i8i2k_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, + ldc); +#else + gemm_kernel_i8i2k_m1(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t gemm_kernel_i8i3k(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 0 + gemm_kernel_i8i3k_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#else + gemm_kernel_i8i3k_m4(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 0 + gemm_kernel_i8i3k_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#else + gemm_kernel_i8i3k_m1(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t gemm_kernel_i8i4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 0 + gemm_kernel_i8i4_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i4_m4(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 0 + gemm_kernel_i8i4_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i4_m1(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t gemm_kernel_i8i4_hp(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 0 + gemm_kernel_i8i4_hp_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i4_hp_m4(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 0 + gemm_kernel_i8i4_hp_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i4_hp_m1(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t moe_m2_gemm_kernel_i8i4(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + moe_m2_gemm_kernel_i8i4_impl(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); + return 2; +} + +size_t gemm_kernel_i8i8(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 0 + gemm_kernel_i8i8_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i8_m4(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 0 + gemm_kernel_i8i8_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i8_m1(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t gemm_kernel_i8mxfp4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 1 + gemm_kernel_i8mxfp4_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8mxfp4_m4(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 1 + gemm_kernel_i8mxfp4_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8mxfp4_m1(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t moe_m2_gemm_kernel_i8mxfp4(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + //moe_m2_gemm_kernel_i8mxfp4_impl(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); + return 2; +} + +size_t gemm_kernel_i8i5(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 0 + gemm_kernel_i8i5_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i5_m4(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 0 + gemm_kernel_i8i5_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i5_m1(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t moe_m2_gemm_kernel_i8i5(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { +#if 0 + moe_gemm_kernel_i8i5_mrow_ref<2, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + moe_m2_gemm_kernel_i8i5_impl(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 2; +} + +} // namespace ime2 +} // namespace spacemit_kernels diff --git a/ggml/src/ggml-cpu/spacemit/ime_env.cpp b/ggml/src/ggml-cpu/spacemit/ime_env.cpp new file mode 100644 index 00000000000..a13ba391da2 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/ime_env.cpp @@ -0,0 +1,320 @@ +#include "ime_env.h" + +#include "ggml-impl.h" +#include "spine_mem_pool.h" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace ggml::cpu::riscv64_spacemit { +bool spine_core_info::get_spine_core_info(std::vector & result) { + static std::unordered_map spine_march_mapping_ = { + {0x8000000058000001, spine_core_arch_id::core_arch_x60 }, + { 0x8000000041000001, spine_core_arch_id::core_arch_a60 }, + { 0x8000000058000002, spine_core_arch_id::core_arch_x100}, + { 0x8000000041000002, spine_core_arch_id::core_arch_a100}, + }; + + result.clear(); + std::ifstream file("/proc/cpuinfo"); + std::string line; + + std::vector> cpu_info_list; + + uint64_t current_processor = spine_invalid_core_id; + uint64_t current_marchid = 0; + bool has_processor = false; + bool has_marchid = false; + + if (!file.is_open()) { + return false; + } + + while (std::getline(file, line)) { + if (line.substr(0, 9) == "processor") { + if (has_processor && has_marchid) { + cpu_info_list.push_back({ current_processor, current_marchid }); + } + + size_t colon_pos = line.find(':'); + if (colon_pos != std::string::npos) { + current_processor = std::stoi(line.substr(colon_pos + 1)); + has_processor = true; + } + + has_marchid = false; + } else if (line.substr(0, 7) == "marchid") { + size_t colon_pos = line.find(':'); + if (colon_pos != std::string::npos) { + std::string marchid_str = line.substr(colon_pos + 1); + marchid_str.erase(std::remove_if(marchid_str.begin(), marchid_str.end(), isspace), marchid_str.end()); + current_marchid = std::stoull(marchid_str, nullptr, 16); + has_marchid = true; + } + } + } + + if (has_processor && has_marchid) { + cpu_info_list.push_back({ current_processor, current_marchid }); + } + + if (has_processor && has_marchid) { + for (auto & cpu_info : cpu_info_list) { + if (cpu_info[0] != spine_invalid_core_id && + spine_march_mapping_.find(cpu_info[1]) != spine_march_mapping_.end()) { + auto core_info = spine_core_info(); + core_info.core_id = cpu_info[0]; + core_info.arch_id = spine_core_arch_id(spine_march_mapping_[cpu_info[1]]); + + result.push_back(core_info); + } + } + } + + return has_processor && has_marchid; +} + +namespace { +uint16_t hex_string_to_u16(const std::string & hex_str) { + try { + size_t pos = 0; + if (hex_str.substr(0, 2) == "0x" || hex_str.substr(0, 2) == "0X") { + pos = 2; + } + unsigned long result = std::stoul(hex_str.substr(pos), nullptr, 16); + if (result > std::numeric_limits::max()) { + throw std::out_of_range("Converted value is out of range for uint16_t"); + } + return static_cast(result); + } catch (const std::invalid_argument & e) { + throw std::invalid_argument("Invalid hexadecimal string"); + } catch (const std::out_of_range & e) { + throw; + } +} + +const char * spine_mem_pool_backend_to_string(spine_mem_pool_backend backend) { + switch (backend) { + case spine_mem_pool_backend::none: + return "NONE"; + case spine_mem_pool_backend::posix_memalign: + return "POSIX"; + case spine_mem_pool_backend::transparent_hugepage: + return "HPAGE"; + case spine_mem_pool_backend::hugetlb_1g: + return "HPAGE1GB"; + } + + return "unknown"; +} + +spine_mem_pool_backend parse_mem_backend(const char * mem_backend_str) { + if (mem_backend_str == nullptr || mem_backend_str[0] == '\0') { + return spine_mem_pool_backend::transparent_hugepage; + } + + std::string value(mem_backend_str); + std::transform(value.begin(), value.end(), value.begin(), + [](unsigned char ch) { return static_cast(std::tolower(ch)); }); + + if (value == "none") { + return spine_mem_pool_backend::none; + } + + if (value == "posix") { + return spine_mem_pool_backend::posix_memalign; + } + + if (value == "hpage") { + return spine_mem_pool_backend::transparent_hugepage; + } + + if (value == "hpage1gb") { + return spine_mem_pool_backend::hugetlb_1g; + } + + throw std::runtime_error("invalid SPACEMIT_MEM_BACKEND: " + value + ", expected NONE, POSIX, HPAGE or HPAGE1GB"); +} +} // namespace + +spine_env_info::spine_env_info() { + num_cores = static_cast(std::thread::hardware_concurrency()); + spine_core_info::get_spine_core_info(core_info_list); + + // special for x60 K1 + if (core_info_list.size() == 8 && core_info_list[0].arch_id == spine_core_arch_id::core_arch_x60) { + for (int i = 0; i < 4; i++) { + core_info_list[i].arch_id = spine_core_arch_id::core_arch_a60; + } + } + + // special for qemu + if (core_info_list.size() == 0) { + char * spine_core_arch_str = getenv("SPACEMIT_CORE_ARCH"); + if (spine_core_arch_str != nullptr) { + auto arch_id = hex_string_to_u16(spine_core_arch_str); + for (int i = 0; i < num_cores; i++) { + auto core_info = spine_core_info(); + core_info.core_id = i; + core_info.arch_id = spine_core_arch_id{ arch_id }; + core_info_list.push_back(core_info); + } + } + } + + if (core_info_list.size() == 0) { + throw std::runtime_error( + "Failed to get SPACEMIT_CORE_ARCH from environment or failed to parse it from /proc/cpuinfo"); + } + + char * spine_perfer_core_arch_str = getenv("SPACEMIT_PERFER_CORE_ARCH"); + if (spine_perfer_core_arch_str != nullptr && spine_perfer_core_arch_str != "") { + perfer_core_arch_id = spine_core_arch_id{ hex_string_to_u16(spine_perfer_core_arch_str) }; + } + + char * spine_perfer_core_id_str = getenv("SPACEMIT_PERFER_CORE_ID"); + std::vector perfer_core_id_vec; + if (spine_perfer_core_id_str != nullptr && spine_perfer_core_id_str != "") { + std::string perfer_core_id_str(spine_perfer_core_id_str); + size_t start = 0; + size_t end = 0; + while ((end = perfer_core_id_str.find(',', start)) != std::string::npos) { + std::string core_id_substr = perfer_core_id_str.substr(start, end - start); + perfer_core_id_vec.push_back(std::stoi(core_id_substr)); + start = end + 1; + } + std::string core_id_substr = perfer_core_id_str.substr(start); + perfer_core_id_vec.push_back(std::stoi(core_id_substr)); + } + + perfer_core_ids.reserve(num_cores); + if (perfer_core_arch_id == spine_core_arch_id::core_arch_none) { + for (auto & core_info : core_info_list) { + auto core_arch_id = core_info.arch_id; + auto core_arch_head = (uint16_t) (core_arch_id) >> 12; + if (core_arch_head == 0xA) { + num_perfer_cores++; + perfer_core_arch_id = core_arch_id; + cpu_mask |= (1ULL << core_info.core_id); + perfer_core_ids.push_back(core_info.core_id); + } + } + } else { + for (auto & core_info : core_info_list) { + auto core_arch_id = core_info.arch_id; + if (core_arch_id == perfer_core_arch_id) { + num_perfer_cores++; + cpu_mask |= (1ULL << core_info.core_id); + + auto core_arch_head = (uint16_t) (core_arch_id) >> 12; + if (core_arch_head == 0xA) { + perfer_core_ids.push_back(core_info.core_id); + } + } + } + if (num_perfer_cores == 0) { + GGML_ABORT("can not find core with arch id %x for SPACEMIT_PERFER_CORE_ARCH in core info list\n", + (uint16_t) perfer_core_arch_id); + } + } + + if (perfer_core_id_vec.size() > 0) { + perfer_core_ids.clear(); + cpu_mask = 0; + num_perfer_cores = 0; + for (int core_id : perfer_core_id_vec) { + if (core_id < 0 || core_id >= num_cores) { + GGML_ABORT("invalid core id in SPACEMIT_PERFER_CORE_ID: %d, should be between 0 and %d\n", core_id, + num_cores - 1); + } + auto core_info = core_info_list[core_id]; + auto core_arch_id = core_info.arch_id; + if (core_arch_id == perfer_core_arch_id) { + cpu_mask |= (1ULL << core_id); + perfer_core_ids.push_back(core_id); + } else { + GGML_ABORT( + "core id %d in SPACEMIT_PERFER_CORE_ID has arch id %x which does not match " + "SPACEMIT_PERFER_CORE_ARCH %x\n", + core_id, (uint16_t) core_arch_id, (uint16_t) perfer_core_arch_id); + } + } + std::string perfer_core_id_vec_str; + for (int core_id : perfer_core_id_vec) { + perfer_core_id_vec_str += std::to_string(core_id) + ","; + } + perfer_core_id_vec_str.pop_back(); + GGML_LOG_DEBUG("SPACEMIT_PERFER_CORE_ID is set, perferred core ids: %s\n", perfer_core_id_vec_str.c_str()); + num_perfer_cores = static_cast(perfer_core_id_vec.size()); + } + + use_ime1 = perfer_core_arch_id == spine_core_arch_id::core_arch_a60 || + perfer_core_arch_id == spine_core_arch_id::core_arch_x100; + + use_ime2 = perfer_core_arch_id == spine_core_arch_id::core_arch_a100; + + mem_backend = parse_mem_backend(getenv("SPACEMIT_MEM_BACKEND")); + char * spine_disable_tcm_str = getenv("SPACEMIT_DISABLE_TCM"); + auto user_disable_tcm = spine_disable_tcm_str != nullptr && strcmp(spine_disable_tcm_str, "0") != 0; + + if (!user_disable_tcm) { + spine_mem_pool_tcm_info tcm_info; + if (spine_mem_pool_tcm_init(&tcm_info)) { + use_tcm = tcm_info.available; + tcm_blk_size = tcm_info.blk_size; + GGML_LOG_DEBUG("CPU_RISCV64_SPACEMIT: tcm is available, blk_size: %zu, blk_num: %zu, is_fake_tcm: %d\n", + tcm_info.blk_size, tcm_info.blk_num, tcm_info.is_fake_tcm); + + for (auto & core_info : core_info_list) { + auto core_arch_head = (uint16_t) (core_info.arch_id) >> 12; + if (core_arch_head != 0xA) { + aicpu_id_offset++; + } else { + break; + } + } + } + } + + GGML_LOG_DEBUG( + "CPU_RISCV64_SPACEMIT: num_cores: %d, num_perfer_cores: %d, perfer_core_arch_id: %x, exclude_main_thread: %d, " + "use_ime1: %d, use_ime2: %d, mem_backend: %s, cpu_mask: %lx, aicpu_id_offset: %d\n", + num_cores, num_perfer_cores, (uint16_t) perfer_core_arch_id, exclude_main_thread, use_ime1, use_ime2, + spine_mem_pool_backend_to_string(mem_backend), cpu_mask, aicpu_id_offset); + + const size_t init_barrier_size = sizeof(spine_barrier_t) * spine_init_barrier_count; + init_barrier = + static_cast(spine_mem_pool_shared_mem_alloc(init_barrier_size, alignof(spine_barrier_t))); + if (init_barrier != nullptr) { + init_barrier_is_shared_mem = true; + } else { + GGML_LOG_WARN("CPU_RISCV64_SPACEMIT: failed to allocate init_barrier from shared mem, falling back to heap\n", + __func__); + init_barrier = new spine_barrier_t[spine_init_barrier_count]; + } + + spine_barrier_init(init_barrier, spine_init_barrier_count, 2); +} + +spine_env_info::~spine_env_info() { + if (init_barrier_is_shared_mem) { + spine_mem_pool_shared_mem_free(init_barrier); + } else { + delete[] init_barrier; + } + + init_barrier = nullptr; + init_barrier_is_shared_mem = false; +} + +spine_env_info global_spine_env_info; + +} // namespace ggml::cpu::riscv64_spacemit diff --git a/ggml/src/ggml-cpu/spacemit/ime_env.h b/ggml/src/ggml-cpu/spacemit/ime_env.h new file mode 100644 index 00000000000..a6ca06d26a4 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/ime_env.h @@ -0,0 +1,55 @@ +#pragma once + +#include "spine_barrier.h" +#include "spine_mem_pool.h" + +#include +#include +#include + +namespace ggml::cpu::riscv64_spacemit { + +constexpr uint64_t spine_invalid_core_id = 0xFFFFFFFF; +constexpr size_t spine_init_barrier_count = 16; + +enum class spine_core_arch_id : uint16_t { + core_arch_none = 0, + core_arch_x60 = 0x503C, + core_arch_x100 = 0x5064, + core_arch_x200 = 0x50C8, + core_arch_a60 = 0xA03C, + core_arch_a100 = 0xA064, + core_arch_a200 = 0xA0C8, +}; + +struct spine_core_info { + uint64_t core_id{ spine_invalid_core_id }; + spine_core_arch_id arch_id{ spine_core_arch_id::core_arch_none }; + + static bool get_spine_core_info(std::vector & result); +}; + +struct spine_env_info { + std::vector core_info_list; + std::vector perfer_core_ids; + int aicpu_id_offset{ 0 }; + int num_cores{ 0 }; + int num_perfer_cores{ 0 }; + spine_core_arch_id perfer_core_arch_id{ spine_core_arch_id::core_arch_none }; + bool exclude_main_thread{ false }; + bool use_ime2{ false }; + bool use_ime1{ false }; + bool use_tcm{ false }; + spine_mem_pool_backend mem_backend{ spine_mem_pool_backend::transparent_hugepage }; + uint64_t tcm_blk_size{ 0 }; + uint64_t cpu_mask{ 0 }; + spine_barrier_t * init_barrier{ nullptr }; + bool init_barrier_is_shared_mem{ false }; + + spine_env_info(); + ~spine_env_info(); +}; + +extern spine_env_info global_spine_env_info; + +} // namespace ggml::cpu::riscv64_spacemit diff --git a/ggml/src/ggml-cpu/spacemit/ime_kernels.h b/ggml/src/ggml-cpu/spacemit/ime_kernels.h index 75706341505..0a1fafffb25 100644 --- a/ggml/src/ggml-cpu/spacemit/ime_kernels.h +++ b/ggml/src/ggml-cpu/spacemit/ime_kernels.h @@ -1,26 +1,189 @@ #pragma once +#include #include +#include + +namespace spacemit_kernels { + +#define BLOCK_QNK_LEN 256 + +template struct nrow_block_q2_k { + // [4bit scale + 4bit zp] * N * 16 + uint8_t scales[N * BLOCK_QNK_LEN / 16]; + // [b0, b16, b32, b48] [b1, b17, b33, b49] ... [b15, b31, b47, b63] + // [b64, b80, b96, b112] ...[b79, b95, b111, b127] + // [b128, b144, b160, b176] ...[b143, b159, b175, b191] + // [b192, b208, b224, b240] ...[b207, b223, b239, b255] + uint8_t qs[N * BLOCK_QNK_LEN / 4]; + uint16_t scales16[N]; + uint16_t zeros16[N]; +}; + +template struct nrow_block_q3_k { + // [8bit scale] * N * 16 + int8_t scales[N * 16]; + // [b0, b1, b2, b3, b4, b5, b6, b7] ... [b248, b249, b250, b251, b252, b253, b254, b255] + uint8_t hmask[N * BLOCK_QNK_LEN / 8]; + // [b0, b16, b32, b48] [b1, b17, b33, b49] ... [b15, b31, b47, b63] + // [b64, b80, b96, b112] ...[b79, b95, b111, b127] + // [b128, b144, b160, b176] ...[b143, b159, b175, b191] + // [b192, b208, b224, b240] ...[b207, b223, b239, b255] + uint8_t qs[N * BLOCK_QNK_LEN / 4]; + uint16_t scales16[N]; +}; + +template struct nrow_block_mxfp4 { + uint8_t e[N]; + uint8_t qh[4 * N]; + uint8_t qs[16 * N]; +}; + +template struct __attribute__((packed)) nrow_block_q5_1 { + uint16_t scales16[N]; + uint8_t zp[N]; + // n0 [bh0, bh1, bh2, bh3, bh4, bh5, bh6, bh7] .... + uint8_t qh[4 * N]; + // n0 [b0, b1], [b2, b3] .... [b30, b31] + // n1 [b0, b1], [b2, b3] .... [b30, b31] + uint8_t qs[16 * N]; +}; + +static_assert(sizeof(nrow_block_q5_1<1>) == sizeof(uint8_t) + 22, "wrong nrow_block_q5_1 block size/padding"); + +template struct __attribute__((packed)) nrow_block_q5_0 { + uint16_t scales16[N]; + // n0 [bh0, bh1, bh2, bh3, bh4, bh5, bh6, bh7] .... + uint8_t qh[4 * N]; + // n0 [b0, b1], [b2, b3] .... [b30, b31] + // n1 [b0, b1], [b2, b3] .... [b30, b31] + uint8_t qs[16 * N]; +}; + +static_assert(sizeof(nrow_block_q5_0<1>) == 22, "wrong nrow_block_q5_0 block size/padding"); + +using gemm_kernel_quantize_def = std::function< + size_t(size_t, const uint8_t *, const uint8_t *, const uint8_t *, float *, size_t, size_t, size_t, size_t)>; + +using moe_gemm_kernel_quantize_def = std::function< + size_t(size_t, const uint8_t **, const uint8_t *, const uint8_t *, float **, size_t, size_t, size_t, size_t)>; -namespace sqnbitgemm_spacemit_ime { namespace ime1 { -size_t gemm_kernel_i8i4(size_t blk_len, - const std::byte * quant_a_ptr, - const std::byte * quant_b_data, - const float * quant_b_scale, - const std::byte * quant_b_zp, - float * c_ptr, - size_t count_m, - size_t count_n, - size_t count_k, - size_t block_count_k, - size_t ldc, - const float * bias, - const size_t scale_stride); - -void quantize_a_row_i8(size_t blk_len, const float * a_ptr, size_t count_k, std::byte * quant_a_ptr); - -void quantize_a_4row_i8(size_t blk_len, const float * a_ptr, size_t count_k, std::byte * quant_a_ptr); +size_t gemm_kernel_i8i4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +void quantize_a_row_i8(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); + +void quantize_a_4row_i8(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); } // namespace ime1 -} // namespace sqnbitgemm_spacemit_ime + +namespace ime2 { +size_t gemm_kernel_i8i2k(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t gemm_kernel_i8i3k(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t gemm_kernel_i8i4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t gemm_kernel_i8i4_hp(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t moe_m2_gemm_kernel_i8i4(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t gemm_kernel_i8i8(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t gemm_kernel_i8mxfp4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t moe_m2_gemm_kernel_i8mxfp4(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t gemm_kernel_i8i5(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t moe_m2_gemm_kernel_i8i5(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); +} // namespace ime2 +} // namespace spacemit_kernels diff --git a/ggml/src/ggml-cpu/spacemit/repack.cpp b/ggml/src/ggml-cpu/spacemit/repack.cpp new file mode 100644 index 00000000000..3c879c4b7a0 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/repack.cpp @@ -0,0 +1,1795 @@ +#define GGML_COMMON_IMPL_CPP +#define GGML_COMMON_DECL_CPP + +#include "repack.h" + +#include "ggml-common.h" +#include "ggml-cpu.h" +#include "ggml-impl.h" +#include "ime_kernels.h" + +#include +#include +#include +#include + +// clang-format off +#if defined(__riscv) + +#if !defined(__riscv_v) || !defined(__riscv_v_intrinsic) +#error "riscv v extension or v_intrinsic not enabled" +#else +#include +#endif + +#if !defined(__riscv_zfh) +#error "riscv zfh extension not enabled" +#endif + +#else +#error "riscv not enabled in this build" +#endif + +#if defined(__GNUC__) +#pragma GCC diagnostic ignored "-Wcast-qual" +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +// clang-format on + +template constexpr int QK_0() { + if constexpr (K == 4) { + return QK4_0; + } + if constexpr (K == 8) { + return QK8_0; + } + return -1; +} + +template struct block { + ggml_half d[N]; // deltas for N qK_0 blocks + uint8_t qs[(QK_0() * N * K) / 8]; // quants for N qK_0 blocks +}; + +template struct block_with_zp { + ggml_half d[N]; // deltas for N qK_1 blocks + uint8_t zp[N]; // zero points for N qK_1 blocks + uint8_t qs[(QK_0() * N * K) / 8]; // quants for N qK_1 blocks +}; + +// control size +static_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8, "wrong block<4,16> size/padding"); +static_assert(sizeof(block_with_zp<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8 + 16 * sizeof(uint8_t), + "wrong block_with_zp<4,16> size/padding"); + +static_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 16, "wrong block<8,16> size/padding"); + +static_assert(sizeof(block<4, 32>) == 32 * sizeof(ggml_half) + QK4_0 * 16, "wrong block<4,32> size/padding"); +static_assert(sizeof(block_with_zp<4, 32>) == 32 * sizeof(ggml_half) + QK4_0 * 16 + 32 * sizeof(uint8_t), + "wrong block_with_zp<4,32> size/padding"); + +using block_q4_0x16 = block<4, 16>; +using block_q4_1x16 = block_with_zp<4, 16>; +using block_q8_0x16 = block<8, 16>; + +using block_q4_0x32 = block<4, 32>; +using block_q4_1x32 = block_with_zp<4, 32>; +using block_q8_0x32 = block<8, 32>; + +struct block_q4_0x32x256 { + block_q4_0x32 blocks[8]; // [f16 * 32 | i4 * 32 * 32] * 8 +}; + +struct block_q4_1x32x256 { + block_q4_0x32 blocks[8]; + uint8_t zps[32 * 8]; +}; + +static block_q4_0x16 make_block_q4_0x16(block_q4_0 * in, unsigned int blck_size_interleave) { + block_q4_0x16 out; + GGML_ASSERT(QK4_0 / blck_size_interleave == 2); + + for (int i = 0; i < 16; i++) { + out.d[i] = in[i].d; + } + + for (int i = 0; i < 16; i++) { + // [0, 15], in.d & 0x0F + for (int j = 0; j < QK4_0 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b0 b8] ......... [b7 b15] + out.qs[i * QK4_0 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_0 / 4] & 0x0F) << 4); + } + } + + for (int i = 0; i < 16; i++) { + // [16, 31], in.d & 0xF0 + for (int j = 0; j < QK4_0 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b16 b24] ......... [b23 b31] + out.qs[4 * QK4_0 + i * QK4_0 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_0 / 4] & 0xF0); + } + } + + return out; +} + +static block_q4_1x16 make_block_q4_1x16(block_q4_1 * in, unsigned int blck_size_interleave) { + block_q4_1x16 out; + GGML_ASSERT(QK4_1 / blck_size_interleave == 2); + + for (int i = 0; i < 16; i++) { + float d = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + float m = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m); + float mid = -std::nearbyintf(m / d); + mid = std::min(15.0f, std::max(0.0f, mid)); + out.d[i] = GGML_FP32_TO_FP16(d); + out.zp[i] = static_cast(mid); + } + + for (int i = 0; i < 16; i++) { + // [0, 15], in.d & 0x0F + for (int j = 0; j < QK4_1 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b0 b8] ......... [b7 b15] + out.qs[i * QK4_1 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_1 / 4] & 0x0F) << 4); + } + } + + for (int i = 0; i < 16; i++) { + // [16, 31], in.d & 0xF0 + for (int j = 0; j < QK4_1 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b16 b24] ......... [b23 b31] + out.qs[4 * QK4_1 + i * QK4_1 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_1 / 4] & 0xF0); + } + } + + return out; +} + +static int repack_q4_0_to_q4_0_16_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_0); + GGML_ASSERT(interleave_block == 16); + + constexpr int nrows_interleaved = 16; + + block_q4_0x16 * dst = (block_q4_0x16 *) t->data; + const block_q4_0 * src = (const block_q4_0 *) data; + block_q4_0 dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_0x16(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q4_1_to_q4_1_16_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_1); + GGML_ASSERT(interleave_block == 16); + + constexpr int nrows_interleaved = 16; + + block_q4_1x16 * dst = (block_q4_1x16 *) t->data; + const block_q4_1 * src = (const block_q4_1 *) data; + block_q4_1 dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_1; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_1 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_1x16(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static inline void get_scale_min_k4(int j, + const uint8_t * GGML_RESTRICT q, + uint8_t * GGML_RESTRICT d, + uint8_t * GGML_RESTRICT m) { + if (j < 4) { + *d = q[j] & 63; + *m = q[j + 4] & 63; + } else { + *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4); + *m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4); + } +} + +static int repack_q4_k_to_q4_1_16_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_K); + GGML_ASSERT(interleave_block == 16); + GGML_ASSERT(QK_K / QK4_1 == 8); + + constexpr int nrows_interleaved = 16; + + block_q4_1x16 * dst = (block_q4_1x16 *) t->data; + const block_q4_K * src = (const block_q4_K *) data; + block_q4_1 dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int j = 0; j < 8; j++) { + for (int i = 0; i < nrows_interleaved; i++) { + uint8_t sc, m; + const float d = GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + const float min = + GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin); + get_scale_min_k4(j, src[x + i * nblocks].scales, &sc, &m); + const float d1 = d * sc; + const float m1 = min * m; + + dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d = GGML_FP32_TO_FP16(d1); + dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m = GGML_FP32_TO_FP16(-m1); + // src -> [b0, b32] [b1, b33] ... [b31, b63] + // dst -> [b0, b16] [b1, b17] ... [b15, b31] [b32, b48] [b33, b49] ... [b47, b63] + const uint8_t * q = src[x + i * nblocks].qs + (j / 2) * QK4_1; + if (j % 2 == 0) { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = (q[ii] & 0x0F) | ((q[ii + 16] & 0x0F) << 4); + } + } else { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = ((q[ii] & 0xF0) >> 4) | (q[ii + 16] & 0xF0); + } + } + } + *dst++ = make_block_q4_1x16(dst_tmp, interleave_block); + } + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static block_q4_0x32 make_block_q4_0x32(block_q4_0 * in, unsigned int blck_size_interleave) { + block_q4_0x32 out; + assert(QK4_0 / blck_size_interleave == 1); + GGML_UNUSED(blck_size_interleave); + + for (int i = 0; i < 32; i++) { + out.d[i] = in[i].d; + } + + for (int i = 0; i < 32; i++) { + // [0, 15], in.d & 0x0F + for (int j = 0; j < QK4_0 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b0 b1] ......... [b14 b15] + out.qs[i * QK4_0 / 2 + j] = (in[i].qs[j * 2] & 0x0F) | ((in[i].qs[j * 2 + 1] & 0x0F) << 4); + } + } + + for (int i = 0; i < 32; i++) { + // [16, 31], in.d & 0xF0 + for (int j = 0; j < QK4_0 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b16 b17] ......... [b30 b31] + out.qs[i * QK4_0 / 2 + QK4_0 / 4 + j] = ((in[i].qs[j * 2] & 0xF0) >> 4) | (in[i].qs[j * 2 + 1] & 0xF0); + } + } + + return out; +} + +static block_q4_1x32 make_block_q4_1x32(block_q4_1 * in, unsigned int blck_size_interleave) { + block_q4_1x32 out; + GGML_ASSERT(QK4_1 / blck_size_interleave == 1); + GGML_UNUSED(blck_size_interleave); + + for (int i = 0; i < 32; i++) { + float d = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + float m = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m); + float mid = -std::nearbyintf(m / d); + mid = std::min(15.0f, std::max(0.0f, mid)); + out.d[i] = GGML_FP32_TO_FP16(d); + out.zp[i] = static_cast(mid); + } + + for (int i = 0; i < 32; i++) { + // [0, 15], in.d & 0x0F + for (int j = 0; j < QK4_1 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b0 b1] ......... [b14 b15] + out.qs[i * QK4_1 / 2 + j] = (in[i].qs[j * 2] & 0x0F) | ((in[i].qs[j * 2 + 1] & 0x0F) << 4); + } + } + + for (int i = 0; i < 32; i++) { + // [16, 31], in.d & 0xF0 + for (int j = 0; j < QK4_1 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b16 b24] ......... [b23 b31] + out.qs[i * QK4_1 / 2 + QK4_1 / 4 + j] = ((in[i].qs[j * 2] & 0xF0) >> 4) | (in[i].qs[j * 2 + 1] & 0xF0); + } + } + + return out; +} + +static block_q8_0x32 make_block_q8_0x32(block_q8_0 * in, unsigned int blck_size_interleave) { + block_q8_0x32 out; + GGML_ASSERT(QK8_0 / blck_size_interleave == 1); + GGML_UNUSED(blck_size_interleave); + + for (int i = 0; i < 32; i++) { + out.d[i] = in[i].d; + } + + for (int i = 0; i < 32; i++) { + memcpy(out.qs + i * QK8_0, in[i].qs, QK8_0); + } + + return out; +} + +static int repack_q2_k_to_q2_k_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q2_K); + GGML_ASSERT(interleave_block == 32); + GGML_ASSERT(QK_K == 256); + + constexpr int nrows_interleaved = 32; + + const block_q2_K * src = (const block_q2_K *) data; + + auto * dst = (spacemit_kernels::nrow_block_q2_k<32> *) t->data; + + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q2_K)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { + return -1; + } + + uint8_t qs_aux[256] = { 0 }; + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + const block_q2_K * src_block = &src[(b + i) * nblocks + x]; + + // scale for [16, N] + for (int j = 0; j < 16; j++) { + auto zp_aux = (dst->scales[j * nrows_interleaved + i]) & 0xF0; + + dst->scales[j * nrows_interleaved + i] = (src_block->scales[j] & 0x0F) | zp_aux; + } + + // zp for [N, 16] + for (int j = 0; j < 16; j++) { + auto scale_aux = (dst->scales[16 * i + j]) & 0x0F; + + dst->scales[16 * i + j] = (src_block->scales[j] & 0xF0) | scale_aux; + } + + for (int k = 0; k < 4; k++) { + for (int j = 0; j < 32; j++) { + qs_aux[k * 32 + j] = (src_block->qs[j] >> (2 * k)) & 0x03; + } + } + + for (int k = 0; k < 4; k++) { + for (int j = 0; j < 32; j++) { + qs_aux[k * 32 + j + 128] = (src_block->qs[j + 32] >> (2 * k)) & 0x03; + } + } + + // from nrows_interleaved * [2 * 32byte] + // to 4 * [nrows_interleaved * 16byte] + for (int k = 0; k < 4; k++) { + for (int j = 0; j < 16; j++) { + uint8_t qs0 = qs_aux[j + k * 64]; + uint8_t qs16 = qs_aux[j + 16 + k * 64]; + uint8_t qs32 = qs_aux[j + 32 + k * 64]; + uint8_t qs48 = qs_aux[j + 48 + k * 64]; + + dst->qs[(k * nrows_interleaved + i) * 16 + j] = + (qs0 & 0x03) | ((qs16 & 0x03) << 2) | ((qs32 & 0x03) << 4) | ((qs48 & 0x03) << 6); + } + } + + dst->scales16[i] = src_block->GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; + dst->zeros16[i] = src_block->GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin; + } + dst++; + } + } + + return 0; +} + +static int repack_q3_k_to_q3_k_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q3_K); + GGML_ASSERT(interleave_block == 32); + GGML_ASSERT(QK_K == 256); + + constexpr int nrows_interleaved = 32; + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + const block_q3_K * src = (const block_q3_K *) data; + + auto * dst = (spacemit_kernels::nrow_block_q3_k<32> *) t->data; + + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q3_K)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { + return -1; + } + + uint32_t b_scale_aux[4] = { 0 }; + uint8_t qs_aux[256] = { 0 }; + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + const block_q3_K * src_block = &src[(b + i) * nblocks + x]; + + uint32_t * auxs = b_scale_aux; + int8_t * scale = (int8_t *) auxs; + memcpy(auxs, src_block->scales, 12); + + uint32_t tmp = auxs[2]; + auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); + auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + + for (int j = 0; j < 16; j++) { + dst->scales[j * nrows_interleaved + i] = scale[j] - 32; + } + + for (int k = 0; k < 4; k++) { + for (int j = 0; j < 32; j++) { + qs_aux[k * 32 + j] = (src_block->qs[j] >> (2 * k)) & 0x03; + } + } + + for (int k = 0; k < 4; k++) { + for (int j = 0; j < 32; j++) { + qs_aux[k * 32 + j + 128] = (src_block->qs[j + 32] >> (2 * k)) & 0x03; + } + } + + // from nrows_interleaved * [2 * 32byte] + // to 4 * [nrows_interleaved * 16byte] + for (int k = 0; k < 4; k++) { + for (int j = 0; j < 16; j++) { + uint8_t qs0 = qs_aux[j + k * 64]; + uint8_t qs16 = qs_aux[j + 16 + k * 64]; + uint8_t qs32 = qs_aux[j + 32 + k * 64]; + uint8_t qs48 = qs_aux[j + 48 + k * 64]; + + dst->qs[(k * nrows_interleaved + i) * 16 + j] = + (qs0 & 0x03) | ((qs16 & 0x03) << 2) | ((qs32 & 0x03) << 4) | ((qs48 & 0x03) << 6); + } + } + + //memcpy(dst->hmask + i * 32, src_block->hmask, 32); + + // from nrows_interleaved * [32byte] + // to 16 * [nrows_interleaved * uint16_t] + uint16_t * dst_mask = ((uint16_t *) dst->hmask) + i; + for (int j = 0; j < 16; j++, dst_mask += nrows_interleaved) { + uint8_t b_shift = j / 2; + uint8_t * b_mask_col = (uint8_t *) (src_block->hmask + (j % 2) * 16); + // b0 - b15 + uint16_t msk_out_0 = 0; + + for (int k = 0; k < 8; k++) { + msk_out_0 |= (uint16_t) ((b_mask_col[k] >> b_shift) & 0x01) << k; + } + for (int k = 8; k < 16; k++) { + msk_out_0 |= (uint16_t) ((b_mask_col[k] >> b_shift) & 0x01) << k; + } + + dst_mask[0] = msk_out_0; + } + + dst->scales16[i] = src_block->d; + } + + dst++; + } + } + + return 0; +} + +static int repack_q4_0_to_q4_0_32_bl_ref(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_0); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + block_q4_0x32 * dst = (block_q4_0x32 *) t->data; + const block_q4_0 * src = (const block_q4_0 *) data; + block_q4_0 dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_0x32(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q4_0_to_q4_0_256_32_bl_ref(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_0); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + block_q4_0x32x256 * dst = (block_q4_0x32x256 *) t->data; + const block_q4_0 * src = (const block_q4_0 *) data; + block_q4_0 dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + GGML_ASSERT(nblocks % 8 == 0); // for 256-block interleaving + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x += 8) { + for (int j = 0; j < 8; j++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + j + i * nblocks]; + } + dst->blocks[j] = make_block_q4_0x32(dst_tmp, interleave_block); + } + dst++; + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q4_0_to_q4_1_256_32_bl_ref(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_1); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + block_q4_1x32x256 * dst = (block_q4_1x32x256 *) t->data; + const block_q4_1 * src = (const block_q4_1 *) data; + block_q4_1 dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1)); + GGML_ASSERT(nblocks % 8 == 0); // for 256-block interleaving + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x += 8) { + for (int j = 0; j < 8; j++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + j + i * nblocks]; + } + + block_q4_0x32 * dst_block = &dst->blocks[j]; + uint8_t * dst_zp = dst->zps + j * nrows_interleaved; + + for (int i = 0; i < nrows_interleaved; i++) { + float d = GGML_FP16_TO_FP32(dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + float m = GGML_FP16_TO_FP32(dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m); + float mid = -std::nearbyintf(m / d); + mid = std::min(15.0f, std::max(0.0f, mid)); + + dst_block->d[i] = GGML_FP32_TO_FP16(d); + dst_zp[i] = static_cast(mid); + } + + for (int i = 0; i < nrows_interleaved; i++) { + for (int k = 0; k < QK4_1 / 4; k++) { + dst_block->qs[i * QK4_1 / 2 + k] = + (dst_tmp[i].qs[k * 2] & 0x0F) | ((dst_tmp[i].qs[k * 2 + 1] & 0x0F) << 4); + } + } + + for (int i = 0; i < nrows_interleaved; i++) { + for (int k = 0; k < QK4_1 / 4; k++) { + dst_block->qs[i * QK4_1 / 2 + QK4_1 / 4 + k] = + ((dst_tmp[i].qs[k * 2] & 0xF0) >> 4) | (dst_tmp[i].qs[k * 2 + 1] & 0xF0); + } + } + } + dst++; + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +// RVV optimized version of repack_q4_0_to_q4_0_32_bl +// Eliminates the intermediate dst_tmp buffer and vectorizes nibble repack. +static int repack_q4_0_to_q4_0_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_0); + GGML_ASSERT(interleave_block == 32); + + constexpr int nrows_interleaved = 32; + constexpr int qs_bytes = QK4_0 / 2; // 16 + + block_q4_0x32 * dst = (block_q4_0x32 *) t->data; + const block_q4_0 * src = (const block_q4_0 *) data; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { + return -1; + } + + const ptrdiff_t row_stride = (ptrdiff_t) nblocks * sizeof(block_q4_0); + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + const block_q4_0 * col_src = src + x; + + // --- 1) Gather 32 scale values (ggml_half d) with stride load --- + // d is at offset 0 of each block_q4_0, stride between rows = row_stride + { + const uint8_t * d_base = (const uint8_t *) &col_src->d; + ggml_half * d_dst = dst->d; + size_t remaining = 32; + size_t offset = 0; + while (remaining > 0) { + size_t vl = __riscv_vsetvl_e16m1(remaining); + vuint16m1_t vd = + __riscv_vlse16_v_u16m1((const uint16_t *) (d_base + offset * row_stride), row_stride, vl); + __riscv_vse16_v_u16m1((uint16_t *) (d_dst + offset), vd, vl); + offset += vl; + remaining -= vl; + } + } + + // --- 2) Nibble repack qs for each of the 32 rows --- + // For each row i: + // src qs[16]: [b0|b16] [b1|b17] ... [b15|b31] (lo nibble = b_j, hi nibble = b_{j+16}) + // dst qs low 8B: (qs[2j] & 0x0F) | ((qs[2j+1] & 0x0F) << 4) for j=0..7 + // dst qs high 8B: ((qs[2j] >> 4)) | (qs[2j+1] & 0xF0) for j=0..7 + { + const size_t vl8 = __riscv_vsetvl_e8m1(8); + for (int i = 0; i < 32; i++) { + const uint8_t * sq = col_src[i * nblocks].qs; + uint8_t * dq = dst->qs + i * qs_bytes; + + // stride-2 load to separate even/odd bytes + vuint8m1_t v_even = __riscv_vlse8_v_u8m1(sq, 2, vl8); // qs[0], qs[2], ..., qs[14] + vuint8m1_t v_odd = __riscv_vlse8_v_u8m1(sq + 1, 2, vl8); // qs[1], qs[3], ..., qs[15] + + // low nibble part: (even & 0x0F) | ((odd & 0x0F) << 4) + vuint8m1_t v_even_lo = __riscv_vand_vx_u8m1(v_even, 0x0F, vl8); + vuint8m1_t v_odd_lo = __riscv_vand_vx_u8m1(v_odd, 0x0F, vl8); + vuint8m1_t v_lo = __riscv_vor_vv_u8m1(v_even_lo, __riscv_vsll_vx_u8m1(v_odd_lo, 4, vl8), vl8); + + // high nibble part: (even >> 4) | (odd & 0xF0) + vuint8m1_t v_even_hi = __riscv_vsrl_vx_u8m1(v_even, 4, vl8); + vuint8m1_t v_odd_hi = __riscv_vand_vx_u8m1(v_odd, 0xF0, vl8); + vuint8m1_t v_hi = __riscv_vor_vv_u8m1(v_even_hi, v_odd_hi, vl8); + + __riscv_vse8_v_u8m1(dq, v_lo, vl8); + __riscv_vse8_v_u8m1(dq + 8, v_hi, vl8); + } + } + + dst++; + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q4_1_to_q4_1_32_bl_ref(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_1); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + block_q4_1x32 * dst = (block_q4_1x32 *) t->data; + const block_q4_1 * src = (const block_q4_1 *) data; + block_q4_1 dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_1; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_1 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_1x32(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +// RVV optimized version of repack_q4_1_to_q4_1_32_bl +// Eliminates the intermediate dst_tmp buffer and vectorizes nibble repack + zp computation. +static int repack_q4_1_to_q4_1_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_1); + GGML_ASSERT(interleave_block == 32); + + constexpr int nrows_interleaved = 32; + constexpr int qs_bytes = QK4_1 / 2; // 16 + + block_q4_1x32 * dst = (block_q4_1x32 *) t->data; + const block_q4_1 * src = (const block_q4_1 *) data; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_1; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_1 != 0) { + return -1; + } + + const ptrdiff_t row_stride = (ptrdiff_t) nblocks * sizeof(block_q4_1); + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + const block_q4_1 * col_src = src + x; + + // --- 1) Gather d and m, compute zp = clamp(nearbyint(-m/d), 0, 15) --- + // block_q4_1 layout: [d(f16), m(f16), qs[16]] + // d is at byte offset 0, m is at byte offset 2 from each block start + { + const uint8_t * dm_base = (const uint8_t *) &col_src->GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; + ggml_half * d_dst = dst->d; + uint8_t * zp_dst = dst->zp; + size_t remaining = 32; + size_t offset = 0; + while (remaining > 0) { + size_t vl = __riscv_vsetvl_e16m1(remaining); + + // stride load d (f16) from each row + vuint16m1_t vd_raw = + __riscv_vlse16_v_u16m1((const uint16_t *) (dm_base + offset * row_stride), row_stride, vl); + __riscv_vse16_v_u16m1((uint16_t *) (d_dst + offset), vd_raw, vl); + + // stride load m (f16) from each row (offset +2 bytes from d) + vuint16m1_t vm_raw = + __riscv_vlse16_v_u16m1((const uint16_t *) (dm_base + 2 + offset * row_stride), row_stride, vl); + + // convert to f32 for zp computation: zp = nearbyint(-m / d) + vfloat16m1_t vd_f16 = __riscv_vreinterpret_v_u16m1_f16m1(vd_raw); + vfloat16m1_t vm_f16 = __riscv_vreinterpret_v_u16m1_f16m1(vm_raw); + + // -m / d in f16 directly (SpaceMIT X60 supports f16 arithmetic) + vfloat16m1_t v_neg_m = __riscv_vfneg_v_f16m1(vm_f16, vl); + vfloat16m1_t v_ratio = __riscv_vfdiv_vv_f16m1(v_neg_m, vd_f16, vl); + + // Convert to f32 for nearbyint, then clamp + vfloat32m2_t v_ratio_f32 = __riscv_vfwcvt_f_f_v_f32m2(v_ratio, vl); + + // Use integer rounding: convert f32 -> int (rounds to nearest) + vint32m2_t v_zp_i32 = __riscv_vfcvt_x_f_v_i32m2(v_ratio_f32, vl); + + // clamp to [0, 15] + v_zp_i32 = __riscv_vmax_vx_i32m2(v_zp_i32, 0, vl); + v_zp_i32 = __riscv_vmin_vx_i32m2(v_zp_i32, 15, vl); + + // narrow i32 -> u8 + vint16m1_t v_zp_i16 = __riscv_vncvt_x_x_w_i16m1(v_zp_i32, vl); + vint8mf2_t v_zp_i8 = __riscv_vncvt_x_x_w_i8mf2(v_zp_i16, vl); + vuint8mf2_t v_zp_u8 = __riscv_vreinterpret_v_i8mf2_u8mf2(v_zp_i8); + __riscv_vse8_v_u8mf2(zp_dst + offset, v_zp_u8, vl); + + offset += vl; + remaining -= vl; + } + } + + // --- 2) Nibble repack qs for each of the 32 rows --- + { + const size_t vl8 = __riscv_vsetvl_e8m1(8); + for (int i = 0; i < 32; i++) { + const uint8_t * sq = col_src[i * nblocks].qs; + uint8_t * dq = dst->qs + i * qs_bytes; + + // stride-2 load to separate even/odd bytes + vuint8m1_t v_even = __riscv_vlse8_v_u8m1(sq, 2, vl8); + vuint8m1_t v_odd = __riscv_vlse8_v_u8m1(sq + 1, 2, vl8); + + // low nibble part: (even & 0x0F) | ((odd & 0x0F) << 4) + vuint8m1_t v_even_lo = __riscv_vand_vx_u8m1(v_even, 0x0F, vl8); + vuint8m1_t v_odd_lo = __riscv_vand_vx_u8m1(v_odd, 0x0F, vl8); + vuint8m1_t v_lo = __riscv_vor_vv_u8m1(v_even_lo, __riscv_vsll_vx_u8m1(v_odd_lo, 4, vl8), vl8); + + // high nibble part: (even >> 4) | (odd & 0xF0) + vuint8m1_t v_even_hi = __riscv_vsrl_vx_u8m1(v_even, 4, vl8); + vuint8m1_t v_odd_hi = __riscv_vand_vx_u8m1(v_odd, 0xF0, vl8); + vuint8m1_t v_hi = __riscv_vor_vv_u8m1(v_even_hi, v_odd_hi, vl8); + + __riscv_vse8_v_u8m1(dq, v_lo, vl8); + __riscv_vse8_v_u8m1(dq + 8, v_hi, vl8); + } + } + + dst++; + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q4_k_to_q4_1_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_K); + GGML_ASSERT(interleave_block == 32); + GGML_ASSERT(QK_K / QK4_1 == 8); + + constexpr int nrows_interleaved = 32; + + block_q4_1x32 * dst = (block_q4_1x32 *) t->data; + const block_q4_K * src = (const block_q4_K *) data; + block_q4_1 dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int j = 0; j < 8; j++) { + for (int i = 0; i < nrows_interleaved; i++) { + uint8_t sc, m; + const float d = GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + const float min = + GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin); + get_scale_min_k4(j, src[x + i * nblocks].scales, &sc, &m); + const float d1 = d * sc; + const float m1 = min * m; + + dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d = GGML_FP32_TO_FP16(d1); + dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m = GGML_FP32_TO_FP16(-m1); + // src -> [b0, b32] [b1, b33] ... [b31, b63] + // dst -> [b0, b16] [b1, b17] ... [b15, b31] [b32, b48] [b33, b49] ... [b47, b63] + const uint8_t * q = src[x + i * nblocks].qs + (j / 2) * QK4_1; + if (j % 2 == 0) { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = (q[ii] & 0x0F) | ((q[ii + 16] & 0x0F) << 4); + } + } else { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = ((q[ii] & 0xF0) >> 4) | (q[ii + 16] & 0xF0); + } + } + } + *dst++ = make_block_q4_1x32(dst_tmp, interleave_block); + } + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q6_k_to_q8_0_32_bl_ref(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q6_K); + GGML_ASSERT(interleave_block == 32); + GGML_ASSERT(QK_K / QK4_1 == 8); + + constexpr int nrows_interleaved = 32; + + block_q8_0x32 * dst = (block_q8_0x32 *) t->data; + const block_q6_K * src = (const block_q6_K *) data; + block_q8_0 dst_tmp[32]; + int8_t aux8[QK4_1]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + if (t->ne[0] % QK_K != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + int64_t nrow_real = std::min((int64_t) nrow - b, (int64_t) nrows_interleaved); + for (int64_t x = 0; x < nblocks; x++) { + for (int bi = 0; bi < 8; bi++) { + int i = 0; + for (; i < nrow_real; i++) { + const uint8_t * q4 = src[x + i * nblocks].ql; + const uint8_t * qh = src[x + i * nblocks].qh; + const int8_t * scales = src[x + i * nblocks].scales; + float d = GGML_FP16_TO_FP32(src[x + i * nblocks].d); + + q4 += 64 * (bi / 4); + qh += 32 * (bi / 4); + int8_t * GGML_RESTRICT a = aux8; + + int8_t bi_idx = bi % 4; + + if (bi_idx == 0) { + for (int l = 0; l < 32; ++l) { + a[l] = (int8_t) ((q4[l] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + } + } else if (bi_idx == 1) { + for (int l = 0; l < 32; ++l) { + a[l] = (int8_t) ((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + } + } else if (bi_idx == 2) { + for (int l = 0; l < 32; ++l) { + a[l] = (int8_t) ((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + } + } else if (bi_idx == 3) { + for (int l = 0; l < 32; ++l) { + a[l] = (int8_t) ((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + } + } + a = aux8; + + float a_max_abs = 0.0f; + float scale_0 = scales[bi * 2 + 0] * d; + float scale_1 = scales[bi * 2 + 1] * d; + for (int l = 0; l < 16; ++l) { + a_max_abs = std::max(a_max_abs, std::abs(a[l] * scale_0)); + } + + for (int l = 16; l < 32; ++l) { + a_max_abs = std::max(a_max_abs, std::abs(a[l] * scale_1)); + } + + float reflect_scale = a_max_abs / ((1 << 7) - 1); + float reflect_scale_0 = scale_0 / reflect_scale; + float reflect_scale_1 = scale_1 / reflect_scale; + + for (int l = 0; l < 16; ++l) { + float a_temp = std::clamp(std::nearbyintf(a[l] * reflect_scale_0), -128.0f, 127.0f); + a[l] = (int8_t) (a_temp); + } + + for (int l = 16; l < 32; ++l) { + float a_temp = std::clamp(std::nearbyintf(a[l] * reflect_scale_1), -128.0f, 127.0f); + a[l] = (int8_t) (a_temp); + } + + dst_tmp[i].d = GGML_FP32_TO_FP16(reflect_scale); + + memcpy(dst_tmp[i].qs, a, 32 * sizeof(int8_t)); + } + + for (; i < nrows_interleaved; i++) { + memset(&dst_tmp[i], 0, sizeof(block_q8_0)); + } + + *dst++ = make_block_q8_0x32(dst_tmp, interleave_block); + } + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +// RVV optimized version of repack_q6_k_to_q8_0_32_bl +// Vectorizes the Q6_K dequant -> requant pipeline using RVV intrinsics. +// For each sub-block (bi), dequant 32 Q6_K values to int6 -> apply two sub-block scales -> +// find max abs -> compute reflect_scale -> requant to int8 -> gather d with stride load. +static int repack_q6_k_to_q8_0_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q6_K); + GGML_ASSERT(interleave_block == 32); + GGML_ASSERT(QK_K / QK4_1 == 8); + + constexpr int nrows_interleaved = 32; + + block_q8_0x32 * dst = (block_q8_0x32 *) t->data; + const block_q6_K * src = (const block_q6_K *) data; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { + return -1; + } + + const ptrdiff_t row_stride = (ptrdiff_t) nblocks * sizeof(block_q6_K); + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int bi = 0; bi < 8; bi++) { + // --- 1) Gather 32 d values with stride load --- + // We need to compute reflect_scale per row first, so gather d later. + // Process each row: dequant Q6_K sub-block -> requant to Q8_0 + for (int i = 0; i < nrows_interleaved; i++) { + const block_q6_K * src_blk = &src[x + i * nblocks]; + const uint8_t * q4 = src_blk->ql + 64 * (bi / 4); + const uint8_t * qh = src_blk->qh + 32 * (bi / 4); + const int8_t * scales = src_blk->scales; + float d = GGML_FP16_TO_FP32(src_blk->d); + + int8_t bi_idx = bi % 4; + + // --- Dequant 32 Q6_K values to int6 (range [-32, 31]) using RVV --- + // vl = 32 for e8m2 (VLEN=256) or loop for smaller VLEN + const size_t vl16 = __riscv_vsetvl_e8m1(16); + + vint8m1_t va_lo, va_hi; // 16 elements each + + if (bi_idx == 0) { + // a[l] = (q4[l] & 0xF) | (((qh[l] >> 0) & 3) << 4) - 32 + vuint8m1_t vq4_lo = __riscv_vle8_v_u8m1(q4, vl16); + vuint8m1_t vq4_hi = __riscv_vle8_v_u8m1(q4 + 16, vl16); + vuint8m1_t vqh_lo = __riscv_vle8_v_u8m1(qh, vl16); + vuint8m1_t vqh_hi = __riscv_vle8_v_u8m1(qh + 16, vl16); + + vuint8m1_t vlo4_lo = __riscv_vand_vx_u8m1(vq4_lo, 0x0F, vl16); + vuint8m1_t vlo4_hi = __riscv_vand_vx_u8m1(vq4_hi, 0x0F, vl16); + vuint8m1_t vh_lo = __riscv_vsll_vx_u8m1(__riscv_vand_vx_u8m1(vqh_lo, 0x03, vl16), 4, vl16); + vuint8m1_t vh_hi = __riscv_vsll_vx_u8m1(__riscv_vand_vx_u8m1(vqh_hi, 0x03, vl16), 4, vl16); + + vuint8m1_t vcomb_lo = __riscv_vor_vv_u8m1(vlo4_lo, vh_lo, vl16); + vuint8m1_t vcomb_hi = __riscv_vor_vv_u8m1(vlo4_hi, vh_hi, vl16); + + va_lo = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_lo), 32, vl16); + va_hi = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_hi), 32, vl16); + } else if (bi_idx == 1) { + // a[l] = (q4[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4) - 32 + vuint8m1_t vq4_lo = __riscv_vle8_v_u8m1(q4 + 32, vl16); + vuint8m1_t vq4_hi = __riscv_vle8_v_u8m1(q4 + 48, vl16); + vuint8m1_t vqh_lo = __riscv_vle8_v_u8m1(qh, vl16); + vuint8m1_t vqh_hi = __riscv_vle8_v_u8m1(qh + 16, vl16); + + vuint8m1_t vlo4_lo = __riscv_vand_vx_u8m1(vq4_lo, 0x0F, vl16); + vuint8m1_t vlo4_hi = __riscv_vand_vx_u8m1(vq4_hi, 0x0F, vl16); + vuint8m1_t vh_lo = __riscv_vsll_vx_u8m1( + __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vqh_lo, 2, vl16), 0x03, vl16), 4, vl16); + vuint8m1_t vh_hi = __riscv_vsll_vx_u8m1( + __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vqh_hi, 2, vl16), 0x03, vl16), 4, vl16); + + vuint8m1_t vcomb_lo = __riscv_vor_vv_u8m1(vlo4_lo, vh_lo, vl16); + vuint8m1_t vcomb_hi = __riscv_vor_vv_u8m1(vlo4_hi, vh_hi, vl16); + + va_lo = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_lo), 32, vl16); + va_hi = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_hi), 32, vl16); + } else if (bi_idx == 2) { + // a[l] = (q4[l] >> 4) | (((qh[l] >> 4) & 3) << 4) - 32 + vuint8m1_t vq4_lo = __riscv_vle8_v_u8m1(q4, vl16); + vuint8m1_t vq4_hi = __riscv_vle8_v_u8m1(q4 + 16, vl16); + vuint8m1_t vqh_lo = __riscv_vle8_v_u8m1(qh, vl16); + vuint8m1_t vqh_hi = __riscv_vle8_v_u8m1(qh + 16, vl16); + + vuint8m1_t vhi4_lo = __riscv_vsrl_vx_u8m1(vq4_lo, 4, vl16); + vuint8m1_t vhi4_hi = __riscv_vsrl_vx_u8m1(vq4_hi, 4, vl16); + vuint8m1_t vh_lo = __riscv_vsll_vx_u8m1( + __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vqh_lo, 4, vl16), 0x03, vl16), 4, vl16); + vuint8m1_t vh_hi = __riscv_vsll_vx_u8m1( + __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vqh_hi, 4, vl16), 0x03, vl16), 4, vl16); + + vuint8m1_t vcomb_lo = __riscv_vor_vv_u8m1(vhi4_lo, vh_lo, vl16); + vuint8m1_t vcomb_hi = __riscv_vor_vv_u8m1(vhi4_hi, vh_hi, vl16); + + va_lo = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_lo), 32, vl16); + va_hi = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_hi), 32, vl16); + } else { // bi_idx == 3 + // a[l] = (q4[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4) - 32 + vuint8m1_t vq4_lo = __riscv_vle8_v_u8m1(q4 + 32, vl16); + vuint8m1_t vq4_hi = __riscv_vle8_v_u8m1(q4 + 48, vl16); + vuint8m1_t vqh_lo = __riscv_vle8_v_u8m1(qh, vl16); + vuint8m1_t vqh_hi = __riscv_vle8_v_u8m1(qh + 16, vl16); + + vuint8m1_t vhi4_lo = __riscv_vsrl_vx_u8m1(vq4_lo, 4, vl16); + vuint8m1_t vhi4_hi = __riscv_vsrl_vx_u8m1(vq4_hi, 4, vl16); + vuint8m1_t vh_lo = __riscv_vsll_vx_u8m1( + __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vqh_lo, 6, vl16), 0x03, vl16), 4, vl16); + vuint8m1_t vh_hi = __riscv_vsll_vx_u8m1( + __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vqh_hi, 6, vl16), 0x03, vl16), 4, vl16); + + vuint8m1_t vcomb_lo = __riscv_vor_vv_u8m1(vhi4_lo, vh_lo, vl16); + vuint8m1_t vcomb_hi = __riscv_vor_vv_u8m1(vhi4_hi, vh_hi, vl16); + + va_lo = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_lo), 32, vl16); + va_hi = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_hi), 32, vl16); + } + + // --- Widen to i16 for scaled abs computation --- + float scale_0 = scales[bi * 2 + 0] * d; + float scale_1 = scales[bi * 2 + 1] * d; + + // Widen i8 -> i16 -> f32 for abs*scale computation + vint16m2_t va_lo_w = __riscv_vsext_vf2_i16m2(va_lo, vl16); + vint16m2_t va_hi_w = __riscv_vsext_vf2_i16m2(va_hi, vl16); + + // Compute |a[l] * scale_0| for lo half, |a[l] * scale_1| for hi half + vfloat32m4_t vf_lo = __riscv_vfcvt_f_x_v_f32m4(__riscv_vsext_vf2_i32m4(va_lo_w, vl16), vl16); + vfloat32m4_t vf_hi = __riscv_vfcvt_f_x_v_f32m4(__riscv_vsext_vf2_i32m4(va_hi_w, vl16), vl16); + + vfloat32m4_t vabs_lo = __riscv_vfabs_v_f32m4(__riscv_vfmul_vf_f32m4(vf_lo, scale_0, vl16), vl16); + vfloat32m4_t vabs_hi = __riscv_vfabs_v_f32m4(__riscv_vfmul_vf_f32m4(vf_hi, scale_1, vl16), vl16); + + // Find max abs across both halves + vfloat32m4_t vabs_max = __riscv_vfmax_vv_f32m4(vabs_lo, vabs_hi, vl16); + + // Reduce to scalar max + vfloat32m1_t vzero = __riscv_vfmv_v_f_f32m1(0.0f, 1); + vfloat32m1_t vmax_red = __riscv_vfredmax_vs_f32m4_f32m1(vabs_max, vzero, vl16); + float a_max_abs = __riscv_vfmv_f_s_f32m1_f32(vmax_red); + + float reflect_scale = a_max_abs / 127.0f; + float reflect_scale_0 = scale_0 / reflect_scale; + float reflect_scale_1 = scale_1 / reflect_scale; + + // --- Requant: a[l] = clamp(nearbyint(a[l] * reflect_scale_x), -128, 127) --- + vfloat32m4_t vscaled_lo = __riscv_vfmul_vf_f32m4(vf_lo, reflect_scale_0, vl16); + vfloat32m4_t vscaled_hi = __riscv_vfmul_vf_f32m4(vf_hi, reflect_scale_1, vl16); + + // fcvt.x rounds to nearest (using current rounding mode) + vint32m4_t vi_lo = __riscv_vfcvt_x_f_v_i32m4(vscaled_lo, vl16); + vint32m4_t vi_hi = __riscv_vfcvt_x_f_v_i32m4(vscaled_hi, vl16); + + // Clamp to [-128, 127] + vi_lo = __riscv_vmax_vx_i32m4(vi_lo, -128, vl16); + vi_lo = __riscv_vmin_vx_i32m4(vi_lo, 127, vl16); + vi_hi = __riscv_vmax_vx_i32m4(vi_hi, -128, vl16); + vi_hi = __riscv_vmin_vx_i32m4(vi_hi, 127, vl16); + + // Narrow i32 -> i16 -> i8 + vint16m2_t vi16_lo = __riscv_vncvt_x_x_w_i16m2(vi_lo, vl16); + vint16m2_t vi16_hi = __riscv_vncvt_x_x_w_i16m2(vi_hi, vl16); + vint8m1_t vi8_lo = __riscv_vncvt_x_x_w_i8m1(vi16_lo, vl16); + vint8m1_t vi8_hi = __riscv_vncvt_x_x_w_i8m1(vi16_hi, vl16); + + // Store d and qs directly into dst block + dst->d[i] = GGML_FP32_TO_FP16(reflect_scale); + int8_t * dq = (int8_t *) dst->qs + i * QK8_0; + __riscv_vse8_v_i8m1(dq, vi8_lo, vl16); + __riscv_vse8_v_i8m1(dq + 16, vi8_hi, vl16); + } + dst++; + } + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q8_0_to_q8_0_32_bl_ref(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q8_0); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + block_q8_0x32 * dst = (block_q8_0x32 *) t->data; + const block_q8_0 * src = (const block_q8_0 *) data; + block_q8_0 dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK8_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q8_0)); + + if (t->ne[0] % QK8_0 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + int64_t nrows_real = std::min((int64_t) nrow - b, (int64_t) nrows_interleaved); + for (int64_t x = 0; x < nblocks; x++) { + int i = 0; + for (; i < nrows_real; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + for (; i < nrows_interleaved; i++) { + memset(&dst_tmp[i], 0, sizeof(block_q8_0)); + } + *dst++ = make_block_q8_0x32(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +// RVV optimized version of repack_q8_0_to_q8_0_32_bl +// Eliminates the intermediate dst_tmp buffer and vectorizes scale gather + qs copy. +static int repack_q8_0_to_q8_0_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q8_0); + GGML_ASSERT(interleave_block == 32); + + constexpr int nrows_interleaved = 32; + + block_q8_0x32 * dst = (block_q8_0x32 *) t->data; + const block_q8_0 * src = (const block_q8_0 *) data; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK8_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q8_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK8_0 != 0) { + return -1; + } + + const ptrdiff_t row_stride = (ptrdiff_t) nblocks * sizeof(block_q8_0); + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + const block_q8_0 * col_src = src + x; + + // --- 1) Gather 32 scale values (ggml_half d) with stride load --- + { + const uint8_t * d_base = (const uint8_t *) &col_src->d; + ggml_half * d_dst = dst->d; + size_t remaining = 32; + size_t offset = 0; + while (remaining > 0) { + size_t vl = __riscv_vsetvl_e16m1(remaining); + vuint16m1_t vd = + __riscv_vlse16_v_u16m1((const uint16_t *) (d_base + offset * row_stride), row_stride, vl); + __riscv_vse16_v_u16m1((uint16_t *) (d_dst + offset), vd, vl); + offset += vl; + remaining -= vl; + } + } + + // --- 2) Copy qs for each of the 32 rows (32 bytes per row) --- + { + for (int i = 0; i < 32; i++) { + const int8_t * sq = col_src[i * nblocks].qs; + int8_t * dq = (int8_t *) dst->qs + i * QK8_0; + + size_t len = QK8_0; + size_t idx = 0; + while (len > 0) { + size_t vl = __riscv_vsetvl_e8m2(len); + vint8m2_t vs = __riscv_vle8_v_i8m2(sq + idx, vl); + __riscv_vse8_v_i8m2(dq + idx, vs, vl); + idx += vl; + len -= vl; + } + } + } + + dst++; + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static void convert_mxfp4_to_5bit(const block_mxfp4 & src, spacemit_kernels::nrow_block_mxfp4<1> & dst) { + dst.e[0] = src.e; + + // Decode all 32 mxfp4 values to signed integers via kvalues_mxfp4 + int8_t vals[32]; + for (int j = 0; j < QK_MXFP4 / 2; j++) { + vals[j] = kvalues_mxfp4[src.qs[j] & 0xF]; + vals[j + QK_MXFP4 / 2] = kvalues_mxfp4[src.qs[j] >> 4]; + } + + // vals [b0, b1, b2, b3, ..., b30, b31] + // Pack abs into qs with reorder: [b0,b1]..[b14,b15]..[b30,b31] + for (int j = 0; j < QK_MXFP4 / 2; j++) { + uint8_t lo0 = static_cast(std::abs(vals[j * 2])); + uint8_t lo1 = static_cast(std::abs(vals[j * 2 + 1])); + dst.qs[j] = (lo0 & 0x0F) | ((lo1 & 0x0F) << 4); + } + + // Pack sign bits into qh[4] (32 bits total, 1 bit per weight) + // reorder: [0,1,2,...,15,16,17,...,31] after the qs reorder above + uint32_t sign_bits = 0; + for (int j = 0; j < 32; j++) { + if (vals[j] < 0) { + sign_bits |= (1u << j); + } + } + memcpy(dst.qh, &sign_bits, 4); +} + +static spacemit_kernels::nrow_block_mxfp4<32> make_block_mxfp4x32(spacemit_kernels::nrow_block_mxfp4<1> * in, + unsigned int blck_size_interleave) { + spacemit_kernels::nrow_block_mxfp4<32> out; + GGML_ASSERT(QK_MXFP4 / blck_size_interleave == 1); + GGML_UNUSED(blck_size_interleave); + + for (int i = 0; i < 32; i++) { + out.e[i] = in[i].e[0]; + } + + // qs: copy per-row 16 bytes + for (int i = 0; i < 32; i++) { + memcpy(out.qs + i * 16, in[i].qs, 16); + } + + // qh: copy per-row 4 bytes + for (int i = 0; i < 32; i++) { + memcpy(out.qh + i * 4, in[i].qh, 4); + } + + return out; +} + +static int repack_mxfp4_to_mxfp4_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_MXFP4); + GGML_ASSERT(interleave_block == 32); + + constexpr int nrows_interleaved = 32; + + spacemit_kernels::nrow_block_mxfp4<32> * dst = (spacemit_kernels::nrow_block_mxfp4<32> *) t->data; + const block_mxfp4 * src = (const block_mxfp4 *) data; + spacemit_kernels::nrow_block_mxfp4<1> dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_MXFP4; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_mxfp4)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_MXFP4 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + convert_mxfp4_to_5bit(src[x + i * nblocks], dst_tmp[i]); + } + *dst++ = make_block_mxfp4x32(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; +} + +static spacemit_kernels::nrow_block_q5_1<32> make_block_q5_1x32(spacemit_kernels::nrow_block_q5_1<1> * in, + unsigned int blck_size_interleave) { + spacemit_kernels::nrow_block_q5_1<32> out; + GGML_ASSERT(QK5_1 / blck_size_interleave == 1); + GGML_UNUSED(blck_size_interleave); + + for (int i = 0; i < 32; i++) { + out.scales16[i] = in[i].scales16[0]; + out.zp[i] = in[i].zp[0]; + } + + // qs: low 4 bits, reorder from [b0,b16],[b1,b17]... to [b0,b1]...[b14,b15] and [b16,b17]...[b30,b31] + for (int i = 0; i < 32; i++) { + // low half [0..15] + for (int j = 0; j < QK5_1 / 4; j++) { + out.qs[i * QK5_1 / 2 + j] = (in[i].qs[j * 2] & 0x0F) | ((in[i].qs[j * 2 + 1] & 0x0F) << 4); + } + // high half [16..31] + for (int j = 0; j < QK5_1 / 4; j++) { + out.qs[i * QK5_1 / 2 + QK5_1 / 4 + j] = ((in[i].qs[j * 2] & 0xF0) >> 4) | (in[i].qs[j * 2 + 1] & 0xF0); + } + } + + // qh: 5th bit, copy directly + for (int i = 0; i < 32; i++) { + for (int j = 0; j < 4; j++) { + out.qh[i * 4 + j] = in[i].qh[j]; + } + } + + return out; +} + +static spacemit_kernels::nrow_block_q5_0<32> make_block_q5_0x32(spacemit_kernels::nrow_block_q5_0<1> * in, + unsigned int blck_size_interleave) { + spacemit_kernels::nrow_block_q5_0<32> out; + GGML_ASSERT(QK5_0 / blck_size_interleave == 1); + GGML_UNUSED(blck_size_interleave); + + for (int i = 0; i < 32; i++) { + out.scales16[i] = in[i].scales16[0]; + } + + // qs: low 4 bits, reorder from [b0,b16],[b1,b17]... to [b0,b1]...[b14,b15] and [b16,b17]...[b30,b31] + for (int i = 0; i < 32; i++) { + // low half [0..15] + for (int j = 0; j < QK5_0 / 4; j++) { + out.qs[i * QK5_0 / 2 + j] = (in[i].qs[j * 2] & 0x0F) | ((in[i].qs[j * 2 + 1] & 0x0F) << 4); + } + // high half [16..31] + for (int j = 0; j < QK5_0 / 4; j++) { + out.qs[i * QK5_0 / 2 + QK5_0 / 4 + j] = ((in[i].qs[j * 2] & 0xF0) >> 4) | (in[i].qs[j * 2 + 1] & 0xF0); + } + } + + // qh: 5th bit, copy directly + for (int i = 0; i < 32; i++) { + for (int j = 0; j < 4; j++) { + out.qh[i * 4 + j] = in[i].qh[j]; + } + } + + return out; +} + +static int repack_q5_0_to_q5_0_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q5_0); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + spacemit_kernels::nrow_block_q5_0<32> * dst = (spacemit_kernels::nrow_block_q5_0<32> *) t->data; + const block_q5_0 * src = (const block_q5_0 *) data; + spacemit_kernels::nrow_block_q5_0<1> dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK5_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q5_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK5_0 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + const block_q5_0 & s = src[x + i * nblocks]; + + dst_tmp[i].scales16[0] = s.d; + memcpy(dst_tmp[i].qs, s.qs, sizeof(dst_tmp[i].qs)); + memcpy(dst_tmp[i].qh, s.qh, sizeof(dst_tmp[i].qh)); + } + *dst++ = make_block_q5_0x32(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; +} + +static int repack_q5_1_to_q5_1_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q5_1); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + spacemit_kernels::nrow_block_q5_1<32> * dst = (spacemit_kernels::nrow_block_q5_1<32> *) t->data; + const block_q5_1 * src = (const block_q5_1 *) data; + spacemit_kernels::nrow_block_q5_1<1> dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK5_1; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q5_1)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK5_1 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + const block_q5_1 & s = src[x + i * nblocks]; + + float d = GGML_FP16_TO_FP32(s.GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + float m = GGML_FP16_TO_FP32(s.GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m); + + if (d == 0.0f) { + dst_tmp[i].scales16[0] = GGML_FP32_TO_FP16(std::fabs(m)); + dst_tmp[i].zp[0] = m < 0.0f ? 1 : 0; + memset(dst_tmp[i].qh, 0, sizeof(dst_tmp[i].qh)); + memset(dst_tmp[i].qs, m > 0.0f ? 0x11 : 0x00, sizeof(dst_tmp[i].qs)); + continue; + } + + float mid = std::nearbyintf(-m / d); + mid = std::min(31.0f, std::max(0.0f, mid)); + + dst_tmp[i].scales16[0] = GGML_FP32_TO_FP16(d); + dst_tmp[i].zp[0] = static_cast(mid); + + // qs: copy low 4 bits directly (same nibble packing) + memcpy(dst_tmp[i].qs, s.qs, QK5_1 / 2); + + // qh: copy 5th bit directly + memcpy(dst_tmp[i].qh, s.qh, 4); + } + *dst++ = make_block_q5_1x32(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; +} + +static int repack_q5_k_to_q5_1_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q5_K); + GGML_ASSERT(interleave_block == 32); + GGML_ASSERT(QK_K / QK5_1 == 8); + + constexpr int nrows_interleaved = 32; + + spacemit_kernels::nrow_block_q5_1<32> * dst = (spacemit_kernels::nrow_block_q5_1<32> *) t->data; + const block_q5_K * src = (const block_q5_K *) data; + spacemit_kernels::nrow_block_q5_1<1> dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int j = 0; j < 8; j++) { + for (int i = 0; i < nrows_interleaved; i++) { + uint8_t sc, m; + const float d = GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + const float min = + GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin); + get_scale_min_k4(j, src[x + i * nblocks].scales, &sc, &m); + + float d1 = d * sc; + float m1 = min * m; + + float mid = std::nearbyintf(m1 / d1); + mid = std::min(31.0f, std::max(0.0f, mid)); + dst_tmp[i].scales16[0] = GGML_FP32_TO_FP16(d1); + dst_tmp[i].zp[0] = static_cast(mid); + + // src -> [b0, b32] [b1, b33] ... [b31, b63] + // dst -> [b0, b16] [b1, b17] ... [b15, b31] [b32, b48] [b33, b49] ... [b47, b63] + const uint8_t * q = src[x + i * nblocks].qs + (j / 2) * QK5_1; + if (j % 2 == 0) { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = (q[ii] & 0x0F) | ((q[ii + 16] & 0x0F) << 4); + } + } else { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = ((q[ii] & 0xF0) >> 4) | (q[ii + 16] & 0xF0); + } + } + + // Extract the 5th bit (qh) for this sub-block + // block_q5_K.qh[32]: for sub-block j, the 5th bit is at bit position j in qh[l] + // qs was reordered: dst_qs maps to src weights [0,16,1,17,...,15,31] + // So qh must follow the same reorder to stay aligned with qs + // dst qh[4] = 32 bits for 32 weights in the reordered layout: + // byte 0: weights 0..7 (from src_qh[0..7]) + // byte 1: weights 8..15 (from src_qh[8..15]) + // byte 2: weights 16..23 (from src_qh[16..23]) + // byte 3: weights 24..31 (from src_qh[24..31]) + const uint8_t * src_qh = src[x + i * nblocks].qh; + for (int bi = 0; bi < 4; bi++) { + uint8_t qh_byte = 0; + for (int k = 0; k < 8; k++) { + int src_idx = bi * 8 + k; + qh_byte |= ((src_qh[src_idx] >> j) & 1) << k; + } + dst_tmp[i].qh[bi] = qh_byte; + } + } + *dst++ = make_block_q5_1x32(dst_tmp, interleave_block); + } + } + src += nrows_interleaved * nblocks; + } + return 0; +} + +namespace ggml::cpu::riscv64_spacemit { + +template int repack(ggml_tensor *, const void *, size_t); + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_0_to_q4_0_16_bl(t, 16, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_1_to_q4_1_16_bl(t, 16, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_k_to_q4_1_16_bl(t, 16, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q2_k_to_q2_k_32_bl(t, 32, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q3_k_to_q3_k_32_bl(t, 32, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { +#if 0 + return repack_q4_0_to_q4_0_32_bl_ref(t, 32, data, data_size); +#else + return repack_q4_0_to_q4_0_32_bl(t, 32, data, data_size); +#endif +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { +#if 1 + return repack_q4_0_to_q4_0_256_32_bl_ref(t, 32, data, data_size); +#else + //return repack_q4_0_to_q4_0_256_32_bl(t, 32, data, data_size); +#endif +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { +#if 0 + return repack_q4_1_to_q4_1_32_bl_ref(t, 32, data, data_size); +#else + return repack_q4_1_to_q4_1_32_bl(t, 32, data, data_size); +#endif +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { +#if 1 + return repack_q4_0_to_q4_1_256_32_bl_ref(t, 32, data, data_size); +#else + return repack_q4_1_to_q4_1_256_32_bl(t, 32, data, data_size); +#endif +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_k_to_q4_1_32_bl(t, 32, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { +#if 1 + return repack_q6_k_to_q8_0_32_bl_ref(t, 32, data, data_size); +#else + return repack_q6_k_to_q8_0_32_bl(t, 32, data, data_size); +#endif +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { +#if 1 + return repack_q8_0_to_q8_0_32_bl_ref(t, 32, data, data_size); +#else + return repack_q8_0_to_q8_0_32_bl(t, 32, data, data_size); +#endif +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_mxfp4_to_mxfp4_32_bl(t, 32, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q5_0_to_q5_0_32_bl(t, 32, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q5_1_to_q5_1_32_bl(t, 32, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q5_k_to_q5_1_32_bl(t, 32, data, data_size); +} + +} // namespace ggml::cpu::riscv64_spacemit diff --git a/ggml/src/ggml-cpu/spacemit/repack.h b/ggml/src/ggml-cpu/spacemit/repack.h new file mode 100644 index 00000000000..950cbde7593 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/repack.h @@ -0,0 +1,14 @@ +#pragma once + +#include "ggml-common.h" +#include "ggml.h" + +#include +#include + +namespace ggml::cpu::riscv64_spacemit { + +template +int repack(ggml_tensor * t, const void * data, size_t data_size); + +} // namespace ggml::cpu::riscv64_spacemit diff --git a/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp b/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp new file mode 100644 index 00000000000..d2f89743622 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp @@ -0,0 +1,3178 @@ +#include "rvv_kernels.h" + +#include "common.h" +#include "ggml.h" +#include "ops.h" +#include "string.h" + +#include +#include +#include +#include + +#if !defined(__riscv_v) || !defined(__riscv_v_intrinsic) +# error "riscv v extension or v_intrinsic not enabled" +#else +# include +#endif + +#if !defined(__riscv_zfh) +# error "riscv zfh extension not enabled" +#endif + +#if defined(__GNUC__) +# pragma GCC diagnostic ignored "-Woverlength-strings" +# pragma GCC diagnostic ignored "-Wcast-qual" +# pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +namespace spacemit_kernels::rvv { + +namespace { + +auto align_up(size_t value, size_t alignment) { + return (value + alignment - 1) / alignment * alignment; +} + +static inline bool flash_attn_ext_supported_d_vlen1024_vf16(int64_t d) { + return d > 0 && d <= 128; +} + +static inline bool flash_attn_ext_supported_shape_vlen1024_vf16(int64_t DK, int64_t DV) { + return flash_attn_ext_supported_d_vlen1024_vf16(DK) && flash_attn_ext_supported_d_vlen1024_vf16(DV); +} + +static inline float reduce_sum_f32m4_vlen1024(vfloat32m4_t v, size_t vl) { + vfloat32m1_t s_v = __riscv_vfmv_v_f_f32m1(0.0f, 1); + s_v = __riscv_vfredusum_vs_f32m4_f32m1(v, s_v, vl); + return __riscv_vfmv_f_s_f32m1_f32(s_v); +} + +static inline float reduce_sum_f32m2_vlen1024(vfloat32m2_t v, size_t vl) { + vfloat32m1_t s_v = __riscv_vfmv_v_f_f32m1(0.0f, 1); + s_v = __riscv_vfredusum_vs_f32m2_f32m1(v, s_v, vl); + return __riscv_vfmv_f_s_f32m1_f32(s_v); +} + +// Adapted from ggml_v_expf_m2 in vec.h. This is accurate enough for softmax. +static inline vfloat32m2_t rvv_expf_approx_f32m2(vfloat32m2_t x, size_t vl) { + const vfloat32m2_t r = __riscv_vfmv_v_f_f32m2(0x1.8p23f, vl); + const vfloat32m2_t z = __riscv_vfmacc_vf_f32m2(r, 0x1.715476p+0f, x, vl); + const vfloat32m2_t n = __riscv_vfsub_vv_f32m2(z, r, vl); + const vfloat32m2_t b = + __riscv_vfnmsac_vf_f32m2(__riscv_vfnmsac_vf_f32m2(x, 0x1.62e4p-1f, n, vl), 0x1.7f7d1cp-20f, n, vl); + const vuint32m2_t e = __riscv_vsll_vx_u32m2(__riscv_vreinterpret_v_f32m2_u32m2(z), 23, vl); + const vfloat32m2_t k = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vadd_vx_u32m2(e, 0x3f800000, vl)); + const vbool16_t c = __riscv_vmfgt_vf_f32m2_b16(__riscv_vfabs_v_f32m2(n, vl), 126.0f, vl); + const vfloat32m2_t u = __riscv_vfmul_vv_f32m2(b, b, vl); + const vfloat32m2_t j = __riscv_vfmacc_vv_f32m2( + __riscv_vfmul_vf_f32m2(b, 0x1.ffffecp-1f, vl), + __riscv_vfmacc_vv_f32m2( + __riscv_vfmacc_vf_f32m2(__riscv_vfmv_v_f_f32m2(0x1.fffdb6p-2f, vl), 0x1.555e66p-3f, b, vl), + __riscv_vfmacc_vf_f32m2(__riscv_vfmv_v_f_f32m2(0x1.573e2ep-5f, vl), 0x1.0e4020p-7f, b, vl), u, vl), + u, vl); + + if (!__riscv_vcpop_m_b16(c, vl)) { + return __riscv_vfmacc_vv_f32m2(k, j, k, vl); + } + + const vbool16_t dm = __riscv_vmfle_vf_f32m2_b16(n, 0.0f, vl); + const vuint32m2_t d = __riscv_vmerge_vxm_u32m2(__riscv_vmv_v_x_u32m2(0, vl), 0x82000000, dm, vl); + const vfloat32m2_t s1 = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vadd_vx_u32m2(d, 0x7f000000, vl)); + const vfloat32m2_t s2 = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vsub_vv_u32m2(e, d, vl)); + const vfloat32m2_t r1 = + __riscv_vmerge_vvm_f32m2(__riscv_vfmacc_vv_f32m2(k, k, j, vl), + __riscv_vfmul_vv_f32m2(__riscv_vfmacc_vv_f32m2(s2, s2, j, vl), s1, vl), c, vl); + return __riscv_vmerge_vvm_f32m2(r1, __riscv_vfmul_vv_f32m2(s1, s1, vl), + __riscv_vmfgt_vf_f32m2_b16(__riscv_vfabs_v_f32m2(n, vl), 192.0f, vl), vl); +} + +static inline vfloat32m2_t rvv_tanh_approx_f32m2(vfloat32m2_t x, size_t vl) { + const vfloat32m2_t abs_x = __riscv_vfabs_v_f32m2(x, vl); + const vfloat32m2_t neg_2_abs = __riscv_vfmul_vf_f32m2(abs_x, -2.0f, vl); + const vfloat32m2_t exp_term = rvv_expf_approx_f32m2(neg_2_abs, vl); + const vfloat32m2_t numerator = __riscv_vfsub_vf_f32m2(exp_term, 1.0f, vl); + const vfloat32m2_t denominator = __riscv_vfadd_vf_f32m2(exp_term, 1.0f, vl); + const vfloat32m2_t tanh_abs = __riscv_vfneg_v_f32m2(__riscv_vfdiv_vv_f32m2(numerator, denominator, vl), vl); + const vbool16_t neg_mask = __riscv_vmflt_vf_f32m2_b16(x, 0.0f, vl); + const vfloat32m2_t tanh_neg = __riscv_vfneg_v_f32m2(tanh_abs, vl); + return __riscv_vmerge_vvm_f32m2(tanh_abs, tanh_neg, neg_mask, vl); +} + +static void rvv_softcap_tanh_inplace_f32(float * dst, int64_t dst_stride, int64_t tile_rows, int64_t n, float softcap) { + for (int tq = 0; tq < tile_rows; ++tq, dst += dst_stride) { + float * dst_row = dst; + int64_t remaining = n; + while (remaining > 0) { + const size_t vl = __riscv_vsetvl_e32m2(remaining); + vfloat32m2_t v = __riscv_vle32_v_f32m2(dst_row, vl); + v = rvv_tanh_approx_f32m2(v, vl); + v = __riscv_vfmul_vf_f32m2(v, softcap, vl); + __riscv_vse32_v_f32m2(dst_row, v, vl); + dst_row += vl; + remaining -= vl; + } + } +} + +static inline float rvv_softmax_exp_inplace_f32(float * dst, int64_t n, float max_value) { + float row_sum = 0.0f; + while (n > 0) { + const size_t vl = __riscv_vsetvl_e32m2(n); + vfloat32m2_t v = __riscv_vle32_v_f32m2(dst, vl); + v = __riscv_vfsub_vf_f32m2(v, max_value, vl); + v = rvv_expf_approx_f32m2(v, vl); + __riscv_vse32_v_f32m2(dst, v, vl); + row_sum += reduce_sum_f32m2_vlen1024(v, vl); + dst += vl; + n -= vl; + } + return row_sum; +} + +static inline float rvv_add_max_inplace_f32(float * dst, const float * src, int64_t n) { + float max_val = -INFINITY; + while (n > 0) { + const size_t vl = __riscv_vsetvl_e32m4(n); + vfloat32m4_t vdst = __riscv_vle32_v_f32m4(dst, vl); + vfloat32m4_t vsrc = __riscv_vle32_v_f32m4(src, vl); + vdst = __riscv_vfadd_vv_f32m4(vdst, vsrc, vl); + __riscv_vse32_v_f32m4(dst, vdst, vl); + + vfloat32m1_t seed = __riscv_vfmv_v_f_f32m1(max_val, 1); + seed = __riscv_vfredmax_vs_f32m4_f32m1(vdst, seed, vl); + max_val = __riscv_vfmv_f_s_f32m1_f32(seed); + + dst += vl; + src += vl; + n -= vl; + } + return max_val; +} + +static inline float rvv_softcap_add_max_inplace_f32(float * dst, const float * src, int64_t n, float softcap) { + if (softcap == 0.0f) { + return rvv_add_max_inplace_f32(dst, src, n); + } + + float max_val = -INFINITY; + while (n > 0) { + const size_t vl = __riscv_vsetvl_e32m2(n); + vfloat32m2_t vdst = __riscv_vle32_v_f32m2(dst, vl); + vfloat32m2_t vsrc = __riscv_vle32_v_f32m2(src, vl); + vdst = rvv_tanh_approx_f32m2(vdst, vl); + vdst = __riscv_vfmul_vf_f32m2(vdst, softcap, vl); + vdst = __riscv_vfadd_vv_f32m2(vdst, vsrc, vl); + __riscv_vse32_v_f32m2(dst, vdst, vl); + + vfloat32m1_t seed = __riscv_vfmv_v_f_f32m1(max_val, 1); + seed = __riscv_vfredmax_vs_f32m2_f32m1(vdst, seed, vl); + max_val = __riscv_vfmv_f_s_f32m1_f32(seed); + + dst += vl; + src += vl; + n -= vl; + } + return max_val; +} + +static inline void rvv_zero_f32(float * dst, int64_t n) { + while (n > 0) { + const size_t vl = __riscv_vsetvl_e32m4(n); + const vfloat32m4_t z = __riscv_vfmv_v_f_f32m4(0.0f, vl); + __riscv_vse32_v_f32m4(dst, z, vl); + dst += vl; + n -= vl; + } +} + +static inline void rvv_scale_f32(float * dst, float scale, int64_t n) { + while (n > 0) { + const size_t vl = __riscv_vsetvl_e32m4(n); + vfloat32m4_t v = __riscv_vle32_v_f32m4(dst, vl); + v = __riscv_vfmul_vf_f32m4(v, scale, vl); + __riscv_vse32_v_f32m4(dst, v, vl); + dst += vl; + n -= vl; + } +} + +static inline void rvv_add_inplace_f32(float * dst, + int64_t dst_stride, + const float * src, + int64_t src_stride, + int64_t tile_rows, + int64_t n) { + for (int tq = 0; tq < tile_rows; ++tq, dst += dst_stride, src += src_stride) { + int64_t remaining = n; + float * dst_row = dst; + const float * src_row = src; + while (remaining > 0) { + const size_t vl = __riscv_vsetvl_e32m4(remaining); + vfloat32m4_t vdst = __riscv_vle32_v_f32m4(dst_row, vl); + vfloat32m4_t vsrc = __riscv_vle32_v_f32m4(src_row, vl); + vdst = __riscv_vfadd_vv_f32m4(vdst, vsrc, vl); + __riscv_vse32_v_f32m4(dst_row, vdst, vl); + dst_row += vl; + src_row += vl; + remaining -= vl; + } + } +} + +static inline float rvv_max_f32(const float * src, int64_t n) { + float max_val = -INFINITY; + while (n > 0) { + const size_t vl = __riscv_vsetvl_e32m4(n); + const vfloat32m4_t v = __riscv_vle32_v_f32m4(src, vl); + vfloat32m1_t seed = __riscv_vfmv_v_f_f32m1(max_val, 1); + seed = __riscv_vfredmax_vs_f32m4_f32m1(v, seed, vl); + max_val = __riscv_vfmv_f_s_f32m1_f32(seed); + src += vl; + n -= vl; + } + return max_val; +} + +static void rvv_pack_f32_as_scaled_f16(void * dst, + int64_t dst_row_stride, + const void * src, + int64_t src_row_stride, + int64_t tile_rows, + int64_t n, + float scale) { + for (int tq = 0; tq < tile_rows; ++tq) { + const float * row_ptr = (const float *) ((const char *) src + tq * src_row_stride); + _Float16 * dst_row_ptr = (_Float16 *) ((char *) dst + tq * dst_row_stride); + int64_t remaining = n; + while (remaining > 0) { + const size_t vl = __riscv_vsetvl_e32m4(remaining); + vfloat32m4_t v32 = __riscv_vle32_v_f32m4(row_ptr, vl); + v32 = __riscv_vfmul_vf_f32m4(v32, scale, vl); + const vfloat16m2_t v16 = __riscv_vfncvt_f_f_w_f16m2(v32, vl); + __riscv_vse16_v_f16m2(dst_row_ptr, v16, vl); + dst_row_ptr += vl; + row_ptr += vl; + remaining -= vl; + } + } +} + +static void rvv_pack_scaled_f16_as_f32(void * dst, + int64_t dst_row_stride, + const void * src, + int64_t src_row_stride, + int64_t tile_rows, + int64_t n, + float scale) { + for (int tq = 0; tq < tile_rows; ++tq) { + const _Float16 * row_ptr = (const _Float16 *) ((const char *) src + tq * src_row_stride); + float * dst_row_ptr = (float *) ((char *) dst + tq * dst_row_stride); + int64_t remaining = n; + while (remaining > 0) { + const size_t vl = __riscv_vsetvl_e16m2(remaining); + const vfloat16m2_t v16 = __riscv_vle16_v_f16m2(row_ptr, vl); + vfloat32m4_t v32 = __riscv_vfwcvt_f_f_v_f32m4(v16, vl); + v32 = __riscv_vfmul_vf_f32m4(v32, scale, vl); + __riscv_vse32_v_f32m4(dst_row_ptr, v32, vl); + dst_row_ptr += vl; + row_ptr += vl; + remaining -= vl; + } + } +} + +static void rvv_pack_scaled_f32_as_f32(void * dst, + int64_t dst_row_stride, + const void * src, + int64_t src_row_stride, + int64_t tile_rows, + int64_t n, + float * scale) { + for (int tq = 0; tq < tile_rows; ++tq) { + const float * row_ptr = (const float *) ((const char *) src + tq * src_row_stride); + float * dst_row_ptr = (float *) ((char *) dst + tq * dst_row_stride); + int64_t remaining = n; + while (remaining > 0) { + const size_t vl = __riscv_vsetvl_e32m4(remaining); + vfloat32m4_t v32 = __riscv_vle32_v_f32m4(row_ptr, vl); + v32 = __riscv_vfmul_vf_f32m4(v32, scale[tq], vl); + __riscv_vse32_v_f32m4(dst_row_ptr, v32, vl); + dst_row_ptr += vl; + row_ptr += vl; + remaining -= vl; + } + } +} + +static inline void rvv_transposed_s32_mn_to_nm(int8_t * dst, + int64_t n_dst_stride, + int8_t * src, + int64_t m_src_stride, + int64_t m, + int64_t n) { + int8_t * in = src; + int8_t * out = dst; + + __asm__ volatile( + "vsetvli t0, zero, e32, m1, tu, mu \n\t" + "mul t3, t0, %[os0] \n\t" + "srli t2, %[isz0], 3 \n\t" + "blez t2, M1%= \n\t" + + "LOOP_M8%=: \n\t" + "addi a1, %[dst], 0 \n\t" + "addi s1, %[src], 0 \n\t" + "add s2, %[src], %[is0] \n\t" + "add s3, s2, %[is0] \n\t" + "add s4, s3, %[is0] \n\t" + "add s5, s4, %[is0] \n\t" + "add s6, s5, %[is0] \n\t" + "add s7, s6, %[is0] \n\t" + "add s8, s7, %[is0] \n\t" + "addi t1, %[isz1], 0 \n\t" + + "LOOP_M8N%=: \n\t" + "vsetvli t0, t1, e32, m1, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle32.v v0, (s1) \n\t" + "sh2add s1, t0, s1 \n\t" + "vle32.v v1, (s2) \n\t" + "sh2add s2, t0, s2 \n\t" + "vle32.v v2, (s3) \n\t" + "sh2add s3, t0, s3 \n\t" + "vle32.v v3, (s4) \n\t" + "sh2add s4, t0, s4 \n\t" + "vle32.v v4, (s5) \n\t" + "sh2add s5, t0, s5 \n\t" + "vle32.v v5, (s6) \n\t" + "sh2add s6, t0, s6 \n\t" + "vle32.v v6, (s7) \n\t" + "sh2add s7, t0, s7 \n\t" + "vle32.v v7, (s8) \n\t" + "sh2add s8, t0, s8 \n\t" + "vssseg8e32.v v0, (a1), %[os0] \n\t" + "add a1, a1, t3 \n\t" + "bnez t1, LOOP_M8N%= \n\t" + "sh3add %[src], %[is0], %[src] \n\t" + "addi %[dst], %[dst], 32 \n\t" + "addi t2, t2, -1 \n\t" + "bnez t2, LOOP_M8%= \n\t" + + "M1%=: \n\t" + "andi t2, %[isz0], 7 \n\t" + "blez t2, END%= \n\t" + + "LOOP_M1%=: \n\t" + "addi a1, %[dst], 0 \n\t" + "addi s1, %[src], 0 \n\t" + "addi t1, %[isz1], 0 \n\t" + + "LOOP_M1N%=: \n\t" + "vsetvli t0, t1, e32, m1, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle32.v v0, (s1) \n\t" + "sh2add s1, t0, s1 \n\t" + "vsse32.v v0, (a1), %[os0] \n\t" + "add a1, a1, t3 \n\t" + "bnez t1, LOOP_M1N%= \n\t" + "add %[src], %[is0], %[src] \n\t" + "addi %[dst], %[dst], 4 \n\t" + "addi t2, t2, -1 \n\t" + "bnez t2, LOOP_M1%= \n\t" + "END%=: \n\t" + + : [src] "+r"(in), [dst] "+r"(out), [isz0] "+r"(m) + : [isz1] "r"(n), [is0] "r"(m_src_stride), [os0] "r"(n_dst_stride) + : "cc", "t0", "t1", "t2", "t3", "s1", "s2", "s3", "s4", "s5", "s6", "s7", "s8", "a1"); +} + +static inline void rvv_transposed_s16_mn_to_nm(int8_t * dst, + int64_t n_dst_stride, + int8_t * src, + int64_t m_src_stride, + int64_t m, + int64_t n) { + int8_t * in = src; + int8_t * out = dst; + + __asm__ volatile( + "vsetvli t0, zero, e16, m1, tu, mu \n\t" + "mul t3, t0, %[os0] \n\t" + "srli t2, %[isz0], 3 \n\t" + "blez t2, M1%= \n\t" + + "LOOP_M8%=: \n\t" + "addi a1, %[dst], 0 \n\t" + "addi s1, %[src], 0 \n\t" + "add s2, %[src], %[is0] \n\t" + "add s3, s2, %[is0] \n\t" + "add s4, s3, %[is0] \n\t" + "add s5, s4, %[is0] \n\t" + "add s6, s5, %[is0] \n\t" + "add s7, s6, %[is0] \n\t" + "add s8, s7, %[is0] \n\t" + "addi t1, %[isz1], 0 \n\t" + + "LOOP_M8N%=: \n\t" + "vsetvli t0, t1, e16, m1, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle16.v v0, (s1) \n\t" + "sh1add s1, t0, s1 \n\t" + "vle16.v v1, (s2) \n\t" + "sh1add s2, t0, s2 \n\t" + "vle16.v v2, (s3) \n\t" + "sh1add s3, t0, s3 \n\t" + "vle16.v v3, (s4) \n\t" + "sh1add s4, t0, s4 \n\t" + "vle16.v v4, (s5) \n\t" + "sh1add s5, t0, s5 \n\t" + "vle16.v v5, (s6) \n\t" + "sh1add s6, t0, s6 \n\t" + "vle16.v v6, (s7) \n\t" + "sh1add s7, t0, s7 \n\t" + "vle16.v v7, (s8) \n\t" + "sh1add s8, t0, s8 \n\t" + "vssseg8e16.v v0, (a1), %[os0] \n\t" + "add a1, a1, t3 \n\t" + "bnez t1, LOOP_M8N%= \n\t" + "sh3add %[src], %[is0], %[src] \n\t" + "addi %[dst], %[dst], 16 \n\t" + "addi t2, t2, -1 \n\t" + "bnez t2, LOOP_M8%= \n\t" + + "M1%=: \n\t" + "andi t2, %[isz0], 7 \n\t" + "blez t2, END%= \n\t" + + "LOOP_M1%=: \n\t" + "addi a1, %[dst], 0 \n\t" + "addi s1, %[src], 0 \n\t" + "addi t1, %[isz1], 0 \n\t" + + "LOOP_M1N%=: \n\t" + "vsetvli t0, t1, e16, m1, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle16.v v0, (s1) \n\t" + "sh1add s1, t0, s1 \n\t" + "vsse16.v v0, (a1), %[os0] \n\t" + "add a1, a1, t3 \n\t" + "bnez t1, LOOP_M1N%= \n\t" + "add %[src], %[is0], %[src] \n\t" + "addi %[dst], %[dst], 2 \n\t" + "addi t2, t2, -1 \n\t" + "bnez t2, LOOP_M1%= \n\t" + "END%=: \n\t" + + : [src] "+r"(in), [dst] "+r"(out), [isz0] "+r"(m) + : [isz1] "r"(n), [is0] "r"(m_src_stride), [os0] "r"(n_dst_stride) + : "cc", "t0", "t1", "t2", "t3", "s1", "s2", "s3", "s4", "s5", "s6", "s7", "s8", "a1"); +} + +static inline void rvv_qk_dot_tile_f16_x1(float * dst, + const _Float16 * q_row, + const _Float16 * k_pack, + int64_t dk, + int64_t kv_tile) { + const size_t vl = __riscv_vsetvl_e16m1(kv_tile); + vfloat32m2_t acc = __riscv_vfmv_v_f_f32m2(0.0f, vl); + + for (int64_t d = 0; d < dk; ++d) { + const vfloat16m1_t k_vec = __riscv_vle16_v_f16m1(k_pack + d * ggml_fa_tile_config::KV, vl); + acc = __riscv_vfwmacc_vf_f32m2(acc, q_row[d], k_vec, vl); + } + + __riscv_vse32_v_f32m2(dst, acc, vl); +} + +static inline void rvv_qk_dot_tile_f16_x4(float * dst0, + float * dst1, + float * dst2, + float * dst3, + const _Float16 * q0, + const _Float16 * q1, + const _Float16 * q2, + const _Float16 * q3, + const _Float16 * k_pack, + int64_t dk, + int64_t kv_tile) { + const size_t vl = __riscv_vsetvl_e16m1(kv_tile); + vfloat32m2_t acc0 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + vfloat32m2_t acc1 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + vfloat32m2_t acc2 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + vfloat32m2_t acc3 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + + for (int64_t d = 0; d < dk; ++d) { + const vfloat16m1_t k_vec = __riscv_vle16_v_f16m1(k_pack + d * ggml_fa_tile_config::KV, vl); + acc0 = __riscv_vfwmacc_vf_f32m2(acc0, q0[d], k_vec, vl); + acc1 = __riscv_vfwmacc_vf_f32m2(acc1, q1[d], k_vec, vl); + acc2 = __riscv_vfwmacc_vf_f32m2(acc2, q2[d], k_vec, vl); + acc3 = __riscv_vfwmacc_vf_f32m2(acc3, q3[d], k_vec, vl); + } + + __riscv_vse32_v_f32m2(dst0, acc0, vl); + __riscv_vse32_v_f32m2(dst1, acc1, vl); + __riscv_vse32_v_f32m2(dst2, acc2, vl); + __riscv_vse32_v_f32m2(dst3, acc3, vl); +} + +static inline void rvv_pv_accumulate_f16_x1(float * dst, + const float * prob, + const _Float16 * v_pack, + int64_t kv_tile, + int64_t dv) { + int64_t d_left = dv; + int64_t d_off = 0; + + while (d_left > 0) { + const size_t vl = __riscv_vsetvl_e16m2(d_left); + vfloat32m4_t acc = __riscv_vle32_v_f32m4(dst + d_off, vl); + + for (int64_t tk = 0; tk < kv_tile; ++tk) { + const vfloat16m2_t v16 = __riscv_vle16_v_f16m2(v_pack + tk * dv + d_off, vl); + const vfloat32m4_t v32 = __riscv_vfwcvt_f_f_v_f32m4(v16, vl); + acc = __riscv_vfmacc_vf_f32m4(acc, prob[tk], v32, vl); + } + + __riscv_vse32_v_f32m4(dst + d_off, acc, vl); + d_left -= vl; + d_off += vl; + } +} + +static inline void rvv_pv_accumulate_f16_x4(float * dst0, + float * dst1, + float * dst2, + float * dst3, + const float * prob0, + const float * prob1, + const float * prob2, + const float * prob3, + const _Float16 * v_pack, + int64_t kv_tile, + int64_t dv) { + int64_t d_left = dv; + int64_t d_off = 0; + + while (d_left > 0) { + const size_t vl = __riscv_vsetvl_e16m2(d_left); + vfloat32m4_t acc0 = __riscv_vle32_v_f32m4(dst0 + d_off, vl); + vfloat32m4_t acc1 = __riscv_vle32_v_f32m4(dst1 + d_off, vl); + vfloat32m4_t acc2 = __riscv_vle32_v_f32m4(dst2 + d_off, vl); + vfloat32m4_t acc3 = __riscv_vle32_v_f32m4(dst3 + d_off, vl); + + for (int64_t tk = 0; tk < kv_tile; ++tk) { + const vfloat16m2_t v16 = __riscv_vle16_v_f16m2(v_pack + tk * dv + d_off, vl); + const vfloat32m4_t v32 = __riscv_vfwcvt_f_f_v_f32m4(v16, vl); + acc0 = __riscv_vfmacc_vf_f32m4(acc0, prob0[tk], v32, vl); + acc1 = __riscv_vfmacc_vf_f32m4(acc1, prob1[tk], v32, vl); + acc2 = __riscv_vfmacc_vf_f32m4(acc2, prob2[tk], v32, vl); + acc3 = __riscv_vfmacc_vf_f32m4(acc3, prob3[tk], v32, vl); + } + + __riscv_vse32_v_f32m4(dst0 + d_off, acc0, vl); + __riscv_vse32_v_f32m4(dst1 + d_off, acc1, vl); + __riscv_vse32_v_f32m4(dst2 + d_off, acc2, vl); + __riscv_vse32_v_f32m4(dst3 + d_off, acc3, vl); + d_left -= vl; + d_off += vl; + } +} + +static inline void rvv_qk_dot_tile(float * dst, + const float * q_row, + const float * k_pack, + int64_t dk, + int64_t kv_tile, + float scale) { + const size_t vl = __riscv_vsetvl_e32m4(kv_tile); + vfloat32m4_t acc = __riscv_vfmv_v_f_f32m4(0.0f, vl); + + for (int64_t d = 0; d < dk; ++d) { + const vfloat32m4_t k_vec = __riscv_vle32_v_f32m4(k_pack + d * kv_tile, vl); + acc = __riscv_vfmacc_vf_f32m4(acc, q_row[d] * scale, k_vec, vl); + } + + __riscv_vse32_v_f32m4(dst, acc, vl); +} + +static inline void rvv_pv_accumulate(float * dst, + const float * prob, + const float * v_pack, + int64_t kv_tile, + int64_t dv) { + int64_t d_left = dv; + int64_t d_off = 0; + + while (d_left > 0) { + const size_t vl = __riscv_vsetvl_e32m4(d_left); + vfloat32m4_t acc = __riscv_vle32_v_f32m4(dst + d_off, vl); + + for (int64_t tk = 0; tk < kv_tile; ++tk) { + const vfloat32m4_t v_vec = __riscv_vle32_v_f32m4(v_pack + tk * dv + d_off, vl); + acc = __riscv_vfmacc_vf_f32m4(acc, prob[tk], v_vec, vl); + } + + __riscv_vse32_v_f32m4(dst + d_off, acc, vl); + d_left -= vl; + d_off += vl; + } +} + +static void permute_transpose_impl(const ggml_tensor * src0, + ggml_tensor * dst, + int64_t batch, + int64_t m, + int64_t n, + int64_t batch_stride, + int64_t m_src_stride, + int64_t n_src_stride, + int64_t n_dst_stride, + int ith, + int nth) { + GGML_ASSERT(n_src_stride == sizeof(int32_t) || n_src_stride == sizeof(int16_t)); + + if (n_src_stride == sizeof(int32_t)) { + for (int64_t bi = ith; bi < batch; bi += nth) { + rvv_transposed_s32_mn_to_nm((int8_t *) ((char *) dst->data + bi * batch_stride), n_dst_stride, + (int8_t *) ((char *) src0->data + bi * batch_stride), m_src_stride, m, n); + } + } else if (n_src_stride == sizeof(int16_t)) { + for (int64_t bi = ith; bi < batch; bi += nth) { + rvv_transposed_s32_mn_to_nm((int8_t *) ((char *) dst->data + bi * batch_stride), n_dst_stride, + (int8_t *) ((char *) src0->data + bi * batch_stride), m_src_stride, m, n); + } + } else { + GGML_ABORT("not implemented"); + } +} + +template +static void flash_attn_ext_f16_one_chunk_inner_vlen1024_vf16_mrow(float ** pq, + const char * k_data_row, + const char * v_data_row, + const ggml_fp16_t * mp, + float ** sinks, + float ** dst, + float scale, + float logit_softcap, + float slope, + int64_t nek1, + int64_t nbk1, + int64_t nbv1, + int64_t DV, + int64_t DK, + void * tcm_buffer, + size_t tcm_buffer_size) { + GGML_ASSERT(flash_attn_ext_supported_shape_vlen1024_vf16(DK, DV)); + float S[QLEN] = { 0.0f }; // sum + float M[QLEN] = { -INFINITY }; // maximum KQ value + + _Float16 * kq16_buffer = (_Float16 *) tcm_buffer; + _Float16 * qv_buffer = kq16_buffer + QLEN * DV; + const size_t qkv_temp_buffer_size = (QLEN * DV + QLEN * DK) * sizeof(_Float16); + char * kv_tile_buffer = (char *) (qv_buffer + QLEN * DK); + + { + vfloat16m2_t VKQ16_v = __riscv_vfmv_v_f_f16m2(0.0f, DV); + for (int64_t i = 0; i < QLEN; ++i) { + __riscv_vse16_v_f16m2(kq16_buffer + i * DV, VKQ16_v, DV); + vfloat16m2_t Q_q_v = __riscv_vfncvt_f_f_w_f16m2(__riscv_vle32_v_f32m4(pq[i], DK), DK); + __riscv_vse16_v_f16m2(qv_buffer + i * DK, Q_q_v, DK); + } + } + + const uintptr_t scratch_addr = reinterpret_cast(kv_tile_buffer); + const size_t scratch_size = tcm_buffer_size > qkv_temp_buffer_size ? tcm_buffer_size - qkv_temp_buffer_size : 0; + const uintptr_t kq_tile_addr = align_up(scratch_addr, alignof(float)); + const size_t scratch_prefix = kq_tile_addr - scratch_addr; + const size_t packed_tile_size = + QLEN * sizeof(float) + DK * sizeof(_Float16) + DV * sizeof(_Float16) + sizeof(float); + const int64_t max_ic_tile_step = ((int64_t) __riscv_vsetvlmax_e16m1()) & ~((int64_t) 7); + const int64_t max_fit_by_tcm = + scratch_size > scratch_prefix ? (int64_t) ((scratch_size - scratch_prefix) / packed_tile_size) : 0; + const int64_t ic_tile_step = std::min(max_ic_tile_step, max_fit_by_tcm) & ~((int64_t) 7); + + const uintptr_t k_tile_addr = kq_tile_addr + QLEN * ic_tile_step * sizeof(float); + const uintptr_t v_tile_addr = k_tile_addr + DK * ic_tile_step * sizeof(_Float16); + const uintptr_t mv_tile_addr = v_tile_addr + ic_tile_step * DV * sizeof(_Float16); + + if (ic_tile_step >= 8) { + float * kq_tile_buffer = reinterpret_cast(kq_tile_addr); + _Float16 * k_tile_pack = reinterpret_cast<_Float16 *>(k_tile_addr); + _Float16 * v_tile_pack = reinterpret_cast<_Float16 *>(v_tile_addr); + float * mv_tile_pack = reinterpret_cast(mv_tile_addr); + + const int64_t k_tile_byte_stride = ic_tile_step * (int64_t) sizeof(_Float16); + + int64_t ic_step = 0; + for (int64_t ic = 0; ic < nek1; ++ic) { + const float mv = mp ? slope * ((_Float16 *) mp)[ic] : 0.0f; + + if (mv != -INFINITY) { + const _Float16 * k_data = (const _Float16 *) (k_data_row + ic * nbk1); + const _Float16 * v_data = (const _Float16 *) (v_data_row + ic * nbv1); + + const vfloat16m2_t k_data_v = __riscv_vle16_v_f16m2(k_data, DK); + const vfloat16m2_t v_data_v = __riscv_vle16_v_f16m2(v_data, DV); + __riscv_vsse16_v_f16m2(k_tile_pack + ic_step, k_tile_byte_stride, k_data_v, DK); + __riscv_vse16_v_f16m2(v_tile_pack + ic_step * DV, v_data_v, DV); + mv_tile_pack[ic_step] = mv; + ic_step++; + } + + if (ic_step > 0 && (ic_step == ic_tile_step || ic == (nek1 - 1))) { + if constexpr (QLEN == 4) { + const size_t qk_vl = __riscv_vsetvl_e16m1(ic_step); + vfloat32m2_t qk_acc0 = __riscv_vfmv_v_f_f32m2(0.0f, qk_vl); + vfloat32m2_t qk_acc1 = __riscv_vfmv_v_f_f32m2(0.0f, qk_vl); + vfloat32m2_t qk_acc2 = __riscv_vfmv_v_f_f32m2(0.0f, qk_vl); + vfloat32m2_t qk_acc3 = __riscv_vfmv_v_f_f32m2(0.0f, qk_vl); + + for (int64_t d = 0; d < DK; ++d) { + const vfloat16m1_t k_vec = __riscv_vle16_v_f16m1(k_tile_pack + d * ic_tile_step, qk_vl); + qk_acc0 = __riscv_vfwmacc_vf_f32m2(qk_acc0, qv_buffer[0 * DK + d], k_vec, qk_vl); + qk_acc1 = __riscv_vfwmacc_vf_f32m2(qk_acc1, qv_buffer[1 * DK + d], k_vec, qk_vl); + qk_acc2 = __riscv_vfwmacc_vf_f32m2(qk_acc2, qv_buffer[2 * DK + d], k_vec, qk_vl); + qk_acc3 = __riscv_vfwmacc_vf_f32m2(qk_acc3, qv_buffer[3 * DK + d], k_vec, qk_vl); + } + + qk_acc0 = __riscv_vfmul_vf_f32m2(qk_acc0, scale, qk_vl); + qk_acc1 = __riscv_vfmul_vf_f32m2(qk_acc1, scale, qk_vl); + qk_acc2 = __riscv_vfmul_vf_f32m2(qk_acc2, scale, qk_vl); + qk_acc3 = __riscv_vfmul_vf_f32m2(qk_acc3, scale, qk_vl); + + __riscv_vse32_v_f32m2(kq_tile_buffer + 0 * ic_tile_step, qk_acc0, qk_vl); + __riscv_vse32_v_f32m2(kq_tile_buffer + 1 * ic_tile_step, qk_acc1, qk_vl); + __riscv_vse32_v_f32m2(kq_tile_buffer + 2 * ic_tile_step, qk_acc2, qk_vl); + __riscv_vse32_v_f32m2(kq_tile_buffer + 3 * ic_tile_step, qk_acc3, qk_vl); + } else { + static_assert(QLEN == 2, "unsupported QLEN"); + + const size_t qk_vl = __riscv_vsetvl_e16m1(ic_step); + vfloat32m2_t qk_acc0 = __riscv_vfmv_v_f_f32m2(0.0f, qk_vl); + vfloat32m2_t qk_acc1 = __riscv_vfmv_v_f_f32m2(0.0f, qk_vl); + + for (int64_t d = 0; d < DK; ++d) { + const vfloat16m1_t k_vec = __riscv_vle16_v_f16m1(k_tile_pack + d * ic_tile_step, qk_vl); + qk_acc0 = __riscv_vfwmacc_vf_f32m2(qk_acc0, qv_buffer[0 * DK + d], k_vec, qk_vl); + qk_acc1 = __riscv_vfwmacc_vf_f32m2(qk_acc1, qv_buffer[1 * DK + d], k_vec, qk_vl); + } + + qk_acc0 = __riscv_vfmul_vf_f32m2(qk_acc0, scale, qk_vl); + qk_acc1 = __riscv_vfmul_vf_f32m2(qk_acc1, scale, qk_vl); + + __riscv_vse32_v_f32m2(kq_tile_buffer + 0 * ic_tile_step, qk_acc0, qk_vl); + __riscv_vse32_v_f32m2(kq_tile_buffer + 1 * ic_tile_step, qk_acc1, qk_vl); + } + + for (int i = 0; i < QLEN; ++i) { + float * row_ptr = kq_tile_buffer + i * ic_tile_step; + const float tile_max = + rvv_softcap_add_max_inplace_f32(row_ptr, mv_tile_pack, ic_step, logit_softcap); + + const float Mold = M[i]; + + if (tile_max > Mold) { + const float ms = expf(Mold - tile_max); + M[i] = tile_max; + S[i] *= ms; + + vfloat16m2_t VKQ16_v = __riscv_vle16_v_f16m2(kq16_buffer + i * DV, DV); + VKQ16_v = __riscv_vfmul_vf_f16m2(VKQ16_v, (_Float16) ms, DV); + __riscv_vse16_v_f16m2(kq16_buffer + i * DV, VKQ16_v, DV); + } + + S[i] += rvv_softmax_exp_inplace_f32(row_ptr, ic_step, M[i]); + } + + if constexpr (QLEN == 4) { + vfloat16m2_t pv_acc0 = __riscv_vle16_v_f16m2(kq16_buffer + 0 * DV, DV); + vfloat16m2_t pv_acc1 = __riscv_vle16_v_f16m2(kq16_buffer + 1 * DV, DV); + vfloat16m2_t pv_acc2 = __riscv_vle16_v_f16m2(kq16_buffer + 2 * DV, DV); + vfloat16m2_t pv_acc3 = __riscv_vle16_v_f16m2(kq16_buffer + 3 * DV, DV); + + for (int64_t tk = 0; tk < ic_step; ++tk) { + const vfloat16m2_t v16 = __riscv_vle16_v_f16m2(v_tile_pack + tk * DV, DV); + pv_acc0 = + __riscv_vfmacc_vf_f16m2(pv_acc0, (_Float16) kq_tile_buffer[0 * ic_tile_step + tk], v16, DV); + pv_acc1 = + __riscv_vfmacc_vf_f16m2(pv_acc1, (_Float16) kq_tile_buffer[1 * ic_tile_step + tk], v16, DV); + pv_acc2 = + __riscv_vfmacc_vf_f16m2(pv_acc2, (_Float16) kq_tile_buffer[2 * ic_tile_step + tk], v16, DV); + pv_acc3 = + __riscv_vfmacc_vf_f16m2(pv_acc3, (_Float16) kq_tile_buffer[3 * ic_tile_step + tk], v16, DV); + } + + __riscv_vse16_v_f16m2(kq16_buffer + 0 * DV, pv_acc0, DV); + __riscv_vse16_v_f16m2(kq16_buffer + 1 * DV, pv_acc1, DV); + __riscv_vse16_v_f16m2(kq16_buffer + 2 * DV, pv_acc2, DV); + __riscv_vse16_v_f16m2(kq16_buffer + 3 * DV, pv_acc3, DV); + } else { + static_assert(QLEN == 2, "unsupported QLEN"); + vfloat16m2_t pv_acc0 = __riscv_vle16_v_f16m2(kq16_buffer + 0 * DV, DV); + vfloat16m2_t pv_acc1 = __riscv_vle16_v_f16m2(kq16_buffer + 1 * DV, DV); + + for (int64_t tk = 0; tk < ic_step; ++tk) { + const vfloat16m2_t v16 = __riscv_vle16_v_f16m2(v_tile_pack + tk * DV, DV); + pv_acc0 = + __riscv_vfmacc_vf_f16m2(pv_acc0, (_Float16) kq_tile_buffer[0 * ic_tile_step + tk], v16, DV); + pv_acc1 = + __riscv_vfmacc_vf_f16m2(pv_acc1, (_Float16) kq_tile_buffer[1 * ic_tile_step + tk], v16, DV); + } + + __riscv_vse16_v_f16m2(kq16_buffer + 0 * DV, pv_acc0, DV); + __riscv_vse16_v_f16m2(kq16_buffer + 1 * DV, pv_acc1, DV); + } + + ic_step = 0; + } + } + } else { + for (int64_t ic = 0; ic < nek1; ++ic) { + const float mv = mp ? slope * ((_Float16 *) mp)[ic] : 0.0f; + + const char * k_data = k_data_row + ic * nbk1; + const char * v_data = v_data_row + ic * nbv1; + + vfloat16m2_t k_data_v; + vfloat16m2_t v_data_v; + + if (mv != -INFINITY) { + k_data_v = __riscv_vle16_v_f16m2((_Float16 *) k_data, DK); + v_data_v = __riscv_vle16_v_f16m2((_Float16 *) v_data, DV); + } else { + continue; + } + + for (int i = 0; i < QLEN; ++i) { + vfloat16m2_t Q_q_v = __riscv_vle16_v_f16m2(qv_buffer + i * DK, DK); + vfloat32m4_t qk_acc_v = __riscv_vfwmul_vv_f32m4(k_data_v, Q_q_v, DK); + float s = reduce_sum_f32m4_vlen1024(qk_acc_v, DK); + s = s * scale; + if (logit_softcap != 0.0f) { + s = logit_softcap * tanhf(s); + } + s += mv; + + const float Mold = M[i]; + + float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value + float vs = 1.0f; // post-softmax KQ value, expf(s - M) + + vfloat16m2_t VKQ16_v = __riscv_vle16_v_f16m2(kq16_buffer + i * DV, DV); + if (s > M[i]) { + // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f + M[i] = s; + ms = expf(Mold - M[i]); + + // V = V*expf(Mold - M) + VKQ16_v = __riscv_vfmul_vf_f16m2(VKQ16_v, ms, DV); + } else { + // no new maximum, ms == 1.0f, vs != 1.0f + vs = expf(s - M[i]); + } + VKQ16_v = __riscv_vfmacc_vf_f16m2(VKQ16_v, vs, v_data_v, DV); + __riscv_vse16_v_f16m2(kq16_buffer + i * DV, VKQ16_v, DV); + S[i] = S[i] * ms + vs; // scale and increment sum with partial sum + } + } + } + + for (int i = 0; i < QLEN; ++i) { + vfloat16m2_t VKQ16_v = __riscv_vle16_v_f16m2(kq16_buffer + i * DV, DV); + vfloat32m4_t VKQ32_v = __riscv_vfwcvt_f_f_v_f32m4(VKQ16_v, DV); + + // sinks + if (sinks[i]) { + const float s = *(sinks[i]); + + float ms = 1.0f; + float vs = 1.0f; + + if (s > M[i]) { + ms = expf(M[i] - s); + M[i] = s; + VKQ32_v = __riscv_vfmul_vf_f32m4(VKQ32_v, ms, DV); + } else { + vs = expf(s - M[i]); + } + + S[i] = S[i] * ms + vs; + } + + // V /= S + const float S_inv = S[i] == 0.0f ? 0.0f : 1.0f / S[i]; + + VKQ32_v = __riscv_vfmul_vf_f32m4(VKQ32_v, S_inv, DV); + + __riscv_vse32_v_f32m4(dst[i], VKQ32_v, DV); + } +} + +static void flash_attn_ext_f16_one_chunk_inner_vlen1024_vf16_m1(const float * pq, + const char * k_data_row, + const char * v_data_row, + const ggml_fp16_t * mp, + const float * sinks, + float * dst, + float scale, + float logit_softcap, + float slope, + int64_t nek1, + int64_t nbk1, + int64_t nbv1, + int64_t DV, + int64_t DK) { + GGML_ASSERT(flash_attn_ext_supported_shape_vlen1024_vf16(DK, DV)); + + float S = 0.0f; // sum + float M = -INFINITY; // maximum KQ value + + vfloat16m2_t VKQ16_v = __riscv_vfmv_v_f_f16m2(0.0f, DV); + + vfloat16m2_t Q_q_v = __riscv_vfncvt_f_f_w_f16m2(__riscv_vle32_v_f32m4(pq, DK), DK); + + for (int64_t ic = 0; ic < nek1; ++ic) { + const float mv = mp ? slope * ((_Float16 *) mp)[ic] : 0.0f; + if (mv == -INFINITY) { + continue; + } + + const char * k_data = k_data_row + ic * nbk1; + + vfloat16m2_t k_data_v = __riscv_vle16_v_f16m2((_Float16 *) k_data, DK); + + vfloat32m4_t qk_acc_v = __riscv_vfwmul_vv_f32m4(k_data_v, Q_q_v, DK); + float s = reduce_sum_f32m4_vlen1024(qk_acc_v, DK); + + s = s * scale; // scale KQ value + + if (logit_softcap != 0.0f) { + s = logit_softcap * tanhf(s); + } + + s += mv; // apply mask + + const float Mold = M; + + float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value + float vs = 1.0f; // post-softmax KQ value, expf(s - M) + + const char * v_data = v_data_row + ic * nbv1; + + vfloat16m2_t v_data_v = __riscv_vle16_v_f16m2((_Float16 *) v_data, DV); + + if (s > M) { + // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f + M = s; + ms = expf(Mold - M); + + // V = V*expf(Mold - M) + VKQ16_v = __riscv_vfmul_vf_f16m2(VKQ16_v, ms, DV); + } else { + // no new maximum, ms == 1.0f, vs != 1.0f + vs = expf(s - M); + } + + VKQ16_v = __riscv_vfmacc_vf_f16m2(VKQ16_v, vs, v_data_v, DV); + + S = S * ms + vs; // scale and increment sum with partial sum + } + + vfloat32m4_t VKQ32_v = __riscv_vfwcvt_f_f_v_f32m4(VKQ16_v, DV); + + // sinks + if (sinks) { + const float s = *sinks; + + float ms = 1.0f; + float vs = 1.0f; + + if (s > M) { + ms = expf(M - s); + M = s; + VKQ32_v = __riscv_vfmul_vf_f32m4(VKQ32_v, ms, DV); + } else { + vs = expf(s - M); + } + + S = S * ms + vs; + } + + // V /= S + const float S_inv = S == 0.0f ? 0.0f : 1.0f / S; + + VKQ32_v = __riscv_vfmul_vf_f32m4(VKQ32_v, S_inv, DV); + + __riscv_vse32_v_f32m4(dst, VKQ32_v, DV); +} + +} // namespace + +void memcpy1d(void * dst, const void * src, int64_t size) { + size_t byte_size_all = size; + size_t vlen = __riscv_vlenb() * 8; + if (vlen == 256) { + // 1024 bytes + __asm__ volatile( + // + "srli t0, %[size], 10 \n\t" + "blez t0, memcpy_tail%= \n\t" + "vsetvli t1, x0, e8, m8, tu, mu \n\t" + "memcpy_main_loop%=: \n\t" + "addi t0, t0, -1 \n\t" + "vle8.v v0, (%[s]) \n\t" + "addi %[s], %[s], 256 \n\t" + "vle8.v v8, (%[s]) \n\t" + "addi %[s], %[s], 256 \n\t" + "vle8.v v16, (%[s]) \n\t" + "addi %[s], %[s], 256 \n\t" + "vle8.v v24, (%[s]) \n\t" + "addi %[s], %[s], 256 \n\t" + // + "vse8.v v0, (%[d]) \n\t" + "addi %[d], %[d], 256 \n\t" + "vse8.v v8, (%[d]) \n\t" + "addi %[d], %[d], 256 \n\t" + "vse8.v v16, (%[d]) \n\t" + "addi %[d], %[d], 256 \n\t" + "vse8.v v24, (%[d]) \n\t" + "addi %[d], %[d], 256 \n\t" + // + "bnez t0, memcpy_main_loop%= \n\t" + "memcpy_tail%=: \n\t" + "andi t1, %[size], 1023 \n\t" + "blez t1, out%= \n\t" + "memcpy_tail_loop%=: \n\t" + "vsetvli t0, t1, e8, m8, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle8.v v0, (%[s]) \n\t" + "add %[s], %[s], t0 \n\t" + "vse8.v v0, (%[d]) \n\t" + "add %[d], %[d], t0 \n\t" + "bnez t1, memcpy_tail_loop%= \n\t" + "out%=: \n\t" + : [s] "+r"(src), [d] "+r"(dst) + : [size] "r"(byte_size_all) + : "cc", "t0", "t1"); + } else if (vlen == 1024) { + // 2048 bytes + __asm__ volatile( + // + "srli t0, %[size], 11 \n\t" + "blez t0, memcpy_tail%= \n\t" + "vsetvli t1, x0, e8, m8, tu, mu \n\t" + "addi t2, %[s], 1024 \n\t" + "addi t3, %[d], 1024 \n\t" + "li t5, 2048 \n\t" + "memcpy_main_loop%=: \n\t" + "addi t0, t0, -1 \n\t" + "vle8.v v0, (%[s]) \n\t" + "add %[s], %[s], t5 \n\t" + "vle8.v v8, (t2) \n\t" + "add t2, t2, t5 \n\t" + // + "vse8.v v0, (%[d]) \n\t" + "add %[d], %[d], t5 \n\t" + "vse8.v v8, (t3) \n\t" + "add t3, t3, t5 \n\t" + // + "bnez t0, memcpy_main_loop%= \n\t" + "memcpy_tail%=: \n\t" + "andi t1, %[size], 2047 \n\t" + "blez t1, out%= \n\t" + "memcpy_tail_loop%=: \n\t" + "vsetvli t0, t1, e8, m2, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle8.v v0, (%[s]) \n\t" + "add %[s], %[s], t0 \n\t" + "vse8.v v0, (%[d]) \n\t" + "add %[d], %[d], t0 \n\t" + "bnez t1, memcpy_tail_loop%= \n\t" + "out%=: \n\t" + : [s] "+r"(src), [d] "+r"(dst) + : [size] "r"(byte_size_all) + : "cc", "t0", "t1", "t2", "t3", "t5"); + } else { + __asm__ volatile( + // + "add t1, %[size], zero \n\t" + "memcpy_tail_loop%=: \n\t" + "vsetvli t0, t1, e8, m8, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle8.v v0, (%[s]) \n\t" + "add %[s], %[s], t0 \n\t" + "vse8.v v0, (%[d]) \n\t" + "add %[d], %[d], t0 \n\t" + "bnez t1, memcpy_tail_loop%= \n\t" + : [s] "+r"(src), [d] "+r"(dst) + : [size] "r"(byte_size_all) + : "cc", "t0", "t1", "t2", "t4", "t3"); + } +} + +void memcpy2d(void * dst, int64_t dst_stride, const void * src, int64_t src_stride, int64_t tile_rows, int64_t size) { + for (int64_t i = 0; i < tile_rows; ++i) { + memcpy1d((char *) dst + i * dst_stride, (const char *) src + i * src_stride, size); + } +} + +void forward_flash_attn_ext_f16_one_chunk_vlen1024_vf16(const ggml_compute_params * params, + ggml_tensor * dst, + int ir0, + int ir1, + void * tcm_buffer, + size_t tcm_buffer_size) { + const ggml_tensor * q = dst->src[0]; + const ggml_tensor * k = dst->src[1]; + const ggml_tensor * v = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; + const ggml_tensor * sinks = dst->src[4]; + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int64_t DK = nek0; + const int64_t DV = nev0; + const int64_t N = neq1; + + GGML_ASSERT(flash_attn_ext_supported_shape_vlen1024_vf16(DK, DV)); + + // broadcast factors + const int64_t rk2 = neq2 / nek2; + const int64_t rk3 = neq3 / nek3; + + const int64_t rv2 = neq2 / nev2; + const int64_t rv3 = neq3 / nev3; + + // parallelize by q rows using ggml_vec_dot_f32 + + float scale = *((float *) dst->op_params + 0); + float max_bias = *((float *) dst->op_params + 1); + float logit_softcap = *((float *) dst->op_params + 2); + + if (logit_softcap != 0) { + scale /= logit_softcap; + } + + const uint32_t n_head = neq2; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); + + const float m0 = powf(2.0f, -(max_bias) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + const int KV_row_size = DK * sizeof(_Float16) + DV * sizeof(_Float16); + + int ith = params->ith; + int ir_step = 1; + for (int ir = ir0; ir < ir1; ir += ir_step) { + // q indices + const int iq3 = ir / (neq2 * neq1); + const int iq2 = (ir - iq3 * neq2 * neq1) / neq1; + const int iq1 = (ir - iq3 * neq2 * neq1 - iq2 * neq1); + + const int iq3_1 = (ir + 1) / (neq2 * neq1); + const int iq2_1 = (ir + 1 - iq3_1 * neq2 * neq1) / neq1; + const int iq1_1 = (ir + 1 - iq3_1 * neq2 * neq1 - iq2_1 * neq1); + + const int iq3_2 = (ir + 2) / (neq2 * neq1); + const int iq2_2 = (ir + 2 - iq3_2 * neq2 * neq1) / neq1; + const int iq1_2 = (ir + 2 - iq3_2 * neq2 * neq1 - iq2_2 * neq1); + + const int iq3_3 = (ir + 3) / (neq2 * neq1); + const int iq2_3 = (ir + 3 - iq3_3 * neq2 * neq1) / neq1; + const int iq1_3 = (ir + 3 - iq3_3 * neq2 * neq1 - iq2_3 * neq1); + + const uint32_t h = iq2; // head index + const float slope = + (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1) : 1.0f; + + const ggml_fp16_t * mp = + mask ? (ggml_fp16_t *) ((char *) mask->data + iq1 * mask->nb[1] + (iq2 % mask->ne[2]) * mask->nb[2] + + (iq3 % mask->ne[3]) * mask->nb[3]) : + NULL; + + const bool mp_equal_2 = iq1_1 == iq1 && (iq2 % mask->ne[2]) == (iq2_1 % mask->ne[2]) && + (iq3 % mask->ne[3]) == (iq3_1 % mask->ne[3]); + + const bool mp_equal_4 = mp_equal_2 && iq1_2 == iq1 && (iq2 % mask->ne[2]) == (iq2_2 % mask->ne[2]) && + (iq3 % mask->ne[3]) == (iq3_2 % mask->ne[3]) && iq1_3 == iq1 && + (iq2 % mask->ne[2]) == (iq2_3 % mask->ne[2]) && + (iq3 % mask->ne[3]) == (iq3_3 % mask->ne[3]); + + // k indices + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; + + const int ik3_1 = iq3_1 / rk3; + const int ik2_1 = iq2_1 / rk2; + + const int ik3_2 = iq3_2 / rk3; + const int ik2_2 = iq2_2 / rk2; + + const int ik3_3 = iq3_3 / rk3; + const int ik2_3 = iq2_3 / rk2; + + // v indices + const int iv3 = iq3 / rv3; + const int iv2 = iq2 / rv2; + + const int iv3_1 = iq3_1 / rv3; + const int iv2_1 = iq2_1 / rv2; + + const int iv3_2 = iq3_2 / rv3; + const int iv2_2 = iq2_2 / rv2; + + const int iv3_3 = iq3_3 / rv3; + const int iv2_3 = iq2_3 / rv2; + + const float * pq = (const float *) ((char *) q->data + (iq1 * nbq1 + iq2 * nbq2 + iq3 * nbq3)); + + std::array pq_buffer; + std::array sinks_buffer; + std::array dst_buffer; + + if (tcm_buffer != nullptr && 4 * KV_row_size < tcm_buffer_size && ir < (ir1 - 3) && mp_equal_4 && + ik3_3 == ik3 && ik2_3 == ik2 && iv3_3 == iv3 && iv2_3 == iv2 && ik3_2 == ik3 && ik2_2 == ik2 && + iv3_2 == iv3 && iv2_2 == iv2 && ik3_1 == ik3 && ik2_1 == ik2 && iv3_1 == iv3 && iv2_1 == iv2) { + ir_step = 4; + + pq_buffer[0] = (float *) ((char *) q->data + (iq1 * nbq1 + iq2 * nbq2 + iq3 * nbq3)); + pq_buffer[1] = (float *) ((char *) q->data + (iq1_1 * nbq1 + iq2_1 * nbq2 + iq3_1 * nbq3)); + pq_buffer[2] = (float *) ((char *) q->data + (iq1_2 * nbq1 + iq2_2 * nbq2 + iq3_2 * nbq3)); + pq_buffer[3] = (float *) ((char *) q->data + (iq1_3 * nbq1 + iq2_3 * nbq2 + iq3_3 * nbq3)); + + sinks_buffer[0] = sinks ? ((float *) ((char *) sinks->data)) + iq2 : nullptr; + sinks_buffer[1] = sinks ? ((float *) ((char *) sinks->data)) + iq2_1 : nullptr; + sinks_buffer[2] = sinks ? ((float *) ((char *) sinks->data)) + iq2_2 : nullptr; + sinks_buffer[3] = sinks ? ((float *) ((char *) sinks->data)) + iq2_3 : nullptr; + + dst_buffer[0] = (float *) ((char *) dst->data + (iq3 * ne2 * ne1 + iq2 + iq1 * ne1) * nb1); + dst_buffer[1] = (float *) ((char *) dst->data + (iq3_1 * ne2 * ne1 + iq2_1 + iq1_1 * ne1) * nb1); + dst_buffer[2] = (float *) ((char *) dst->data + (iq3_2 * ne2 * ne1 + iq2_2 + iq1_2 * ne1) * nb1); + dst_buffer[3] = (float *) ((char *) dst->data + (iq3_3 * ne2 * ne1 + iq2_3 + iq1_3 * ne1) * nb1); + + flash_attn_ext_f16_one_chunk_inner_vlen1024_vf16_mrow<4>( // + pq_buffer.data(), // + (const char *) k->data + (ik2 * nbk2 + ik3 * nbk3), // + (const char *) v->data + (iv2 * nbv2 + iv3 * nbv3), // + mp, // + sinks_buffer.data(), // + dst_buffer.data(), // + scale, logit_softcap, slope, nek1, nbk1, nbv1, DV, DK, tcm_buffer, tcm_buffer_size); + } else if (tcm_buffer != nullptr && 2 * KV_row_size < tcm_buffer_size && ir < (ir1 - 1) && mp_equal_2 && + ik3_1 == ik3 && ik2_1 == ik2 && iv3_1 == iv3 && iv2_1 == iv2) { + ir_step = 2; + + pq_buffer[0] = (float *) ((char *) q->data + (iq1 * nbq1 + iq2 * nbq2 + iq3 * nbq3)); + pq_buffer[1] = (float *) ((char *) q->data + (iq1_1 * nbq1 + iq2_1 * nbq2 + iq3_1 * nbq3)); + + sinks_buffer[0] = sinks ? ((float *) ((char *) sinks->data)) + iq2 : nullptr; + sinks_buffer[1] = sinks ? ((float *) ((char *) sinks->data)) + iq2_1 : nullptr; + + dst_buffer[0] = (float *) ((char *) dst->data + (iq3 * ne2 * ne1 + iq2 + iq1 * ne1) * nb1); + dst_buffer[1] = (float *) ((char *) dst->data + (iq3_1 * ne2 * ne1 + iq2_1 + iq1_1 * ne1) * nb1); + + flash_attn_ext_f16_one_chunk_inner_vlen1024_vf16_mrow<2>( // + pq_buffer.data(), // + (const char *) k->data + (ik2 * nbk2 + ik3 * nbk3), // + (const char *) v->data + (iv2 * nbv2 + iv3 * nbv3), // + mp, // + sinks_buffer.data(), // + dst_buffer.data(), // + scale, logit_softcap, slope, nek1, nbk1, nbv1, DV, DK, tcm_buffer, tcm_buffer_size); + } else { + ir_step = 1; + flash_attn_ext_f16_one_chunk_inner_vlen1024_vf16_m1( // + pq, // + (const char *) k->data + (ik2 * nbk2 + ik3 * nbk3), // + (const char *) v->data + (iv2 * nbv2 + iv3 * nbv3), // + mp, // + sinks ? ((float *) ((char *) sinks->data)) + h : nullptr, // + (float *) ((char *) dst->data + (iq3 * ne2 * ne1 + iq2 + iq1 * ne1) * nb1), // + scale, logit_softcap, slope, nek1, nbk1, nbv1, DV, DK); + } + } +} + +void forward_flash_attn_ext_f16_tiled_vlen1024_vf16(const ggml_compute_params * params, + ggml_tensor * dst, + int ir0, + int ir1, + void * tcm_buffer, + size_t tcm_buffer_size) { + const ggml_tensor * q = dst->src[0]; + const ggml_tensor * k = dst->src[1]; + const ggml_tensor * v = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; + const ggml_tensor * sinks = dst->src[4]; + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int64_t DK = nek0; + const int64_t DV = nev0; + const int64_t N = neq1; + + GGML_ASSERT(flash_attn_ext_supported_shape_vlen1024_vf16(DK, DV)); + + GGML_ASSERT(ne0 == DV); + GGML_ASSERT(ne2 == N); + + // input tensor rows must be contiguous + GGML_ASSERT(nbq0 == ggml_type_size(q->type)); + GGML_ASSERT(nbk0 == ggml_type_size(k->type)); + GGML_ASSERT(nbv0 == ggml_type_size(v->type)); + + GGML_ASSERT(neq0 == DK); + GGML_ASSERT(nek0 == DK); + GGML_ASSERT(nev0 == DV); + + GGML_ASSERT(neq1 == N); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + GGML_ASSERT(k->type == v->type); + const ggml_type kv_type = k->type; + + // broadcast factors + const int64_t rk2 = neq2 / nek2; + const int64_t rk3 = neq3 / nek3; + + const int64_t rv2 = neq2 / nev2; + const int64_t rv3 = neq3 / nev3; + + float * param_list = (float *) dst->op_params; + float scale = param_list[0]; + float max_bias = param_list[1]; + float logit_softcap = param_list[2]; + + if (logit_softcap != 0) { + scale /= logit_softcap; + } + + const uint32_t n_head = neq2; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); + + const float m0 = powf(2.0f, -(max_bias) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + int ith = params->ith; + + static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q; + static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV; + + // Per-thread scratch layout: + // Q_f32: Q_TILE_SZ * DK + // KQ: Q_TILE_SZ * KV_TILE_SZ + // mask32: Q_TILE_SZ * KV_TILE_SZ + // VKQ32: Q_TILE_SZ * DV + // V32: KV_TILE_SZ * DV + // K_f32: DK * KV_TILE_SZ (transposed K tile) + float * base = (float *) params->wdata + ith * (Q_TILE_SZ * DK + 2 * Q_TILE_SZ * KV_TILE_SZ + Q_TILE_SZ * DV + + KV_TILE_SZ * DV + KV_TILE_SZ * DK + CACHE_LINE_SIZE_F32); + const size_t base_size = + (Q_TILE_SZ * DK + 2 * Q_TILE_SZ * KV_TILE_SZ + Q_TILE_SZ * DV + KV_TILE_SZ * DV + KV_TILE_SZ * DK) * + sizeof(float) + + CACHE_LINE_SIZE_F32; + + if (base_size <= tcm_buffer_size && tcm_buffer != nullptr) { + base = (float *) tcm_buffer; + } + + float S_M_Buf[Q_TILE_SZ * 2]; // buffer to hold S, M, bias for one tile to reduce register pressure in main loop + float * S = S_M_Buf; + float * M = S_M_Buf + Q_TILE_SZ; + + int ir = ir0; + while (ir < ir1) { + // q indices for the start of this tile + const int iq3 = ir / (neq2 * neq1); + const int iq2 = (ir - iq3 * neq2 * neq1) / neq1; + const int iq1 = (ir - iq3 * neq2 * neq1 - iq2 * neq1); + + // Number of valid rows in this tile: + // - limited by tile size (Q_TILE_SZ) + // - limited by chunk boundary (ir1 - ir) + // - limited by head boundary (neq1 - iq1) to avoid crossing into next head + const int tile_rows = MIN(Q_TILE_SZ, MIN((int) (ir1 - ir), (int) (neq1 - iq1))); + GGML_ASSERT(tile_rows > 0); + + const uint32_t h = iq2; // head index + const float slope = + (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1) : 1.0f; + + for (int i = 0; i < Q_TILE_SZ; ++i) { + S[i] = 0.; + M[i] = -INFINITY; + } + + float * Q_f32 = base; + float * KQ = (float *) ((char *) base + Q_TILE_SZ * DK * sizeof(float)); + float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ; + float * VKQ32 = mask32 + Q_TILE_SZ * KV_TILE_SZ; + float * V32 = VKQ32 + Q_TILE_SZ * DV; + float * K_f32 = V32 + KV_TILE_SZ * DV; + _Float16 * Q_f16 = (_Float16 *) Q_f32; + _Float16 * V_f16 = (_Float16 *) V32; + _Float16 * K_f16 = (_Float16 *) K_f32; + + rvv_zero_f32(VKQ32, Q_TILE_SZ * DV); + + // k indices + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; + + // v indices + const int iv3 = iq3 / rv3; + const int iv2 = iq2 / rv2; + + const float * pq = (const float *) ((char *) q->data + (iq1 * nbq1 + iq2 * nbq2 + iq3 * nbq3)); + if (kv_type == GGML_TYPE_F16) { + rvv_pack_f32_as_scaled_f16((uint8_t *) Q_f16, DK * sizeof(_Float16), (uint8_t *) pq, nbq1, tile_rows, DK, + scale); + } else { + memcpy2d(Q_f32, DK * sizeof(float), pq, nbq1, tile_rows, DK * sizeof(float)); + } + + for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) { + const int kv_tile = (int) std::min((int64_t) KV_TILE_SZ, nek1 - ic); + + rvv_zero_f32(K_f32, DK * KV_TILE_SZ); + rvv_zero_f32(V32, KV_TILE_SZ * DV); + + // skip the tile entirely if all the masks are -inf + if (mask) { + bool can_skip = true; + const ggml_fp16_t * mp_row = + (const ggml_fp16_t *) ((const char *) mask->data + iq1 * mask->nb[1] + + (iq2 % mask->ne[2]) * mask->nb[2] + (iq3 % mask->ne[3]) * mask->nb[3]); + rvv_pack_scaled_f16_as_f32(mask32, KV_TILE_SZ * sizeof(float), mp_row + ic, mask->nb[1], tile_rows, + kv_tile, slope); + + for (int tq = 0; tq < tile_rows; tq++) { + for (int tk = 0; tk < kv_tile; tk++) { + if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) { + can_skip = false; + } + } + // Pad remaining mask entries with -inf + for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) { + mask32[tq * KV_TILE_SZ + tk] = -INFINITY; + } + } + + if (can_skip) { + continue; + } + } + + if (kv_type == GGML_TYPE_F16) { + rvv_transposed_s16_mn_to_nm((int8_t *) K_f16, KV_TILE_SZ * sizeof(_Float16), + (int8_t *) k->data + ic * nbk1 + ik2 * nbk2 + ik3 * nbk3, nbk1, kv_tile, + DK); + + int tq = 0; + for (; tq + 3 < tile_rows; tq += 4) { + rvv_qk_dot_tile_f16_x4(KQ + (tq + 0) * KV_TILE_SZ, KQ + (tq + 1) * KV_TILE_SZ, + KQ + (tq + 2) * KV_TILE_SZ, KQ + (tq + 3) * KV_TILE_SZ, + Q_f16 + (tq + 0) * DK, Q_f16 + (tq + 1) * DK, Q_f16 + (tq + 2) * DK, + Q_f16 + (tq + 3) * DK, K_f16, DK, kv_tile); + } + for (; tq < tile_rows; ++tq) { + rvv_qk_dot_tile_f16_x1(KQ + tq * KV_TILE_SZ, Q_f16 + tq * DK, K_f16, DK, kv_tile); + } + } else { + for (int tk = 0; tk < kv_tile; tk++) { + const char * k_data = (const char *) k->data + (ic + tk) * nbk1 + ik2 * nbk2 + ik3 * nbk3; + float * k_col = K_f32 + tk; + const float * k_src = (const float *) k_data; + for (int64_t dk = 0; dk < DK; ++dk) { + k_col[dk * KV_TILE_SZ] = k_src[dk]; + } + } + + for (int tq = 0; tq < tile_rows; ++tq) { + rvv_qk_dot_tile(KQ + tq * KV_TILE_SZ, Q_f32 + tq * DK, K_f32, DK, KV_TILE_SZ, scale); + } + } + + // Set padded KQ entries to -inf so softmax gives them zero weight + if (kv_tile < KV_TILE_SZ) { + for (int tq = 0; tq < tile_rows; tq++) { + for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) { + KQ[tq * KV_TILE_SZ + tk] = -INFINITY; + } + } + } + + if (logit_softcap != 0.0f) { + rvv_softcap_tanh_inplace_f32(KQ, KV_TILE_SZ, tile_rows, KV_TILE_SZ, logit_softcap); + } + + if (mask) { + rvv_add_inplace_f32(KQ, KV_TILE_SZ, mask32, KV_TILE_SZ, tile_rows, KV_TILE_SZ); + } + + bool skip[Q_TILE_SZ] = {}; + + for (int tq = 0; tq < tile_rows; tq++) { + float * kq_row = KQ + tq * KV_TILE_SZ; + + const float tile_max = rvv_max_f32(kq_row, KV_TILE_SZ); + + if (tile_max == -INFINITY) { + skip[tq] = true; + continue; + } + + const float Mold = M[tq]; + const float Mnew = fmaxf(Mold, tile_max); + + if (Mnew > Mold) { + const float ms = expf(Mold - Mnew); + rvv_scale_f32(VKQ32 + tq * DV, ms, DV); + S[tq] *= ms; + } + M[tq] = Mnew; + + S[tq] += rvv_softmax_exp_inplace_f32(kq_row, KV_TILE_SZ, Mnew); + } + + // Pack V as contiguous [KV_TILE_SZ][DV]. + if (kv_type == GGML_TYPE_F16) { + const char * v_data = (const char *) v->data + ic * nbv1 + iv2 * nbv2 + iv3 * nbv3; + memcpy2d(V_f16, DV * sizeof(_Float16), v_data, nbv1, kv_tile, DV * sizeof(_Float16)); + + int tq = 0; + for (; tq + 3 < tile_rows; tq += 4) { + if (skip[tq + 0] || skip[tq + 1] || skip[tq + 2] || skip[tq + 3]) { + for (int i = 0; i < 4; ++i) { + if (!skip[tq + i]) { + rvv_pv_accumulate_f16_x1(VKQ32 + (tq + i) * DV, KQ + (tq + i) * KV_TILE_SZ, V_f16, + KV_TILE_SZ, DV); + } + } + continue; + } + + rvv_pv_accumulate_f16_x4(VKQ32 + (tq + 0) * DV, VKQ32 + (tq + 1) * DV, VKQ32 + (tq + 2) * DV, + VKQ32 + (tq + 3) * DV, KQ + (tq + 0) * KV_TILE_SZ, + KQ + (tq + 1) * KV_TILE_SZ, KQ + (tq + 2) * KV_TILE_SZ, + KQ + (tq + 3) * KV_TILE_SZ, V_f16, KV_TILE_SZ, DV); + } + for (; tq < tile_rows; ++tq) { + if (!skip[tq]) { + rvv_pv_accumulate_f16_x1(VKQ32 + tq * DV, KQ + tq * KV_TILE_SZ, V_f16, KV_TILE_SZ, DV); + } + } + } else { + const char * v_data = (const char *) v->data + ic * nbv1 + iv2 * nbv2 + iv3 * nbv3; + memcpy2d(V32, DV * sizeof(float), v_data, nbv1, kv_tile, DV * sizeof(float)); + + for (int tq = 0; tq < tile_rows; ++tq) { + if (!skip[tq]) { + rvv_pv_accumulate(VKQ32 + tq * DV, KQ + tq * KV_TILE_SZ, V32, KV_TILE_SZ, DV); + } + } + } + } + + // sinks (apply only to valid rows in the tile) + if (sinks) { + const float s = ((float *) ((char *) sinks->data))[h]; + + for (int tq = 0; tq < tile_rows; tq++) { + float ms = 1.0f; + float vs = 1.0f; + + if (s > M[tq]) { + ms = expf(M[tq] - s); + rvv_scale_f32(VKQ32 + tq * DV, ms, DV); + } else { + vs = expf(s - M[tq]); + } + + float S_temp = S[tq] * ms + vs; + S[tq] = S_temp == 0.0f ? 0.0f : 1.0f / S_temp; + } + } else { + for (int tq = 0; tq < tile_rows; tq++) { + const float S_inv = S[tq] == 0.0f ? 0.0f : 1.0f / S[tq]; + S[tq] = S_inv; + } + } + + float * dst_ptr = (float *) ((char *) dst->data + (iq3 * ne2 * ne1 + iq2 + (iq1) *ne1) * nb1); + rvv_pack_scaled_f32_as_f32(dst_ptr, nb1 * ne1, VKQ32, DV * sizeof(float), tile_rows, DV, S); + + ir += tile_rows; + } +} + +void forward_rms_norm_f32(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + int ith = params->ith; + int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + float epsilon = *((float *) dst->op_params); + + GGML_ASSERT(epsilon > 0.0f); + + auto * input = (char *) src0->data; + auto * output = (char *) dst->data; + + const auto hidden_size = ne00; + const auto task_count = ne01 * ne02 * ne03; + const auto task_per_thread = (task_count + nth - 1) / nth; + + const auto task_begin = ith * task_per_thread; + const auto task_end = std::min((ith + 1) * task_per_thread, task_count); + + for (auto task_idx = task_begin; task_idx < task_end; task_idx++) { + int64_t i03 = task_idx / (ne02 * ne01); + int64_t i02 = (task_idx - i03 * ne02 * ne01) / ne01; + int64_t i01 = (task_idx - i03 * ne02 * ne01 - i02 * ne01); + + auto * p_input = (float *) (input + i01 * nb01 + i02 * nb02 + i03 * nb03); + auto * p_output = (float *) (output + i01 * nb1 + i02 * nb2 + i03 * nb3); + auto * p_temp_output = p_output; + + size_t gvl = __riscv_vsetvlmax_e32m4(); + vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl); + int64_t length = hidden_size; + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl); + sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl); + __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl); + + p_input += gvl; + p_temp_output += gvl; + length -= gvl; + } + + gvl = __riscv_vsetvlmax_e32m1(); + vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl); + vfloat32m1_t mean_square_v = + __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0), __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl); + + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl); + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl); + mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl); + + float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v); + mean_square /= hidden_size; + + mean_square = sqrt(mean_square + epsilon); + + mean_square = 1.0f / mean_square; + length = hidden_size; + p_temp_output = p_output; + + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + __riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + length -= gvl; + } + } +} + +template +void quantize_a_nrow_i8_ref(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float) * MB_ROWS); + int8_t * quant_a_blk = + reinterpret_cast(quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS); + + for (size_t row = 0; row < MB_ROWS; row++) { + float max_abs_a = 0.0f; + for (size_t bk = 0; bk < blk_len; bk++) { + max_abs_a = std::max(max_abs_a, std::abs(a_ptr[row * count_k + k + bk])); + } + + float rep_scale_a = ((1 << 7) - 1) / max_abs_a; + scale_a_ptr[row] = 1 / rep_scale_a; + + int16_t a_sum = 0; + for (size_t bk = 0; bk < blk_len; bk++) { + const int8_t quantized = static_cast( + std::clamp(std::nearbyintf(a_ptr[row * count_k + k + bk] * rep_scale_a), -128.0f, 127.0f)); + quant_a_blk[row * blk_len + bk] = quantized; + a_sum += quantized; + } + a_sum_ptr[row] = -a_sum; + } + } +} + +template +void quantize_a_nrow_i8_hp_ref(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + constexpr size_t k_subblk_len = 32; + const size_t subblk_count = blk_len / k_subblk_len; + + GGML_ASSERT(blk_len == 256); + + float scale_temp[8] = { 0.0f }; + int64_t a_blk_stride = q8_hp_blk_size(blk_len, true, true); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + int64_t a_subblk_stride = q8_hp_blk_size(k_subblk_len, false, false) * MB_ROWS; + + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + _Float16 * a_sum_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_subblk_stride * subblk_count); + + float scale_avg = 0.0f; + for (size_t kk = 0; kk < subblk_count; kk++) { + float max_abs_a = 0.0f; + for (size_t row = 0; row < MB_ROWS; row++) { + for (size_t bk = 0; bk < k_subblk_len; bk++) { + max_abs_a = std::max(max_abs_a, std::abs(a_ptr[row * count_k + k + bk + kk * k_subblk_len])); + } + } + scale_temp[kk] = max_abs_a / ((1 << 7) - 1); + scale_avg += scale_temp[kk]; + } + + scale_avg /= subblk_count; + float scale_factor = 1.0f / scale_avg; + + _Float16 * scale_avg_ptr = + reinterpret_cast<_Float16 *>(quant_a_ptr + a_nrow_block_stride - sizeof(_Float16) * MB_ROWS); + scale_avg_ptr[0] = scale_avg; + + for (size_t kk = 0; kk < subblk_count; kk++) { + uint8_t * a_subblk_base = quant_a_ptr + kk * a_subblk_stride; + _Float16 * scale_a_ptr = reinterpret_cast<_Float16 *>(a_subblk_base); + int8_t * quant_a_blk = reinterpret_cast(a_subblk_base + sizeof(_Float16) * MB_ROWS); + + scale_a_ptr[0] = static_cast<_Float16>(scale_temp[kk] * scale_factor); + + const float rep_scale_a = 1.0f / scale_temp[kk]; + + for (size_t row = 0; row < MB_ROWS; row++) { + int16_t a_sum = 0; + for (size_t bk = 0; bk < k_subblk_len; bk++) { + const int8_t quantized = static_cast( + std::clamp(std::nearbyintf(a_ptr[row * count_k + k + bk + kk * k_subblk_len] * rep_scale_a), + -128.0f, 127.0f)); + quant_a_blk[row * k_subblk_len + bk] = quantized; + a_sum += quantized; + } + a_sum_ptr[row * subblk_count + kk] = static_cast<_Float16>(-a_sum) * static_cast<_Float16>(8.0f); + } + } + } +} + +template +void quantize_a_nrow_i8k_ref(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + int64_t a_blk_stride = q8k_blk_size(256); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + int64_t a_sum_size = 256 / 16; + + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float) * MB_ROWS); + int8_t * quant_a_blk = + reinterpret_cast(quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * a_sum_size * MB_ROWS); + + for (size_t row = 0; row < MB_ROWS; row++) { + float max_a = 0.0f; + float max_abs_a = 0.0f; + for (size_t bk = 0; bk < blk_len; bk++) { + float ax = std::abs(a_ptr[row * count_k + k + bk]); + if (ax > max_abs_a) { + max_abs_a = ax; + max_a = a_ptr[row * count_k + k + bk]; + } + } + + if (!max_abs_a) { + scale_a_ptr[row] = 0; + for (size_t bki = 0; bki < a_sum_size; bki++) { + for (size_t bk = bki * 16; bk < (bki + 1) * 16; bk++) { + quant_a_blk[row * blk_len + bk] = 0; + } + a_sum_ptr[row * a_sum_size + bki] = 0; + } + continue; + } + + float rep_scale_a = ((1 << 7) - 1) / max_abs_a; + scale_a_ptr[row] = 1 / rep_scale_a; + + for (size_t bki = 0; bki < a_sum_size; bki++) { + int16_t a_sum = 0; + for (size_t bk = bki * 16; bk < (bki + 1) * 16; bk++) { + const int8_t quantized = static_cast( + std::clamp(std::nearbyintf(a_ptr[row * count_k + k + bk] * rep_scale_a), -128.0f, 127.0f)); + quant_a_blk[row * blk_len + bk] = quantized; + a_sum += quantized; + } + a_sum_ptr[row * a_sum_size + bki] = -a_sum; + } + } + } +} + +void quantize_a_row_i8(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + GGML_ASSERT(blk_len == 32); + int64_t a_blk_stride = q8_blk_size(blk_len, true); + size_t vlenb = __riscv_vlenb(); + + if (vlenb == 128) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_blk_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float)); + int8_t * quant_a_blk = reinterpret_cast(quant_a_ptr + sizeof(float) + sizeof(int16_t)); + + size_t vl = __riscv_vsetvl_e32m1(blk_len); + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_ptr + k, vl); + vfloat32m1_t v_a_abs = __riscv_vfabs_v_f32m1(v_a, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m1_f32m1(v_a_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[0] = scale_a; + + vfloat32m1_t v_a_scale = __riscv_vfmul_vf_f32m1(v_a, rep_scale_a, vl); + vint16mf2_t v_a_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a_scale, vl); + vint8mf4_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[0] = -a_sum; + + __riscv_vse8_v_i8mf4(quant_a_blk, v_a_quant_i8, vl); + } + } else if (vlenb == 32) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_blk_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float)); + int8_t * quant_a_blk = reinterpret_cast(quant_a_ptr + sizeof(float) + sizeof(int16_t)); + + size_t vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t v_a = __riscv_vle32_v_f32m4(a_ptr + k, vl); + vfloat32m4_t v_a_abs = __riscv_vfabs_v_f32m4(v_a, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m4_f32m1(v_a_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[0] = scale_a; + + vfloat32m4_t v_a_scale = __riscv_vfmul_vf_f32m4(v_a, rep_scale_a, vl); + vint16m2_t v_a_quant = __riscv_vfncvt_x_f_w_i16m2(v_a_scale, vl); + vint8m1_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[0] = -a_sum; + + __riscv_vse8_v_i8m1(quant_a_blk, v_a_quant_i8, vl); + } + } else { + quantize_a_nrow_i8_ref<1>(blk_len, a_ptr, count_k, quant_a_ptr); + } +} + +void quantize_a_4row_i8(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + GGML_ASSERT(blk_len == 32); + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t a_nrow_block_stride = a_blk_stride * 4; + size_t vlenb = __riscv_vlenb(); + + if (vlenb == 128) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float) * 4); + int8_t * quant_a_blk = reinterpret_cast(quant_a_ptr + sizeof(float) * 4 + sizeof(int16_t) * 4); + + for (size_t mi = 0; mi < 4; mi++) { + size_t vl = __riscv_vsetvl_e32m1(blk_len); + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_ptr + mi * count_k + k, vl); + vfloat32m1_t v_a_abs = __riscv_vfabs_v_f32m1(v_a, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m1_f32m1(v_a_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[mi] = scale_a; + + vfloat32m1_t v_a_scale = __riscv_vfmul_vf_f32m1(v_a, rep_scale_a, vl); + vint16mf2_t v_a_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a_scale, vl); + vint8mf4_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[mi] = -a_sum; + + __riscv_vse8_v_i8mf4(quant_a_blk + mi * blk_len, v_a_quant_i8, vl); + } + } + } else if (vlenb == 32) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float) * 4); + int8_t * quant_a_blk = reinterpret_cast(quant_a_ptr + sizeof(float) * 4 + sizeof(int16_t) * 4); + + for (size_t mi = 0; mi < 4; mi++) { + size_t vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t v_a = __riscv_vle32_v_f32m4(a_ptr + mi * count_k + k, vl); + vfloat32m4_t v_a_abs = __riscv_vfabs_v_f32m4(v_a, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m4_f32m1(v_a_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[mi] = scale_a; + + vfloat32m4_t v_a_scale = __riscv_vfmul_vf_f32m4(v_a, rep_scale_a, vl); + vint16m2_t v_a_quant = __riscv_vfncvt_x_f_w_i16m2(v_a_scale, vl); + vint8m1_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[mi] = -a_sum; + + __riscv_vse8_v_i8m1(quant_a_blk + mi * blk_len, v_a_quant_i8, vl); + } + } + } else { + quantize_a_nrow_i8_ref<4>(blk_len, a_ptr, count_k, quant_a_ptr); + } +} + +void quantize_a_row_i8_hp(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + constexpr size_t k_subblk_len = 32; + GGML_ASSERT(blk_len == 256); + + constexpr size_t subblk_count = 256 / k_subblk_len; + int64_t a_blk_stride = q8_hp_blk_size(blk_len, true, true); + int64_t a_subblk_stride = q8_hp_blk_size(k_subblk_len, false, false); + size_t vlenb = __riscv_vlenb(); + float scale_temp[subblk_count] = { 0.0f }; + + if (vlenb == 128) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_blk_stride) { + _Float16 * a_sum_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_subblk_stride * subblk_count); + _Float16 * scale_avg_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_blk_stride - sizeof(_Float16)); + float scale_avg = 0.0f; + + for (size_t kk = 0; kk < subblk_count; ++kk) { + const float * a_src_ptr = a_ptr + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m1(k_subblk_len); + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_src_ptr, vl); + vfloat32m1_t v_a_abs = __riscv_vfabs_v_f32m1(v_a, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m1_f32m1(v_a_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + scale_temp[kk] = max_abs_a / ((1 << 7) - 1); + scale_avg += scale_temp[kk]; + } + + scale_avg /= subblk_count; + const float scale_factor = scale_avg ? 1.0f / scale_avg : 0.0f; + scale_avg_ptr[0] = static_cast<_Float16>(scale_avg); + + for (size_t kk = 0; kk < subblk_count; ++kk) { + uint8_t * a_subblk_base = quant_a_ptr + kk * a_subblk_stride; + _Float16 * scale_a_ptr = reinterpret_cast<_Float16 *>(a_subblk_base); + int8_t * quant_a_blk = reinterpret_cast(a_subblk_base + sizeof(_Float16)); + const float * a_src_ptr = a_ptr + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m1(k_subblk_len); + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_src_ptr, vl); + float rep_scale_a = scale_temp[kk] ? 1.0f / scale_temp[kk] : 0.0f; + scale_a_ptr[0] = static_cast<_Float16>(scale_temp[kk] * scale_factor); + + vfloat32m1_t v_a_scale = __riscv_vfmul_vf_f32m1(v_a, rep_scale_a, vl); + vint16mf2_t v_a_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a_scale, vl); + vint8mf4_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[kk] = static_cast<_Float16>(-a_sum) * static_cast<_Float16>(8.0f); + + __riscv_vse8_v_i8mf4(quant_a_blk, v_a_quant_i8, vl); + } + } + } else if (vlenb == 32) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_blk_stride) { + _Float16 * a_sum_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_subblk_stride * subblk_count); + _Float16 * scale_avg_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_blk_stride - sizeof(_Float16)); + float scale_avg = 0.0f; + + for (size_t kk = 0; kk < subblk_count; ++kk) { + const float * a_src_ptr = a_ptr + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m4(k_subblk_len); + vfloat32m4_t v_a = __riscv_vle32_v_f32m4(a_src_ptr, vl); + vfloat32m4_t v_a_abs = __riscv_vfabs_v_f32m4(v_a, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m4_f32m1(v_a_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + scale_temp[kk] = max_abs_a / ((1 << 7) - 1); + scale_avg += scale_temp[kk]; + } + + scale_avg /= subblk_count; + const float scale_factor = scale_avg ? 1.0f / scale_avg : 0.0f; + scale_avg_ptr[0] = static_cast<_Float16>(scale_avg); + + for (size_t kk = 0; kk < subblk_count; ++kk) { + uint8_t * a_subblk_base = quant_a_ptr + kk * a_subblk_stride; + _Float16 * scale_a_ptr = reinterpret_cast<_Float16 *>(a_subblk_base); + int8_t * quant_a_blk = reinterpret_cast(a_subblk_base + sizeof(_Float16)); + const float * a_src_ptr = a_ptr + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m4(k_subblk_len); + vfloat32m4_t v_a = __riscv_vle32_v_f32m4(a_src_ptr, vl); + float rep_scale_a = scale_temp[kk] ? 1.0f / scale_temp[kk] : 0.0f; + scale_a_ptr[0] = static_cast<_Float16>(scale_temp[kk] * scale_factor); + + vfloat32m4_t v_a_scale = __riscv_vfmul_vf_f32m4(v_a, rep_scale_a, vl); + vint16m2_t v_a_quant = __riscv_vfncvt_x_f_w_i16m2(v_a_scale, vl); + vint8m1_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[kk] = static_cast<_Float16>(-a_sum) * static_cast<_Float16>(8.0f); + + __riscv_vse8_v_i8m1(quant_a_blk, v_a_quant_i8, vl); + } + } + } else { + quantize_a_nrow_i8_hp_ref<1>(blk_len, a_ptr, count_k, quant_a_ptr); + } +} + +void quantize_a_4row_i8_hp(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + constexpr size_t k_subblk_len = 32; + GGML_ASSERT(blk_len == 256); + + constexpr size_t subblk_count = 256 / k_subblk_len; + int64_t a_blk_stride = q8_hp_blk_size(blk_len, true, true); + int64_t a_nrow_block_stride = a_blk_stride * 4; + int64_t a_subblk_stride = q8_hp_blk_size(k_subblk_len, false, false) * 4; + size_t vlenb = __riscv_vlenb(); + float scale_temp[subblk_count] = { 0.0f }; + + if (vlenb == 128) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + _Float16 * a_sum_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_subblk_stride * subblk_count); + _Float16 * scale_avg_ptr = + reinterpret_cast<_Float16 *>(quant_a_ptr + a_nrow_block_stride - sizeof(_Float16) * 4); + float scale_avg = 0.0f; + + for (size_t kk = 0; kk < subblk_count; ++kk) { + const float * a_src_ptr0 = a_ptr + 0 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr1 = a_ptr + 1 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr2 = a_ptr + 2 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr3 = a_ptr + 3 * count_k + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m1(k_subblk_len); + vfloat32m1_t v_a0 = __riscv_vle32_v_f32m1(a_src_ptr0, vl); + vfloat32m1_t v_a1 = __riscv_vle32_v_f32m1(a_src_ptr1, vl); + vfloat32m1_t v_a2 = __riscv_vle32_v_f32m1(a_src_ptr2, vl); + vfloat32m1_t v_a3 = __riscv_vle32_v_f32m1(a_src_ptr3, vl); + vfloat32m1_t v_a0_abs = __riscv_vfabs_v_f32m1(v_a0, vl); + vfloat32m1_t v_a1_abs = __riscv_vfabs_v_f32m1(v_a1, vl); + vfloat32m1_t v_a2_abs = __riscv_vfabs_v_f32m1(v_a2, vl); + vfloat32m1_t v_a3_abs = __riscv_vfabs_v_f32m1(v_a3, vl); + + vfloat32m1_t v_max_abs = __riscv_vfmax_vv_f32m1(v_a0_abs, v_a1_abs, vl); + v_max_abs = __riscv_vfmax_vv_f32m1(v_max_abs, v_a2_abs, vl); + v_max_abs = __riscv_vfmax_vv_f32m1(v_max_abs, v_a3_abs, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m1_f32m1(v_max_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + scale_temp[kk] = max_abs_a / ((1 << 7) - 1); + scale_avg += scale_temp[kk]; + } + + scale_avg /= subblk_count; + const float scale_factor = scale_avg ? 1.0f / scale_avg : 0.0f; + scale_avg_ptr[0] = static_cast<_Float16>(scale_avg); + + for (size_t kk = 0; kk < subblk_count; ++kk) { + uint8_t * a_subblk_base = quant_a_ptr + kk * a_subblk_stride; + _Float16 * scale_a_ptr = reinterpret_cast<_Float16 *>(a_subblk_base); + int8_t * quant_a_blk = reinterpret_cast(a_subblk_base + sizeof(_Float16) * 4); + const float * a_src_ptr0 = a_ptr + 0 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr1 = a_ptr + 1 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr2 = a_ptr + 2 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr3 = a_ptr + 3 * count_k + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m1(k_subblk_len); + vfloat32m1_t v_a0 = __riscv_vle32_v_f32m1(a_src_ptr0, vl); + vfloat32m1_t v_a1 = __riscv_vle32_v_f32m1(a_src_ptr1, vl); + vfloat32m1_t v_a2 = __riscv_vle32_v_f32m1(a_src_ptr2, vl); + vfloat32m1_t v_a3 = __riscv_vle32_v_f32m1(a_src_ptr3, vl); + + float rep_scale_a = scale_temp[kk] ? 1.0f / scale_temp[kk] : 0.0f; + scale_a_ptr[0] = static_cast<_Float16>(scale_temp[kk] * scale_factor); + + vfloat32m1_t v_a0_scale = __riscv_vfmul_vf_f32m1(v_a0, rep_scale_a, vl); + vfloat32m1_t v_a1_scale = __riscv_vfmul_vf_f32m1(v_a1, rep_scale_a, vl); + vfloat32m1_t v_a2_scale = __riscv_vfmul_vf_f32m1(v_a2, rep_scale_a, vl); + vfloat32m1_t v_a3_scale = __riscv_vfmul_vf_f32m1(v_a3, rep_scale_a, vl); + vint16mf2_t v_a0_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a0_scale, vl); + vint16mf2_t v_a1_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a1_scale, vl); + vint16mf2_t v_a2_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a2_scale, vl); + vint16mf2_t v_a3_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a3_scale, vl); + vint8mf4_t v_a0_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a0_quant, vl); + vint8mf4_t v_a1_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a1_quant, vl); + vint8mf4_t v_a2_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a2_quant, vl); + vint8mf4_t v_a3_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a3_quant, vl); + + vint16m1_t tmp_sum0 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t tmp_sum1 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t tmp_sum2 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t tmp_sum3 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a0_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a0_quant_i8, tmp_sum0, vl); + vint16m1_t v_a1_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a1_quant_i8, tmp_sum1, vl); + vint16m1_t v_a2_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a2_quant_i8, tmp_sum2, vl); + vint16m1_t v_a3_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a3_quant_i8, tmp_sum3, vl); + + a_sum_ptr[0 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a0_sum)) * static_cast<_Float16>(8.0f); + a_sum_ptr[1 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a1_sum)) * static_cast<_Float16>(8.0f); + a_sum_ptr[2 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a2_sum)) * static_cast<_Float16>(8.0f); + a_sum_ptr[3 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a3_sum)) * static_cast<_Float16>(8.0f); + + __riscv_vse8_v_i8mf4(quant_a_blk + 0 * k_subblk_len, v_a0_quant_i8, vl); + __riscv_vse8_v_i8mf4(quant_a_blk + 1 * k_subblk_len, v_a1_quant_i8, vl); + __riscv_vse8_v_i8mf4(quant_a_blk + 2 * k_subblk_len, v_a2_quant_i8, vl); + __riscv_vse8_v_i8mf4(quant_a_blk + 3 * k_subblk_len, v_a3_quant_i8, vl); + } + } + } else if (vlenb == 32) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + _Float16 * a_sum_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_subblk_stride * subblk_count); + _Float16 * scale_avg_ptr = + reinterpret_cast<_Float16 *>(quant_a_ptr + a_nrow_block_stride - sizeof(_Float16) * 4); + float scale_avg = 0.0f; + + for (size_t kk = 0; kk < subblk_count; ++kk) { + const float * a_src_ptr0 = a_ptr + 0 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr1 = a_ptr + 1 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr2 = a_ptr + 2 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr3 = a_ptr + 3 * count_k + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m4(k_subblk_len); + vfloat32m4_t v_a0 = __riscv_vle32_v_f32m4(a_src_ptr0, vl); + vfloat32m4_t v_a1 = __riscv_vle32_v_f32m4(a_src_ptr1, vl); + vfloat32m4_t v_a2 = __riscv_vle32_v_f32m4(a_src_ptr2, vl); + vfloat32m4_t v_a3 = __riscv_vle32_v_f32m4(a_src_ptr3, vl); + + vfloat32m4_t v_a0_abs = __riscv_vfabs_v_f32m4(v_a0, vl); + vfloat32m4_t v_a1_abs = __riscv_vfabs_v_f32m4(v_a1, vl); + vfloat32m4_t v_a2_abs = __riscv_vfabs_v_f32m4(v_a2, vl); + vfloat32m4_t v_a3_abs = __riscv_vfabs_v_f32m4(v_a3, vl); + + vfloat32m4_t v_max_abs = __riscv_vfmax_vv_f32m4(v_a0_abs, v_a1_abs, vl); + v_max_abs = __riscv_vfmax_vv_f32m4(v_max_abs, v_a2_abs, vl); + v_max_abs = __riscv_vfmax_vv_f32m4(v_max_abs, v_a3_abs, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m4_f32m1(v_max_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + scale_temp[kk] = max_abs_a / ((1 << 7) - 1); + scale_avg += scale_temp[kk]; + } + + scale_avg /= subblk_count; + const float scale_factor = scale_avg ? 1.0f / scale_avg : 0.0f; + scale_avg_ptr[0] = static_cast<_Float16>(scale_avg); + + for (size_t kk = 0; kk < subblk_count; ++kk) { + uint8_t * a_subblk_base = quant_a_ptr + kk * a_subblk_stride; + _Float16 * scale_a_ptr = reinterpret_cast<_Float16 *>(a_subblk_base); + int8_t * quant_a_blk = reinterpret_cast(a_subblk_base + sizeof(_Float16) * 4); + const float * a_src_ptr0 = a_ptr + 0 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr1 = a_ptr + 1 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr2 = a_ptr + 2 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr3 = a_ptr + 3 * count_k + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m4(k_subblk_len); + vfloat32m4_t v_a0 = __riscv_vle32_v_f32m4(a_src_ptr0, vl); + vfloat32m4_t v_a1 = __riscv_vle32_v_f32m4(a_src_ptr1, vl); + vfloat32m4_t v_a2 = __riscv_vle32_v_f32m4(a_src_ptr2, vl); + vfloat32m4_t v_a3 = __riscv_vle32_v_f32m4(a_src_ptr3, vl); + + float rep_scale_a = scale_temp[kk] ? 1.0f / scale_temp[kk] : 0.0f; + scale_a_ptr[0] = static_cast<_Float16>(scale_temp[kk] * scale_factor); + + vfloat32m4_t v_a0_scale = __riscv_vfmul_vf_f32m4(v_a0, rep_scale_a, vl); + vfloat32m4_t v_a1_scale = __riscv_vfmul_vf_f32m4(v_a1, rep_scale_a, vl); + vfloat32m4_t v_a2_scale = __riscv_vfmul_vf_f32m4(v_a2, rep_scale_a, vl); + vfloat32m4_t v_a3_scale = __riscv_vfmul_vf_f32m4(v_a3, rep_scale_a, vl); + vint16m2_t v_a0_quant = __riscv_vfncvt_x_f_w_i16m2(v_a0_scale, vl); + vint16m2_t v_a1_quant = __riscv_vfncvt_x_f_w_i16m2(v_a1_scale, vl); + vint16m2_t v_a2_quant = __riscv_vfncvt_x_f_w_i16m2(v_a2_scale, vl); + vint16m2_t v_a3_quant = __riscv_vfncvt_x_f_w_i16m2(v_a3_scale, vl); + vint8m1_t v_a0_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a0_quant, vl); + vint8m1_t v_a1_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a1_quant, vl); + vint8m1_t v_a2_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a2_quant, vl); + vint8m1_t v_a3_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a3_quant, vl); + + vint16m1_t tmp_sum0 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t tmp_sum1 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t tmp_sum2 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t tmp_sum3 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a0_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a0_quant_i8, tmp_sum0, vl); + vint16m1_t v_a1_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a1_quant_i8, tmp_sum1, vl); + vint16m1_t v_a2_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a2_quant_i8, tmp_sum2, vl); + vint16m1_t v_a3_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a3_quant_i8, tmp_sum3, vl); + + a_sum_ptr[0 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a0_sum)) * static_cast<_Float16>(8.0f); + a_sum_ptr[1 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a1_sum)) * static_cast<_Float16>(8.0f); + a_sum_ptr[2 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a2_sum)) * static_cast<_Float16>(8.0f); + a_sum_ptr[3 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a3_sum)) * static_cast<_Float16>(8.0f); + + __riscv_vse8_v_i8m1(quant_a_blk + 0 * k_subblk_len, v_a0_quant_i8, vl); + __riscv_vse8_v_i8m1(quant_a_blk + 1 * k_subblk_len, v_a1_quant_i8, vl); + __riscv_vse8_v_i8m1(quant_a_blk + 2 * k_subblk_len, v_a2_quant_i8, vl); + __riscv_vse8_v_i8m1(quant_a_blk + 3 * k_subblk_len, v_a3_quant_i8, vl); + } + } + } else { + quantize_a_nrow_i8_hp_ref<4>(blk_len, a_ptr, count_k, quant_a_ptr); + } +} + +void quantize_a_row_i8k(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + GGML_ASSERT(blk_len == 256); + constexpr int64_t a_blk_stride = q8k_blk_size(256); + constexpr int64_t a_sum_size = 256 / 16; + size_t vlenb = __riscv_vlenb(); + + if (vlenb == 128) { + // vlen = 1024 bits, can process 32 float32 elements with m1 + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_blk_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float)); + int8_t * quant_a_blk = + reinterpret_cast(quant_a_ptr + sizeof(float) + sizeof(int16_t) * a_sum_size); + + // Find max absolute value across all 256 elements + size_t vl = __riscv_vsetvl_e32m1(16); + vfloat32m1_t v_max_abs = __riscv_vfmv_v_f_f32m1(0.0f, vl); + + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_ptr + k + bki * 16, vl); + vfloat32m1_t v_a_abs = __riscv_vfabs_v_f32m1(v_a, vl); + v_max_abs = __riscv_vfmax_vv_f32m1(v_a_abs, v_max_abs, vl); + } + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_local_max = __riscv_vfredmax_vs_f32m1_f32m1(v_max_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_local_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[0] = scale_a; + + // Quantize and compute sums for each 16-element group + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_ptr + k + bki * 16, vl); + vfloat32m1_t v_a_scale = __riscv_vfmul_vf_f32m1(v_a, rep_scale_a, vl); + vint16mf2_t v_a_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a_scale, vl); + vint8mf4_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[bki] = -a_sum; + + __riscv_vse8_v_i8mf4(quant_a_blk + bki * 16, v_a_quant_i8, vl); + } + } + } else if (vlenb == 32) { + // vlen = 256 bits, can process 8 float32 elements with m1 + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_blk_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float)); + int8_t * quant_a_blk = + reinterpret_cast(quant_a_ptr + sizeof(float) + sizeof(int16_t) * a_sum_size); + + // Find max absolute value across all 256 elements + size_t vl = __riscv_vsetvl_e32m2(16); + vfloat32m2_t v_max_abs = __riscv_vfmv_v_f_f32m2(0.0f, vl); + + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m2_t v_a = __riscv_vle32_v_f32m2(a_ptr + k + bki * 16, vl); + vfloat32m2_t v_a_abs = __riscv_vfabs_v_f32m2(v_a, vl); + v_max_abs = __riscv_vfmax_vv_f32m2(v_a_abs, v_max_abs, vl); + } + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_local_max = __riscv_vfredmax_vs_f32m2_f32m1(v_max_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_local_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[0] = scale_a; + + // Quantize and compute sums for each 16-element group + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m2_t v_a = __riscv_vle32_v_f32m2(a_ptr + k + bki * 16, vl); + vfloat32m2_t v_a_scale = __riscv_vfmul_vf_f32m2(v_a, rep_scale_a, vl); + vint16m1_t v_a_quant = __riscv_vfncvt_x_f_w_i16m1(v_a_scale, vl); + vint8mf2_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf2(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf2_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[bki] = -a_sum; + + __riscv_vse8_v_i8mf2(quant_a_blk + bki * 16, v_a_quant_i8, vl); + } + } + } else { + quantize_a_nrow_i8k_ref<1>(blk_len, a_ptr, count_k, quant_a_ptr); + } +} + +void quantize_a_4row_i8k(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + GGML_ASSERT(blk_len == 256); + constexpr int64_t a_blk_stride = q8k_blk_size(256); + constexpr int64_t a_nrow_block_stride = a_blk_stride * 4; + constexpr int64_t a_sum_size = 256 / 16; + size_t vlenb = __riscv_vlenb(); + + if (vlenb == 128) { + // vlen = 1024 bits + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float) * 4); + int8_t * quant_a_blk = + reinterpret_cast(quant_a_ptr + sizeof(float) * 4 + sizeof(int16_t) * a_sum_size * 4); + + for (size_t mi = 0; mi < 4; mi++) { + // Find max absolute value across all 256 elements for this row + size_t vl = __riscv_vsetvl_e32m1(16); + vfloat32m1_t v_max_abs = __riscv_vfmv_v_f_f32m1(0.0f, vl); + + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_ptr + mi * count_k + k + bki * 16, vl); + vfloat32m1_t v_a_abs = __riscv_vfabs_v_f32m1(v_a, vl); + v_max_abs = __riscv_vfmax_vv_f32m1(v_a_abs, v_max_abs, vl); + } + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_local_max = __riscv_vfredmax_vs_f32m1_f32m1(v_max_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_local_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[mi] = scale_a; + + // Quantize and compute sums for each 16-element group + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_ptr + mi * count_k + k + bki * 16, vl); + vfloat32m1_t v_a_scale = __riscv_vfmul_vf_f32m1(v_a, rep_scale_a, vl); + vint16mf2_t v_a_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a_scale, vl); + vint8mf4_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[mi * a_sum_size + bki] = -a_sum; + + __riscv_vse8_v_i8mf4(quant_a_blk + mi * blk_len + bki * 16, v_a_quant_i8, vl); + } + } + } + } else if (vlenb == 32) { + // vlen = 256 bits + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float) * 4); + int8_t * quant_a_blk = + reinterpret_cast(quant_a_ptr + sizeof(float) * 4 + sizeof(int16_t) * a_sum_size * 4); + + for (size_t mi = 0; mi < 4; mi++) { + // Find max absolute value across all 256 elements for this row + size_t vl = __riscv_vsetvl_e32m2(16); + vfloat32m2_t v_max_abs = __riscv_vfmv_v_f_f32m2(0.0f, vl); + + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m2_t v_a = __riscv_vle32_v_f32m2(a_ptr + mi * count_k + k + bki * 16, vl); + vfloat32m2_t v_a_abs = __riscv_vfabs_v_f32m2(v_a, vl); + v_max_abs = __riscv_vfmax_vv_f32m2(v_a_abs, v_max_abs, vl); + } + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_local_max = __riscv_vfredmax_vs_f32m2_f32m1(v_max_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_local_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[mi] = scale_a; + + // Quantize and compute sums for each 16-element group + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m2_t v_a = __riscv_vle32_v_f32m2(a_ptr + mi * count_k + k + bki * 16, vl); + vfloat32m2_t v_a_scale = __riscv_vfmul_vf_f32m2(v_a, rep_scale_a, vl); + vint16m1_t v_a_quant = __riscv_vfncvt_x_f_w_i16m1(v_a_scale, vl); + vint8mf2_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf2(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf2_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[mi * a_sum_size + bki] = -a_sum; + + __riscv_vse8_v_i8mf2(quant_a_blk + mi * blk_len + bki * 16, v_a_quant_i8, vl); + } + } + } + } else { + quantize_a_nrow_i8k_ref<4>(blk_len, a_ptr, count_k, quant_a_ptr); + } +} + +void forward_cpy_with_permute(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + const int ith = params->ith; + const int nth = params->nth; + + // [batch, m, n] -> [batch, n, m] + int64_t batch = src0->ne[2] * src0->ne[3]; + int64_t m = src0->ne[1]; + int64_t n = src0->ne[0]; + + int64_t batch_stride = src0->nb[2]; + int64_t m_src_stride = src0->nb[0]; + int64_t n_src_stride = src0->nb[1]; + int64_t n_dst_stride = n_src_stride * m; + + permute_transpose_impl(src0, dst, batch, m, n, batch_stride, m_src_stride, n_src_stride, n_dst_stride, ith, nth); +} + +void forward_cont_with_permute(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + const int ith = params->ith; + const int nth = params->nth; + + // [batch, m, n] -> [batch, n, m] + int64_t batch = dst->ne[2] * dst->ne[3]; + int64_t n = dst->ne[1]; + int64_t m = dst->ne[0]; + + int64_t batch_stride = dst->nb[2]; + int64_t m_src_stride = src0->nb[0]; + int64_t n_src_stride = src0->nb[1]; + int64_t n_dst_stride = dst->nb[1]; + + permute_transpose_impl(src0, dst, batch, m, n, batch_stride, m_src_stride, n_src_stride, n_dst_stride, ith, nth); +} + +void forward_norm_f32(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + int ith = params->ith; + int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + float epsilon = *((float *) dst->op_params); + + GGML_ASSERT(epsilon > 0.0f); + + auto * input = (char *) src0->data; + auto * output = (char *) dst->data; + + const auto hidden_size = ne00; + const auto task_count = ne01 * ne02 * ne03; + const auto task_per_thread = (task_count + nth - 1) / nth; + + const auto task_begin = ith * task_per_thread; + const auto task_end = std::min((ith + 1) * task_per_thread, task_count); + + for (auto task_idx = task_begin; task_idx < task_end; task_idx++) { + int64_t i03 = task_idx / (ne02 * ne01); + int64_t i02 = (task_idx - i03 * ne02 * ne01) / ne01; + int64_t i01 = (task_idx - i03 * ne02 * ne01 - i02 * ne01); + + auto * p_input = (float *) (input + i01 * nb01 + i02 * nb02 + i03 * nb03); + auto * p_output = (float *) (output + i01 * nb1 + i02 * nb2 + i03 * nb3); + auto * p_temp_output = p_output; + + size_t gvl = __riscv_vsetvlmax_e32m4(); + vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl); + vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl); + int64_t length = hidden_size; + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + // load data + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl); + + sum = __riscv_vfadd_vv_f32m4(sum, src_data, gvl); + sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl); + + __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl); + + p_input += gvl; + p_temp_output += gvl; + length -= gvl; + } + + gvl = __riscv_vsetvlmax_e32m1(); + + float mean = 0.f; + vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl); + vfloat32m1_t mean_v = + __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum, 0), __riscv_vget_v_f32m4_f32m1(sum, 1), gvl); + mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 2), gvl); + mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 3), gvl); + mean_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_v, zero_v, gvl); + mean = __riscv_vfmv_f_s_f32m1_f32(mean_v); + mean /= hidden_size; + + vfloat32m1_t mean_square_v = + __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0), __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl); + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl); + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl); + mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl); + + float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v); + mean_square /= hidden_size; + mean_square = sqrt(mean_square - mean * mean + epsilon); + + mean_square = 1.0f / mean_square; + length = hidden_size; + p_temp_output = p_output; + + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + __riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + length -= gvl; + } + } +} + +template void forward_binary(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + const ggml_tensor * src1 = op->src[1]; + ggml_tensor * dst = op; + GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); + + auto src0_rows = ggml_nrows(src0); + auto src1_rows = ggml_nrows(src1); + + int ith = params->ith; + int nth = params->nth; + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(nb0 == sizeof(T)); + GGML_ASSERT(nb00 == sizeof(T)); + + const auto [ir0, ir1] = get_thread_range(params, src0); + + auto compute_func_vv = [&](int64_t blk_len, int64_t r, T * src0_ptr, T * src1_ptr, T * dst_ptr) { + int64_t idx = 0; + if constexpr (op_type == GGML_OP_ADD) { + if constexpr (std::is_same_v) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + idx + r, vl); + vfloat32m4_t rhs = __riscv_vle32_v_f32m4(src1_ptr + idx, vl); + vfloat32m4_t res = __riscv_vfadd_vv_f32m4(lhs, rhs, vl); + __riscv_vse32_v_f32m4(dst_ptr + idx + r, res, vl); + } + } else if constexpr (std::is_same_v) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + idx + r), vl); + vfloat16m4_t rhs = __riscv_vle16_v_f16m4((src1_ptr + idx), vl); + vfloat16m4_t res = __riscv_vfadd_vv_f16m4(lhs, rhs, vl); + __riscv_vse16_v_f16m4((dst_ptr + idx + r), res, vl); + } + } else { + GGML_ABORT("fatal error"); + } + } else if constexpr (op_type == GGML_OP_SUB) { + if constexpr (std::is_same_v) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + idx + r, vl); + vfloat32m4_t rhs = __riscv_vle32_v_f32m4(src1_ptr + idx, vl); + vfloat32m4_t res = __riscv_vfsub_vv_f32m4(lhs, rhs, vl); + __riscv_vse32_v_f32m4(dst_ptr + idx + r, res, vl); + } + } else if constexpr (std::is_same_v) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + idx + r), vl); + vfloat16m4_t rhs = __riscv_vle16_v_f16m4((src1_ptr + idx), vl); + vfloat16m4_t res = __riscv_vfsub_vv_f16m4(lhs, rhs, vl); + __riscv_vse16_v_f16m4((dst_ptr + idx + r), res, vl); + } + } else { + GGML_ABORT("fatal error"); + } + } else if constexpr (op_type == GGML_OP_MUL) { + if constexpr (std::is_same_v) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + idx + r, vl); + vfloat32m4_t rhs = __riscv_vle32_v_f32m4(src1_ptr + idx, vl); + vfloat32m4_t res = __riscv_vfmul_vv_f32m4(lhs, rhs, vl); + __riscv_vse32_v_f32m4(dst_ptr + idx + r, res, vl); + } + } else if constexpr (std::is_same_v) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + idx + r), vl); + vfloat16m4_t rhs = __riscv_vle16_v_f16m4((src1_ptr + idx), vl); + vfloat16m4_t res = __riscv_vfmul_vv_f16m4(lhs, rhs, vl); + __riscv_vse16_v_f16m4((dst_ptr + idx + r), res, vl); + } + } else { + GGML_ABORT("fatal error"); + } + } else if constexpr (op_type == GGML_OP_DIV) { + if constexpr (std::is_same_v) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + idx + r, vl); + vfloat32m4_t rhs = __riscv_vle32_v_f32m4(src1_ptr + idx, vl); + vfloat32m4_t res = __riscv_vfdiv_vv_f32m4(lhs, rhs, vl); + __riscv_vse32_v_f32m4(dst_ptr + idx + r, res, vl); + } + } else if constexpr (std::is_same_v) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + idx + r), vl); + vfloat16m4_t rhs = __riscv_vle16_v_f16m4((src1_ptr + idx), vl); + vfloat16m4_t res = __riscv_vfdiv_vv_f16m4(lhs, rhs, vl); + __riscv_vse16_v_f16m4((dst_ptr + idx + r), res, vl); + } + } else { + GGML_ABORT("fatal error"); + } + } else { + GGML_ABORT("fatal error"); + } + }; + + if (src0_rows == src1_rows && src0_rows == 1 && ne00 == ne10) { + int64_t task_per_thread = (ne00 + nth - 1) / nth; + int64_t task_begin = ith * task_per_thread; + int64_t task_end = std::min((ith + 1) * task_per_thread, ne00); + + T * dst_ptr = ((T *) dst->data) + task_begin; + T * src0_ptr = ((T *) src0->data) + task_begin; + T * src1_ptr = ((T *) src1->data) + task_begin; + + compute_func_vv(task_end - task_begin, 0, src0_ptr, src1_ptr, dst_ptr); + } else if (ne10 > 1) { + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir / (ne02 * ne01); + const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01; + const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + T * dst_ptr = (T *) ((char *) dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1); + T * src0_ptr = (T *) ((char *) src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01); + T * src1_ptr = (T *) ((char *) src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11); + + // src1 is broadcastable across src0 and dst in i1, i2, i3 + for (int64_t r = 0; r < ne00; r += ne10) { + compute_func_vv(ne10, r, src0_ptr, src1_ptr, dst_ptr); + } + } + } else { + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir / (ne02 * ne01); + const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01; + const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + T * dst_ptr = (T *) ((char *) dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1); + T * src0_ptr = (T *) ((char *) src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01); + T * src1_ptr = (T *) ((char *) src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11); + + T rhs_scalar = src1_ptr[0]; + int64_t blk_len = ne00; + int64_t r = 0; + + for (size_t vl; blk_len > 0; blk_len -= vl, r += vl) { + if constexpr (op_type == GGML_OP_ADD) { + if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + r, vl); + vfloat32m4_t res = __riscv_vfadd_vf_f32m4(lhs, rhs_scalar, vl); + __riscv_vse32_v_f32m4(dst_ptr + r, res, vl); + } else if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + r), vl); + vfloat16m4_t res = __riscv_vfadd_vf_f16m4(lhs, rhs_scalar, vl); + __riscv_vse16_v_f16m4((dst_ptr + r), res, vl); + } else { + GGML_ABORT("fatal error"); + } + } else if constexpr (op_type == GGML_OP_SUB) { + if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + r, vl); + vfloat32m4_t res = __riscv_vfsub_vf_f32m4(lhs, rhs_scalar, vl); + __riscv_vse32_v_f32m4(dst_ptr + r, res, vl); + } else if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + r), vl); + vfloat16m4_t res = __riscv_vfsub_vf_f16m4(lhs, rhs_scalar, vl); + __riscv_vse16_v_f16m4((dst_ptr + r), res, vl); + } else { + GGML_ABORT("fatal error"); + } + } else if constexpr (op_type == GGML_OP_MUL) { + if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + r, vl); + vfloat32m4_t res = __riscv_vfmul_vf_f32m4(lhs, rhs_scalar, vl); + __riscv_vse32_v_f32m4(dst_ptr + r, res, vl); + } else if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + r), vl); + vfloat16m4_t res = __riscv_vfmul_vf_f16m4(lhs, rhs_scalar, vl); + __riscv_vse16_v_f16m4((dst_ptr + r), res, vl); + } else { + GGML_ABORT("fatal error"); + } + } else if constexpr (op_type == GGML_OP_DIV) { + if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + r, vl); + vfloat32m4_t res = __riscv_vfdiv_vf_f32m4(lhs, rhs_scalar, vl); + __riscv_vse32_v_f32m4(dst_ptr + r, res, vl); + } else if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + r), vl); + vfloat16m4_t res = __riscv_vfdiv_vf_f16m4(lhs, rhs_scalar, vl); + __riscv_vse16_v_f16m4((dst_ptr + r), res, vl); + } else { + GGML_ABORT("fatal error"); + } + } else { + GGML_ABORT("fatal error"); + } + } + } + } +} + +template void forward_sum_rows(const ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(ne0 == 1); + GGML_ASSERT(ne1 == ne01); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne3 == ne03); + + int64_t n_task = ne01 * ne02 * ne03; + int64_t task_per_thread = (n_task + nth - 1) / nth; + int64_t ir_start = ith * task_per_thread; + int64_t ir_end = std::min(ir_start + task_per_thread, n_task); + + for (int64_t ir = ir_start; ir < ir_end; ir++) { + const int64_t i3 = ir / (ne02 * ne01); + const int64_t i2 = (ir - i3 * ne02 * ne01) / ne01; + const int64_t i1 = (ir - i3 * ne02 * ne01 - i2 * ne01); + + T * src_row = (T *) ((char *) src0->data + i1 * nb01 + i2 * nb02 + i3 * nb03); + T * dst_row = (T *) ((char *) op->data + i1 * nb1 + i2 * nb2 + i3 * nb3); + + float row_sum = 0; + + if constexpr (std::is_same_v) { + size_t gvl = __riscv_vsetvlmax_e32m4(); + vfloat32m4_t acc_vec = __riscv_vfmv_v_f_f32m4(0.0f, gvl); + int64_t length = ne00; + const float * p_data = src_row; + + while (length > 0) { + size_t vl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t vec = __riscv_vle32_v_f32m4(p_data, vl); + acc_vec = __riscv_vfadd_vv_f32m4(acc_vec, vec, vl); + p_data += vl; + length -= vl; + } + + gvl = __riscv_vsetvlmax_e32m1(); + vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.0f, gvl); + vfloat32m1_t sum_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(acc_vec, 0), + __riscv_vget_v_f32m4_f32m1(acc_vec, 1), gvl); + sum_v = __riscv_vfadd_vv_f32m1(sum_v, __riscv_vget_v_f32m4_f32m1(acc_vec, 2), gvl); + sum_v = __riscv_vfadd_vv_f32m1(sum_v, __riscv_vget_v_f32m4_f32m1(acc_vec, 3), gvl); + sum_v = __riscv_vfredusum_vs_f32m1_f32m1(sum_v, zero_v, gvl); + row_sum = __riscv_vfmv_f_s_f32m1_f32(sum_v); + } else if constexpr (std::is_same_v) { + size_t gvl = __riscv_vsetvlmax_e16m2(); + vfloat32m4_t acc_vec = __riscv_vfmv_v_f_f32m4(0.0f, gvl); + int64_t length = ne00; + const _Float16 * p_data = src_row; + + while (length > 0) { + size_t vl = __riscv_vsetvl_e16m2(length); + vfloat16m2_t vec_f16 = __riscv_vle16_v_f16m2(p_data, vl); + vfloat32m4_t vec_f32 = __riscv_vfwcvt_f_f_v_f32m4(vec_f16, vl); + acc_vec = __riscv_vfadd_vv_f32m4(acc_vec, vec_f32, vl); + p_data += vl; + length -= vl; + } + + gvl = __riscv_vsetvlmax_e32m1(); + vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.0f, gvl); + vfloat32m1_t sum_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(acc_vec, 0), + __riscv_vget_v_f32m4_f32m1(acc_vec, 1), gvl); + sum_v = __riscv_vfadd_vv_f32m1(sum_v, __riscv_vget_v_f32m4_f32m1(acc_vec, 2), gvl); + sum_v = __riscv_vfadd_vv_f32m1(sum_v, __riscv_vget_v_f32m4_f32m1(acc_vec, 3), gvl); + sum_v = __riscv_vfredusum_vs_f32m1_f32m1(sum_v, zero_v, gvl); + row_sum = __riscv_vfmv_f_s_f32m1_f32(sum_v); + } else { + GGML_ABORT("fatal error"); + } + + dst_row[0] = row_sum; + } +} + +template void forward_repeat_nrows(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + + const int ith = params->ith; + const int nth = params->nth; + + int64_t nrows = ggml_nrows(src0); + int64_t nrows_per_thread = (nrows + nth - 1) / nth; + int64_t ir_start = ith * nrows_per_thread; + int64_t ir_end = std::min(ir_start + nrows_per_thread, nrows); + + if (src0->ne[0] == 1) { + for (int64_t ir = ir_start; ir < ir_end; ir++) { + T * src_row = (T *) ((char *) src0->data + ir * src0->nb[1]); + T * dst_row = (T *) ((char *) dst->data + ir * dst->nb[1]); + + T src_scalar = src_row[0]; + + int64_t length = dst->ne[0]; + int64_t idx = 0; + size_t vl = 0; + + while (length > 0) { + if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e32m4(length); + vint32m4_t vec = __riscv_vmv_v_x_i32m4(src_scalar, vl); + __riscv_vse32_v_i32m4(dst_row + idx, vec, vl); + } else if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e16m4(length); + vint16m4_t vec = __riscv_vmv_v_x_i16m4(src_scalar, vl); + __riscv_vse16_v_i16m4((dst_row + idx), vec, vl); + } else { + GGML_ABORT("fatal error"); + } + idx += vl; + length -= vl; + } + } + } else if (src0->ne[0] == dst->ne[0]) { + for (int64_t ir = ir_start; ir < ir_end; ir++) { + T * src_row = (T *) ((char *) src0->data + ir * src0->nb[1]); + T * dst_row = (T *) ((char *) dst->data + ir * dst->nb[1]); + + int64_t length = dst->ne[0]; + int64_t idx = 0; + size_t vl = 0; + + while (length > 0) { + if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e32m4(length); + vint32m4_t vec = __riscv_vle32_v_i32m4(src_row + idx, vl); + __riscv_vse32_v_i32m4(dst_row + idx, vec, vl); + } else if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e16m4(length); + vint16m4_t vec = __riscv_vle16_v_i16m4((src_row + idx), vl); + __riscv_vse16_v_i16m4((dst_row + idx), vec, vl); + } else { + GGML_ABORT("fatal error"); + } + idx += vl; + length -= vl; + } + } + } else { + GGML_ABORT("fatal error"); + } +} + +template void forward_repeat_dim1(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + const int64_t ne2 = dst->ne[2]; + const int64_t ne3 = dst->ne[3]; + + const int64_t total_batches = ne2 * ne3; + const int64_t batches_per_thread = (total_batches + nth - 1) / nth; + const int64_t batch_start = ith * batches_per_thread; + const int64_t batch_end = std::min(batch_start + batches_per_thread, total_batches); + + for (int64_t b = batch_start; b < batch_end; b++) { + const int64_t i3 = b / ne2; + const int64_t i2 = b % ne2; + + T * src_base = (T *) ((char *) src0->data + i2 * src0->nb[2] + i3 * src0->nb[3]); + T * dst_batch = (T *) ((char *) dst->data + i2 * dst->nb[2] + i3 * dst->nb[3]); + + for (int64_t i1 = 0; i1 < ne1; i1++) { + T * dst_ptr = (T *) ((char *) dst_batch + i1 * dst->nb[1]); + int64_t length = ne0; + int64_t idx = 0; + + while (length > 0) { + if constexpr (std::is_same_v) { + size_t vl = __riscv_vsetvl_e32m4(length); + vint32m4_t vec = __riscv_vle32_v_i32m4(src_base + idx, vl); + __riscv_vse32_v_i32m4(dst_ptr + idx, vec, vl); + idx += vl; + length -= vl; + } else if constexpr (std::is_same_v) { + size_t vl = __riscv_vsetvl_e16m4(length); + vint16m4_t vec = __riscv_vle16_v_i16m4((src_base + idx), vl); + __riscv_vse16_v_i16m4((dst_ptr + idx), vec, vl); + idx += vl; + length -= vl; + } else { + GGML_ABORT("fatal error"); + } + } + } + } +} + +template void forward_get_rows(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + const ggml_tensor * src1 = op->src[1]; + ggml_tensor * dst = op; + + GGML_TENSOR_BINARY_OP_LOCALS + + const int64_t nc = ne00; + const int64_t nr = ggml_nelements(src1); + + assert(ne0 == nc); + assert(ne02 == ne11); + assert(nb00 == sizeof(float)); + assert(ggml_nrows(op) == nr); + + const int ith = params->ith; + const int nth = params->nth; + + int rows_nth = nth; + int cols_nth = 1; + + if (nr == 1) { + rows_nth = 1; + cols_nth = nth; + } + + // rows per thread + const int dr = (nr + rows_nth - 1) / rows_nth; + const int dc = (nc + cols_nth - 1) / cols_nth; + + int rows_ith = ith % rows_nth; + int cols_ith = ith % cols_nth; + + // row range for this thread + const int ir0 = dr * rows_ith; + const int ir1 = MIN(ir0 + dr, nr); + + const int cr0 = dc * cols_ith; + const int cr1 = MIN(cr0 + dc, nc); + + for (int64_t i = ir0; i < ir1; ++i) { + const int64_t i12 = i / (ne11 * ne10); + const int64_t i11 = (i - i12 * ne11 * ne10) / ne10; + const int64_t i10 = (i - i12 * ne11 * ne10 - i11 * ne10); + const int64_t i01 = *(int32_t *) ((char *) src1->data + i10 * nb10 + i11 * nb11 + i12 * nb12); + + GGML_ASSERT(i01 >= 0 && i01 < ne01); + + memcpy1d(((char *) dst->data + i10 * nb1 + i11 * nb2 + i12 * nb3) + cr0 * sizeof(T), + ((char *) src0->data + i01 * nb01 + i11 * nb02 + i12 * nb03) + cr0 * sizeof(T), + (cr1 - cr0) * sizeof(T)); + } +} + +template void forward_concat(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + const ggml_tensor * src1 = op->src[1]; + ggml_tensor * dst = op; + + GGML_ASSERT(ggml_type_size(src0->type) == sizeof(float)); + + GGML_TENSOR_BINARY_OP_LOCALS + + const int32_t dim = ggml_get_op_params_i32(dst, 0); + + GGML_ASSERT(dim == 0 && nb0 == sizeof(float) && nb1 == sizeof(float) * (ne00 + ne10)); + + const int64_t nr = ggml_nrows(dst); + const int64_t nc = ne0; + + const int ith = params->ith; + const int nth = params->nth; + + int rows_nth = nth; + int cols_nth = 1; + + if (nr == 1) { + rows_nth = 1; + cols_nth = nth; + } + + const int dr = (nr + rows_nth - 1) / rows_nth; + const int dc = (nc + cols_nth - 1) / cols_nth; + + int rows_ith = ith % rows_nth; + int cols_ith = ith % cols_nth; + + // row range for this thread + const int ir0 = dr * rows_ith; + const int ir1 = MIN(ir0 + dr, nr); + + const int cr0 = dc * cols_ith; + const int cr1 = MIN(cr0 + dc, nc); + + int64_t o[4] = { 0, 0, 0, 0 }; + o[dim] = src0->ne[dim]; + const float * x; + + for (int64_t i = ir0; i < ir1; ++i) { + const int64_t i3 = i / (ne02 * ne01); + const int64_t i2 = (i - i3 * ne02 * ne01) / ne01; + const int64_t i1 = (i - i3 * ne02 * ne01 - i2 * ne01); + + for (int i0 = cr0; i0 < cr1; i0++) { + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + x = (const float *) ((const char *) src0->data + (i0) *nb00 + (i1) *nb01 + (i2) *nb02 + (i3) *nb03); + } else { + x = (const float *) ((const char *) src1->data + (i0 - o[0]) * nb10 + (i1 - o[1]) * nb11 + + (i2 - o[2]) * nb12 + (i3 - o[3]) * nb13); + } + + float * y = (float *) ((char *) dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3); + + *y = *x; + } + } +} + +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); +template void forward_sum_rows(const ggml_compute_params * params, ggml_tensor * op); +template void forward_sum_rows<_Float16>(const ggml_compute_params * params, ggml_tensor * op); +template void forward_repeat_nrows(ggml_compute_params * params, ggml_tensor * op); +template void forward_repeat_nrows(ggml_compute_params * params, ggml_tensor * op); +template void forward_repeat_dim1(ggml_compute_params * params, ggml_tensor * op); +template void forward_repeat_dim1(ggml_compute_params * params, ggml_tensor * op); +template void forward_get_rows(ggml_compute_params * params, ggml_tensor * op); +template void forward_get_rows(ggml_compute_params * params, ggml_tensor * op); +template void forward_concat(ggml_compute_params * params, ggml_tensor * op); +template void forward_concat(ggml_compute_params * params, ggml_tensor * op); + +} // namespace spacemit_kernels::rvv diff --git a/ggml/src/ggml-cpu/spacemit/rvv_kernels.h b/ggml/src/ggml-cpu/spacemit/rvv_kernels.h new file mode 100644 index 00000000000..edddf957c21 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/rvv_kernels.h @@ -0,0 +1,95 @@ +#pragma once + +#include "ggml-cpu-impl.h" + +#include +#include +#include +#include + +namespace spacemit_kernels { + +constexpr auto div_round_up(auto up, auto down) { + return (up + down - 1) / down; +} + +// Q8 Blk [f32] [s16] [int8 * blk_len] +// Q8 Blk N [f32 * N] [s16 * N] [int8 * blk_len * N] +constexpr size_t q8_blk_size(size_t blk_len, bool with_blk_sum = false) { + const size_t blk_size = sizeof(float) + blk_len * sizeof(int8_t) + (with_blk_sum ? sizeof(int16_t) : 0); + return blk_size; +} + +// Q8 HP row block: K is split into K32 subblocks. +// Each subblock stores [f32 scale] [int8 * 32], with an optional fp16 sum trailer per subblock. +constexpr size_t q8_hp_blk_size(size_t blk_len, bool with_blk_sum = false, bool with_blk_scale = false) { + const size_t subblk_count = div_round_up(blk_len, size_t(32)); + const size_t blk_size = blk_len * sizeof(int8_t) + subblk_count * sizeof(_Float16) + + (with_blk_sum ? subblk_count * sizeof(_Float16) : 0) + + (with_blk_scale ? sizeof(_Float16) : 0); + return blk_size; +} + +// Q8K Blk [f32] [s16 * (blk_len / 16)] [int8 * blk_len] +// Q8K Blk N [f32 * N] [s16 * (blk_len / 16) * N] [int8 * blk_len * N] +constexpr size_t q8k_blk_size(size_t blk_len) { + const size_t blk_size = sizeof(float) + blk_len * sizeof(int8_t) + sizeof(int16_t) * blk_len / 16; + return blk_size; +} + +using quantize_a_row_def = std::function; + +namespace rvv { +void memcpy1d(void * dst, const void * src, int64_t size); + +void memcpy2d(void * dst, int64_t dst_stride, const void * src, int64_t src_stride, int64_t tile_rows, int64_t size); + +void forward_flash_attn_ext_f16_one_chunk_vlen1024_vf16(const ggml_compute_params * params, + ggml_tensor * dst, + int ir0, + int ir1, + void * tcm_buffer, + size_t tcm_buffer_size); + +void forward_flash_attn_ext_f16_tiled_vlen1024_vf16(const ggml_compute_params * params, + ggml_tensor * dst, + int ir0, + int ir1, + void * tcm_buffer, + size_t tcm_buffer_size); + +void forward_rms_norm_f32(ggml_compute_params * params, ggml_tensor * op); + +void forward_norm_f32(ggml_compute_params * params, ggml_tensor * op); + +void forward_cont_with_permute(ggml_compute_params * params, ggml_tensor * op); + +void forward_cpy_with_permute(ggml_compute_params * params, ggml_tensor * op); + +template void forward_get_rows(ggml_compute_params * params, ggml_tensor * op); + +template void forward_concat(ggml_compute_params * params, ggml_tensor * op); + +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); + +template void forward_sum_rows(const ggml_compute_params * params, ggml_tensor * op); + +template void forward_repeat_nrows(ggml_compute_params * params, ggml_tensor * op); + +template void forward_repeat_dim1(ggml_compute_params * params, ggml_tensor * op); + +void quantize_a_row_i8(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); + +void quantize_a_4row_i8(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); + +void quantize_a_row_i8_hp(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); + +void quantize_a_4row_i8_hp(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); + +void quantize_a_row_i8k(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); + +void quantize_a_4row_i8k(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); + +} // namespace rvv + +} // namespace spacemit_kernels diff --git a/ggml/src/ggml-cpu/spacemit/spine_barrier.h b/ggml/src/ggml-cpu/spacemit/spine_barrier.h new file mode 100644 index 00000000000..f897dad4b8a --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/spine_barrier.h @@ -0,0 +1,34 @@ +#pragma once + +#include +#include + +#define SPINE_CACHE_LINE 64 +#define SPINE_CACHE_ALIGN __attribute__((aligned(SPINE_CACHE_LINE))) + +struct spine_barrier_t { + SPINE_CACHE_ALIGN std::atomic pending_; + SPINE_CACHE_ALIGN std::atomic rounds_; + SPINE_CACHE_ALIGN int64_t total_; +}; + +inline void spine_barrier_wait(spine_barrier_t * b) { + auto cur_round = b->rounds_.load(std::memory_order_acquire); + auto cnt = --b->pending_; + if (cnt == 0) { + b->pending_.store(b->total_); + b->rounds_.store(cur_round + 1); + } else { + while (cur_round == b->rounds_.load(std::memory_order_relaxed)) { + __asm__ volatile("pause " ::: "memory"); + } + } +} + +inline void spine_barrier_init(spine_barrier_t * b, int num_barriers, uint64_t thread_count) { + for (int i = 0; i < num_barriers; i++) { + b[i].total_ = thread_count; + b[i].pending_.store(thread_count); + b[i].rounds_.store(0); + } +} diff --git a/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp b/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp new file mode 100644 index 00000000000..1409423b145 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp @@ -0,0 +1,760 @@ +#include "spine_mem_pool.h" + +#include "common.h" +#include "ime_env.h" +#include "spine_tcm.h" + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ggml::cpu::riscv64_spacemit { +namespace { + +constexpr size_t SPINE_MEM_POOL_CHUNK_SIZE = 512ull * 1024ull * 1024ull; +constexpr size_t SPINE_SHARE_MEM_POOL_CHUNK_SIZE = 512ull * 1024ull; +constexpr size_t SPINE_MEM_POOL_1G_REGION_SIZE = 1ull << 30; +constexpr uint64_t HUGETLB_1G_FLAG_REQUIRE_PUD = 1ull << 0; +constexpr char SPINE_MEM_POOL_HUGETLB_1G_DEV[] = "/dev/hugetlb_1g"; +constexpr char SPINE_MEM_POOL_TCM_SYNC_MEM_DEV[] = "/dev/tcm_sync_mem"; + +struct hugetlb_1g_region { + uint64_t size{ 0 }; + uint64_t dma_addr{ 0 }; + uint64_t flags{ 0 }; + uint64_t reserved{ 0 }; +}; + +#define HUGETLB_1G_IOC_MAGIC 'M' +#define HUGETLB_1G_IOC_ALLOC _IOWR(HUGETLB_1G_IOC_MAGIC, 0x00, struct hugetlb_1g_region) +#define HUGETLB_1G_IOC_FREE _IO(HUGETLB_1G_IOC_MAGIC, 0x01) + +struct free_block { + size_t offset{ 0 }; + size_t size{ 0 }; +}; + +struct pool_chunk { + uint8_t * base{ nullptr }; + size_t size{ 0 }; + int fd{ -1 }; + std::vector free_blocks; +}; + +struct pool_allocation { + void * chunk_base{ nullptr }; + size_t chunk_size{ 0 }; + void * base{ nullptr }; + size_t size{ 0 }; +}; + +bool is_power_of_two(size_t value) { + return value != 0 && (value & (value - 1)) == 0; +} + +bool align_up(size_t value, size_t alignment, size_t * aligned_value) { + if (aligned_value == nullptr || alignment == 0) { + return false; + } + + const size_t remainder = value % alignment; + if (remainder == 0) { + *aligned_value = value; + return true; + } + + const size_t padding = alignment - remainder; + if (value > std::numeric_limits::max() - padding) { + return false; + } + + *aligned_value = value + padding; + return true; +} + +bool align_up_uintptr(uintptr_t value, size_t alignment, uintptr_t * aligned_value) { + if (aligned_value == nullptr || alignment == 0) { + return false; + } + + const uintptr_t remainder = value % alignment; + if (remainder == 0) { + *aligned_value = value; + return true; + } + + const uintptr_t padding = alignment - remainder; + if (value > std::numeric_limits::max() - padding) { + return false; + } + + *aligned_value = value + padding; + return true; +} + +class spine_mem_pool_manager { + public: + explicit spine_mem_pool_manager(size_t default_chunk_size) : default_chunk_size_(default_chunk_size) {} + + virtual ~spine_mem_pool_manager() = default; + + void * alloc(size_t size, size_t alignment) { + if (size == 0 || !is_power_of_two(alignment)) { + return nullptr; + } + + size_t aligned_size = 0; + if (!align_up(size, alignment, &aligned_size)) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: align_up failed for size %zu alignment %zu\n", __func__, size, + alignment); + return nullptr; + } + + pool_allocation allocation; + + std::lock_guard lock(mutex_); + + if (!try_alloc_locked(aligned_size, alignment, &allocation)) { + if (!add_chunk_locked(aligned_size, alignment)) { + return nullptr; + } + + if (!try_alloc_locked(aligned_size, alignment, &allocation)) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: allocation retry failed for size %zu alignment %zu\n", + __func__, aligned_size, alignment); + return nullptr; + } + } + + try { + const auto [allocation_it, inserted] = allocations_.emplace(allocation.base, allocation); + if (!inserted) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: duplicate allocation key %p\n", __func__, allocation.base); + rollback_allocation_locked(allocation); + return nullptr; + } + } catch (const std::bad_alloc &) { + rollback_allocation_locked(allocation); + throw; + } + + return allocation.base; + } + + void free(void * base) { + if (base == nullptr) { + return; + } + + std::lock_guard lock(mutex_); + + auto allocation_it = allocations_.find(base); + if (allocation_it == allocations_.end()) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: unknown allocation %p\n", __func__, base); + return; + } + + pool_allocation allocation = allocation_it->second; + allocations_.erase(allocation_it); + + auto chunk_it = find_chunk_locked(allocation); + if (chunk_it == chunks_.end()) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: unknown chunk for allocation %p size %zu\n", __func__, + allocation.base, allocation.size); + return; + } + + auto * chunk_base = chunk_it->base; + auto * alloc_base = static_cast(allocation.base); + if (alloc_base < chunk_base || alloc_base >= chunk_base + chunk_it->size) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: allocation %p out of chunk range %p..%p\n", __func__, + allocation.base, chunk_base, chunk_base + chunk_it->size); + return; + } + + const size_t offset = static_cast(alloc_base - chunk_base); + if (offset > chunk_it->size || allocation.size > chunk_it->size - offset) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: allocation %p size %zu exceeds chunk size %zu\n", __func__, + allocation.base, allocation.size, chunk_it->size); + return; + } + + insert_free_block_locked(*chunk_it, { offset, allocation.size }); + maybe_release_empty_chunk_locked(chunk_it); + } + + protected: + void release_chunks() { + std::lock_guard lock(mutex_); + + allocations_.clear(); + for (auto & chunk : chunks_) { + dealloc_chunk(&chunk); + } + chunks_.clear(); + } + + size_t default_chunk_size() const { return default_chunk_size_; } + + static void clear_chunk(pool_chunk * chunk) { + chunk->base = nullptr; + chunk->size = 0; + chunk->fd = -1; + chunk->free_blocks.clear(); + } + + virtual bool alloc_chunk(size_t min_size, size_t alignment, void * hint_addr, pool_chunk * chunk) = 0; + virtual void dealloc_chunk(pool_chunk * chunk) = 0; + + private: + struct alloc_candidate { + size_t chunk_index{ 0 }; + size_t block_index{ 0 }; + size_t aligned_offset{ 0 }; + uintptr_t address{ std::numeric_limits::max() }; + bool valid{ false }; + }; + + std::vector::iterator find_chunk_locked(const pool_allocation & allocation) { + return std::find_if(chunks_.begin(), chunks_.end(), [&](const pool_chunk & chunk) { + return chunk.base == allocation.chunk_base && chunk.size == allocation.chunk_size; + }); + } + + bool add_chunk_locked(size_t min_size, size_t alignment) { + pool_chunk chunk; + const size_t chunk_request = default_chunk_size_ == 0 ? min_size : std::max(min_size, default_chunk_size_); + void * hint_addr = nullptr; + + for (const auto & existing_chunk : chunks_) { + auto * chunk_end = existing_chunk.base + existing_chunk.size; + if (hint_addr == nullptr || chunk_end > hint_addr) { + hint_addr = chunk_end; + } + } + + if (!alloc_chunk(chunk_request, alignment, hint_addr, &chunk)) { + return false; + } + + if (chunk.base == nullptr || chunk.size < min_size) { + GGML_LOG_ERROR( + "CPU_RISCV64_SPACEMIT: %s: invalid chunk returned for request size %zu, chunk_base=%p chunk_size=%zu\n", + __func__, min_size, chunk.base, chunk.size); + dealloc_chunk(&chunk); + return false; + } + + try { + chunk.free_blocks.push_back({ 0, chunk.size }); + chunks_.push_back(std::move(chunk)); + } catch (const std::bad_alloc &) { + dealloc_chunk(&chunk); + throw; + } + + return true; + } + + void rollback_allocation_locked(const pool_allocation & allocation) { + auto chunk_it = find_chunk_locked(allocation); + if (chunk_it == chunks_.end()) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: failed to rollback allocation %p, owning chunk not found\n", + __func__, allocation.base); + return; + } + + auto * chunk_base = chunk_it->base; + auto * alloc_base = static_cast(allocation.base); + if (alloc_base < chunk_base || alloc_base >= chunk_base + chunk_it->size) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: failed to rollback allocation %p, chunk range is invalid\n", + __func__, allocation.base); + return; + } + + const size_t offset = static_cast(alloc_base - chunk_base); + if (offset > chunk_it->size || allocation.size > chunk_it->size - offset) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: failed to rollback allocation %p size %zu\n", __func__, + allocation.base, allocation.size); + return; + } + + insert_free_block_locked(*chunk_it, { offset, allocation.size }); + maybe_release_empty_chunk_locked(chunk_it); + } + + bool try_alloc_locked(size_t size, size_t alignment, pool_allocation * allocation) { + alloc_candidate best; + + for (size_t chunk_index = 0; chunk_index < chunks_.size(); ++chunk_index) { + const auto & chunk = chunks_[chunk_index]; + for (size_t block_index = 0; block_index < chunk.free_blocks.size(); ++block_index) { + const auto & block = chunk.free_blocks[block_index]; + + uintptr_t aligned_addr = 0; + const auto block_addr = reinterpret_cast(chunk.base + block.offset); + if (!align_up_uintptr(block_addr, alignment, &aligned_addr)) { + continue; + } + + if (aligned_addr < block_addr) { + continue; + } + + const size_t aligned_offset = block.offset + static_cast(aligned_addr - block_addr); + const size_t padding = aligned_offset - block.offset; + if (padding > block.size || size > block.size - padding) { + continue; + } + + if (!best.valid || aligned_addr < best.address) { + best.chunk_index = chunk_index; + best.block_index = block_index; + best.aligned_offset = aligned_offset; + best.address = aligned_addr; + best.valid = true; + } + } + } + + if (!best.valid) { + return false; + } + + auto & chunk = chunks_[best.chunk_index]; + const free_block block = chunk.free_blocks[best.block_index]; + const size_t padding = best.aligned_offset - block.offset; + const size_t alloc_end = best.aligned_offset + size; + const size_t block_end = block.offset + block.size; + + chunk.free_blocks.erase(chunk.free_blocks.begin() + best.block_index); + auto insert_it = chunk.free_blocks.begin() + best.block_index; + if (padding != 0) { + insert_it = chunk.free_blocks.insert(insert_it, { block.offset, padding }); + ++insert_it; + } + if (alloc_end < block_end) { + chunk.free_blocks.insert(insert_it, { alloc_end, block_end - alloc_end }); + } + + allocation->chunk_base = chunk.base; + allocation->chunk_size = chunk.size; + allocation->base = chunk.base + best.aligned_offset; + allocation->size = size; + return true; + } + + void maybe_release_empty_chunk_locked(std::vector::iterator chunk_it) { + if (chunk_it->free_blocks.size() != 1) { + return; + } + + const auto & block = chunk_it->free_blocks.front(); + if (block.offset != 0 || block.size != chunk_it->size) { + return; + } + + dealloc_chunk(&*chunk_it); + chunks_.erase(chunk_it); + } + + void insert_free_block_locked(pool_chunk & chunk, free_block block) { + auto it = chunk.free_blocks.begin(); + while (it != chunk.free_blocks.end() && it->offset < block.offset) { + ++it; + } + + if (it != chunk.free_blocks.begin()) { + const auto & prev = *(it - 1); + if (prev.offset + prev.size > block.offset) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: overlapping free block at offset %zu size %zu\n", __func__, + block.offset, block.size); + return; + } + } + + if (it != chunk.free_blocks.end() && block.offset + block.size > it->offset) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: overlapping next free block at offset %zu size %zu\n", __func__, + block.offset, block.size); + return; + } + + it = chunk.free_blocks.insert(it, block); + + if (it != chunk.free_blocks.begin()) { + auto prev = it - 1; + if (prev->offset + prev->size == it->offset) { + it->offset = prev->offset; + it->size += prev->size; + it = chunk.free_blocks.erase(prev); + } + } + + if (it + 1 != chunk.free_blocks.end() && it->offset + it->size == (it + 1)->offset) { + it->size += (it + 1)->size; + chunk.free_blocks.erase(it + 1); + } + } + + std::mutex mutex_; + std::vector chunks_; + std::unordered_map allocations_; + size_t default_chunk_size_{ 0 }; +}; + +class spine_mem_pool_posix final : public spine_mem_pool_manager { + public: + spine_mem_pool_posix() : spine_mem_pool_manager(0) {} + + ~spine_mem_pool_posix() override { release_chunks(); } + + private: + bool alloc_chunk(size_t min_size, size_t alignment, void * hint_addr, pool_chunk * chunk) override { + (void) hint_addr; + + const size_t alloc_alignment = std::max(alignment, sizeof(void *)); + void * base = nullptr; + const int rc = posix_memalign(&base, alloc_alignment, min_size); + if (rc != 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: posix_memalign failed for size %zu alignment %zu, rc=%d\n", + __func__, min_size, alloc_alignment, rc); + return false; + } + + chunk->base = static_cast(base); + chunk->size = min_size; + chunk->fd = -1; + return true; + } + + void dealloc_chunk(pool_chunk * chunk) override { + std::free(chunk->base); + clear_chunk(chunk); + } +}; + +class spine_mem_pool_transparent_hugepage final : public spine_mem_pool_manager { + public: + spine_mem_pool_transparent_hugepage() : spine_mem_pool_manager(SPINE_MEM_POOL_CHUNK_SIZE) {} + + ~spine_mem_pool_transparent_hugepage() override { release_chunks(); } + + private: + bool alloc_chunk(size_t min_size, size_t alignment, void * hint_addr, pool_chunk * chunk) override { + (void) alignment; + + size_t chunk_size = 0; + if (!align_up(min_size, default_chunk_size(), &chunk_size)) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: failed to round chunk size for %zu\n", __func__, min_size); + return false; + } + + void * map_addr = mmap(hint_addr, chunk_size, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); + if (map_addr == MAP_FAILED) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: mmap failed for chunk size %zu, errno=%d\n", __func__, chunk_size, + errno); + return false; + } + + if (madvise(map_addr, chunk_size, MADV_HUGEPAGE) != 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: madvise(MADV_HUGEPAGE) failed for chunk size %zu, errno=%d\n", + __func__, chunk_size, errno); + munmap(map_addr, chunk_size); + return false; + } + + chunk->base = static_cast(map_addr); + chunk->size = chunk_size; + chunk->fd = -1; + return true; + } + + void dealloc_chunk(pool_chunk * chunk) override { + if (chunk->base != nullptr && chunk->size != 0 && munmap(chunk->base, chunk->size) != 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: munmap failed for chunk %p size %zu, errno=%d\n", __func__, + chunk->base, chunk->size, errno); + } + + clear_chunk(chunk); + } +}; + +class spine_mem_pool_hugetlb_1g final : public spine_mem_pool_manager { + public: + spine_mem_pool_hugetlb_1g() : spine_mem_pool_manager(SPINE_MEM_POOL_1G_REGION_SIZE) {} + + ~spine_mem_pool_hugetlb_1g() override { release_chunks(); } + + private: + bool alloc_chunk(size_t min_size, size_t alignment, void * hint_addr, pool_chunk * chunk) override { + (void) alignment; + (void) hint_addr; + + size_t region_size = 0; + if (!align_up(min_size, SPINE_MEM_POOL_1G_REGION_SIZE, ®ion_size)) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: failed to round hugetlb_1g size for %zu\n", __func__, min_size); + return false; + } + + const int fd = open(SPINE_MEM_POOL_HUGETLB_1G_DEV, O_RDWR); + if (fd < 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: open(%s) failed, errno=%d\n", __func__, + SPINE_MEM_POOL_HUGETLB_1G_DEV, errno); + return false; + } + + hugetlb_1g_region region; + region.size = region_size; + region.flags = HUGETLB_1G_FLAG_REQUIRE_PUD; + if (ioctl(fd, HUGETLB_1G_IOC_ALLOC, ®ion) < 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: HUGETLB_1G_IOC_ALLOC failed for size %zu, errno=%d\n", __func__, + region_size, errno); + close(fd); + return false; + } + + void * map_addr = mmap(nullptr, region.size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + if (map_addr == MAP_FAILED) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: mmap failed for hugetlb_1g size %llu, errno=%d\n", __func__, + static_cast(region.size), errno); + ioctl(fd, HUGETLB_1G_IOC_FREE); + close(fd); + return false; + } + + chunk->base = static_cast(map_addr); + chunk->size = region.size; + chunk->fd = fd; + return true; + } + + void dealloc_chunk(pool_chunk * chunk) override { + if (chunk->base != nullptr && chunk->size != 0 && munmap(chunk->base, chunk->size) != 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: munmap failed for hugetlb_1g chunk %p size %zu, errno=%d\n", + __func__, chunk->base, chunk->size, errno); + } + + if (chunk->fd >= 0) { + if (ioctl(chunk->fd, HUGETLB_1G_IOC_FREE) < 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: HUGETLB_1G_IOC_FREE failed for chunk %p, errno=%d\n", + __func__, chunk->base, errno); + } + + close(chunk->fd); + } + + clear_chunk(chunk); + } +}; + +class spine_mem_pool_shared_mem final : public spine_mem_pool_manager { + public: + spine_mem_pool_shared_mem() : spine_mem_pool_manager(SPINE_SHARE_MEM_POOL_CHUNK_SIZE) {} + + ~spine_mem_pool_shared_mem() override { release_chunks(); } + + private: + bool alloc_chunk(size_t min_size, size_t alignment, void * hint_addr, pool_chunk * chunk) override { + (void) alignment; + + if (hint_addr != nullptr) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: shared_mem does not support multiple active chunks\n", __func__); + return false; + } + + if (min_size > default_chunk_size()) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: shared_mem request %zu exceeds chunk size %zu\n", __func__, + min_size, default_chunk_size()); + return false; + } + + const int fd = open(SPINE_MEM_POOL_TCM_SYNC_MEM_DEV, O_RDWR | O_SYNC); + if (fd < 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: open(%s) failed, errno=%d\n", __func__, + SPINE_MEM_POOL_TCM_SYNC_MEM_DEV, errno); + return false; + } + + void * map_addr = mmap(nullptr, default_chunk_size(), PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + if (map_addr == MAP_FAILED) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: mmap failed for %s size %zu, errno=%d\n", __func__, + SPINE_MEM_POOL_TCM_SYNC_MEM_DEV, default_chunk_size(), errno); + close(fd); + return false; + } + + chunk->base = static_cast(map_addr); + chunk->size = default_chunk_size(); + chunk->fd = fd; + return true; + } + + void dealloc_chunk(pool_chunk * chunk) override { + if (chunk->base != nullptr && chunk->size != 0 && munmap(chunk->base, chunk->size) != 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: munmap failed for shared_mem chunk %p size %zu, errno=%d\n", + __func__, chunk->base, chunk->size, errno); + } + + if (chunk->fd >= 0) { + close(chunk->fd); + } + + clear_chunk(chunk); + } +}; + +spine_mem_pool_manager & get_spine_mem_pool_manager() { + static std::once_flag pool_once; + static std::unique_ptr selected_pool; + static spine_mem_pool_backend selected_backend = spine_mem_pool_backend::none; + + spine_mem_pool_backend backend = global_spine_env_info.mem_backend; + if (backend == spine_mem_pool_backend::none) { + backend = spine_mem_pool_backend::transparent_hugepage; + } + + std::call_once(pool_once, [&]() { + selected_backend = backend; + + switch (selected_backend) { + case spine_mem_pool_backend::posix_memalign: + selected_pool = std::make_unique(); + break; + case spine_mem_pool_backend::transparent_hugepage: + selected_pool = std::make_unique(); + break; + case spine_mem_pool_backend::hugetlb_1g: + selected_pool = std::make_unique(); + break; + case spine_mem_pool_backend::none: + selected_backend = spine_mem_pool_backend::transparent_hugepage; + selected_pool = std::make_unique(); + break; + } + }); + + if (backend != selected_backend) { + GGML_LOG_ERROR( + "CPU_RISCV64_SPACEMIT: %s: mem pool backend is process-global and mutually exclusive, requested=%d but " + "selected=%d\n", + __func__, static_cast(backend), static_cast(selected_backend)); + } + + if (selected_pool) { + return *selected_pool; + } + + throw std::bad_alloc(); +} + +spine_mem_pool_manager & get_spine_mem_pool_shared_mem_manager() { + static std::once_flag shared_mem_pool_once; + static std::unique_ptr shared_mem_pool; + + std::call_once(shared_mem_pool_once, [&]() { shared_mem_pool = std::make_unique(); }); + + if (shared_mem_pool) { + return *shared_mem_pool; + } + + throw std::bad_alloc(); +} + +} // namespace + +bool spine_mem_pool_tcm_init(spine_mem_pool_tcm_info * info) noexcept { + if (info == nullptr) { + return false; + } + + *info = {}; + + if (spine_tcm_open_handle(NULL) != 0 || !spine_tcm_is_available()) { + return false; + } + + spine_tcm_mem_info_t mem_info; + if (spine_tcm_mem_info(&mem_info) != 0) { + return false; + } + + info->available = true; + info->blk_size = mem_info.blk_size; + info->blk_num = mem_info.blk_num; + info->is_fake_tcm = mem_info.is_fake_tcm != 0; + return true; +} + +void * spine_mem_pool_tcm_mem_get(int cpu_id) noexcept { + return spine_tcm_mem_get(cpu_id); +} + +void * spine_mem_pool_tcm_mem_wait(int cpu_id) noexcept { + return spine_tcm_mem_try_wait(cpu_id, 1000 * 1000); +} + +int spine_mem_pool_tcm_mem_release(int cpu_id) noexcept { + return spine_tcm_mem_release(cpu_id); +} + +void * spine_mem_pool_alloc(size_t size, size_t alignment) noexcept { + try { + return get_spine_mem_pool_manager().alloc(size, alignment); + } catch (const std::bad_alloc &) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: bad_alloc while allocating size %zu\n", __func__, size); + return nullptr; + } +} + +void * spine_mem_pool_shared_mem_alloc(size_t size, size_t alignment) noexcept { + try { + return get_spine_mem_pool_shared_mem_manager().alloc(size, alignment); + } catch (const std::bad_alloc &) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: bad_alloc while allocating shared memory size %zu\n", __func__, size); + return nullptr; + } +} + +void spine_mem_pool_free(void * base) noexcept { + try { + get_spine_mem_pool_manager().free(base); + } catch (const std::bad_alloc &) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: bad_alloc while freeing allocation %p\n", __func__, base); + } +} + +void spine_mem_pool_shared_mem_free(void * base) noexcept { + try { + get_spine_mem_pool_shared_mem_manager().free(base); + } catch (const std::bad_alloc &) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: bad_alloc while freeing shared allocation %p\n", __func__, base); + } +} + +} // namespace ggml::cpu::riscv64_spacemit + +extern "C" { +void * ggml_backend_cpu_riscv64_spacemit_alloc_shared(size_t size, size_t alignment) { + void * result = ggml::cpu::riscv64_spacemit::spine_mem_pool_shared_mem_alloc(size, alignment); + if (result == nullptr) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: failed to allocate shared memory size %zu alignment %zu\n", __func__, + size, alignment); + } + return result; +} + +void ggml_backend_cpu_riscv64_spacemit_free_shared(void * ptr) { + ggml::cpu::riscv64_spacemit::spine_mem_pool_shared_mem_free(ptr); +} +} diff --git a/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h b/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h new file mode 100644 index 00000000000..8740d2c99ef --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h @@ -0,0 +1,32 @@ +#pragma once + +#include +#include + +namespace ggml::cpu::riscv64_spacemit { + +enum class spine_mem_pool_backend : uint8_t { + none, + posix_memalign, + transparent_hugepage, + hugetlb_1g, +}; + +struct spine_mem_pool_tcm_info { + bool available{ false }; + size_t blk_size{ 0 }; + size_t blk_num{ 0 }; + bool is_fake_tcm{ false }; +}; + +bool spine_mem_pool_tcm_init(spine_mem_pool_tcm_info * info) noexcept; +void * spine_mem_pool_tcm_mem_get(int cpu_id) noexcept; +void * spine_mem_pool_tcm_mem_wait(int cpu_id) noexcept; +int spine_mem_pool_tcm_mem_release(int cpu_id) noexcept; + +void * spine_mem_pool_alloc(size_t size, size_t alignment) noexcept; +void * spine_mem_pool_shared_mem_alloc(size_t size, size_t alignment) noexcept; +void spine_mem_pool_free(void * base) noexcept; +void spine_mem_pool_shared_mem_free(void * base) noexcept; + +} // namespace ggml::cpu::riscv64_spacemit diff --git a/ggml/src/ggml-cpu/spacemit/spine_tcm.h b/ggml/src/ggml-cpu/spacemit/spine_tcm.h new file mode 100644 index 00000000000..f300d7d5c04 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/spine_tcm.h @@ -0,0 +1,409 @@ +#ifndef SPINE_TCM_PUBLIC_H_ +#define SPINE_TCM_PUBLIC_H_ + +/* + * spine_tcm public API + * + * Usage: + * 1. Direct link mode + * Define SPINE_TCM_DIRECT_LINK and link against libspine_tcm.so. + * + * if (spine_tcm_is_available()) { + * void *buffer = spine_tcm_mem_get(0); + * spine_tcm_mem_free(0); + * } + * + * 2. Header-only loader mode + * Include this header without linking libspine_tcm.so. The loader first + * tries to reuse a process-global spine_tcm instance and falls back to + * dlopen("libspine_tcm.so") when needed. + * + * spine_tcm_open_handle(NULL); // optional pre-bind + * if (spine_tcm_is_available()) { + * void *buffer = spine_tcm_mem_get(0); + * spine_tcm_mem_free(0); + * } + */ + +#include +#include +#include + +#if !defined(SPINE_TCM_BUILD_SHARED) && !defined(SPINE_TCM_DIRECT_LINK) +# include +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +#if defined(_WIN32) +# if defined(SPINE_TCM_BUILD_SHARED) +# define SPINE_TCM_API __declspec(dllexport) +# else +# define SPINE_TCM_API __declspec(dllimport) +# endif +#else +# define SPINE_TCM_API __attribute__((visibility("default"))) +#endif + +typedef struct spine_tcm_mem_info { + size_t blk_size; + size_t blk_num; + int is_fake_tcm; +} spine_tcm_mem_info_t; + +typedef struct spine_tcm_block_info { + int id; + void * va; + size_t size; + uint64_t phys_addr; + uint64_t cpu_affinity_mask; + int owner_tid; + int is_acquired; +} spine_tcm_block_info_t; + +/* Shared-library runtime ABI exported by libspine_tcm.so. */ +SPINE_TCM_API const char * spine_tcm_runtime_version(void); +SPINE_TCM_API int spine_tcm_runtime_is_available(void); +SPINE_TCM_API int spine_tcm_runtime_layout_info(spine_tcm_mem_info_t * info); +SPINE_TCM_API int spine_tcm_runtime_mem_info(int id, spine_tcm_block_info_t * info); +SPINE_TCM_API void * spine_tcm_runtime_mem_get(int id); +SPINE_TCM_API int spine_tcm_runtime_mem_free(int id); +SPINE_TCM_API void * spine_tcm_runtime_mem_try_wait(int id, size_t timeout_us); +SPINE_TCM_API int spine_tcm_runtime_mem_release(int id); +SPINE_TCM_API int spine_tcm_runtime_mem_force_release(int id); +SPINE_TCM_API int spine_tcm_runtime_mem_query(int id); + +#if defined(SPINE_TCM_DIRECT_LINK) +/* Optional no-op in direct-link mode. */ +static inline int spine_tcm_open_handle(const char * so_path) { + (void) so_path; + return 0; +} + +static inline const char * spine_tcm_version(void) { + return spine_tcm_runtime_version(); +} + +/* Returns 1 when the runtime driver is available, otherwise 0. */ +static inline int spine_tcm_is_available(void) { + return spine_tcm_runtime_is_available(); +} + +/* Returns runtime memory geometry and whether the current backend is fake TCM. */ +static inline int spine_tcm_mem_info(spine_tcm_mem_info_t * info) { + return spine_tcm_runtime_layout_info(info); +} + +/* Returns per-block runtime metadata for the given TCM id. */ +static inline int spine_tcm_block_info(int id, spine_tcm_block_info_t * info) { + return spine_tcm_runtime_mem_info(id, info); +} + +/* Returns a cached buffer for the given TCM id, or NULL on failure. */ +static inline void * spine_tcm_mem_get(int id) { + return spine_tcm_runtime_mem_get(id); +} + +/* Releases one reference acquired by spine_tcm_mem_get(id). */ +static inline int spine_tcm_mem_free(int id) { + return spine_tcm_runtime_mem_free(id); +} + +/* Waits for a TCM block handoff and returns the driver-owned buffer when available. */ +static inline void * spine_tcm_mem_try_wait(int id, size_t over_time) { + return spine_tcm_runtime_mem_try_wait(id, over_time); +} + +/* Releases a buffer acquired by spine_tcm_mem_try_wait(id, over_time). */ +static inline int spine_tcm_mem_release(int id) { + return spine_tcm_runtime_mem_release(id); +} + +/* Forces a release for the given TCM id when the backend supports it. */ +static inline int spine_tcm_mem_force_release(int id) { + return spine_tcm_runtime_mem_force_release(id); +} + +/* Returns whether the given TCM id is currently acquired. */ +static inline int spine_tcm_mem_query(int id) { + return spine_tcm_runtime_mem_query(id); +} +#elif !defined(SPINE_TCM_BUILD_SHARED) +typedef struct spine_tcm_handle { + void * module_handle; + int use_global_scope; + int owns_module_handle; + const char * (*runtime_version)(void); + int (*runtime_is_available)(void); + int (*runtime_layout_info)(spine_tcm_mem_info_t * info); + int (*runtime_mem_info)(int id, spine_tcm_block_info_t * info); + void * (*runtime_mem_get)(int id); + int (*runtime_mem_free)(int id); + void * (*runtime_mem_try_wait)(int id, size_t over_time); + int (*runtime_mem_release)(int id); + int (*runtime_mem_force_release)(int id); + int (*runtime_mem_query)(int id); +} spine_tcm_handle_t; + +static inline spine_tcm_handle_t * spine_tcm_default_handle(void) { + static spine_tcm_handle_t handle = { 0 }; + return &handle; +} + +static inline void spine_tcm_handle_reset(spine_tcm_handle_t * handle) { + if (handle != NULL) { + memset(handle, 0, sizeof(*handle)); + } +} + +static inline int spine_tcm_handle_bind(spine_tcm_handle_t * handle) { + void * symbol_scope = handle->use_global_scope ? RTLD_DEFAULT : handle->module_handle; + + handle->runtime_version = (const char * (*) (void) ) dlsym(symbol_scope, "spine_tcm_runtime_version"); + handle->runtime_is_available = (int (*)(void)) dlsym(symbol_scope, "spine_tcm_runtime_is_available"); + handle->runtime_layout_info = + (int (*)(spine_tcm_mem_info_t *)) dlsym(symbol_scope, "spine_tcm_runtime_layout_info"); + handle->runtime_mem_info = + (int (*)(int, spine_tcm_block_info_t *)) dlsym(symbol_scope, "spine_tcm_runtime_mem_info"); + handle->runtime_mem_get = (void * (*) (int) ) dlsym(symbol_scope, "spine_tcm_runtime_mem_get"); + handle->runtime_mem_free = (int (*)(int)) dlsym(symbol_scope, "spine_tcm_runtime_mem_free"); + handle->runtime_mem_try_wait = (void * (*) (int, size_t)) dlsym(symbol_scope, "spine_tcm_runtime_mem_try_wait"); + handle->runtime_mem_release = (int (*)(int)) dlsym(symbol_scope, "spine_tcm_runtime_mem_release"); + handle->runtime_mem_force_release = (int (*)(int)) dlsym(symbol_scope, "spine_tcm_runtime_mem_force_release"); + handle->runtime_mem_query = (int (*)(int)) dlsym(symbol_scope, "spine_tcm_runtime_mem_query"); + + return handle->runtime_version != NULL && handle->runtime_is_available != NULL && + handle->runtime_layout_info != NULL && handle->runtime_mem_info != NULL && + handle->runtime_mem_get != NULL && handle->runtime_mem_free != NULL && + handle->runtime_mem_try_wait != NULL && handle->runtime_mem_release != NULL && + handle->runtime_mem_force_release != NULL && handle->runtime_mem_query != NULL ? + 0 : + -1; +} + +/* + * Try to bind against an already-loaded process-global spine_tcm instance. + * The shared library exports spine_tcm_runtime_marker only for this probe. + */ +static inline int spine_tcm_try_bind_global(spine_tcm_handle_t * handle) { + if (dlsym(RTLD_DEFAULT, "spine_tcm_runtime_marker") == NULL) { + return -1; + } + + handle->use_global_scope = 1; + return spine_tcm_handle_bind(handle); +} + +/* + * Optional pre-bind entry point. + * + * Behavior: + * - Reuses an already-loaded global spine_tcm instance when available. + * - Otherwise loads the shared library from so_path or the default soname. + * - Repeated calls are safe and return 0 after the first successful bind. + */ +static inline int spine_tcm_open_handle(const char * so_path) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + const char * library = (so_path != NULL && so_path[0] != '\0') ? so_path : "libspine_tcm.so"; + + if (resolved->module_handle != NULL || resolved->use_global_scope) { + return 0; + } + + if (spine_tcm_try_bind_global(resolved) == 0) { + return 0; + } + + spine_tcm_handle_reset(resolved); + + resolved->module_handle = dlopen(library, RTLD_LAZY | RTLD_GLOBAL); + resolved->owns_module_handle = resolved->module_handle != NULL ? 1 : 0; + + if (resolved->module_handle == NULL) { + spine_tcm_handle_reset(resolved); + return -1; + } + + if (spine_tcm_handle_bind(resolved) != 0) { + if (resolved->owns_module_handle) { + dlclose(resolved->module_handle); + } + spine_tcm_handle_reset(resolved); + return -1; + } + + return 0; +} + +/* Returns 1 when the runtime driver is available, otherwise 0. */ +static inline int spine_tcm_is_available(void) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_is_available == NULL) { + return 0; + } + + return resolved->runtime_is_available(); +} + +/* Returns runtime memory geometry and whether the current backend is fake TCM. */ +static inline int spine_tcm_mem_info(spine_tcm_mem_info_t * info) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_layout_info == NULL) { + return -1; + } + + return resolved->runtime_layout_info(info); +} + +static inline const char * spine_tcm_version(void) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_version == NULL) { + return "unknown"; + } + + return resolved->runtime_version(); +} + +/* Returns per-block runtime metadata for the given TCM id. */ +static inline int spine_tcm_block_info(int id, spine_tcm_block_info_t * info) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_mem_info == NULL) { + return -1; + } + + return resolved->runtime_mem_info(id, info); +} + +/* Returns a cached buffer for the given TCM id, or NULL on failure. */ +static inline void * spine_tcm_mem_get(int id) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + return NULL; + } + + if (resolved->runtime_mem_get == NULL) { + return NULL; + } + + return resolved->runtime_mem_get(id); +} + +/* Releases one reference acquired by spine_tcm_mem_get(id). */ +static inline int spine_tcm_mem_free(int id) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_mem_free == NULL) { + return -1; + } + + return resolved->runtime_mem_free(id); +} + +/* Waits for a TCM block handoff and returns the driver-owned buffer when available. */ +static inline void * spine_tcm_mem_try_wait(int id, size_t over_time) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + return NULL; + } + + if (resolved->runtime_mem_try_wait == NULL) { + return NULL; + } + + return resolved->runtime_mem_try_wait(id, over_time); +} + +/* Releases a buffer acquired by spine_tcm_mem_try_wait(id, over_time). */ +static inline int spine_tcm_mem_release(int id) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_mem_release == NULL) { + return -1; + } + + return resolved->runtime_mem_release(id); +} + +/* Forces a release for the given TCM id when the backend supports it. */ +static inline int spine_tcm_mem_force_release(int id) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || + resolved->runtime_mem_force_release == NULL) { + return -1; + } + + return resolved->runtime_mem_force_release(id); +} + +/* Returns whether the given TCM id is currently acquired. */ +static inline int spine_tcm_mem_query(int id) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_mem_query == NULL) { + return -1; + } + + return resolved->runtime_mem_query(id); +} +#else +static inline const char * spine_tcm_version(void) { + return spine_tcm_runtime_version(); +} +#endif + +#define SPINE_TCM_VERSION (spine_tcm_version()) + +#ifdef __cplusplus +} +#endif + +#endif From f4f2d70cb2e57449767d6c2fdfeb7a862eee8cee Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 14 May 2026 13:05:52 +0300 Subject: [PATCH 04/55] logs : reduce (llama/23021) * logs : reduce * args : fix envs * server : fix build * common : print verbosity level at start * server : clean-up logs * server : print prompt processing timings + sampling params * minor : whitespaces --- ggml/src/ggml-metal/ggml-metal-device.m | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index fab7891c008..780dfe81bb3 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -672,7 +672,7 @@ ggml_metal_device_t ggml_metal_device_init(int device) { ![[dev->mtl_device name] containsString:@"M6"] && ![[dev->mtl_device name] containsString:@"A19"] && ![[dev->mtl_device name] containsString:@"A20"]) { - GGML_LOG_WARN("%s: tensor API disabled for pre-M5 and pre-A19 devices\n", __func__); + GGML_LOG_INFO("%s: tensor API disabled for pre-M5 and pre-A19 devices\n", __func__); dev->props.has_tensor = false; } From b333d4d84b1d43d3c4b33b6c6395de601dcec03f Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Thu, 14 May 2026 09:31:36 -0700 Subject: [PATCH 05/55] ggml-webgpu: makes the flash attn vec path subgroup-aware (llama/23040) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ggml-webgpu: makes the flash attn vec path compile and size its split/reduce work from the device’s reported subgroup range instead of assuming 32 subgroup size. * ggml-webgpu: remove the extra max_wg_size >= max_subgroup_size guard. Remove hardcoded 32 when determine the value of reduce_wg_size and vec_nwg_cap --- ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp | 13 +++++++++---- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 12 +++++++----- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 62a523365b9..4c4eda1cbe5 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -770,9 +770,14 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( (v_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u); const bool kv_vec_type_supported = K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; - const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && (context.src0->ne[0] % 32 == 0) && - (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && - kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && + const uint32_t kv_vec_head_align = K->type == GGML_TYPE_F16 ? GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : + (uint32_t) ggml_blck_size(K->type); + const bool kv_vec_head_dims_aligned = context.src0->ne[0] % kv_vec_head_align == 0 && + context.src2->ne[0] % kv_vec_head_align == 0; + // Compile with enough invocations to cover the largest reported subgroup. + const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && + kv_vec_head_dims_aligned && kv_vec_type_supported && + (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (context.src2->type == K->type); const bool tile_can_dispatch_all_q_rows = context.max_subgroup_size > 0 && @@ -808,7 +813,7 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( decisions.q_tile = 1u; decisions.kv_tile = std::max(8u, std::min(32u, max_kv_tile)); decisions.kv_tile = (decisions.kv_tile / 8u) * 8u; - decisions.wg_size = std::max(1u, std::min(32u, context.max_subgroup_size)); + decisions.wg_size = context.max_subgroup_size; if (decisions.kv_direct) { decisions.kv_tile = std::min(decisions.kv_tile, GGML_WEBGPU_KV_SEQ_PAD); while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) { diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 401c75c1230..78cb02be06d 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1832,7 +1832,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, uint32_t blk_nblk1 = 0; uint32_t blk_batch_count = 0; - const uint32_t vec_nwg_cap = std::max(1u, std::min(32u, ctx->global_ctx->capabilities.max_subgroup_size)); + const uint32_t vec_nwg_cap = ctx->global_ctx->capabilities.min_subgroup_size; uint32_t nwg = 1u; const uint64_t kv_span = (uint64_t) std::max(1u, decisions->kv_tile); while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { @@ -1953,8 +1953,11 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, std::vector reduce_params; std::vector reduce_entries; if (use_vec_reduce) { - const uint32_t reduce_wg_size = std::max( - 32u, std::min(nwg * 32u, ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup)); + const uint32_t reduce_sg_size = ctx->global_ctx->capabilities.max_subgroup_size; + const uint32_t reduce_wg_size = + std::max(reduce_sg_size, (uint32_t) std::min( + (uint64_t) nwg * reduce_sg_size, + ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup)); ggml_webgpu_shader_lib_context reduce_shader_ctx = shader_lib_ctx; reduce_shader_ctx.max_wg_size = reduce_wg_size; reduce_pipeline = ctx->shader_lib->get_flash_attn_vec_reduce_pipeline(reduce_shader_ctx); @@ -3542,8 +3545,7 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { const uint32_t kv_tile = decisions.kv_tile; - const uint32_t vec_nwg_cap = std::max( - 1u, std::min(32u, ctx->webgpu_global_ctx->capabilities.max_subgroup_size)); + const uint32_t vec_nwg_cap = ctx->webgpu_global_ctx->capabilities.min_subgroup_size; uint32_t nwg = 1u; const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile); while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { From 9b73285ae8ecc1fbdb66336cef28e5de6530de49 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Thu, 14 May 2026 22:58:58 +0200 Subject: [PATCH 06/55] HIP: RDNA3 mma FA, faster AMD transpose, tune AMD (llama/22880) Adds RDNA3 support to the CUDA mma FA kernel. To make the RDNA3 tensor cores work with the FP16 accumulation for VKQ the tiles they need to be 32 logical units long in direction of the attention head; for head sizes 80 and 112 that are not exactly divided by 32 the regular length of 16 with FP32 accumulation is used instead. The longer tiles also enable more efficient transposition for a warp size of 32 which is why it's also used for RDNA4. However, this scrambles the data layout of the accumulators along the attention head dimension. To prevent accidental misuse I added another entry to ggml_cuda_mma::data_layout. I also tuned the kernel parameters for RDNA3, RDNA4, and CDNA1 in general, during which I discovered that the kernel can be made to work for head sizes up to 256 for CDNA. For RDNA3/4 I was not able to get better performance that the tile kernel for head sizes > 128. --- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 319 ++++++++++++++++++++------- ggml/src/ggml-cuda/fattn.cu | 57 ++--- ggml/src/ggml-cuda/mma.cuh | 149 +++++++++++-- 3 files changed, 398 insertions(+), 127 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 43e22c5e5ee..a25e912c4d2 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -125,61 +125,107 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co } static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_rdna(const int DKQ, const int DV, const int ncols) { - GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 128, 2, 64, 128, 128, 128, 2, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 2, 64, 32, 32, 32, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 128, 2, 64, 32, 32, 32, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 128, 2, 64, 32, 32, 32, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 128, 2, 64, 32, 32, 32, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 64, 160, 128, 64, 2, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 128, 2, 64, 160, 128, 64, 2, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 64, 2, 32, 40, 40, 40, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 64, 2, 32, 40, 40, 40, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 128, 2, 64, 40, 40, 40, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 64, 128, 2, 64, 40, 40, 40, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 128, 128, 128, 1, false); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 64, 2, 32, 48, 48, 48, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 64, 2, 32, 48, 48, 48, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 128, 2, 64, 48, 48, 48, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 64, 128, 2, 64, 48, 48, 48, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 64, 2, 32, 56, 56, 56, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 64, 2, 32, 56, 56, 56, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2, 64, 56, 56, 56, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 128, 2, 64, 56, 56, 56, 1, true); - // TODO tune specifically for RDNA - return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 64, 2, 32, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 64, 2, 32, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 128, 2, 64, 64, 64, 64, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 8, 64, 2, 32, 96, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 16, 64, 2, 32, 96, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 32, 128, 2, 64, 96, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 64, 128, 2, 64, 96, 64, 64, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 2, 32, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 2, 32, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 32, 160, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 128, 2, 32, 160, 128, 128, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 128, 3, 64, 96, 64, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 128, 3, 64, 96, 64, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 128, 2, 32, 128, 128, 128, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 128, 3, 64, 96, 64, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 128, 3, 64, 96, 64, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 128, 2, 32, 160, 128, 128, 1, true); + + return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false); } static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_cdna(const int DKQ, const int DV, const int ncols) { - // Conservative configs for CDNA (MI100+): 64KB LDS, wavefront64, nstages=1 (no cp.async). - GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 2, 128, 32, 32, 32, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 128, 2, 64, 32, 32, 32, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 128, 2, 64, 32, 32, 32, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 256, 2, 64, 32, 32, 32, 1, true); - - GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 128, 2, 128, 40, 40, 40, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 128, 2, 64, 40, 40, 40, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 128, 2, 64, 40, 40, 40, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 1, 64, 32, 32, 32, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 256, 2, 64, 32, 32, 32, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 32, 32, 32, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 256, 4, 64, 32, 32, 32, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 256, 2, 64, 40, 40, 40, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 256, 2, 64, 40, 40, 40, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 256, 2, 64, 40, 40, 40, 1, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 64, 256, 2, 64, 40, 40, 40, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 128, 2, 128, 48, 48, 48, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 128, 2, 64, 48, 48, 48, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 128, 2, 64, 48, 48, 48, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 256, 2, 64, 48, 48, 48, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 256, 2, 64, 48, 48, 48, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 256, 2, 64, 48, 48, 48, 1, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 64, 256, 2, 64, 48, 48, 48, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 128, 2, 128, 56, 56, 56, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 128, 2, 64, 56, 56, 56, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2, 64, 56, 56, 56, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 256, 2, 64, 56, 56, 56, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 256, 2, 64, 56, 56, 56, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 256, 2, 64, 56, 56, 56, 1, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 256, 2, 64, 56, 56, 56, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 128, 2, 128, 64, 64, 64, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 128, 2, 64, 64, 64, 64, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 256, 2, 64, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 256, 2, 64, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64, 64, 64, 1, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 256, 2, 64, 64, 64, 64, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 4, 64, 128, 128, 128, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 4, 32, 128, 128, 128, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 256, 2, 32, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 8, 256, 1, 64, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 16, 256, 1, 64, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 32, 256, 1, 64, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 64, 512, 1, 64, 64, 64, 64, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 256, 1, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 256, 1, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 256, 1, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 512, 1, 64, 128, 128, 64, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 256, 1, 64, 160, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 256, 1, 64, 160, 128, 128, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 256, 1, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 256, 1, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 256, 1, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 64, 128, 128, 128, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 256, 1, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 256, 1, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 256, 1, 64, 160, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 64, 160, 128, 128, 1, true); - // Fallback for unsupported DKQ values (e.g. 576). Must return non-zero values to satisfy - // compile-time static_asserts even though the kernel guard prevents runtime execution. - // nthreads=256 gives nwarps=4 (warp_size=64) or 8 (warp_size=32), nbatch_fa=128 satisfies np*16 divisibility. - return fattn_mma_config(256, 1, 128, 4, 4, 4, 1, false); + return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false); } static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) { @@ -510,7 +556,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int jt, const int kb0, const int k_VKQ_sup) { -#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE) +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int ncols = ncols1 * ncols2; constexpr int cols_per_warp = T_B_KQ::I; @@ -712,6 +758,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #pragma unroll for (int i00 = 0; i00 < nbatch_fa; i00 += np*T_C_KQ::J) { const int i0 = i00 + (threadIdx.y % np)*T_C_KQ::J; + + // The mask is stored as 16 bit half values, loading them as 32 bit half2 values is preferred in terms of speed. + // However, this is not possible for RDNA3 where 2 consecutive l indices are not consecutive in the mask memory layout. +#ifdef RDNA3 +#pragma unroll + for (int l = 0; l < T_C_KQ::ne; ++l) { + const int i = i0 + T_C_KQ::get_j(l); + const int j = ((threadIdx.y / np)*cols_per_warp + T_C_KQ::get_i(l)) / ncols2; + + KQ_C[i00/(np*T_C_KQ::J)].x[l] += __half2float(tile_mask[j*(nbatch_fa + 8) + i]); + } +#else #pragma unroll for (int l0 = 0; l0 < T_C_KQ::ne; l0 += 2) { const int i = (i0 + T_C_KQ::get_j(l0)) / 2; @@ -721,6 +779,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 0] += slope*tmp.x; KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 1] += slope*tmp.y; } +#endif // RDNA3 } } @@ -827,13 +886,23 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } #elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) - const half2 KQ_max_scale_h2 = make_half2( - KQ_max_scale[0], KQ_max_scale[0]); + if constexpr (std::is_same_v) { + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]); #pragma unroll - for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { + for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { #pragma unroll - for (int l = 0; l < T_C_VKQ::ne; ++l) { - VKQ_C[i].x[l] *= KQ_max_scale_h2; + for (int l = 0; l < T_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } + } else { + static_assert(std::is_same_v, "bad VKQ type"); +#pragma unroll + for (int i = 0; i < DV/T_C_VKQ::J; ++i) { +#pragma unroll + for (int l = 0; l < T_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale[0]; + } } } #else // Volta @@ -901,9 +970,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const half2 * tile_V_i = !V_is_K_view || i0_stop > 2*nbatch_K2 ? tile_V : tile_V + i0_start/2; #if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) - constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J; #pragma unroll - for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) { + for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += T_A_VKQ::I) { static_assert((nbatch_fa/2) % (np*T_A_VKQ::J) == 0, "bad loop size"); #pragma unroll for (int k00 = 0; k00 < nbatch_fa/2; k00 += np*T_A_VKQ::J) { @@ -912,15 +980,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load. load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); if constexpr (T_B_KQ::I == 8) { - mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]); + mma(VKQ_C[i_VKQ_0/T_A_VKQ::I], A, B[k00/(np*T_A_VKQ::J)]); } else { // Wide version of VKQ_C is column-major. #if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) // AMD matrix C is column-major. - mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]); + mma(VKQ_C[i_VKQ_0/T_A_VKQ::I], A, B[k00/(np*T_A_VKQ::J)]); #else // swap A and B for CUDA. - mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::J)], A); + mma(VKQ_C[i_VKQ_0/T_A_VKQ::I], B[k00/(np*T_A_VKQ::J)], A); #endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) } } @@ -953,11 +1021,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); NO_DEVICE_CODE; -#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE) +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) } #if defined(TURING_MMA_AVAILABLE) -template struct mma_tile_sizes { +template struct mma_tile_sizes { using T_A_KQ = tile<16, 8, half2>; // row-major using T_B_KQ = tile<16, 8, half2>; // column-major using T_C_KQ = tile<16, 16, float>; // column-major @@ -965,7 +1033,7 @@ template struct mma_tile_sizes { using T_B_VKQ = tile<16, 8, half2>; // column-major using T_C_VKQ = tile<16, 8, half2>; // column-major }; -template<> struct mma_tile_sizes<8> { +template struct mma_tile_sizes { using T_A_KQ = tile<16, 8, half2>; // row-major using T_B_KQ = tile< 8, 8, half2>; // column-major using T_C_KQ = tile<16, 8, float>; // row-major @@ -973,8 +1041,60 @@ template<> struct mma_tile_sizes<8> { using T_B_VKQ = tile< 8, 8, half2>; // column-major using T_C_VKQ = tile<16, 4, half2>; // row-major }; -#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) -template struct mma_tile_sizes { +#elif defined(AMD_WMMA_AVAILABLE) +#ifdef RDNA3 +template struct mma_tile_sizes { + using T_A_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major + using T_B_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major + using T_C_KQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major + using T_A_VKQ = tile<32, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major + using T_B_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major + using T_C_VKQ = tile<16, 16, half2, DATA_LAYOUT_I_MAJOR>; // column-major +}; +template struct mma_tile_sizes<80, ncols> { + using T_A_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major + using T_B_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major + using T_C_KQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major + using T_A_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major + using T_B_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major + using T_C_VKQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major +}; +template struct mma_tile_sizes<112, ncols> { + using T_A_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major + using T_B_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major + using T_C_KQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major + using T_A_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major + using T_B_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major + using T_C_VKQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major +}; +#else +template struct mma_tile_sizes { + using T_A_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR>; // row-major + using T_B_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR>; // column-major + using T_C_KQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major + using T_A_VKQ = tile<32, 8, half2, DATA_LAYOUT_I_MAJOR>; // row-major + using T_B_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR>; // column-major + using T_C_VKQ = tile<16, 16, half2, DATA_LAYOUT_I_MAJOR_SCRAMBLED>; // column-major +}; +template struct mma_tile_sizes<80, ncols> { + using T_A_KQ = tile<16, 8, half2>; // row-major + using T_B_KQ = tile<16, 8, half2>; // column-major + using T_C_KQ = tile<16, 16, float>; // column-major + using T_A_VKQ = tile<16, 8, half2>; // row-major + using T_B_VKQ = tile<16, 8, half2>; // column-major + using T_C_VKQ = tile<16, 8, half2>; // column-major +}; +template struct mma_tile_sizes<112, ncols> { + using T_A_KQ = tile<16, 8, half2>; // row-major + using T_B_KQ = tile<16, 8, half2>; // column-major + using T_C_KQ = tile<16, 16, float>; // column-major + using T_A_VKQ = tile<16, 8, half2>; // row-major + using T_B_VKQ = tile<16, 8, half2>; // column-major + using T_C_VKQ = tile<16, 8, half2>; // column-major +}; +#endif // RDNA3 +#elif defined(AMD_MFMA_AVAILABLE) +template struct mma_tile_sizes { using T_A_KQ = tile<16, 8, half2>; // row-major using T_B_KQ = tile<16, 8, half2>; // column-major using T_C_KQ = tile<16, 16, float>; // column-major @@ -983,7 +1103,7 @@ template struct mma_tile_sizes { using T_C_VKQ = tile<16, 8, half2>; // column-major }; #else // Volta -template struct mma_tile_sizes { +template struct mma_tile_sizes { using T_A_KQ = tile< 8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major using T_B_KQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major using T_C_KQ = tile<32, 8, float, DATA_LAYOUT_I_MAJOR>; // column-major @@ -1018,17 +1138,17 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int zt_gqa, const int kb0_start, const int kb0_stop) { -#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE) +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) //In this kernel Q, K, V are matrices while i, j, k are matrix indices. constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int ncols = ncols1 * ncols2; - using T_A_KQ = typename mma_tile_sizes::T_A_KQ; - using T_B_KQ = typename mma_tile_sizes::T_B_KQ; - using T_C_KQ = typename mma_tile_sizes::T_C_KQ; - using T_A_VKQ = typename mma_tile_sizes::T_A_VKQ; - using T_B_VKQ = typename mma_tile_sizes::T_B_VKQ; - using T_C_VKQ = typename mma_tile_sizes::T_C_VKQ; + using T_A_KQ = typename mma_tile_sizes::T_A_KQ; + using T_B_KQ = typename mma_tile_sizes::T_B_KQ; + using T_C_KQ = typename mma_tile_sizes::T_C_KQ; + using T_A_VKQ = typename mma_tile_sizes::T_A_VKQ; + using T_B_VKQ = typename mma_tile_sizes::T_B_VKQ; + using T_C_VKQ = typename mma_tile_sizes::T_C_VKQ; constexpr int cols_per_warp = T_B_KQ::I; constexpr int cols_per_thread = get_cols_per_thread(); @@ -1061,6 +1181,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( T_B_KQ Q_B[(Q_in_reg ? DKQ/(2*T_B_KQ::J) : 1)]; #if defined(TURING_MMA_AVAILABLE) T_C_VKQ VKQ_C[cols_per_warp == 8 ? DV/T_C_VKQ::I : DV/(2*T_C_VKQ::J)]; +#elif defined(AMD_WMMA_AVAILABLE) && defined(RDNA3) + T_C_VKQ VKQ_C[DV % 32 != 0 ? DV/T_C_VKQ::J : DV/(2*T_C_VKQ::J)]; #elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)]; #else // Volta @@ -1269,12 +1391,23 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } } #elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) - const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]); + if constexpr (std::is_same_v) { + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]); #pragma unroll - for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { + for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { #pragma unroll - for (int l = 0; l < T_C_VKQ::ne; ++l) { - VKQ_C[i].x[l] *= KQ_max_scale_h2; + for (int l = 0; l < T_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } + } else { + static_assert(std::is_same_v, "bad VKQ type"); +#pragma unroll + for (int i = 0; i < DV/T_C_VKQ::J; ++i) { +#pragma unroll + for (int l = 0; l < T_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale[0]; + } } } #else // Volta @@ -1433,6 +1566,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #pragma unroll for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) { if constexpr (cols_per_warp == 8) { + static_assert(std::is_same_v, "bad VKQ type"); const int jc_cwd = threadIdx.y*T_B_KQ::I + T_B_KQ::get_i(-1); // jc combine write data #pragma unroll for (int k1 = 0; k1 < nbatch_combine; k1 += T_B_KQ::J) { @@ -1447,14 +1581,45 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } } else { const int j0 = threadIdx.y*cols_per_warp; + if constexpr (std::is_same_v) { + if constexpr (T_C_VKQ::dl == DATA_LAYOUT_I_MAJOR) { #pragma unroll - for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J) { + for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J) { #pragma unroll - for (int l = 0; l < T_C_VKQ::ne; ++l) { - const int j = j0 + T_C_VKQ::get_i(l); - const int k = k1 + T_C_VKQ::get_j(l); + for (int l = 0; l < T_C_VKQ::ne; ++l) { + const int j = j0 + T_C_VKQ::get_i(l); + const int k = k1 + T_C_VKQ::get_j(l); - tile_Q[j*tile_stride + k] = VKQ_C[(k00 + k1)/T_C_VKQ::J].x[l]; + tile_Q[j*tile_stride + k] = VKQ_C[(k00 + k1)/T_C_VKQ::J].x[l]; + } + } + } else { + static_assert(T_C_VKQ::dl == DATA_LAYOUT_I_MAJOR_SCRAMBLED, "bad T_C_VKQ data layout"); + using T_C_VKQ_us = tile; // us == unscrambled +#pragma unroll + for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J) { + const T_C_VKQ_us VKQ_C_us = unscramble(VKQ_C[(k00 + k1)/T_C_VKQ::J]); +#pragma unroll + for (int l = 0; l < T_C_VKQ_us::ne; ++l) { + const int j = j0 + T_C_VKQ_us::get_i(l); + const int k = k1 + T_C_VKQ_us::get_j(l); + + tile_Q[j*tile_stride + k] = VKQ_C_us.x[l]; + } + } + } + } else { + static_assert(std::is_same_v, "bad VKQ type"); + half * tile_Q_h = (half *) tile_Q; +#pragma unroll + for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J/2) { +#pragma unroll + for (int l = 0; l < T_C_VKQ::ne; ++l) { + const int j = j0 + T_C_VKQ::get_i(l); + const int k = 2*k1 + T_C_VKQ::get_j(l); + + tile_Q_h[j*(2*tile_stride) + k] = VKQ_C[(k00 + k1)/(T_C_VKQ::J/2)].x[l]; + } } } } @@ -1532,7 +1697,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); NO_DEVICE_CODE; -#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE) +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) } template @@ -1559,7 +1724,7 @@ static __global__ void flash_attn_ext_f16( const int32_t nb21, const int32_t nb22, const int64_t nb23, const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)) +#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)) // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(DKQ == 128 || DKQ == 256 || DKQ == 512)) { @@ -1585,14 +1750,14 @@ static __global__ void flash_attn_ext_f16( #endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING #if defined(AMD_WMMA_AVAILABLE) - if (ncols1*ncols2 > 32 || ncols1*ncols2 < 16 || DKQ > 128 || ncols2 == 1) { + if (ncols1*ncols2 < 16 || ncols2 == 1 || DKQ > 128) { NO_DEVICE_CODE; return; } #endif // defined(AMD_WMMA_AVAILABLE) #if defined(AMD_MFMA_AVAILABLE) - if (DKQ != 64 && DKQ != 80 && DKQ != 96 && DKQ != 112 && DKQ != 128) { + if (ncols1*ncols2 < 16 || DKQ > 256) { NO_DEVICE_CODE; return; } @@ -1715,7 +1880,7 @@ static __global__ void flash_attn_ext_f16( ne31, ne32, ne33, nb31, nb32, nb33); NO_DEVICE_CODE; -#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)) +#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)) } template diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index e045b04f727..1c7777e8a71 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -19,13 +19,14 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con } if constexpr (ncols2 <= 16) { - if ((turing_mma_available(cc) || amd_wmma_available(cc)) && Q->ne[1] <= 16/ncols2) { + if (Q->ne[1] <= 16/ncols2) { ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); return; } } - if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || amd_wmma_available(cc) || Q->ne[1] <= 32/ncols2) { + if (Q->ne[1] <= 32/ncols2 || (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) || + (GGML_CUDA_CC_IS_AMD(cc) && DKQ > 256)) { ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); return; } @@ -477,12 +478,13 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_MMA_F16; } + const int ncols2_max = Q->ne[0] == 320 ? 32 : ((Q->ne[0] == 576 || Q->ne[0] == 192) ? 16 : 8); + int gqa_ratio_eff = 1; + while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) { + gqa_ratio_eff *= 2; + } + if (volta_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) { - int gqa_ratio_eff = 1; - const int ncols2_max = (Q->ne[0] == 576 || Q->ne[0] == 192) ? 16 : 8; - while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) { - gqa_ratio_eff *= 2; - } if (can_use_vector_kernel && Q->ne[1] * gqa_ratio_eff <= 2) { return BEST_FATTN_KERNEL_VEC; } @@ -500,41 +502,22 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_WMMA_F16; } - if (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc) && gqa_opt_applies && Q->ne[0] <= 128 && Q->ne[0] != 40 && Q->ne[0] != 72) { - if (can_use_vector_kernel) { - if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) { - if (Q->ne[1] == 1) { - if (!gqa_opt_applies) { - return BEST_FATTN_KERNEL_VEC; - } - } - } else { - if (Q->ne[1] <= 2) { - return BEST_FATTN_KERNEL_VEC; - } - } + // AMD MFMA needs a certain minimum batch size to outscale the tile kernel for large head sizes. + if ((amd_mfma_available(cc) && Q->ne[0] <= 256) && Q->ne[0] != 40 && Q->ne[0] != 72) { + if ((Q->ne[0] <= 64 && Q->ne[1] * gqa_ratio_eff > 8)) { + return BEST_FATTN_KERNEL_MMA_F16; } - int gqa_ratio_eff = 1; - const int ncols2_max = Q->ne[0] == 576 ? 16 : 8; - while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) { - gqa_ratio_eff *= 2; + if ((Q->ne[0] <= 128 && Q->ne[1] * gqa_ratio_eff > 16)) { + return BEST_FATTN_KERNEL_MMA_F16; } - if (Q->ne[1] * gqa_ratio_eff <= 8) { - return BEST_FATTN_KERNEL_TILE; // AMD WMMA is only faster if the full tile width of 16 can be utilized. + if ((Q->ne[0] <= 256 && Q->ne[1] * gqa_ratio_eff > 64)) { + return BEST_FATTN_KERNEL_MMA_F16; } - return BEST_FATTN_KERNEL_MMA_F16; } - // Use MFMA flash attention for CDNA (MI100+): - if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 192 && Q->ne[0] != 256 && Q->ne[0] != 512 && Q->ne[0] != 576) { - const int64_t eff_nq = Q->ne[1] * (gqa_opt_applies ? gqa_ratio : 1); - // MMA vs tile crossover benchmarked on MI300X @ d32768: - // hsk=64 (gqa=4): MMA wins at eff >= 128 (+11%) - // hsk=128 (gqa=4): MMA wins at eff >= 128 (+4%) - if (eff_nq >= (GGML_CUDA_CC_IS_CDNA1(cc) && Q->ne[0] == 64 ? 64 : 128)) { - return BEST_FATTN_KERNEL_MMA_F16; - } - // Fall through to tile kernel for small effective batch sizes. + // AMD WMMA is always faster than the tile kernel if the full tile width of 16 can be utilized. + if ((amd_wmma_available(cc) && gqa_opt_applies && Q->ne[0] <= 128) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[1] * gqa_ratio_eff > 8) { + return BEST_FATTN_KERNEL_MMA_F16; } // If there are no tensor cores available, use the generic tile kernel: diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index 79bb2934c5f..8d7c69dc3e8 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -80,6 +80,7 @@ namespace ggml_cuda_mma { DATA_LAYOUT_J_MAJOR = 10, // Matrix C for CDNA and RDNA4, int and float matrix C for RDNA3. DATA_LAYOUT_I_MAJOR_MIRRORED = 20, // Volta, matrix A&B for RDNA3. DATA_LAYOUT_J_MAJOR_MIRRORED = 30, + DATA_LAYOUT_I_MAJOR_SCRAMBLED = 40, // Scrambled matrix C for faster transposition (RDNA4/CDNA), convert to float to unscramble. }; // Implemented mma combinations are: // - (I_MAJOR, I_MAJOR) -> I_MAJOR @@ -312,13 +313,19 @@ namespace ggml_cuda_mma { half2 x[ne] = {{0.0f, 0.0f}}; static constexpr __device__ bool supported() { - if (I == 16 && J == 8) return true; + if (I == 16 && J == 8) return true; + if (I == 16 && J == 16) return true; + if (I == 32 && J == 8) return true; return false; } static __device__ __forceinline__ int get_i(const int l) { if constexpr (I == 16 && J == 8) { return threadIdx.x % 16; + } else if constexpr (I == 16 && J == 16) { + return threadIdx.x % 16; + } else if constexpr (I == 32 && J == 8) { + return (threadIdx.x % 16) * 2 + l / (ne/2); } else { NO_DEVICE_CODE; return -1; @@ -327,7 +334,15 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ int get_j(const int l) { if constexpr (I == 16 && J == 8) { - return ne * (threadIdx.x / 16) + l; + return (threadIdx.x / 16) * ne + l; + } else if constexpr (I == 16 && J == 16) { +#ifdef RDNA3 + return l*2 + (threadIdx.x / 16); +#else + return (threadIdx.x / 16) * ne + l; +#endif // RDNA3 + } else if constexpr (I == 32 && J == 8) { + return (threadIdx.x / 16) * (ne/2) + l % (ne/2); } else { NO_DEVICE_CODE; return -1; @@ -338,13 +353,19 @@ namespace ggml_cuda_mma { half2 x[ne] = {{0.0f, 0.0f}}; static constexpr __device__ bool supported() { - if (I == 16 && J == 8) return true; + if (I == 16 && J == 8) return true; + if (I == 16 && J == 16) return true; + if (I == 32 && J == 8) return true; return false; } static __device__ __forceinline__ int get_i(const int l) { if constexpr (I == 16 && J == 8) { return threadIdx.x % 16; + } else if constexpr (I == 16 && J == 16) { + return threadIdx.x % 16; + } else if constexpr (I == 32 && J == 8) { + return (threadIdx.x % 16) * 2 + l / (ne/2); } else { NO_DEVICE_CODE; return -1; @@ -353,7 +374,11 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ int get_j(const int l) { if constexpr (I == 16 && J == 8) { - return ne * (threadIdx.x / 16) + l; + return (threadIdx.x / 16) * ne + l; + } else if constexpr (I == 16 && J == 16) { + return (threadIdx.x / 16) * ne + l; + } else if constexpr (I == 32 && J == 8) { + return (threadIdx.x / 16) * (ne/2) + l % (ne/2); } else { NO_DEVICE_CODE; return -1; @@ -516,12 +541,15 @@ namespace ggml_cuda_mma { if (I == 16 && J == 16) return true; if (I == 16 && J == 8) return true; if (I == 16 && J == 4) return true; + if (I == 32 && J == 8) return true; return false; } - static __device__ __forceinline__ int get_i(const int /*l*/) { - if constexpr (supported()) { + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 16) { return threadIdx.x % 16; + } else if constexpr (I == 32) { + return (threadIdx.x % 16) * 2 + l / (ne/2); } else { NO_DEVICE_CODE; return -1; @@ -529,8 +557,10 @@ namespace ggml_cuda_mma { } static __device__ __forceinline__ int get_j(const int l) { - if constexpr (supported()) { + if constexpr (I == 16) { return l; + } else if constexpr (I == 32) { + return l % (ne/2); } else { NO_DEVICE_CODE; return -1; @@ -644,6 +674,40 @@ namespace ggml_cuda_mma { } }; + template + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_SCRAMBLED; + + static constexpr int ne = I * J / ggml_cuda_get_physical_warp_size(); + half2 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + if (I == 16 && J == 16) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int l) { + return tile::get_i(l); + } + }; + + static __device__ __forceinline__ tile<16, 16, half2, DATA_LAYOUT_I_MAJOR> unscramble(const tile<16, 16, half2, DATA_LAYOUT_I_MAJOR_SCRAMBLED> & t) { +#if defined(AMD_MFMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) + tile<16, 16, half2, DATA_LAYOUT_I_MAJOR> ret; +#pragma unroll + for (int l0 = 0; l0 < t.ne/2; ++l0) { + ret.x[2*l0 + 0] = __lows2half2(t.x[l0], t.x[l0 + t.ne/2]); + ret.x[2*l0 + 1] = __highs2half2(t.x[l0], t.x[l0 + t.ne/2]); + } + return ret; +#else + NO_DEVICE_CODE; + GGML_UNUSED(t); +#endif // defined(AMD_MFMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) + } + #if defined(TURING_MMA_AVAILABLE) template static __device__ __forceinline__ tile get_half2(const tile & tile_float) { @@ -660,6 +724,21 @@ namespace ggml_cuda_mma { ret.x[0] = ggml_cuda_movmatrix(t.x[0]); ret.x[1] = ggml_cuda_movmatrix(t.x[1]); + return ret; + } +#elif defined(AMD_WMMA_AVAILABLE) && defined(RDNA3) + static __device__ __forceinline__ tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> get_half2( + const tile<16, 16, float, DATA_LAYOUT_I_MAJOR> & tile_float) { + tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> ret; +#pragma unroll + for (int l = 0; l < tile_float.ne; ++l) { + float tmp[2]; + int i = threadIdx.x / 16; + tmp[i] = tile_float.x[l]; + i ^= 1; + tmp[i] = __shfl_xor_sync(0xFFFFFFFF, tile_float.x[l], 16, WARP_SIZE); + ret.x[l] = make_half2(tmp[0], tmp[1]); + } return ret; } #elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) @@ -802,21 +881,35 @@ namespace ggml_cuda_mma { #endif // defined(VOLTA_MMA_AVAILABLE) } - template + template static __device__ __forceinline__ void load_ldmatrix_trans( - tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) { + tile & t, const T * __restrict__ xs0, const int stride) { #ifdef TURING_MMA_AVAILABLE + static_assert(I == 16, "bad tile width"); + static_assert(dl == DATA_LAYOUT_I_MAJOR, "bad data layout"); int * xi = (int *) t.x; const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2); asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];" : "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3]) : "l"(xs)); #elif defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) - half * xh = (half *) t.x; + static_assert(dl == DATA_LAYOUT_I_MAJOR || dl == DATA_LAYOUT_I_MAJOR_MIRRORED, "bad data layout"); + if constexpr (I == 32) { #pragma unroll - for (int l = 0; l < t.ne; ++l) { - xh[2*l + 0] = ((const half *) xs0)[(2*t.get_j(l) + 0)*(2*stride) + t.get_i(l)]; - xh[2*l + 1] = ((const half *) xs0)[(2*t.get_j(l) + 1)*(2*stride) + t.get_i(l)]; + for (int l0 = 0; l0 < t.ne/2; ++l0) { + const half2 tmp0 = xs0[(2*t.get_j(l0) + 0)*stride + t.get_i(l0)/2]; + const half2 tmp1 = xs0[(2*t.get_j(l0) + 1)*stride + t.get_i(l0)/2]; + + t.x[l0] = __lows2half2(tmp0, tmp1); + t.x[l0 + t.ne/2] = __highs2half2(tmp0, tmp1); + } + } else { + half * xh = (half *) t.x; +#pragma unroll + for (int l = 0; l < t.ne; ++l) { + xh[2*l + 0] = ((const half *) xs0)[(2*t.get_j(l) + 0)*(2*stride) + t.get_i(l)]; + xh[2*l + 1] = ((const half *) xs0)[(2*t.get_j(l) + 1)*(2*stride) + t.get_i(l)]; + } } #else GGML_UNUSED_VARS(t, xs0, stride); @@ -972,6 +1065,20 @@ namespace ggml_cuda_mma { #endif // TURING_MMA_AVAILABLE } + static __device__ __forceinline__ void mma( + tile<16, 16, half2, DATA_LAYOUT_I_MAJOR_SCRAMBLED> & D, const tile<32, 8, half2, DATA_LAYOUT_I_MAJOR> & A, + const tile<16, 8, half2, DATA_LAYOUT_I_MAJOR> & B) { +#if defined(AMD_MFMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) + tile<16, 8, half2> * D16 = (tile<16, 8, half2> *) &D; + const tile<16, 8, half2> * A16 = (const tile<16, 8, half2> *) &A; + mma(D16[0], A16[0], B); + mma(D16[1], A16[1], B); +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) && defined(RDNA4) + } + template static __device__ __forceinline__ void mma( tile<16, 8, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<8, 8, float, dl_ab> & B) { @@ -1296,6 +1403,22 @@ namespace ggml_cuda_mma { #endif // defined(VOLTA_MMA_AVAILABLE) } + static __device__ __forceinline__ void mma( + tile<16, 16, half2, DATA_LAYOUT_I_MAJOR> & D, const tile<32, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & A, + const tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & B) { +#if defined(AMD_WMMA_AVAILABLE) && defined(RDNA3) + using halfx16_t = __attribute__((ext_vector_type(16))) _Float16; + halfx16_t * xD = (halfx16_t *) D.x; + const halfx16_t * xA = (const halfx16_t *) A.x; + const halfx16_t * xB = (const halfx16_t *) B.x; + xD[0] = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(xA[0], xB[0], xD[0], /*opsel =*/ 0); + xD[0] = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(xA[1], xB[0], xD[0], /*opsel =*/ 1); +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // TURING_MMA_AVAILABLE + } + template static __device__ __forceinline__ void mma( tile<16, 16, int, dl_d> & D, const tile<16, 4, int, dl_ab> & A, const tile<16, 4, int, dl_ab> & B) { From 8436fb2843fb05eb306a8538e812075decad6cae Mon Sep 17 00:00:00 2001 From: Pranav Dhinakar Date: Thu, 14 May 2026 16:55:54 -0700 Subject: [PATCH 07/55] ggml-hexagon: cpy: add contiguous fast-path in reshape copy (llama/23076) --- ggml/src/ggml-hexagon/htp/cpy-ops.c | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/ggml/src/ggml-hexagon/htp/cpy-ops.c b/ggml/src/ggml-hexagon/htp/cpy-ops.c index e5b9d350fd7..5c040a32224 100644 --- a/ggml/src/ggml-hexagon/htp/cpy-ops.c +++ b/ggml/src/ggml-hexagon/htp/cpy-ops.c @@ -88,6 +88,29 @@ static void cpy_thread_sametype_reshape(struct htp_copy_context * ct, struct htp const uint32_t ir0 = dr * ith; const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr; + // Fast path: when both src0 and dst are contiguous in memory + // Replace the element-by-element loop with a single bulk HVX copy per (i03, i02) slice. + const bool src0_contig = (nb00 == ct->src0_type_size) && + (nb01 == ne00 * nb00) && + (nb02 == ne01 * nb01) && + (nb03 == ne02 * nb02); + const bool dst_contig = (nb0 == ct->dst_type_size) && + (nb1 == ne0 * nb0) && + (nb2 == ne1 * nb1) && + (nb3 == ne2 * nb2); + + if (src0_contig && dst_contig) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + uint8_t * src_ptr = (uint8_t *) src0->data + i03*nb03 + i02*nb02 + ir0*nb01; + uint32_t flat = ((i03*ne02 + i02)*ne01 + ir0) * ne00; + uint8_t * dst_ptr = (uint8_t *) dst->data + flat * ct->src0_type_size; + hvx_copy_uu(dst_ptr, src_ptr, (ir1 - ir0) * ne00, ct->src0_type_size); + } + } + return; + } + // dst counters int64_t k10 = 0; int64_t i11 = 0; From 1e50c6c7a9332766e800ffc2b35e58e8f1a39a56 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sat, 16 May 2026 20:06:23 +0800 Subject: [PATCH 08/55] llama + spec: MTP Support (llama/22673) * spec: support MTP * fix batch size * rename files * cont : simplify (llama/7) * MTP: clean-up (llama/9) * MTP: clean-up * review: use llama_context_type instead of llama_graph_type * review: remove llama_model_has_mtp * review: fix convert issues * convert: fix pycheck * review: formatting * use `mtp-` for identifying mtp models * convert: fix mtp conversion * mtp -> draft-mtp * remove unused llama_arch * add need_embd in speculative * llama: allow partial seq_rm for GDN models for speculative decoding Currently speculative checkpoint needs to restart from a checkpoint after some draft tokens are not accepted, this leads to some wastage in running the target again. This PR adds the ability to rollback upto `draft_max` by storing the GDN intermediates. * fix pending state * vulkan: add GDN partial rollback * meta: extend check to axis 1 * metal: add GDN partial rollback Extend the gated delta net kernel to store intermediate states for partial rollback support on the Metal backend. - Add K (snapshot slot count) as a function constant - Read input state from slot 0 of the 3D state tensor - Write intermediate states to different slots during token loop - For K=1, maintain backward-compatible single-slot behavior Ref: https://github.com/ggml-org/llama.cpp/commit/8c05923630110223669f069af2000e9cf10c02bc Assisted-by: llama.cpp:local pi * delta_net_base: use ggml_pad instead of new_tensor * review: add need_rs_seq * review: rename part_bounded to n_rs * review: deslop comments * review: rename, add asserts * server : adjust checkpoint logic (llama/11) * server : adjust checkpoint logic * cont : rm asserts * server-context: fix early exit * spec : fix compatibility with n-gram and add TODOs (llama/13) * metal : cleanup * llama : fix faulty bitwise check in recurrent memory * server : disable RS-based MTP in combination with other spec types * spec : add TODOs * cont : fix comment * cont : update comment * common : fix logic for ngram + mtp compat * llama-memory: enable checkpointing with partial rollback * cont: add test-case for loading into a dirty ctx * llama-memory-recurrent: clear rs_idx in clear * download: fix mtp path * llama-arch: fix enorm op * docs: update docs * conversion: fix type annotations --------- Co-authored-by: Georgi Gerganov --- ggml/include/ggml.h | 5 ++ ggml/src/ggml-backend-meta.cpp | 5 +- ggml/src/ggml-cpu/ggml-cpu.c | 4 +- ggml/src/ggml-cpu/ops.cpp | 43 +++++++-- ggml/src/ggml-cuda/gated_delta_net.cu | 88 +++++++++++++------ ggml/src/ggml-metal/ggml-metal-device.cpp | 5 +- ggml/src/ggml-metal/ggml-metal.metal | 46 ++++++++-- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 8 +- .../vulkan-shaders/gated_delta_net.comp | 29 +++++- ggml/src/ggml.c | 12 +-- 10 files changed, 188 insertions(+), 57 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 3357a0d9985..41566d41aef 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -2541,6 +2541,11 @@ extern "C" { // TODO: add ggml_gated_delta_net_set_bcast() to be able to configure Q, K broadcast type: tiled vs interleaved [TAG_GGML_GDN_BCAST] // ref: https://github.com/ggml-org/llama.cpp/pull/19468#discussion_r2786394306 + // + // state is a 3D tensor of shape (S_v*S_v*H, K, n_seqs): + // K == 1: output carries the final state only. + // K > 1: output carries K snapshot slots; the kernel writes the last min(n_tokens, K) + // per-token snapshots into the trailing slots GGML_API struct ggml_tensor * ggml_gated_delta_net( struct ggml_context * ctx, struct ggml_tensor * q, diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index c0ffd9a048b..df0f405ed9f 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -753,7 +753,9 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(co GGML_ASSERT(src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_1); GGML_ASSERT(src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_1); GGML_ASSERT(src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_1); - GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2); + // state shape is (S_v*S_v*H, K, n_seqs); the heads dim is nested inside axis 0, + // so a head-aligned split on the input cache reshapes to axis 0 here (not axis 2). + GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_1 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_0); return {GGML_BACKEND_SPLIT_AXIS_0, {0}, 1}; }; @@ -2140,4 +2142,3 @@ ggml_backend_t ggml_backend_meta_simple_backend(ggml_backend_t meta_backend, siz const ggml_backend_meta_context * backend_ctx = (const ggml_backend_meta_context *) meta_backend->context; return backend_ctx->backend_configs[index].backend; } - diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 7b05edf6b75..cd5c61a8187 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2943,7 +2943,9 @@ struct ggml_cplan ggml_graph_plan( case GGML_OP_GATED_DELTA_NET: { const int64_t S_v = node->src[2]->ne[0]; - cur = S_v * sizeof(float) * n_tasks; + const int64_t K = node->src[5]->ne[1]; // state is (D, K, n_seqs) + const int64_t per_thread = S_v + (K > 1 ? S_v * S_v : 0); + cur = per_thread * sizeof(float) * n_tasks; } break; case GGML_OP_COUNT: { diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 6bc8dc150ce..7485ba4fc86 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -10513,19 +10513,30 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( const bool kda = (neg0 == S_v); - // scratch layout per thread: [delta(S_v)] - const int64_t scratch_per_thread = S_v; + // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. + const int64_t K = src_state->ne[1]; + GGML_ASSERT(K >= 1); + // per-seq stride in floats (slot 0 of seq s lives at state + s * seq_stride) + const int64_t state_seq_stride = src_state->nb[2] / sizeof(float); + + const int64_t per_thread = S_v + (K > 1 ? S_v * S_v : 0); const int ith = params->ith; - float * delta = (float *)params->wdata + ith * scratch_per_thread + CACHE_LINE_SIZE_F32; + float * delta = (float *)params->wdata + ith * per_thread + CACHE_LINE_SIZE_F32; + float * state_work = K > 1 ? (delta + S_v) : nullptr; // output layout: [attn_scores | new_states] - // attn_scores: S_v * H * n_tokens * n_seqs floats - // new_states: S_v * S_v * H * n_seqs floats - const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; + // attn_scores: S_v * H * n_tokens * n_seqs floats + // new_states: S_v * S_v * H * n_seqs * K floats (K snapshot slots; last min(n_tokens, K)) + const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; + const int64_t state_size_per_snap = S_v * S_v * H * n_seqs; float * attn_out_base = (float *)dst->data; float * state_out_base = (float *)dst->data + attn_score_elems; + // snapshot slot mapping: target_slot = t - shift. When n_tokens < K only the last + // n_tokens slots are written; earlier slots are left untouched (caller-owned). + const int64_t shift = n_tokens - K; + const float * state_in_base = (const float *)src_state->data; //const int64_t rq1 = nev1 / neq1; @@ -10545,10 +10556,15 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( const int64_t iq3 = iv3 / rq3; const int64_t ik3 = iv3 / rk3; - float * s_out = state_out_base + (iv3 * H + iv1) * S_v * S_v; + // For K=1, write directly to the single output slot to avoid an extra memcpy at the end. + // For K>1, work in scratch and copy out per-token when the slot is in range. + float * s_out = (K > 1) + ? state_work + : state_out_base + (iv3 * H + iv1) * S_v * S_v; - // copy input state into output buffer and operate in-place - const float * s_in = state_in_base + (iv3 * H + iv1) * S_v * S_v; + // copy input state into the working buffer and operate in-place + // state layout (D, K, n_seqs): slot 0 of seq iv3 starts at iv3 * state_seq_stride. + const float * s_in = state_in_base + iv3 * state_seq_stride + iv1 * S_v * S_v; memcpy(s_out, s_in, S_v * S_v * sizeof(float)); // attn output pointer for first token of this (head, seq) @@ -10598,6 +10614,15 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( } attn_data += S_v * H; // advance to next token + + if (K > 1) { + const int64_t target_slot = t - shift; + if (target_slot >= 0 && target_slot < K) { + float * curr_state_o = state_out_base + target_slot * state_size_per_snap + + (iv3 * H + iv1) * S_v * S_v; + memcpy(curr_state_o, s_out, S_v * S_v * sizeof(float)); + } + } } } } diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index 6b44bec7317..b4c9845e7a7 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -1,6 +1,6 @@ #include "gated_delta_net.cuh" -template +template __global__ void __launch_bounds__((ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v) * 4, 2) gated_delta_net_cuda(const float * q, const float * k, @@ -23,7 +23,8 @@ gated_delta_net_cuda(const float * q, int64_t sb3, const uint3 neqk1_magic, const uint3 rq3_magic, - float scale) { + float scale, + int K) { const uint32_t h_idx = blockIdx.x; const uint32_t sequence = blockIdx.y; // each warp owns one column, using warp-level primitives to reduce across rows @@ -37,9 +38,13 @@ gated_delta_net_cuda(const float * q, float * attn_data = dst; float * state = dst + attn_score_elems; - const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v; - state += state_offset; - curr_state += state_offset + col * S_v; + // input state layout (D, K, n_seqs) — seq stride is K * D = K * H * S_v * S_v. + // output state layout (per-slot D * n_seqs) — same per-(seq,head) offset as before. + const int64_t state_in_offset = sequence * K * H * S_v * S_v + h_idx * S_v * S_v; + const int64_t state_out_offset = (sequence * H + h_idx) * S_v * S_v; + const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output + state += state_out_offset; + curr_state += state_in_offset + col * S_v; attn_data += (sequence * n_tokens * H + h_idx) * S_v; constexpr int warp_size = ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v; @@ -54,6 +59,10 @@ gated_delta_net_cuda(const float * q, s_shard[r] = curr_state[i]; } + // slot mapping: target_slot = t - shift. When n_tokens < K only the last n_tokens slots + // are written; earlier slots are left untouched (caller-owned). + const int shift = (int) n_tokens - K; + for (int t = 0; t < n_tokens; t++) { const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1; const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1; @@ -135,17 +144,30 @@ gated_delta_net_cuda(const float * q, } attn_data += S_v * H; + + if constexpr (keep_rs_t) { + const int target_slot = t - shift; + if (target_slot >= 0 && target_slot < K) { + float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset; +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + curr_state[col * S_v + i] = s_shard[r]; + } + } + } } - // Write state back to global memory (transposed layout) + if constexpr (!keep_rs_t) { #pragma unroll - for (int r = 0; r < rows_per_lane; r++) { - const int i = r * warp_size + lane; - state[col * S_v + i] = s_shard[r]; + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + state[col * S_v + i] = s_shard[r]; + } } } -template +template static void launch_gated_delta_net( const float * q_d, const float * k_d, const float * v_d, const float * g_d, const float * b_d, const float * s_d, @@ -155,7 +177,7 @@ static void launch_gated_delta_net( int64_t sv1, int64_t sv2, int64_t sv3, int64_t sb1, int64_t sb2, int64_t sb3, int64_t neqk1, int64_t rq3, - float scale, cudaStream_t stream) { + float scale, int K, cudaStream_t stream) { //TODO: Add chunked kernel for even faster pre-fill const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size; const int num_warps = 4; @@ -169,29 +191,29 @@ static void launch_gated_delta_net( switch (S_v) { case 16: - gated_delta_net_cuda<16, KDA><<>>( + gated_delta_net_cuda<16, KDA, keep_rs_t><<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); break; case 32: - gated_delta_net_cuda<32, KDA><<>>( + gated_delta_net_cuda<32, KDA, keep_rs_t><<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); break; case 64: { - gated_delta_net_cuda<64, KDA><<>>( + gated_delta_net_cuda<64, KDA, keep_rs_t><<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); break; } case 128: { - gated_delta_net_cuda<128, KDA><<>>( + gated_delta_net_cuda<128, KDA, keep_rs_t><<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); break; } default: @@ -261,13 +283,29 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * cudaStream_t stream = ctx.stream(); + // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. + const int K = (int) src_state->ne[1]; + const bool keep_rs = K > 1; + if (kda) { - launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, - S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1, rq3, scale, stream); + if (keep_rs) { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } else { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } } else { - launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, - S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1, rq3, scale, stream); + if (keep_rs) { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } else { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } } } diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index f0147af84c1..e288a27f992 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -590,6 +590,8 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net( const int ne20 = op->src[2]->ne[0]; // S_v const int ne21 = op->src[2]->ne[1]; // H const int ne30 = op->src[3]->ne[0]; // G + // state is src[5], 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. + const int K = op->src[5]->ne[1]; const int nsg = op->src[2]->ne[0]/32; @@ -598,7 +600,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net( GGML_ASSERT(ne20 % 32 == 0); snprintf(base, 256, "kernel_gated_delta_net_%s_%d", ggml_type_name(op->src[0]->type), nsg); - snprintf(name, 256, "%s_ne20=%d_ne30=%d", base, ne20, ne30); + snprintf(name, 256, "%s_ne20=%d_ne30=%d_K=%d", base, ne20, ne30, K); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { @@ -606,6 +608,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net( ggml_metal_cv_set_int16(cv, ne20, FC_GATED_DELTA_NET + 0); ggml_metal_cv_set_int16(cv, ne30, FC_GATED_DELTA_NET + 1); + ggml_metal_cv_set_int16(cv, K, FC_GATED_DELTA_NET + 2); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 2d45de8cce2..f6ffb2b3a1c 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2531,6 +2531,7 @@ kernel void kernel_rwkv_wkv7_f32( constant short FC_gated_delta_net_ne20 [[function_constant(FC_GATED_DELTA_NET + 0)]]; constant short FC_gated_delta_net_ne30 [[function_constant(FC_GATED_DELTA_NET + 1)]]; +constant short FC_gated_delta_net_K [[function_constant(FC_GATED_DELTA_NET + 2)]]; #if 1 template @@ -2548,21 +2549,24 @@ kernel void kernel_gated_delta_net_impl( uint3 ntg[[threads_per_threadgroup]]) { #define S_v FC_gated_delta_net_ne20 #define G FC_gated_delta_net_ne30 +#define K FC_gated_delta_net_K const uint tx = tpitg.x; const uint ty = tpitg.y; - const uint i23 = tgpig.z; // B - const uint i21 = tgpig.y; // H - const uint i20 = tgpig.x*NSG + ty; + const uint i23 = tgpig.z; // B (n_seqs) + const uint i21 = tgpig.y; // H (head) + const uint i20 = tgpig.x*NSG + ty; // row within S_v const uint i01 = i21 % args.ne01; const uint i11 = i21 % args.ne11; const float scale = 1.0f / sqrt((float)S_v); + // input state layout (D, K, n_seqs): per-seq stride is K*H*D; we read slot 0. // state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous - device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v; + const uint state_in_base = (i23*K*args.ne21 + i21)*S_v*S_v + i20*S_v; + device const float * s_ptr = (device const float *) (s) + state_in_base; float ls[NSG]; @@ -2580,6 +2584,17 @@ kernel void kernel_gated_delta_net_impl( device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21); device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G; + // snapshot slot mapping: target_slot = t - shift. When n_tokens < K, only the last + // n_tokens slots are written; earlier slots are left untouched (caller-owned). + const int shift = (int)args.ne22 - (int)K; + + // output state base offset: after attention scores + const uint attn_size = args.ne22 * args.ne21 * S_v * args.ne23; + // output state per-slot size: S_v * S_v * H * n_seqs + const uint state_size_per_snap = S_v * S_v * args.ne21 * args.ne23; + // per-(seq,head) offset within a slot + const uint state_out_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v; + for (short t = 0; t < args.ne22; t++) { float s_k = 0.0f; @@ -2627,17 +2642,30 @@ kernel void kernel_gated_delta_net_impl( b_ptr += args.ne21; g_ptr += args.ne21*G; - } - device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v; + if (K > 1u) { + const int target_slot = (int)t - shift; + if (target_slot >= 0 && target_slot < (int)K) { + device float * dst_state = (device float *) (dst) + attn_size + (uint)target_slot * state_size_per_snap + state_out_base; + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + dst_state[is] = ls[j]; + } + } + } + } - FOR_UNROLL (short j = 0; j < NSG; j++) { - const short is = tx*NSG + j; - dst_state[is] = ls[j]; + if (K == 1u) { + device float * dst_state = (device float *) (dst) + attn_size + state_out_base; + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + dst_state[is] = ls[j]; + } } #undef S_v #undef G +#undef K } typedef decltype(kernel_gated_delta_net_impl<4>) kernel_gated_delta_net_t; diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 8c4cf9ef1db..d29a4bab2e2 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1506,6 +1506,7 @@ struct vk_op_gated_delta_net_push_constants { uint32_t sb1, sb2, sb3; uint32_t neq1, rq3; float scale; + uint32_t K; }; struct vk_op_ssm_scan_push_constants { @@ -10767,6 +10768,7 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s const ggml_tensor * src_q = dst->src[0]; const ggml_tensor * src_v = dst->src[2]; const ggml_tensor * src_beta = dst->src[4]; + const ggml_tensor * src_state = dst->src[5]; GGML_ASSERT(dst->buffer != nullptr); @@ -10775,6 +10777,9 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s const uint32_t n_tokens = (uint32_t)src_v->ne[2]; const uint32_t n_seqs = (uint32_t)src_v->ne[3]; + // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. + const uint32_t K = (uint32_t)src_state->ne[1]; + const uint32_t s_off = S_v * H * n_tokens * n_seqs; vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op); @@ -10808,7 +10813,8 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s sv1, sv2, sv3, sb1, sb2, sb3, neq1, rq3, - scale + scale, + K }; ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp index 5e9f8308c1d..33c3202dbb7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp @@ -31,6 +31,7 @@ layout(push_constant) uniform Parameters { uint sb1, sb2, sb3; uint neq1, rq3; float scale; + uint K; }; layout(binding = 0) readonly buffer QBuf { FLOAT_TYPE data_q[]; }; @@ -101,13 +102,21 @@ void main() { const uint iq3 = seq_id / rq3; const uint state_size = S_V * S_V; - const uint state_base = (seq_id * H + head_id) * state_size; + // input state layout (D, K, n_seqs): per-seq stride is K*H*D; we read slot 0. + const uint state_in_base = (seq_id * K * H + head_id) * state_size; + // output state layout per slot: same per-(seq,head) offset as the single-slot case. + const uint state_out_base = (seq_id * H + head_id) * state_size; + const uint state_size_per_snap = state_size * H * n_seqs; FLOAT_TYPE s_shard[ROWS_PER_LANE]; [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { - s_shard[r] = FLOAT_TYPE(data_state[state_base + col * S_V + r * LANES_PER_COLUMN + lane]); + s_shard[r] = FLOAT_TYPE(data_state[state_in_base + col * S_V + r * LANES_PER_COLUMN + lane]); } + // snapshot slot mapping: target_slot = t - shift. When n_tokens < K, only the last + // n_tokens slots are written; earlier slots are left untouched (caller-owned). + const int shift = int(n_tokens) - int(K); + uint attn_off = (seq_id * n_tokens * H + head_id) * S_V; for (uint t = 0; t < n_tokens; t++) { @@ -161,9 +170,21 @@ void main() { } attn_off += S_V * H; + + if (K > 1u) { + const int target_slot = int(t) - shift; + if (target_slot >= 0 && target_slot < int(K)) { + const uint slot_base = s_off + uint(target_slot) * state_size_per_snap + state_out_base; + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + data_dst[slot_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r]; + } + } + } } - [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { - data_dst[s_off + state_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r]; + if (K == 1u) { + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + data_dst[s_off + state_out_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r]; + } } } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 191cf2fa106..476c3079795 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -6210,11 +6210,13 @@ struct ggml_tensor * ggml_gated_delta_net( GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v); GGML_ASSERT(beta->ne[0] == 1); - GGML_ASSERT(ggml_nelements(state) == S_v * S_v * H * n_seqs); - - // concat output and new_state into a single tensor - // output: S_v * H * n_tokens * n_seqs, state: S_v * S_v * H * n_seqs - const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + S_v * n_seqs, 1, 1 }; + // state is a 3D tensor (S_v*S_v*H, K, n_seqs). K is the snapshot slot count. + GGML_ASSERT(state->ne[0] == S_v * S_v * H); + GGML_ASSERT(state->ne[2] == n_seqs); + GGML_ASSERT(state->ne[3] == 1); + const int64_t K = state->ne[1]; + const int64_t state_rows = K * S_v * n_seqs; + const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + state_rows, 1, 1 }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); result->op = GGML_OP_GATED_DELTA_NET; From 130cd40e13ab795707fc746fac4e0b9741f18326 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 16 May 2026 15:59:09 +0300 Subject: [PATCH 09/55] ggml : bump version to 0.12.0 (ggml/1494) --- ggml/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index bdeca34bf9f..4aac5094d1c 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -4,8 +4,8 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) -set(GGML_VERSION_MINOR 11) -set(GGML_VERSION_PATCH 1) +set(GGML_VERSION_MINOR 12) +set(GGML_VERSION_PATCH 0) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") From 2b85f66f1cedc8875b0ced8a797e47175258e231 Mon Sep 17 00:00:00 2001 From: Dev-X25874 <283057883+Dev-X25874@users.noreply.github.com> Date: Thu, 21 May 2026 17:28:08 +0530 Subject: [PATCH 10/55] ggml-alloc: fix out-of-bounds read in ggml_dyn_tallocr_remove_block (ggml/1492) --- ggml/src/ggml-alloc.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c index a4b01ccf8a1..3bda9abbe03 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c @@ -150,7 +150,7 @@ static void ggml_dyn_tallocr_insert_block(struct tallocr_chunk * chunk, size_t o static void ggml_dyn_tallocr_remove_block(struct tallocr_chunk * chunk, int idx) { // shift all elements after idx by 1 to the left, overwriting the element at idx - for (int i = idx; i < chunk->n_free_blocks; i++) { + for (int i = idx; i < chunk->n_free_blocks - 1; i++) { chunk->free_blocks[i] = chunk->free_blocks[i+1]; } chunk->n_free_blocks--; From bb01f4574449a03f4f6f41209dd900ec384534a5 Mon Sep 17 00:00:00 2001 From: Ori Pekelman Date: Thu, 21 May 2026 12:00:16 +0000 Subject: [PATCH 11/55] ggml.h: correct ggml_silu_back arg docstring (a=dy, b=x) (ggml/1500) --- ggml/include/ggml.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 41566d41aef..f6725265504 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1189,8 +1189,8 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); - // a - x - // b - dy + // a - dy + // b - x GGML_API struct ggml_tensor * ggml_silu_back( struct ggml_context * ctx, struct ggml_tensor * a, From fcd611072edb066ee21889826a25198c763548b5 Mon Sep 17 00:00:00 2001 From: Winston Ma Date: Sun, 17 May 2026 01:57:35 +0800 Subject: [PATCH 12/55] vulkan: removed duplicate #include in headers (llama/23144) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index d29a4bab2e2..a296d0ab446 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -49,7 +49,6 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher(); #include #include #include -#include #include #include #include From 17d6153480725f73e468baafd936ed3b6354709b Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Sun, 17 May 2026 03:25:50 -0500 Subject: [PATCH 13/55] vulkan: fuse SSM_CONV + BIAS + SILU (llama/22653) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 126 ++++++++++++++++-- .../ggml-vulkan/vulkan-shaders/ssm_conv.comp | 12 +- 2 files changed, 129 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index a296d0ab446..d76d4819026 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -854,6 +854,8 @@ struct vk_device_struct { vk_pipeline pipeline_ssm_scan_f32_d128; vk_pipeline pipeline_ssm_scan_f32_d256; vk_pipeline pipeline_ssm_conv_f32; + vk_pipeline pipeline_ssm_conv_silu_f32; + vk_pipeline pipeline_ssm_conv_bias_silu_f32; vk_pipeline pipeline_opt_step_adamw_f32; vk_pipeline pipeline_opt_step_sgd_f32; std::map pipeline_conv2d_f32[CONV_SHAPE_COUNT]; @@ -4900,7 +4902,9 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true); } - ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16}, 1); + ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 4, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16, 0, 0}, 1); + ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_silu_f32, "ssm_conv_silu_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 4, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16, 0, 1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_bias_silu_f32, "ssm_conv_bias_silu_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 4, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16, 1, 1}, 1); ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); @@ -9936,7 +9940,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return nullptr; case GGML_OP_SSM_CONV: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_ssm_conv_f32; + switch (ctx->num_additional_fused_ops) { + case 0: return ctx->device->pipeline_ssm_conv_f32; + case 1: return ctx->device->pipeline_ssm_conv_silu_f32; + case 2: return ctx->device->pipeline_ssm_conv_bias_silu_f32; + default: return nullptr; + } } return nullptr; case GGML_OP_OPT_STEP_ADAMW: @@ -10877,11 +10886,28 @@ static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, pc, elements); } -static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const ggml_tensor * src1 = dst->src[1]; +static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) { + ggml_tensor * conv = cgraph->nodes[node_idx]; + const ggml_tensor * src0 = conv->src[0]; + const ggml_tensor * src1 = conv->src[1]; + + // Pick the destination tensor (last node in the fused chain) and the optional bias. + // Fusion modes: 0 = ssm_conv, 1 = ssm_conv+silu, 2 = ssm_conv+add(bias)+silu. + ggml_tensor * dst = conv; + const ggml_tensor * bias = nullptr; + + if (ctx->num_additional_fused_ops == 1) { + dst = cgraph->nodes[node_idx + 1]; // silu + } else if (ctx->num_additional_fused_ops == 2) { + ggml_tensor * add = cgraph->nodes[node_idx + 1]; + bias = (add->src[0] == conv) ? add->src[1] : add->src[0]; + dst = cgraph->nodes[node_idx + 2]; // silu + } - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SSM_CONV, { + // The shader always declares 4 bindings; bind src0 as a dummy when bias isn't fused. + const ggml_tensor * src2 = bias ? bias : src0; + + ggml_vk_op_f32(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SSM_CONV, { (uint32_t)src0->nb[1], (uint32_t)src0->nb[2], (uint32_t)src1->nb[1], (uint32_t)dst->nb[0], (uint32_t)dst->nb[1], (uint32_t)dst->nb[2], @@ -13556,7 +13582,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr break; case GGML_OP_SSM_CONV: - ggml_vk_ssm_conv(ctx, compute_ctx, node); + ggml_vk_ssm_conv(ctx, compute_ctx, cgraph, node_idx); break; @@ -14453,6 +14479,62 @@ static bool ggml_vk_can_fuse(const ggml_backend_vk_context * ctx, const struct g return true; } +// Match SSM_CONV + UNARY(SILU) or SSM_CONV + ADD + UNARY(SILU). num_extra is 1 or 2. +static bool ggml_vk_can_fuse_ssm_conv(const ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, + int node_idx, int num_extra) { + const ggml_tensor * conv = cgraph->nodes[node_idx]; + if (conv->op != GGML_OP_SSM_CONV) { + return false; + } + + const ggml_tensor * silu = nullptr; + const ggml_tensor * bias = nullptr; + + if (num_extra == 1) { + if (!ggml_can_fuse(cgraph, node_idx, { GGML_OP_SSM_CONV, GGML_OP_UNARY })) { + return false; + } + silu = cgraph->nodes[node_idx + 1]; + } else if (num_extra == 2) { + if (!ggml_can_fuse(cgraph, node_idx, { GGML_OP_SSM_CONV, GGML_OP_ADD, GGML_OP_UNARY })) { + return false; + } + const ggml_tensor * add = cgraph->nodes[node_idx + 1]; + silu = cgraph->nodes[node_idx + 2]; + bias = (add->src[0] == conv) ? add->src[1] : add->src[0]; + + if (bias->type != GGML_TYPE_F32 || !ggml_is_contiguous(bias)) { + return false; + } + // bias must be channel-wise (one element per channel of the conv output) + if (ggml_nelements(bias) != conv->ne[0] || bias->ne[0] != conv->ne[0]) { + return false; + } + if (add->type != GGML_TYPE_F32) { + return false; + } + // The shader doesn't apply per-tensor offsets, so reject misaligned bias. + if (get_misalign_bytes(ctx, bias) != 0) { + return false; + } + } else { + return false; + } + + if (ggml_get_unary_op(silu) != GGML_UNARY_OP_SILU) { + return false; + } + if (conv->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) { + return false; + } + // The shader writes to the fused dst using its own strides, but the push constants don't + // carry a per-tensor offset, so the binding must be naturally aligned. + if (get_misalign_bytes(ctx, silu) != 0) { + return false; + } + return true; +} + static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx, topk_moe_mode mode) { @@ -14869,6 +14951,19 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg // they are overwritten, and one workgroup per row. So close enough. op_srcs_fused_elementwise[0] = true; op_srcs_fused_elementwise[1] = true; + } else if (ggml_vk_can_fuse_ssm_conv(ctx, cgraph, i, 2)) { + ctx->num_additional_fused_ops = 2; + fusion_string = "SSM_CONV_BIAS_SILU"; + // ssm_conv reads multiple input tokens per output, so it's not elementwise w.r.t. its srcs. + // The downstream add and silu are elementwise on the conv output. + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = true; + op_srcs_fused_elementwise[2] = true; + } else if (ggml_vk_can_fuse_ssm_conv(ctx, cgraph, i, 1)) { + ctx->num_additional_fused_ops = 1; + fusion_string = "SSM_CONV_SILU"; + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = true; } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) && ggml_check_edges(cgraph, i, rope_view_set_rows_edges) && ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) { @@ -15200,7 +15295,9 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT && graph->nodes[j]->op == GGML_OP_ADD) && !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID) && !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_MUL) && - !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_ADD && graph->nodes[j]->op == GGML_OP_ADD)) { + !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_ADD && graph->nodes[j]->op == GGML_OP_ADD) && + !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_SSM_CONV && graph->nodes[j]->op == GGML_OP_ADD) && + !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_SSM_CONV && graph->nodes[j]->op == GGML_OP_UNARY)) { ok = false; break; } @@ -15283,6 +15380,19 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * } } } + // SSM_CONV + ADD + UNARY: pull the consuming UNARY forward + if (j > 0 && + graph->nodes[j]->op == GGML_OP_ADD && + graph->nodes[j-1]->op == GGML_OP_SSM_CONV) { + for (int k = j + 1; k < std::min(j + 15, graph->n_nodes); ++k) { + if (graph->nodes[k]->op == GGML_OP_UNARY && + graph->nodes[k]->src[0] == graph->nodes[j]) { + current_set.push_back(k); + used[k] = true; + break; + } + } + } } } // Second pass grabs view nodes. diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp b/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp index 6802b1fc955..4cd9b8da359 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp @@ -6,12 +6,15 @@ layout(constant_id = 0) const uint BLOCK_SIZE = 32; layout(constant_id = 1) const uint TOKENS_PER_WG = 16; +layout(constant_id = 2) const bool APPLY_BIAS = false; +layout(constant_id = 3) const bool APPLY_SILU = false; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z = 1) in; layout(binding = 0) readonly buffer Src0 { float src0[]; }; layout(binding = 1) readonly buffer Src1 { float src1[]; }; -layout(binding = 2) buffer Dst { float dst[]; }; +layout(binding = 2) readonly buffer Bias { float bias[]; }; +layout(binding = 3) buffer Dst { float dst[]; }; layout(push_constant) uniform PushConstants { uint nb01; uint nb02; @@ -45,6 +48,13 @@ void main() { } } + if (APPLY_BIAS) { + sum += bias[i1]; + } + if (APPLY_SILU) { + sum = sum / (1.0f + exp(-sum)); + } + const uint dst_idx = i3 * (dst_nb2 / 4) + i2 * (dst_nb1 / 4) + i1; dst[dst_idx] = sum; } From 621cbd8663fb56b66ee73e154cb3501d16c3c2eb Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Sun, 17 May 2026 04:30:16 -0500 Subject: [PATCH 14/55] vulkan: Support unaligned tensors for ROPE (llama/22637) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 17 +++++++++++++++++ .../ggml-vulkan/vulkan-shaders/rope_funcs.glsl | 7 +++++-- .../ggml-vulkan/vulkan-shaders/rope_params.glsl | 3 +++ 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index d76d4819026..14eab8ea4de 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1354,6 +1354,8 @@ struct vk_op_rope_push_constants { uint32_t nb11; uint32_t nb12; uint32_t nb13; + uint32_t a_offset; + uint32_t d_offset; }; static_assert(sizeof(vk_op_rope_push_constants) <= 128, "sizeof(vk_op_rope_push_constants) must be <= 128"); @@ -10126,6 +10128,15 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk GGML_UNUSED(src3); } +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_rope_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) { + p.a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + p.d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + GGML_UNUSED(src1); + GGML_UNUSED(src2); + GGML_UNUSED(src3); +} + template static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst, ggml_op op, PC&& pc) { VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; @@ -11270,6 +11281,7 @@ static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor * (uint32_t)src0->ne[2], nb01, nb02, nb03, nb11, nb12, nb13, + 0, 0, // a_offset, d_offset filled in by init_pushconst_tensor_offsets }; return rope; @@ -11365,6 +11377,11 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, GGML_ASSERT(buf[i] != nullptr); } + // a_offset is unused (the fused path reads from shared memory), but the rope/set_rows dst can be misaligned. + // Round the binding offset down to the storage buffer alignment; the in-element shift goes in pc.rope.d_offset. + pc.rope.d_offset = get_misalign_bytes(ctx, tensors[5]) / ggml_type_size(tensors[5]->type); + offset[5] &= ~(size_t(ctx->device->properties.limits.minStorageBufferOffsetAlignment) - 1); + std::array elements; elements = { (uint32_t)rms->src[0]->ne[1], (uint32_t)rms->src[0]->ne[2], (uint32_t)rms->src[0]->ne[3] }; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl index 2e53459909d..03358793140 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl @@ -9,7 +9,7 @@ uint rope_a_coord(const uint i0, const uint i01, const uint i02, const uint i03, // Per-row offset in shared memory const uint ix = i0; #else - const uint ix = i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i0; + const uint ix = p.a_offset + i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i0; #endif return ix; } @@ -48,6 +48,7 @@ void rope_norm(const uint i0, const uint i1, const uint i2, const uint i3, rope_ idst = i1*p.nb11 + i0; idst += rope_data_i[i2].x * p.set_rows_stride; } + idst += p.d_offset; if (i0 >= p.n_dims) { rope_data_d[idst + 0] = ROPE_D_TYPE(rope_data_a[ix + 0]); @@ -84,6 +85,7 @@ void rope_neox(const uint i0, const uint i1, const uint i2, const uint i3, rope_ idst = i1*p.nb11 + i0/2; idst += rope_data_i[i2].x * p.set_rows_stride; } + idst += p.d_offset; if (i0 >= p.n_dims) { rope_data_d[idst + i0/2 + 0] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 0]); @@ -121,6 +123,7 @@ void rope_multi(const uint i0, const uint i1, const uint i2, const uint i3, rope idst = i1*p.nb11 + i0/2; idst += rope_data_i[i2].x * p.set_rows_stride; } + idst += p.d_offset; if (i0 >= p.n_dims) { rope_data_d[idst + i0/2 + 0] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 0]); @@ -176,7 +179,7 @@ void rope_vision(const uint i0, const uint i1, const uint i2, const uint i3, rop return; } - const uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13; + const uint idst = p.d_offset + i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13; const uint ix = rope_a_coord(i0/2, i1, i2, i3, p); const int sect_dims = p.sections[0] + p.sections[1]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl index 2e2a7e14c66..3602485b943 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl @@ -26,6 +26,9 @@ struct rope_params { uint nb11; uint nb12; uint nb13; + + uint a_offset; + uint d_offset; }; #endif // !defined(GGML_ROPE_PARAMS) From ac163066f138ea60f94cb558a16cafb5291d1fa7 Mon Sep 17 00:00:00 2001 From: Pascal Date: Sun, 17 May 2026 11:31:20 +0200 Subject: [PATCH 15/55] vulkan: add cpy bf16 -> f32 pipelines (llama/22677) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 14 ++++++++++++-- .../ggml-vulkan/vulkan-shaders/contig_copy.comp | 8 ++++++-- ggml/src/ggml-vulkan/vulkan-shaders/copy.comp | 4 +++- .../vulkan-shaders/vulkan-shaders-gen.cpp | 2 ++ 4 files changed, 23 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 14eab8ea4de..d3fb19048d9 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -759,8 +759,8 @@ struct vk_device_struct { vk_pipeline pipeline_pad_f32; vk_pipeline pipeline_roll_f32; vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32; - vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16, pipeline_cpy_f32_i32, pipeline_cpy_i32_f32; - vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32; + vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16, pipeline_cpy_bf16_f32, pipeline_cpy_f32_i32, pipeline_cpy_i32_f32; + vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_bf16_f32, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32; vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT]; vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT]; vk_pipeline pipeline_cpy_transpose_16, pipeline_cpy_transpose_32; @@ -4572,6 +4572,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f32, "cpy_f16_f32", cpy_f16_f32_len, cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16,"cpy_f32_bf16",cpy_f32_bf16_len,cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_bf16_f32,"cpy_bf16_f32",cpy_bf16_f32_len,cpy_bf16_f32_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_i32_f32, "cpy_i32_f32", cpy_i32_f32_len, cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_i32, "cpy_f32_i32", cpy_f32_i32_len, cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -4580,6 +4581,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f32, "contig_cpy_f16_f32", contig_cpy_f16_f32_len, contig_cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_bf16_f32,"contig_cpy_bf16_f32",contig_cpy_bf16_f32_len,contig_cpy_bf16_f32_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_i32_f32, "contig_cpy_i32_f32", contig_cpy_i32_f32_len, contig_cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_i32, "contig_cpy_f32_i32", contig_cpy_f32_i32_len, contig_cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -7544,6 +7546,13 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_cpy_f32_bf16; } } + if (src->type == GGML_TYPE_BF16 && to == GGML_TYPE_F32) { + if (contig) { + return ctx->device->pipeline_contig_cpy_bf16_f32; + } else { + return ctx->device->pipeline_cpy_bf16_f32; + } + } if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_I32) { if (contig) { return ctx->device->pipeline_contig_cpy_f32_i32; @@ -15974,6 +15983,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (src1_type == GGML_TYPE_F32) { switch (src0_type) { case GGML_TYPE_F16: + case GGML_TYPE_BF16: case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp b/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp index ca1a3ac25bd..b3b182fb084 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp @@ -19,7 +19,9 @@ void main() { if (idx + (num_iter-1)*num_threads < p.ne) { [[unroll]] for (uint i = 0; i < num_iter; ++i) { -#if defined(DATA_D_BF16) +#if defined(DATA_A_BF16) + data_d[get_doffset() + idx] = D_TYPE(bf16_to_fp32(uint32_t(data_a[get_aoffset() + idx]))); +#elif defined(DATA_D_BF16) float f = float(data_a[get_aoffset() + idx]); data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f)); #elif !defined(OPTIMIZATION_ERROR_WORKAROUND) @@ -35,7 +37,9 @@ void main() { continue; } -#if defined(DATA_D_BF16) +#if defined(DATA_A_BF16) + data_d[get_doffset() + idx] = D_TYPE(bf16_to_fp32(uint32_t(data_a[get_aoffset() + idx]))); +#elif defined(DATA_D_BF16) float f = float(data_a[get_aoffset() + idx]); data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f)); #elif !defined(OPTIMIZATION_ERROR_WORKAROUND) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp index 9f8bfd3c182..d55e13253a8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp @@ -12,7 +12,9 @@ void main() { return; } -#if defined(DATA_D_BF16) +#if defined(DATA_A_BF16) + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(bf16_to_fp32(uint32_t(data_a[get_aoffset() + src0_idx(idx)]))); +#elif defined(DATA_D_BF16) float f = float(data_a[get_aoffset() + src0_idx(idx)]); data_d[get_doffset() + dst_idx(idx)] = D_TYPE(fp32_to_bf16(f)); #elif !defined(OPTIMIZATION_ERROR_WORKAROUND) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index d99b2b5d802..e3a9d61a558 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -731,6 +731,7 @@ void process_shaders() { string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("cpy_f16_f32", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("cpy_f32_bf16","copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}}); + string_to_spv("cpy_bf16_f32","copy.comp", {{"A_TYPE", "uint16_t"}, {"D_TYPE", "float"}, {"DATA_A_BF16", "1"}}); string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("contig_cpy_f32_i32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}}); string_to_spv("contig_cpy_i32_f32", "contig_copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}}); @@ -738,6 +739,7 @@ void process_shaders() { string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("contig_cpy_f16_f32", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("contig_cpy_f32_bf16","contig_copy.comp",{{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}}); + string_to_spv("contig_cpy_bf16_f32","contig_copy.comp",{{"A_TYPE", "uint16_t"}, {"D_TYPE", "float"}, {"DATA_A_BF16", "1"}}); string_to_spv("cpy_f32_i32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}}); string_to_spv("cpy_i32_f32", "copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}}); From 7187e1f20392df3daa4c26a2f6b9a7d949b7c474 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Ekstr=C3=B6m?= Date: Sun, 17 May 2026 14:12:11 +0300 Subject: [PATCH 16/55] ggml-vulkan/CMakeLists: add a check for SPIRV-Headers (llama/22009) * ci/run: set explicit SPIR-V Headers search path for macOS vulkan CI For whatever reason, the files are under additional sub-path `vulkan/` under the cmake directory, which does not match either current LunarG macOS Vulkan SDK structure (`lib/cmake/SPIRV-Headers`), nor what gets installed when you run the cmake build+install for SPIRV-Headers itself on at least Linux (`share/cmake/SPIRV-Headers`). This allows for SPIRV-Headers to be found, as currently the CI runner's setup does not seem to include the relevant path in list of search locations. * ggml-vulkan/CMakeLists: add a check for SPIRV-Headers This is installed by the project if it is built and installed. Receiving an error during the configuration step is generally preferred to receiving an error in the middle of a build. --- ggml/src/ggml-vulkan/CMakeLists.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt index 715a263a6d0..6dbcea065b3 100644 --- a/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/CMakeLists.txt @@ -8,6 +8,8 @@ endif() find_package(Vulkan COMPONENTS glslc REQUIRED) +find_package(SPIRV-Headers REQUIRED) + if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") # Parallel build object files add_definitions(/MP) From 703eda1e18e419f671cf78fec80c3ef4bf85724a Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Sun, 17 May 2026 18:00:10 +0200 Subject: [PATCH 17/55] CUDA: Continue directly including cuda/iterator (llama/23102) Cont of #22936, forgot to update one site --- ggml/src/ggml-cuda/top-k.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml/src/ggml-cuda/top-k.cu b/ggml/src/ggml-cuda/top-k.cu index 59ce36fb1c9..db1d39e2dc7 100644 --- a/ggml/src/ggml-cuda/top-k.cu +++ b/ggml/src/ggml-cuda/top-k.cu @@ -5,6 +5,7 @@ # include # if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2) # define CUB_TOP_K_AVAILABLE +# include using namespace cub; # endif // CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2 #endif // GGML_CUDA_USE_CUB From 323bc2d04662229cf5699d7920592c001982599f Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Sun, 17 May 2026 15:05:11 -0600 Subject: [PATCH 18/55] feat: Support d_conv=15 for ssm-conv.cu (llama/23017) Branch: ModalityConditionalAdapters AI-usage: none Signed-off-by: Gabe Goodhart --- ggml/src/ggml-cuda/ssm-conv.cu | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-cuda/ssm-conv.cu b/ggml/src/ggml-cuda/ssm-conv.cu index 4841389fbc8..4c4daf85dc6 100644 --- a/ggml/src/ggml-cuda/ssm-conv.cu +++ b/ggml/src/ggml-cuda/ssm-conv.cu @@ -140,11 +140,12 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const floa }; switch (nc) { - case 3: launch_kernel(std::integral_constant{}); break; - case 4: launch_kernel(std::integral_constant{}); break; - case 5: launch_kernel(std::integral_constant{}); break; - case 9: launch_kernel(std::integral_constant{}); break; - default: GGML_ABORT("Only support kernel sizes 3, 4, 5, 9 right now."); + case 3: launch_kernel(std::integral_constant{}); break; + case 4: launch_kernel(std::integral_constant{}); break; + case 5: launch_kernel(std::integral_constant{}); break; + case 9: launch_kernel(std::integral_constant{}); break; + case 15: launch_kernel(std::integral_constant{}); break; + default: GGML_ABORT("Only support kernel sizes 3, 4, 5, 9, 15 right now."); } } From 6a5a4993388e93036ceacc877cbd4a58ffd04e4f Mon Sep 17 00:00:00 2001 From: Intel AI Get-to Market Customer Success and Solutions Date: Sun, 17 May 2026 22:11:51 -0700 Subject: [PATCH 19/55] sycl: route small f32 matmuls to oneMKL, bypass oneDNN (llama/22150) Signed-off-by: Chun Tao Co-authored-by: Chun Tao --- ggml/src/ggml-sycl/ggml-sycl.cpp | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index f5d10b56de0..ebe7c5b351c 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -2385,21 +2385,25 @@ inline void ggml_sycl_op_mul_mat_sycl( const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get(); const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get(); + { + const int64_t gemm_flops = (int64_t)row_diff * src1_ncols * ne10; + const bool use_mkl_direct = gemm_flops < 256 * 256 * 256; #if GGML_SYCL_DNNL - if (!g_ggml_sycl_disable_dnn) { - DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, src0_ddf_i, - DnnlGemmWrapper::to_dt(), src1_ddf1_i, DnnlGemmWrapper::to_dt(), - dst_dd_i, DnnlGemmWrapper::to_dt(), stream); - } - else + if (!g_ggml_sycl_disable_dnn && !use_mkl_direct) { + DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, src0_ddf_i, + DnnlGemmWrapper::to_dt(), src1_ddf1_i, DnnlGemmWrapper::to_dt(), + dst_dd_i, DnnlGemmWrapper::to_dt(), stream); + } + else #endif - { - const float alpha = 1.0f; - const float beta = 0.0f; - SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm( - *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, - src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, - dpct::get_value(&beta, *stream), dst_dd_i, ldc))); + { + const float alpha = 1.0f; + const float beta = 0.0f; + SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm( + *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, + src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, + dpct::get_value(&beta, *stream), dst_dd_i, ldc))); + } } } GGML_UNUSED(dst); From 01bf2afc63d70d3401afb7e8a9dd15933b1402a1 Mon Sep 17 00:00:00 2001 From: Intel AI Get-to Market Customer Success and Solutions Date: Sun, 17 May 2026 22:12:21 -0700 Subject: [PATCH 20/55] sycl: scalar SWAR byte-subtract in Q6_K MMVQ dot product (llama/22156) Signed-off-by: Chun Tao Co-authored-by: Chun Tao --- ggml/src/ggml-sycl/vecdotq.hpp | 99 ++++++++++++++++------------------ 1 file changed, 46 insertions(+), 53 deletions(-) diff --git a/ggml/src/ggml-sycl/vecdotq.hpp b/ggml/src/ggml-sycl/vecdotq.hpp index d7770047424..16b2d65d271 100644 --- a/ggml/src/ggml-sycl/vecdotq.hpp +++ b/ggml/src/ggml-sycl/vecdotq.hpp @@ -85,6 +85,32 @@ static __dpct_inline__ int get_int_from_uint8_aligned( (const int*)(x8 + sizeof(int) * i32)); // assume at least 4 byte alignment } +static __dpct_inline__ int byte_sub_4(const int a, const int b) { + const uint32_t ua = static_cast(a); + const uint32_t ub = static_cast(b); + return static_cast(((ua | 0x80808080u) - ub) ^ 0x80808080u); +} + +static __dpct_inline__ float vec_dot_q6_K_q8_1_impl_mmvq_scalar( + const int vl, const int vh, const int u0, const int u1, const int8_t sc0, + const int8_t sc1, const float d, const float d80, const float d81) { + static_assert(QR6_K == 2, "q6_K MMVQ scalar fast path assumes QR6_K == 2"); + + const int vil0 = (vl >> 0) & 0x0F0F0F0F; + const int vih0 = ((vh >> 0) << 4) & 0x30303030; + const int vi0 = byte_sub_4(vil0 | vih0, 0x20202020); + + const int vil1 = (vl >> 4) & 0x0F0F0F0F; + const int vih1 = ((vh >> 4) << 4) & 0x30303030; + const int vi1 = byte_sub_4(vil1 | vih1, 0x20202020); + + const float sumf = + d80 * (dpct::dp4a(vi0, u0, 0) * sc0) + + d81 * (dpct::dp4a(vi1, u1, 0) * sc1); + + return d * sumf; +} + static __dpct_inline__ void get_int_from_table_16(const uint32_t &q4, const uint8_t *values, int &val1, int &val2) { @@ -279,24 +305,8 @@ vec_dot_q6_K_q8_1_impl_mmvq(const int &vl, const int &vh, const int *__restrict__ u, const int8_t *__restrict__ scales, const float &d, const float *__restrict__ d8) { - - float sumf = 0.0f; - -#pragma unroll - for (int i = 0; i < QR6_K; ++i) { - const int sc = scales[4*i]; - - const int vil = (vl >> (4*i)) & 0x0F0F0F0F; - - const int vih = ((vh >> (4*i)) << 4) & 0x30303030; - - const int vi = dpct::vectorized_binary( - (vil | vih), 0x20202020, dpct::sub_sat()); // vi = (vil | vih) - 32 - - sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc); // SIMD dot product - } - - return d*sumf; + return vec_dot_q6_K_q8_1_impl_mmvq_scalar( + vl, vh, u[0], u[1], scales[0], scales[4], d, d8[0], d8[1]); } // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called @@ -542,23 +552,8 @@ template <> struct reorder_vec_dot_q_sycl { __dpct_inline__ float vec_dot_q6_K_q8_1_impl_mmvq(const int vl, const int vh, const int * __restrict__ u, const int8_t * __restrict__ scales, const float d, const float * __restrict__ d8) { - float sumf = 0.0f; - -#pragma unroll - for (int i = 0; i < QR6_K; ++i) { - const int sc = scales[4 * i]; - - const int vil = (vl >> (4 * i)) & 0x0F0F0F0F; - - const int vih = ((vh >> (4 * i)) << 4) & 0x30303030; - - const int vi = dpct::vectorized_binary((vil | vih), 0x20202020, - dpct::sub_sat()); // vi = (vil | vih) - 32 - - sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc); // SIMD dot product - } - - return d * sumf; + return vec_dot_q6_K_q8_1_impl_mmvq_scalar( + vl, vh, u[0], u[1], scales[0], scales[4], d, d8[0], d8[1]); } __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair ibx_offset, @@ -579,16 +574,15 @@ template <> struct reorder_vec_dot_q_sycl { const int8_t * scs = scales + scale_offset; - int u[QR6_K]; - float d8[QR6_K]; + const int u0 = get_int_from_int8_aligned( + q8_1_quant_ptr + bq8_offset * QK8_1, iqs % QI8_1); + const int u1 = get_int_from_int8_aligned( + q8_1_quant_ptr + (bq8_offset + 2) * QK8_1, iqs % QI8_1); + const float d80 = (*(q8_1_ds + bq8_offset + 0))[0]; + const float d81 = (*(q8_1_ds + bq8_offset + 2))[0]; -#pragma unroll - for (int i = 0; i < QR6_K; ++i) { - u[i] = get_int_from_int8_aligned(q8_1_quant_ptr + (bq8_offset + 2 * i) * QK8_1, iqs % QI8_1); - const sycl::half2 ds_values = *(q8_1_ds + bq8_offset + 2 * i); - d8[i] = ds_values[0]; - } - return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scs, *d, d8); + return vec_dot_q6_K_q8_1_impl_mmvq_scalar( + vl, vh, u0, u1, scs[0], scs[4], *d, d80, d81); } }; #define VDR_Q4_0_Q8_1_MMVQ 2 @@ -1167,16 +1161,15 @@ vec_dot_q6_K_q8_1(const void *__restrict__ vbq, const int8_t * scales = bq6_K->scales + scale_offset; - int u[QR6_K]; - float d8[QR6_K]; - -#pragma unroll - for (int i = 0; i < QR6_K; ++i) { - u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1); - d8[i] = bq8_1[bq8_offset + 2 * i].ds[0]; - } + const int u0 = get_int_from_int8_aligned( + bq8_1[bq8_offset + 0].qs, iqs % QI8_1); + const int u1 = get_int_from_int8_aligned( + bq8_1[bq8_offset + 2].qs, iqs % QI8_1); + const float d80 = bq8_1[bq8_offset + 0].ds[0]; + const float d81 = bq8_1[bq8_offset + 2].ds[0]; - return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8); + return vec_dot_q6_K_q8_1_impl_mmvq_scalar( + vl, vh, u0, u1, scales[0], scales[4], bq6_K->d, d80, d81); } From b94493a78a5856d1937aed66e3d1385aac178048 Mon Sep 17 00:00:00 2001 From: Pranav Dhinakar Date: Mon, 18 May 2026 13:39:36 -0700 Subject: [PATCH 21/55] ggml-hexagon: add PAD op HVX kernel (llama/23078) * ggml-hexagon: add PAD op HVX kernel Implements GGML_OP_PAD on the Hexagon HTP backend using HVX vectorized kernels. Supports zero-padding and circular padding across all 4 tensor dimensions. * hex-ggml: remove duplicate op cases (merge conflict) * hex-pad: fix editorconfig checks and macro alignment --------- Co-authored-by: Max Krasnyansky --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 18 + ggml/src/ggml-hexagon/htp/CMakeLists.txt | 1 + ggml/src/ggml-hexagon/htp/htp-ctx.h | 1 + ggml/src/ggml-hexagon/htp/htp-ops.h | 1 + ggml/src/ggml-hexagon/htp/main.c | 3 + ggml/src/ggml-hexagon/htp/pad-ops.c | 545 +++++++++++++++++++++++ 6 files changed, 569 insertions(+) create mode 100644 ggml/src/ggml-hexagon/htp/pad-ops.c diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 3d1c9da8329..c24a2305e4c 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2744,6 +2744,18 @@ static bool ggml_hexagon_supported_ssm_conv(const struct ggml_hexagon_session * return true; } +static bool ggml_hexagon_supported_pad(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * dst = op; + + if (src0->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + return false; + } + + GGML_UNUSED(sess); + return true; +} + static bool ggml_hexagon_supported_cumsum(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { const struct ggml_tensor * src0 = op->src[0]; const struct ggml_tensor * dst = op; @@ -2857,6 +2869,8 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) { case GGML_OP_FILL: return HTP_OP_FILL; case GGML_OP_DIAG: return HTP_OP_DIAG; case GGML_OP_SOLVE_TRI: return HTP_OP_SOLVE_TRI; + case GGML_OP_PAD: return HTP_OP_PAD; + case GGML_OP_UNARY: switch (ggml_get_unary_op(t)) { case GGML_UNARY_OP_SILU: return HTP_OP_UNARY_SILU; @@ -3416,6 +3430,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_solve_tri(sess, op); break; + case GGML_OP_PAD: + supp = ggml_hexagon_supported_pad(sess, op); + break; + default: break; } diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index bcadac11f95..36f923243cd 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -38,6 +38,7 @@ add_library(${HTP_LIB} SHARED diag-ops.c solve-tri-ops.c gated-delta-net-ops.c + pad-ops.c ) target_compile_definitions(${HTP_LIB} PRIVATE diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index 92f02eac6e3..e500ce46212 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -107,5 +107,6 @@ int op_fill(struct htp_ops_context * octx); int op_diag(struct htp_ops_context * octx); int op_solve_tri(struct htp_ops_context * octx); int op_gated_delta_net(struct htp_ops_context * octx); +int op_pad(struct htp_ops_context * octx); #endif /* HTP_CTX_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 98db864dd42..985ded6f299 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -86,6 +86,7 @@ enum htp_op_code { HTP_OP_SOLVE_TRI, HTP_OP_L2_NORM, HTP_OP_GATED_DELTA_NET, + HTP_OP_PAD, HTP_OP_INVALID }; diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 883a31d6163..85569f07289 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -595,6 +595,9 @@ static int execute_op(struct htp_ops_context * octx) { case HTP_OP_SOLVE_TRI: return op_solve_tri(octx); + case HTP_OP_PAD: + return op_pad(octx); + case HTP_OP_GATED_DELTA_NET: return op_gated_delta_net(octx); diff --git a/ggml/src/ggml-hexagon/htp/pad-ops.c b/ggml/src/ggml-hexagon/htp/pad-ops.c new file mode 100644 index 00000000000..3abc3c2ead1 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/pad-ops.c @@ -0,0 +1,545 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include + +#include + +#include "hex-dma.h" +#include "hvx-utils.h" + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-ops.h" + +/* Circular wrap: maps any integer x into [0, n) */ +static inline uint32_t wrap_around(int32_t x, uint32_t n) { + return (uint32_t)(((x % (int32_t)n) + (int32_t)n) % (int32_t)n); +} + +/* Decompose a flat dst row index into (i1, i2, i3) */ +static inline void pad_decompose_row(uint32_t ir, uint32_t ne1, uint32_t ne2, + uint32_t *i1, uint32_t *i2, uint32_t *i3) { + *i1 = ir % ne1; + *i2 = (ir / ne1) % ne2; + *i3 = ir / (ne1 * ne2); +} + +/* Return non-zero if row (i1,i2,i3) falls in the non-padded interior */ +static inline int pad_is_interior(uint32_t i1, uint32_t i2, uint32_t i3, + int32_t lp1, int32_t rp1, uint32_t ne1, + int32_t lp2, int32_t rp2, uint32_t ne2, + int32_t lp3, int32_t rp3, uint32_t ne3) { + return ((int32_t)i1 >= lp1 && (int32_t)i1 < (int32_t)ne1 - rp1) && + ((int32_t)i2 >= lp2 && (int32_t)i2 < (int32_t)ne2 - rp2) && + ((int32_t)i3 >= lp3 && (int32_t)i3 < (int32_t)ne3 - rp3); +} + +/* Compute the DDR src row pointer for a zero-pad interior row */ +static inline const uint8_t * pad_src_row_ptr(const struct htp_tensor * src, + uint32_t i1, uint32_t i2, uint32_t i3, + int32_t lp1, int32_t lp2, int32_t lp3) { + return (const uint8_t *) src->data + + (i1 - (uint32_t)lp1) * src->nb[1] + + (i2 - (uint32_t)lp2) * src->nb[2] + + (i3 - (uint32_t)lp3) * src->nb[3]; +} + +/* Compute the DDR src row pointer for a circular row (wrap-around indexing) */ +static inline const uint8_t * pad_circ_src_row_ptr(const struct htp_tensor * src, + uint32_t i1, uint32_t i2, uint32_t i3, + int32_t lp1, int32_t lp2, int32_t lp3) { + return (const uint8_t *) src->data + + wrap_around((int32_t)i1 - lp1, src->ne[1]) * src->nb[1] + + wrap_around((int32_t)i2 - lp2, src->ne[2]) * src->nb[2] + + wrap_around((int32_t)i3 - lp3, src->ne[3]) * src->nb[3]; +} + +struct htp_pad_context { + struct htp_ops_context * octx; + + int32_t lp0, rp0; + int32_t lp1, rp1; + int32_t lp2, rp2; + int32_t lp3, rp3; + + uint32_t nrows_per_thread; + uint32_t total_dst_rows; + + size_t type_size; + + // Row sizes for DMA kernel (populated when VTCM is available) + size_t src_row_size; + size_t src_row_size_aligned; + size_t dst_row_size; + size_t dst_row_size_aligned; +}; + +#define htp_pad_preamble \ + const struct htp_tensor * src = octx->src[0]; \ + const struct htp_tensor * dst = octx->dst; \ + \ + const uint32_t ne00 = src->ne[0]; \ + const uint32_t nb00 = src->nb[0]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; \ + \ + const int32_t lp0 = pctx->lp0, rp0 = pctx->rp0; \ + const int32_t lp1 = pctx->lp1, rp1 = pctx->rp1; \ + const int32_t lp2 = pctx->lp2, rp2 = pctx->rp2; \ + const int32_t lp3 = pctx->lp3, rp3 = pctx->rp3; \ + \ + const size_t type_size = pctx->type_size; \ + \ + const uint32_t row_start = pctx->nrows_per_thread * ith; \ + const uint32_t row_end = MIN(row_start + pctx->nrows_per_thread, pctx->total_dst_rows); + + +#define htp_pad_dma_preamble \ + const size_t src_row_size = pctx->src_row_size; \ + const size_t src_row_size_aligned = pctx->src_row_size_aligned; \ + const size_t dst_row_size = pctx->dst_row_size; \ + const size_t dst_row_size_aligned = pctx->dst_row_size_aligned; \ + \ + uint8_t * src_spad_base = octx->src0_spad.data + ith * octx->src0_spad.size_per_thread; \ + uint8_t * dst_spad_base = octx->dst_spad.data + ith * octx->dst_spad.size_per_thread; \ + \ + dma_queue * dma = octx->ctx->dma[ith]; + +// --------------------------------------------------------------------------- +// HVX vectorized PAD kernel +// --------------------------------------------------------------------------- + +static void pad_job_per_thread_hvx(unsigned int nth, unsigned int ith, void * data) { + const struct htp_pad_context * pctx = (const struct htp_pad_context *) data; + struct htp_ops_context * octx = pctx->octx; + htp_pad_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + for (uint32_t dst_row = row_start; dst_row < row_end; dst_row++) { + uint32_t i1, i2, i3; + pad_decompose_row(dst_row, ne1, ne2, &i1, &i2, &i3); + + uint8_t * dst_ptr = (uint8_t *) dst->data + i1 * nb1 + i2 * nb2 + i3 * nb3; + + const int interior = pad_is_interior(i1, i2, i3, + lp1, rp1, ne1, + lp2, rp2, ne2, + lp3, rp3, ne3); + + if (!interior) { + hvx_splat_f32_u(dst_ptr, 0.0f, ne0); + } else { + const uint8_t * src_ptr = pad_src_row_ptr(src, i1, i2, i3, lp1, lp2, lp3); + + if (lp0 > 0) { + hvx_splat_f32_u(dst_ptr, 0.0f, (uint32_t)lp0); + } + + uint8_t * dst_row_start = dst_ptr + (size_t)lp0 * type_size; + if (nb00 == type_size) { + hvx_copy_f32_uu(dst_row_start, src_ptr, ne00); + } else { + for (uint32_t i = 0; i < ne00; i++) { + memcpy(dst_row_start + i * type_size, + src_ptr + (size_t)i * nb00, + type_size); + } + } + + if (rp0 > 0) { + hvx_splat_f32_u(dst_ptr + ((size_t)lp0 + ne00) * type_size, 0.0f, (uint32_t)rp0); + } + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "pad-hvx %d/%d: (%ux%ux%ux%u) -> (%ux%ux%ux%u) rows %u:%u usec %u\n", + ith, nth, + src->ne[0], src->ne[1], src->ne[2], src->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + row_start, row_end, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// --------------------------------------------------------------------------- +// HVX + DMA PAD kernel — aligned, double-buffered +// --------------------------------------------------------------------------- + +static void pad_job_per_thread_hvx_dma(unsigned int nth, unsigned int ith, void * data) { + const struct htp_pad_context * pctx = (const struct htp_pad_context *) data; + struct htp_ops_context * octx = pctx->octx; + htp_pad_preamble; + htp_pad_dma_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + // ----------------------------------------------------------------------- + // Priming phase: push 2 pairs of (dummy_dst_DMA, src_DMA) to seed the + // double-buffer pipeline before the main loop begins. + // ----------------------------------------------------------------------- + for (uint32_t ir = row_start, spad_idx = 0; ir < row_end && spad_idx < 2; ir++, spad_idx++) { + uint8_t * src_spad_cur = src_spad_base + spad_idx * src_row_size_aligned; + uint8_t * dst_spad_cur = dst_spad_base + spad_idx * dst_row_size_aligned; + + dma_queue_push_vtcm_to_ddr(dma, + dma_make_ptr((uint8_t *)dst->data, dst_spad_cur), + dst_row_size, dst_row_size_aligned, 0); + + uint32_t i1, i2, i3; + pad_decompose_row(ir, ne1, ne2, &i1, &i2, &i3); + const int interior = pad_is_interior(i1, i2, i3, + lp1, rp1, ne1, + lp2, rp2, ne2, + lp3, rp3, ne3); + + const uint8_t * src_ptr = interior + ? pad_src_row_ptr(src, i1, i2, i3, lp1, lp2, lp3) : NULL; + + // Interior row: real DMA (1 row) from DDR to VTCM. + // Border row: null DMA (nrows=0) + dma_queue_push_ddr_to_vtcm(dma, + dma_make_ptr(src_spad_cur, + src_ptr ? src_ptr : (const uint8_t *)src_spad_cur), + src_row_size_aligned, src_row_size, src_ptr ? 1 : 0); + } + + // ----------------------------------------------------------------------- + // Main loop: pop completed DMAs, compute in VTCM with aligned HVX ops, + // push dst DMA and prefetch src for the next+1 row. + // ----------------------------------------------------------------------- + for (uint32_t ir = row_start; ir < row_end; ir++) { + uint8_t * dst_spad_cur = (uint8_t *) dma_queue_pop(dma).src; + uint8_t * src_spad_cur = (uint8_t *) dma_queue_pop(dma).dst; + + uint32_t i1, i2, i3; + pad_decompose_row(ir, ne1, ne2, &i1, &i2, &i3); + + uint8_t * dst_ptr = (uint8_t *) dst->data + i1 * nb1 + i2 * nb2 + i3 * nb3; + + const int interior = pad_is_interior(i1, i2, i3, + lp1, rp1, ne1, + lp2, rp2, ne2, + lp3, rp3, ne3); + + if (!interior) { + hvx_splat_f32_a(dst_spad_cur, 0.0f, ne0); + } else { + hvx_splat_f32_a(dst_spad_cur, 0.0f, ne0); + + uint8_t * dst_interior = dst_spad_cur + (size_t)lp0 * type_size; + + if ((uintptr_t)dst_interior % VLEN == 0) { + hvx_copy_f32_aa(dst_interior, src_spad_cur, ne00); + } else { + hvx_copy_f32_ua(dst_interior, src_spad_cur, ne00); + } + } + + dma_queue_push_vtcm_to_ddr(dma, + dma_make_ptr(dst_ptr, dst_spad_cur), + dst_row_size, dst_row_size_aligned, 1); + + const uint32_t next_row = ir + 2; + if (next_row < row_end) { + uint32_t ni1, ni2, ni3; + pad_decompose_row(next_row, ne1, ne2, &ni1, &ni2, &ni3); + const int next_interior = pad_is_interior(ni1, ni2, ni3, + lp1, rp1, ne1, + lp2, rp2, ne2, + lp3, rp3, ne3); + const uint8_t * next_src_ptr = next_interior + ? pad_src_row_ptr(src, ni1, ni2, ni3, lp1, lp2, lp3) : NULL; + + dma_queue_push_ddr_to_vtcm(dma, + dma_make_ptr(src_spad_cur, + next_src_ptr ? next_src_ptr : (const uint8_t *)src_spad_cur), + src_row_size_aligned, src_row_size, next_src_ptr ? 1 : 0); + } + } + + dma_queue_flush(dma); + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "pad-hvx-dma %d/%d: (%ux%ux%ux%u) -> (%ux%ux%ux%u) rows %u:%u usec %u\n", + ith, nth, + src->ne[0], src->ne[1], src->ne[2], src->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + row_start, row_end, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// --------------------------------------------------------------------------- +// HVX circular PAD kernel +// --------------------------------------------------------------------------- + +static void pad_job_per_thread_hvx_circular(unsigned int nth, unsigned int ith, void * data) { + const struct htp_pad_context * pctx = (const struct htp_pad_context *) data; + struct htp_ops_context * octx = pctx->octx; + htp_pad_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + for (uint32_t dst_row = row_start; dst_row < row_end; dst_row++) { + uint32_t i1, i2, i3; + pad_decompose_row(dst_row, ne1, ne2, &i1, &i2, &i3); + + uint8_t * dst_ptr = (uint8_t *) dst->data + i1 * nb1 + i2 * nb2 + i3 * nb3; + const uint8_t * src_row = pad_circ_src_row_ptr(src, i1, i2, i3, lp1, lp2, lp3); + + if (nb00 == type_size) { + + if (lp0 > 0) { + if ((uint32_t)lp0 < 32) { + memcpy(dst_ptr, + src_row + (size_t)(ne00 - (uint32_t)lp0) * type_size, + (size_t)lp0 * type_size); + } else { + hvx_copy_f32_uu(dst_ptr, + src_row + (size_t)(ne00 - (uint32_t)lp0) * type_size, + (uint32_t)lp0); + } + } + hvx_copy_f32_uu(dst_ptr + (size_t)lp0 * type_size, src_row, ne00); + if (rp0 > 0) { + if ((uint32_t)rp0 < 32) { + memcpy(dst_ptr + ((size_t)lp0 + ne00) * type_size, + src_row, + (size_t)rp0 * type_size); + } else { + hvx_copy_f32_uu(dst_ptr + ((size_t)lp0 + ne00) * type_size, + src_row, + (uint32_t)rp0); + } + } + } else { + for (uint32_t i = 0; i < (uint32_t)lp0; i++) { + *(float *)(dst_ptr + i * type_size) = + *(const float *)(src_row + (size_t)(ne00 - (uint32_t)lp0 + i) * nb00); + } + for (uint32_t i = 0; i < ne00; i++) { + *(float *)(dst_ptr + ((size_t)lp0 + i) * type_size) = + *(const float *)(src_row + (size_t)i * nb00); + } + for (uint32_t i = 0; i < (uint32_t)rp0; i++) { + *(float *)(dst_ptr + ((size_t)lp0 + ne00 + i) * type_size) = + *(const float *)(src_row + (size_t)i * nb00); + } + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "pad-hvx-circ %d/%d: (%ux%ux%ux%u) -> (%ux%ux%ux%u) rows %u:%u usec %u\n", + ith, nth, + src->ne[0], src->ne[1], src->ne[2], src->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + row_start, row_end, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// --------------------------------------------------------------------------- +// HVX + DMA circular PAD kernel — aligned, double-buffered +// --------------------------------------------------------------------------- + +static void pad_job_per_thread_hvx_circular_dma(unsigned int nth, unsigned int ith, void * data) { + const struct htp_pad_context * pctx = (const struct htp_pad_context *) data; + struct htp_ops_context * octx = pctx->octx; + htp_pad_preamble; + htp_pad_dma_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + // ----------------------------------------------------------------------- + // Priming phase: push 2 pairs of (dummy_dst_DMA, src_DMA) to seed the + // double-buffer pipeline. Every row is a real src DMA (no null DMAs). + // ----------------------------------------------------------------------- + for (uint32_t ir = row_start, spad_idx = 0; ir < row_end && spad_idx < 2; ir++, spad_idx++) { + uint8_t * src_spad_cur = src_spad_base + spad_idx * src_row_size_aligned; + uint8_t * dst_spad_cur = dst_spad_base + spad_idx * dst_row_size_aligned; + + dma_queue_push_vtcm_to_ddr(dma, + dma_make_ptr((uint8_t *)dst->data, dst_spad_cur), + dst_row_size, dst_row_size_aligned, 0); + + uint32_t pi1, pi2, pi3; + pad_decompose_row(ir, ne1, ne2, &pi1, &pi2, &pi3); + dma_queue_push_ddr_to_vtcm(dma, + dma_make_ptr(src_spad_cur, pad_circ_src_row_ptr(src, pi1, pi2, pi3, lp1, lp2, lp3)), + src_row_size_aligned, src_row_size, 1); + } + + // ----------------------------------------------------------------------- + // Main loop: pop completed DMAs, assemble circular row in VTCM with + // aligned HVX ops, push dst DMA and prefetch src for the next+1 row. + // ----------------------------------------------------------------------- + for (uint32_t ir = row_start; ir < row_end; ir++) { + uint8_t * dst_spad_cur = (uint8_t *) dma_queue_pop(dma).src; + uint8_t * src_spad_cur = (uint8_t *) dma_queue_pop(dma).dst; + + uint32_t i1, i2, i3; + pad_decompose_row(ir, ne1, ne2, &i1, &i2, &i3); + uint8_t * dst_ptr = (uint8_t *) dst->data + i1 * nb1 + i2 * nb2 + i3 * nb3; + + + if (lp0 > 0) { + uint8_t * dst_left = dst_spad_cur; + const uint8_t * src_left = src_spad_cur + (size_t)(ne00 - (uint32_t)lp0) * type_size; + if ((uint32_t)lp0 < 32) { + memcpy(dst_left, src_left, (size_t)lp0 * type_size); + } else { + hvx_copy_f32_uu(dst_left, src_left, (uint32_t)lp0); + } + } + + { + uint8_t * dst_mid = dst_spad_cur + (size_t)lp0 * type_size; + if ((uintptr_t)dst_mid % VLEN == 0) { + hvx_copy_f32_aa(dst_mid, src_spad_cur, ne00); + } else { + hvx_copy_f32_ua(dst_mid, src_spad_cur, ne00); + } + } + + if (rp0 > 0) { + uint8_t * dst_right = dst_spad_cur + ((size_t)lp0 + ne00) * type_size; + if ((uint32_t)rp0 < 32) { + memcpy(dst_right, src_spad_cur, (size_t)rp0 * type_size); + } else { + if ((uintptr_t)dst_right % VLEN == 0) { + hvx_copy_f32_aa(dst_right, src_spad_cur, (uint32_t)rp0); + } else { + hvx_copy_f32_ua(dst_right, src_spad_cur, (uint32_t)rp0); + } + } + } + + dma_queue_push_vtcm_to_ddr(dma, + dma_make_ptr(dst_ptr, dst_spad_cur), + dst_row_size, dst_row_size_aligned, 1); + + const uint32_t next_row = ir + 2; + if (next_row < row_end) { + uint32_t nri1, nri2, nri3; + pad_decompose_row(next_row, ne1, ne2, &nri1, &nri2, &nri3); + dma_queue_push_ddr_to_vtcm(dma, + dma_make_ptr(src_spad_cur, + pad_circ_src_row_ptr(src, nri1, nri2, nri3, lp1, lp2, lp3)), + src_row_size_aligned, src_row_size, 1); + } + } + + dma_queue_flush(dma); + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "pad-hvx-circ-dma %d/%d: (%ux%ux%ux%u) -> (%ux%ux%ux%u) rows %u:%u usec %u\n", + ith, nth, + src->ne[0], src->ne[1], src->ne[2], src->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + row_start, row_end, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +int op_pad(struct htp_ops_context * octx) { + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * dst = octx->dst; + + // Only F32 supported + size_t type_size; + switch (src0->type) { + case HTP_TYPE_F32: type_size = 4; break; + default: + FARF(ERROR, "pad-hvx: unsupported type %u\n", src0->type); + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + const int32_t lp0 = octx->op_params[0]; + const int32_t rp0 = octx->op_params[1]; + const int32_t lp1 = octx->op_params[2]; + const int32_t rp1 = octx->op_params[3]; + const int32_t lp2 = octx->op_params[4]; + const int32_t rp2 = octx->op_params[5]; + const int32_t lp3 = octx->op_params[6]; + const int32_t rp3 = octx->op_params[7]; + const int32_t circular = octx->op_params[8]; + + const uint32_t ne0 = dst->ne[0]; + const uint32_t ne00 = src0->ne[0]; + + const uint32_t total_dst_rows = dst->ne[1] * dst->ne[2] * dst->ne[3]; + const uint32_t n_threads = MIN(octx->n_threads, total_dst_rows > 0 ? total_dst_rows : 1); + + const size_t src_row_size = (size_t)ne00 * type_size; + const size_t dst_row_size = (size_t)ne0 * type_size; + const size_t src_row_size_aligned = hex_round_up(src_row_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + + // Total VTCM needed: 2 buffers (ping+pong) for src and dst, per thread + const size_t vtcm_needed = (size_t)n_threads * 2 * (src_row_size_aligned + dst_row_size_aligned); + + const int use_dma = (src0->nb[0] == (uint32_t)type_size) && + (ne00 >= 512) && + (octx->ctx->vtcm_base != NULL) && + (octx->ctx->vtcm_size >= vtcm_needed); + + if (use_dma) { + octx->src0_spad.size_per_thread = 2 * src_row_size_aligned; + octx->dst_spad.size_per_thread = 2 * dst_row_size_aligned; + octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread; + octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; + } + + struct htp_pad_context pctx = { + .octx = octx, + .lp0 = lp0, .rp0 = rp0, + .lp1 = lp1, .rp1 = rp1, + .lp2 = lp2, .rp2 = rp2, + .lp3 = lp3, .rp3 = rp3, + .nrows_per_thread = (total_dst_rows + n_threads - 1) / n_threads, + .total_dst_rows = total_dst_rows, + .type_size = type_size, + .src_row_size = src_row_size, + .src_row_size_aligned = src_row_size_aligned, + .dst_row_size = dst_row_size, + .dst_row_size_aligned = dst_row_size_aligned, + }; + + FARF(HIGH, "pad-hvx%s%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) pads=(%d,%d,%d,%d,%d,%d,%d,%d)\n", + circular ? "-circ" : "", + use_dma ? "-dma" : "", + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3); + + if (circular && use_dma) { worker_pool_run_func(octx->ctx->worker_pool, pad_job_per_thread_hvx_circular_dma, &pctx, n_threads); } + else if (circular) { worker_pool_run_func(octx->ctx->worker_pool, pad_job_per_thread_hvx_circular, &pctx, n_threads); } + else if (use_dma) { worker_pool_run_func(octx->ctx->worker_pool, pad_job_per_thread_hvx_dma, &pctx, n_threads); } + else { worker_pool_run_func(octx->ctx->worker_pool, pad_job_per_thread_hvx, &pctx, n_threads); } + + return HTP_STATUS_OK; +} + From c4631bbcb9ac5950619781e5bdf628ce2d56762d Mon Sep 17 00:00:00 2001 From: Pranav Dhinakar Date: Mon, 18 May 2026 14:04:57 -0700 Subject: [PATCH 22/55] hexagon: add support for TRI op (llama/22822) * Hexagon: TRI HVX Kernel addition to ggml hexagon HTP ops and context * addressed PR review comments for TRI op * hexagon: clang format * hex-unary: remove merge conflict markers * hex-ggml: remove duplicate op cases (merge conflict) * hex-ggml: fix editor config errors --------- Co-authored-by: Todor Boinovski Co-authored-by: Max Krasnyansky --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 20 +++++ ggml/src/ggml-hexagon/htp/htp-ctx.h | 1 + ggml/src/ggml-hexagon/htp/htp-ops.h | 1 + ggml/src/ggml-hexagon/htp/main.c | 3 + ggml/src/ggml-hexagon/htp/unary-ops.c | 113 ++++++++++++++++++++++++- 5 files changed, 137 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index c24a2305e4c..2f75e97ac66 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2828,6 +2828,21 @@ static bool ggml_hexagon_supported_solve_tri(const struct ggml_hexagon_session * return true; } +static bool ggml_hexagon_supported_tri(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * dst = op; + + if (src0->type != GGML_TYPE_F32) { return false; } + if (dst->type != GGML_TYPE_F32) { return false; } + if (!ggml_are_same_shape(src0, dst)) { return false; } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) { return false; } + + return true; + + GGML_UNUSED(sess); +} + static const char * ggml_backend_hexagon_name(ggml_backend_t backend) { auto sess = static_cast(backend->context); return sess->c_name(); @@ -2869,6 +2884,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) { case GGML_OP_FILL: return HTP_OP_FILL; case GGML_OP_DIAG: return HTP_OP_DIAG; case GGML_OP_SOLVE_TRI: return HTP_OP_SOLVE_TRI; + case GGML_OP_TRI: return HTP_OP_TRI; case GGML_OP_PAD: return HTP_OP_PAD; case GGML_OP_UNARY: @@ -3430,6 +3446,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_solve_tri(sess, op); break; + case GGML_OP_TRI: + supp = ggml_hexagon_supported_tri(sess, op); + break; + case GGML_OP_PAD: supp = ggml_hexagon_supported_pad(sess, op); break; diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index e500ce46212..6fe3e6c7d85 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -107,6 +107,7 @@ int op_fill(struct htp_ops_context * octx); int op_diag(struct htp_ops_context * octx); int op_solve_tri(struct htp_ops_context * octx); int op_gated_delta_net(struct htp_ops_context * octx); +int op_tri(struct htp_ops_context * octx); int op_pad(struct htp_ops_context * octx); #endif /* HTP_CTX_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 985ded6f299..676e948a439 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -86,6 +86,7 @@ enum htp_op_code { HTP_OP_SOLVE_TRI, HTP_OP_L2_NORM, HTP_OP_GATED_DELTA_NET, + HTP_OP_TRI, HTP_OP_PAD, HTP_OP_INVALID diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 85569f07289..12003c1fd8a 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -601,6 +601,9 @@ static int execute_op(struct htp_ops_context * octx) { case HTP_OP_GATED_DELTA_NET: return op_gated_delta_net(octx); + case HTP_OP_TRI: + return op_tri(octx); + case HTP_OP_INVALID: break; diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index d4ae89ee6f0..1ce881353ec 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -17,7 +17,6 @@ #include "ggml-common.h" #include "htp-ctx.h" #include "htp-ops.h" -#include "htp-ops.h" struct htp_unary_context { struct htp_ops_context * octx; @@ -277,6 +276,95 @@ static void sigmoid_f32(const float * restrict src, } } +static void tri_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params, + const uint32_t ir, + const struct htp_unary_context * uctx) { + + const int32_t ttype = op_params[0]; + const HVX_Vector zero = hvx_vec_splat_f32(0.0f); + const uint32_t nvec = row_elems / VLEN_FP32; + const uint32_t nloe = row_elems % VLEN_FP32; + + const uint32_t ne01 = uctx->octx->src[0]->ne[1]; + + for (uint32_t b = 0; b < num_rows; b++) { + const uint32_t abs_row = ir + b; + const uint32_t i01 = abs_row % ne01; + + const HVX_Vector * restrict v_src = (const HVX_Vector *) ((const uint8_t *) src + b * row_size); + HVX_Vector * restrict v_dst = (HVX_Vector *) ((uint8_t *) dst + b * row_size); + + uint32_t boundary; + int keep_left; + switch (ttype) { + case 0: boundary = i01; keep_left = 0; break; // keep col >= row + case 1: boundary = i01 + 1; keep_left = 0; break; // keep col > row + case 2: boundary = i01 + 1; keep_left = 1; break; // keep col <= row + case 3: boundary = i01; keep_left = 1; break; // keep col < row + default: boundary = 0; keep_left = 0; break; + } + if (boundary > row_elems) boundary = row_elems; + + // Full HVX vectors — each starts at a 128-byte aligned offset + for (uint32_t i = 0; i < nvec; i++) { + const uint32_t vec_start = i * VLEN_FP32; + const uint32_t vec_end = vec_start + VLEN_FP32; + if (keep_left) { + if (vec_end <= boundary) { + v_dst[i] = v_src[i]; + } else if (vec_start >= boundary) { + v_dst[i] = zero; + } else { + HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float)); + v_dst[i] = Q6_V_vmux_QVV(mask, v_src[i], zero); + } + } else { + if (vec_end <= boundary) { + v_dst[i] = zero; + } else if (vec_start >= boundary) { + v_dst[i] = v_src[i]; + } else { + HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float)); + v_dst[i] = Q6_V_vmux_QVV(mask, zero, v_src[i]); + } + } + } + + // Tail elements (row_elems not a multiple of VLEN_FP32) + if (nloe > 0) { + const uint32_t vec_start = nvec * VLEN_FP32; + const uint32_t vec_end = vec_start + nloe; + HVX_Vector tail_val; + if (keep_left) { + if (vec_end <= boundary) { + tail_val = v_src[nvec]; + } else if (vec_start >= boundary) { + tail_val = zero; + } else { + HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float)); + tail_val = Q6_V_vmux_QVV(mask, v_src[nvec], zero); + } + } else { + if (vec_end <= boundary) { + tail_val = zero; + } else if (vec_start >= boundary) { + tail_val = v_src[nvec]; + } else { + HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float)); + tail_val = Q6_V_vmux_QVV(mask, zero, v_src[nvec]); + } + } + hvx_vec_store_a(&v_dst[nvec], nloe * sizeof(float), tail_val); + } + } +} + static void softplus_f32(const float * restrict src, float * restrict dst, uint8_t * restrict spad, @@ -498,6 +586,9 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * case HTP_OP_L2_NORM: l2_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); break; + case HTP_OP_TRI: + tri_f32(src0_spad, dst_spad, NULL, block_size, ne00, src0_row_size_aligned, op_params, ir, uctx); + break; default: break; } @@ -571,6 +662,10 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { case HTP_OP_L2_NORM: op_type = "l2norm-f32"; break; + case HTP_OP_TRI: + op_type = "tri-f32"; + break; + default: FARF(ERROR, "Unsupported unary Op %u\n", octx->op); return HTP_STATUS_NO_SUPPORT; @@ -640,6 +735,22 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { return err; } +int op_tri(struct htp_ops_context * octx) { + int err = HTP_STATUS_OK; + + switch (octx->src[0]->type) { + case HTP_TYPE_F32: + err = execute_op_unary_f32(octx); + break; + + default: + err = HTP_STATUS_NO_SUPPORT; + break; + } + + return err; +} + int op_unary(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; From 3d73095d51cd981bb7fdd37bbd3392ec7f1cabc2 Mon Sep 17 00:00:00 2001 From: Radoslav Gerganov Date: Tue, 19 May 2026 09:42:36 +0300 Subject: [PATCH 23/55] rpc : keep last_graph_uid in the device context (llama/23273) With the introduction of MTP we can have multiple compute contexts for the same RPC device. In this case last_graph_uid is not updated properly when contexts are being switched. This patch fixes this by moving last_graph_uid to the device context, making sure it is always updated. closes: #23242 --- ggml/src/ggml-rpc/ggml-rpc.cpp | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 1cb8f563d85..d3805772183 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -199,6 +199,14 @@ static ggml_guid_t ggml_backend_rpc_guid() { return &guid; } +struct ggml_backend_rpc_device_context { + std::string endpoint; + uint32_t device; + std::string name; + std::string description; + uint64_t last_graph_uid; +}; + struct ggml_backend_rpc_buffer_type_context { std::string endpoint; uint32_t device; @@ -211,7 +219,6 @@ struct ggml_backend_rpc_context { std::string endpoint; uint32_t device; std::string name; - uint64_t last_graph_uid; }; struct ggml_backend_rpc_buffer_context { @@ -691,9 +698,11 @@ static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::ve static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; + ggml_backend_dev_t rpc_dev = ggml_backend_get_device(backend); + ggml_backend_rpc_device_context * rpc_dev_ctx = (ggml_backend_rpc_device_context *)rpc_dev->context; GGML_ASSERT(cgraph->n_nodes > 0); - bool reuse = cgraph->uid != 0 && rpc_ctx->last_graph_uid == cgraph->uid; + bool reuse = cgraph->uid != 0 && rpc_dev_ctx->last_graph_uid == cgraph->uid; if (reuse) { rpc_msg_graph_recompute_req request; request.device = rpc_ctx->device; @@ -701,7 +710,7 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request)); RPC_STATUS_ASSERT(status); } else { - rpc_ctx->last_graph_uid = cgraph->uid; + rpc_dev_ctx->last_graph_uid = cgraph->uid; std::vector input; serialize_graph(rpc_ctx->device, cgraph, input); auto sock = get_socket(rpc_ctx->endpoint); @@ -770,7 +779,6 @@ ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) { /* .endpoint = */ endpoint, /* .device = */ device, /* .name = */ dev_name, - /* .last_graph_uid = */ 0, }; auto reg = ggml_backend_rpc_add_server(endpoint); ggml_backend_t backend = new ggml_backend { @@ -1757,15 +1765,6 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir } } -// device interface - -struct ggml_backend_rpc_device_context { - std::string endpoint; - uint32_t device; - std::string name; - std::string description; -}; - static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) { ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context; @@ -1947,10 +1946,11 @@ ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint) { std::string dev_name = "RPC" + std::to_string(dev_id); std::string dev_desc = std::string(endpoint); ggml_backend_rpc_device_context * dev_ctx = new ggml_backend_rpc_device_context { - /* .endpoint = */ endpoint, - /* .device = */ ind, - /* .name = */ dev_name, - /* .description = */ dev_desc + /* .endpoint = */ endpoint, + /* .device = */ ind, + /* .name = */ dev_name, + /* .description = */ dev_desc, + /* .last_graph_uid = */ 0, }; ggml_backend_dev_t dev = new ggml_backend_device { From 092d46616f29294dcf916fa4412bf3248cbcf0ef Mon Sep 17 00:00:00 2001 From: Intel AI Get-to Market Customer Success and Solutions Date: Mon, 18 May 2026 23:44:02 -0700 Subject: [PATCH 24/55] sycl: add GGML_SYCL_USE_ASYNC_MEM_OP env toggle (llama/22153) * sycl: add GGML_SYCL_USE_ASYNC_MEM_OP env toggle Signed-off-by: Chun Tao * Use async mem ops for correctness when SYCL graphs are explicitly on. Signed-off-by: Tao, Chun --------- Signed-off-by: Chun Tao Signed-off-by: Tao, Chun Co-authored-by: Chun Tao --- ggml/src/ggml-sycl/ggml-sycl.cpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index ebe7c5b351c..2ea47f7153a 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -72,6 +72,7 @@ int g_ggml_sycl_disable_graph = 0; int g_ggml_sycl_disable_dnn = 0; int g_ggml_sycl_prioritize_dmmv = 0; int g_ggml_sycl_use_async_mem_op = 0; +int g_ggml_sycl_use_async_mem_op_requested = 1; int g_ggml_sycl_enable_level_zero = 0; int g_ggml_sycl_enable_flash_attention = 1; @@ -304,6 +305,8 @@ static void ggml_check_sycl() try { GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n"); #endif GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv); + g_ggml_sycl_use_async_mem_op_requested = get_sycl_env("GGML_SYCL_USE_ASYNC_MEM_OP", 1); + GGML_LOG_INFO(" GGML_SYCL_USE_ASYNC_MEM_OP: %d\n", g_ggml_sycl_use_async_mem_op_requested); #ifdef SYCL_FLASH_ATTN GGML_LOG_INFO(" GGML_SYCL_ENABLE_FLASH_ATTN: %d\n", g_ggml_sycl_enable_flash_attention); @@ -319,11 +322,11 @@ static void ggml_check_sycl() try { fprintf(stderr, "%s: SYCL_USE_XMX: no\n", __func__); #endif */ - // Currently, we only use async malloc / free when graphs are enabled as it is required for the calls to be - // properly recorded. As this SYCL extension matures it may be beneficial to enable as the default path and in - // other places. + // Async USM allocation/free is also useful outside the graph path: it avoids the host waits in the reorder + // staging path while preserving queue ordering semantics. Graph support still depends on the extension being + // available, but it no longer needs to control the non-graph fast path. #if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC - g_ggml_sycl_use_async_mem_op = !g_ggml_sycl_disable_graph; + g_ggml_sycl_use_async_mem_op = g_ggml_sycl_use_async_mem_op_requested || !g_ggml_sycl_disable_graph; if (g_ggml_sycl_use_async_mem_op) { for (unsigned int i = 0; i < dpct::dev_mgr::instance().device_count(); ++i) { if (!dpct::dev_mgr::instance().get_device(i).has(sycl::aspect::ext_oneapi_async_memory_alloc)) { From 3fbd4a796ef4b8269bb7450b4d6bc5107344d9ba Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Mon, 18 May 2026 23:45:41 -0700 Subject: [PATCH 25/55] ggml-webgpu : extend GDN for K>1 (llama/23299) --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 2 ++ .../wgsl-shaders/gated_delta_net.wgsl | 24 +++++++++++++++---- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 78cb02be06d..921c12b41ac 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1234,6 +1234,7 @@ static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx, const uint32_t h = (uint32_t) src2->ne[1]; const uint32_t n_tokens = (uint32_t) src2->ne[2]; const uint32_t n_seqs = (uint32_t) src2->ne[3]; + const uint32_t K = (uint32_t) src5->ne[1]; const float scale = 1.0f / sqrtf((float) s_v); uint32_t scale_u32; memcpy(&scale_u32, &scale, sizeof(scale_u32)); @@ -1258,6 +1259,7 @@ static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx, (uint32_t) src0->ne[1], (uint32_t) (src2->ne[3] / src0->ne[3]), + K, scale_u32, }; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl index f9d98fda40b..d68520f8282 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl @@ -39,6 +39,7 @@ struct Params { neq1: u32, rq3: u32, + K: u32, scale: f32, }; @@ -62,11 +63,14 @@ fn main( let iq3 = seq_id / params.rq3; let state_size = S_V * S_V; - let state_base = (seq_id * params.h + head_id) * state_size; + let state_in_base = (seq_id * params.K * params.h + head_id) * state_size; + let state_out_base = (seq_id * params.h + head_id) * state_size; + let state_size_per_snap = state_size * params.h * params.n_seqs; + let shift = i32(params.n_tokens) - i32(params.K); var state: array; for (var i = 0u; i < S_V; i++) { - state[i] = src_state[state_base + col * S_V + i]; + state[i] = src_state[state_in_base + col * S_V + i]; } var attn_off = (seq_id * params.n_tokens * params.h + head_id) * S_V; @@ -123,10 +127,22 @@ fn main( dst[attn_off + col] = attn_col * params.scale; attn_off += S_V * params.h; + if (params.K > 1u) { + let target_slot = i32(t) - shift; + if (target_slot >= 0 && target_slot < i32(params.K)) { + let slot_base = params.s_off + u32(target_slot) * state_size_per_snap + state_out_base; + for (var i = 0u; i < S_V; i++) { + dst[slot_base + col * S_V + i] = state[i]; + } + } + } + workgroupBarrier(); } - for (var i = 0u; i < S_V; i++) { - dst[params.s_off + state_base + col * S_V + i] = state[i]; + if (params.K == 1u) { + for (var i = 0u; i < S_V; i++) { + dst[params.s_off + state_out_base + col * S_V + i] = state[i]; + } } } From 05bf9c48d344c0a64a46fe3653013347abc042e9 Mon Sep 17 00:00:00 2001 From: Aparna M P Date: Tue, 19 May 2026 22:18:21 +0530 Subject: [PATCH 26/55] hexagon: enable support for NORM op (llama/23319) --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 5 +- ggml/src/ggml-hexagon/htp/htp-ops.h | 1 + ggml/src/ggml-hexagon/htp/main.c | 1 + ggml/src/ggml-hexagon/htp/unary-ops.c | 97 ++++++++++++++++++++++++++ 4 files changed, 101 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 2f75e97ac66..ebeef3bdbaf 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2870,6 +2870,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) { case GGML_OP_SET_ROWS: return HTP_OP_SET_ROWS; case GGML_OP_SUM_ROWS: return HTP_OP_SUM_ROWS; case GGML_OP_ARGSORT: return HTP_OP_ARGSORT; + case GGML_OP_NORM: return HTP_OP_NORM; case GGML_OP_L2_NORM: return HTP_OP_L2_NORM; case GGML_OP_RMS_NORM: return HTP_OP_RMS_NORM; case GGML_OP_SCALE: return HTP_OP_SCALE; @@ -3338,10 +3339,8 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_add_id(sess, op); break; + case GGML_OP_NORM: case GGML_OP_L2_NORM: - supp = ggml_hexagon_supported_unary(sess, op); - break; - case GGML_OP_RMS_NORM: case GGML_OP_SCALE: supp = ggml_hexagon_supported_unary(sess, op); diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 676e948a439..9d905a30133 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -88,6 +88,7 @@ enum htp_op_code { HTP_OP_GATED_DELTA_NET, HTP_OP_TRI, HTP_OP_PAD, + HTP_OP_NORM, HTP_OP_INVALID }; diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 12003c1fd8a..8e54536f619 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -534,6 +534,7 @@ static int execute_op(struct htp_ops_context * octx) { case HTP_OP_ADD_ID: return op_binary(octx); + case HTP_OP_NORM: case HTP_OP_RMS_NORM: case HTP_OP_SCALE: case HTP_OP_SQR: diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index 1ce881353ec..40d2d60153a 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -158,6 +158,79 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src, } } +static void hvx_fast_norm_f32(const uint8_t * restrict src, + uint8_t * restrict dst, + uint8_t * restrict pad, + const int num_elems, + float epsilon) { + (void)pad; + + const HVX_Vector * restrict v_src = (HVX_Vector *) src; + HVX_Vector * restrict v_dst = (HVX_Vector *) dst; + + const int nvec = num_elems / VLEN_FP32; // number of full vectors + const int nloe = num_elems % VLEN_FP32; // leftover elements + + // Compute sum of squares and sum of values for full vectors + HVX_Vector sum_sq_v = Q6_V_vsplat_R(0x00000000); + HVX_Vector sum_x_v = Q6_V_vsplat_R(0x00000000); + HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon); + + #pragma unroll(4) + for (int i = 0; i < nvec; i++) { + HVX_Vector v1 = v_src[i]; + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1); + sum_sq_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_sq_v, v2); + sum_x_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_x_v, Q6_Vqf32_vadd_VsfVsf(v1, Q6_V_vzero())); + } + + // Handle tail elements using vectorized ops with masking + if (nloe > 0) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]); + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1); + sum_sq_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_sq_v, v2); + sum_x_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_x_v, Q6_Vqf32_vadd_VsfVsf(v1, Q6_V_vzero())); + } + + // Reduce HVX sums + sum_sq_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_sq_v)); + sum_x_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_x_v)); + + HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems); + HVX_Vector denom_v = hvx_vec_inverse_f32(t_v); + HVX_Vector mean_sq_v = Q6_Vqf32_vmpy_VsfVsf(sum_sq_v, denom_v); + HVX_Vector mean_x_v = Q6_Vqf32_vmpy_VsfVsf(sum_x_v, denom_v); + HVX_Vector mean_x_sq_v = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(mean_x_v), Q6_Vsf_equals_Vqf32(mean_x_v)); + HVX_Vector var_v = Q6_Vqf32_vsub_Vqf32Vqf32(mean_sq_v, mean_x_sq_v); + HVX_Vector var_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(var_v, epsilon_v); + + // scale = rsqrt(variance + epsilon), mean_x broadcast for subtraction + HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(var_epsilon_v)); + HVX_Vector mean_x_b = hvx_vec_splat_f32(hvx_vec_get_f32(Q6_Vsf_equals_Vqf32(mean_x_v))); + + #pragma unroll(4) + for (int i = 0; i < nvec; i++) { + HVX_Vector v1 = v_src[i]; + HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v1, mean_x_b); + HVX_Vector v3 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v2), scale_v); + v_dst[i] = Q6_Vsf_equals_Vqf32(v3); + } + + // Handle tail elements using vectorized ops with masking + if (nloe > 0) { + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]); + HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v1, mean_x_b); + HVX_Vector v3 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v2), scale_v); + HVX_Vector result = Q6_Vsf_equals_Vqf32(v3); + + // Store with masking to avoid overwriting memory beyond the tensor + hvx_vec_store_a(&v_dst[nvec], nloe * 4, result); + } +} + static void scale_f32(const float * restrict src, float * restrict dst, uint8_t * restrict spad, @@ -196,6 +269,24 @@ static void rms_norm_f32(const float * restrict src, } } +static void norm_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { + float epsilon = 0.f; + memcpy(&epsilon, op_params, sizeof(float)); + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); + + hvx_fast_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon); + } +} + static void sqr_f32(const float * restrict src, float * restrict dst, uint8_t * restrict spad, @@ -556,6 +647,9 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * // Process block in VTCM switch (htp_op) { + case HTP_OP_NORM: + norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; case HTP_OP_RMS_NORM: rms_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); break; @@ -632,6 +726,9 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { const char * op_type = NULL; switch (octx->op) { + case HTP_OP_NORM: + op_type = "norm-f32"; + break; case HTP_OP_RMS_NORM: op_type = "rmsnorm-f32"; break; From 752744d5a237b063073963db6b30c514752f486d Mon Sep 17 00:00:00 2001 From: Aparna M P Date: Wed, 20 May 2026 02:40:13 +0530 Subject: [PATCH 27/55] hexagon: add MROPE and IMROPE support in HTP rope op (llama/23317) --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 2 +- ggml/src/ggml-hexagon/htp/rope-ops.c | 115 +++++++++++++++++++++---- 2 files changed, 98 insertions(+), 19 deletions(-) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index ebeef3bdbaf..080fb7f47e3 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2661,7 +2661,7 @@ static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess int mode = op_params[2]; - if ((mode & GGML_ROPE_TYPE_MROPE) || (mode & GGML_ROPE_TYPE_VISION)) { + if (mode == GGML_ROPE_TYPE_VISION) { return false; } if (mode & 1) { diff --git a/ggml/src/ggml-hexagon/htp/rope-ops.c b/ggml/src/ggml-hexagon/htp/rope-ops.c index 1d8b0796bc9..9901453e91e 100644 --- a/ggml/src/ggml-hexagon/htp/rope-ops.c +++ b/ggml/src/ggml-hexagon/htp/rope-ops.c @@ -18,9 +18,11 @@ #include "htp-ops.h" #include "htp-ops.h" -// Redefined the types GGML_ROPE_TYPE_NORMAL & GGML_ROPE_TYPE_NEOX as we can't include ggml.h +// Redefined the rope type constants as we can't include ggml.h #define HTP_ROPE_TYPE_NORMAL 0 #define HTP_ROPE_TYPE_NEOX 2 +#define HTP_ROPE_TYPE_MROPE 8 +#define HTP_ROPE_TYPE_IMROPE 40 #define HTP_ROPE_SPAD_NROWS 16 #define HTP_ROPE_SPAD_BLOCK (HTP_ROPE_SPAD_NROWS/2) @@ -82,6 +84,29 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) { return (1 - MIN(1, MAX(0, y))); } +// Compute one (cos, sin) pair into cache[i0], cache[i0+1] applying YaRN scaling. +static inline void rope_yarn_one(float theta, float freq_scale, float * corr_dims, + uint32_t i0, float ext_factor, float mscale, + float * cache) { + float theta_extrap = theta; + + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = freq_scale * theta_extrap; + float theta_final = theta_interp; + float mscale_final = mscale; + + if (ext_factor != 0.0f) { + float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; + theta_final = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + + // Get n-d magnitude scaling corrected for interpolation + mscale_final *= 1.0f + 0.1f * logf(1.0f / freq_scale); + } + + cache[i0 + 0] = cosf(theta_final) * mscale_final; + cache[i0 + 1] = sinf(theta_final) * mscale_final; +} + static void rope_cache_init(const float theta_base, const float freq_scale, const float * freq_factors, @@ -96,26 +121,62 @@ static void rope_cache_init(const float theta_base, for (uint32_t i0 = 0; i0 < ne0; i0 += 2) { const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f; + rope_yarn_one(theta / ff, freq_scale, corr_dims, i0, ext_factor, mscale, cache); - float theta_extrap = theta / ff; - - // Get n-d rotational scaling corrected for extrapolation - float theta_interp = freq_scale * theta_extrap; - float theta_final = theta_interp; - float mscale_final = mscale; + theta *= theta_scale; + } +} - if (ext_factor != 0.0f) { - float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; - theta_final = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; +// pos_t/h/w/e: the four position ids for this sequence step (t=time, h=height, w=width, e=extra). +// sections[4]: number of head dims assigned to each position component. +static void mrope_cache_init(const float pos_t, + const float pos_h, + const float pos_w, + const float pos_e, + const int32_t sections[4], + const bool is_imrope, + const float freq_scale, + const float * freq_factors, + float * corr_dims, + const uint32_t ne0, + const float ext_factor, + const float mscale, + float * cache, + const float theta_scale) { + const int sect_dims = sections[0] + sections[1] + sections[2] + sections[3]; + const int sec_w = sections[0] + sections[1]; + const int sec_e = sec_w + sections[2]; + + float theta_t = pos_t; + float theta_h = pos_h; + float theta_w = pos_w; + float theta_e = pos_e; - // Get n-d magnitude scaling corrected for interpolation - mscale_final *= 1.0f + 0.1f * logf(1.0f / freq_scale); + for (uint32_t i0 = 0; i0 < ne0; i0 += 2) { + const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f; + const int sector = (i0 / 2) % sect_dims; + + float theta; + if (is_imrope) { + // Interleaved: sector mod 3 selects component + if (sector % 3 == 0 && sector < 3 * sections[0]) { theta = theta_t; } + else if (sector % 3 == 1 && sector < 3 * sections[1]) { theta = theta_h; } + else if (sector % 3 == 2 && sector < 3 * sections[2]) { theta = theta_w; } + else { theta = theta_e; } + } else { + // Contiguous sections + if (sector < sections[0]) { theta = theta_t; } + else if (sector < sec_w) { theta = theta_h; } + else if (sector < sec_e) { theta = theta_w; } + else { theta = theta_e; } } - cache[i0 + 0] = cosf(theta_final) * mscale_final; - cache[i0 + 1] = sinf(theta_final) * mscale_final; + rope_yarn_one(theta / ff, freq_scale, corr_dims, i0, ext_factor, mscale, cache); - theta *= theta_scale; + theta_t *= theta_scale; + theta_h *= theta_scale; + theta_w *= theta_scale; + theta_e *= theta_scale; } } @@ -274,7 +335,8 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) { uint64_t tt = HAP_perf_get_qtimer_count(); const int32_t mode = rctx->mode; - const bool is_neox = mode & HTP_ROPE_TYPE_NEOX; + // MROPE and IMROPE use NEOX-style pairing for the rotation + const bool is_neox = (mode & HTP_ROPE_TYPE_NEOX) || (mode & HTP_ROPE_TYPE_MROPE); // VTCM setup uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); @@ -326,8 +388,25 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) { if (i2 != prev_i2) { prev_i2 = i2; - const int32_t p = pos[i2]; - rope_cache_init(p, rctx->freq_scale, freq_factors, rctx->corr_dims, ne0, rctx->ext_factor, rctx->attn_factor, theta_cache, rctx->theta_scale); + const bool is_mrope = (rctx->mode & HTP_ROPE_TYPE_MROPE) != 0; + if (is_mrope) { + // src1 holds four position arrays stacked along ne0: + // pos[i2], pos[i2+ne2], pos[i2+ne2*2], pos[i2+ne2*3] + const bool is_imrope = (rctx->mode == HTP_ROPE_TYPE_IMROPE); + mrope_cache_init( + (float) pos[i2], + (float) pos[i2 + ne2], + (float) pos[i2 + ne2 * 2], + (float) pos[i2 + ne2 * 3], + rctx->sections, is_imrope, + rctx->freq_scale, freq_factors, rctx->corr_dims, + ne0, rctx->ext_factor, rctx->attn_factor, + theta_cache, rctx->theta_scale); + } else { + rope_cache_init(pos[i2], rctx->freq_scale, freq_factors, rctx->corr_dims, + ne0, rctx->ext_factor, rctx->attn_factor, + theta_cache, rctx->theta_scale); + } // FARF(HIGH, "rope-theta %u: ir %u i1 %u i2 %u i3 %u cache %p : usec %u", ith, ir, i1, i2, i3, theta_cache, // (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start)); From 21612e65419764a0d08d73c2174585d4a925a19e Mon Sep 17 00:00:00 2001 From: shaofeiqi Date: Tue, 19 May 2026 14:29:00 -0700 Subject: [PATCH 28/55] opencl: add MoE support for q4_k, q5_k, q6_k on Adreno (llama/23303) * opencl: add q4_k moe support * opencl: add q5_k moe support * opencl: add q6_k moe support * opencl: adjust format --------- Co-authored-by: Li He --- ggml/src/ggml-opencl/CMakeLists.txt | 6 + ggml/src/ggml-opencl/ggml-opencl.cpp | 948 +++++++++++++++++- ggml/src/ggml-opencl/kernels/cvt.cl | 385 +++++++ .../kernels/gemm_moe_q4_k_f32_ns.cl | 279 ++++++ .../kernels/gemm_moe_q5_k_f32_ns.cl | 284 ++++++ .../kernels/gemm_moe_q6_k_f32_ns.cl | 263 +++++ .../kernels/gemv_moe_q4_k_f32_ns.cl | 151 +++ .../kernels/gemv_moe_q5_k_f32_ns.cl | 156 +++ .../kernels/gemv_moe_q6_k_f32_ns.cl | 137 +++ 9 files changed, 2601 insertions(+), 8 deletions(-) create mode 100644 ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index c6aba608736..f75d089b574 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -110,6 +110,12 @@ set(GGML_OPENCL_KERNELS gemv_moe_q5_0_f32_ns gemm_moe_q5_1_f32_ns gemv_moe_q5_1_f32_ns + gemm_moe_q4_k_f32_ns + gemv_moe_q4_k_f32_ns + gemm_moe_q5_k_f32_ns + gemv_moe_q5_k_f32_ns + gemm_moe_q6_k_f32_ns + gemv_moe_q6_k_f32_ns gemm_moe_mxfp4_f32 gemv_moe_mxfp4_f32 gemm_moe_mxfp4_f32_ns diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 0e511592d53..a3af8c2da41 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -558,6 +558,9 @@ struct ggml_backend_opencl_context { cl_kernel kernel_convert_block_q4_1_trans4_ns, kernel_restore_block_q4_1_trans4_ns; cl_kernel kernel_convert_block_q5_0_trans4_ns, kernel_restore_block_q5_0_trans4_ns; cl_kernel kernel_convert_block_q5_1_trans4_ns, kernel_restore_block_q5_1_trans4_ns; + cl_kernel kernel_convert_block_q4_k_trans4_ns, kernel_restore_block_q4_k_trans4_ns; + cl_kernel kernel_convert_block_q5_k_trans4_ns, kernel_restore_block_q5_k_trans4_ns; + cl_kernel kernel_convert_block_q6_k_trans4_ns, kernel_restore_block_q6_k_trans4_ns; cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans; cl_kernel kernel_convert_block_mxfp4_trans4_ns, kernel_restore_block_mxfp4_trans4_ns; cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0, kernel_restore_block_q8_0_trans; @@ -619,6 +622,9 @@ struct ggml_backend_opencl_context { cl_kernel kernel_gemv_moe_q4_1_f32_ns, kernel_gemm_moe_q4_1_f32_ns; cl_kernel kernel_gemv_moe_q5_0_f32_ns, kernel_gemm_moe_q5_0_f32_ns; cl_kernel kernel_gemv_moe_q5_1_f32_ns, kernel_gemm_moe_q5_1_f32_ns; + cl_kernel kernel_gemv_moe_q4_k_f32_ns, kernel_gemm_moe_q4_k_f32_ns; + cl_kernel kernel_gemv_moe_q5_k_f32_ns, kernel_gemm_moe_q5_k_f32_ns; + cl_kernel kernel_gemv_moe_q6_k_f32_ns, kernel_gemm_moe_q6_k_f32_ns; cl_kernel kernel_gemv_moe_mxfp4_f32, kernel_gemm_moe_mxfp4_f32; cl_kernel kernel_gemv_moe_mxfp4_f32_ns, kernel_gemm_moe_mxfp4_f32_ns; cl_kernel kernel_moe_reorder_b; @@ -981,6 +987,12 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_restore_block_q5_0_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_0_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q5_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_1_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q5_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_1_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q4_k_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_k_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q4_k_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_k_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q5_k_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_k_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q5_k_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_k_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q6_k_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_k_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q6_k_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q6_k_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans4_ns", &err), err)); @@ -3071,6 +3083,108 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // gemv_moe_q4_k_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_moe_q4_k_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_moe_q4_k_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemv_moe_q4_k_f32_ns = clCreateKernel(prog, "kernel_gemv_moe_q4_k_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemm_moe_q4_k_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_moe_q4_k_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_moe_q4_k_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemm_moe_q4_k_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_q4_k_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemv_moe_q5_k_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_moe_q5_k_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_moe_q5_k_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemv_moe_q5_k_f32_ns = clCreateKernel(prog, "kernel_gemv_moe_q5_k_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemm_moe_q5_k_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_moe_q5_k_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_moe_q5_k_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemm_moe_q5_k_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_q5_k_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemv_moe_q6_k_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_moe_q6_k_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_moe_q6_k_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemv_moe_q6_k_f32_ns = clCreateKernel(prog, "kernel_gemv_moe_q6_k_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemm_moe_q6_k_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_moe_q6_k_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_moe_q6_k_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemm_moe_q6_k_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_q6_k_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // gemv_moe_mxfp4_f32_ns { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -4148,6 +4262,8 @@ struct ggml_tensor_extra_cl_iq4_nl { struct ggml_tensor_extra_cl_q4_K { // Quantized values cl_mem q = nullptr; + // Quantized values in image1d_buffer_t. + cl_mem q_img = nullptr; // Scales for each super block. cl_mem s = nullptr; // Scales @@ -4176,12 +4292,18 @@ struct ggml_tensor_extra_cl_q4_K { CL_CHECK(clReleaseMemObject(dm)); dm = nullptr; } + if (q_img != nullptr) { + CL_CHECK(clReleaseMemObject(q_img)); + q_img = nullptr; + } } }; struct ggml_tensor_extra_cl_q5_K { // Lower 4 bits of quantized weights. cl_mem q = nullptr; + // Quantized values in image1d_buffer_t. + cl_mem q_img = nullptr; // Upper 1 bit of quantized weights. cl_mem qh = nullptr; // Scales for each block. @@ -4222,6 +4344,10 @@ struct ggml_tensor_extra_cl_q5_K { CL_CHECK(clReleaseMemObject(dm)); dm = nullptr; } + if (q_img != nullptr) { + CL_CHECK(clReleaseMemObject(q_img)); + q_img = nullptr; + } size_q = 0; size_qh = 0; @@ -4234,6 +4360,8 @@ struct ggml_tensor_extra_cl_q5_K { struct ggml_tensor_extra_cl_q6_K { // Lower 4 bits of quantized weights. cl_mem ql = nullptr; + // Lower 4 bits as image1d_buffer_t + cl_mem ql_img = nullptr; // Upper 2 bits of quantized weights. cl_mem qh = nullptr; // Scales for each block. @@ -4267,6 +4395,10 @@ struct ggml_tensor_extra_cl_q6_K { CL_CHECK(clReleaseMemObject(d)); d = nullptr; } + if (ql_img != nullptr) { + CL_CHECK(clReleaseMemObject(ql_img)); + ql_img = nullptr; + } size_ql = 0; size_qh = 0; @@ -4700,7 +4832,10 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te // the quantizations here currently do not - they are only supported by Adreno with certain shapes if (op->src[0]->type == GGML_TYPE_Q4_1 || op->src[0]->type == GGML_TYPE_Q5_0 || - op->src[0]->type == GGML_TYPE_Q5_1) { + op->src[0]->type == GGML_TYPE_Q5_1 || + op->src[0]->type == GGML_TYPE_Q4_K || + op->src[0]->type == GGML_TYPE_Q5_K || + op->src[0]->type == GGML_TYPE_Q6_K) { #ifdef GGML_OPENCL_USE_ADRENO_KERNELS if (op->src[1]->type == GGML_TYPE_F32) { return use_adreno_moe_kernels(backend_ctx, op->src[0]) @@ -6047,14 +6182,57 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); CL_CHECK(err); - #ifdef GGML_OPENCL_USE_ADRENO_KERNELS +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_kernel kernel = backend_ctx->kernel_convert_block_q4_k_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->dm)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_uchar), &mask_F0)); + + size_t global_work_size[] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 256), static_cast(ne02)}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + cl_image_format img_format_q = {CL_R, CL_UNSIGNED_INT32}; + cl_image_desc img_desc_q = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast(ggml_nelements(tensor) / 8), + 0, 0, 0, 0, 0, 0, 0, + { extra->q } + }; + extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err); + CL_CHECK(err); + tensor->extra = extra; + + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS cl_kernel kernel = backend_ctx->kernel_convert_block_q4_K; if (use_adreno_kernels(backend_ctx, tensor)) { kernel = backend_ctx->kernel_convert_block_q4_K_noshuffle; } - #else +#else cl_kernel kernel = backend_ctx->kernel_convert_block_q4_K; - #endif +#endif // GGML_OPENCL_USE_ADRENO_KERNELS cl_uchar mask_0F = 0x0F; cl_uchar mask_F0 = 0xF0; @@ -6157,14 +6335,58 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_CHECK((extra->qh = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); CL_CHECK(err); - #ifdef GGML_OPENCL_USE_ADRENO_KERNELS +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_kernel kernel = backend_ctx->kernel_convert_block_q5_k_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->dm)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_uchar), &mask_F0)); + + size_t global_work_size[] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 256), static_cast(ne02)}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + cl_image_format img_format_q = {CL_R, CL_UNSIGNED_INT32}; + cl_image_desc img_desc_q = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast(ggml_nelements(tensor) / 8), + 0, 0, 0, 0, 0, 0, 0, + { extra->q } + }; + extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err); + CL_CHECK(err); + tensor->extra = extra; + + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS cl_kernel kernel = backend_ctx->kernel_convert_block_q5_K; if (use_adreno_kernels(backend_ctx, tensor)) { kernel = backend_ctx->kernel_convert_block_q5_K_noshuffle; } - #else +#else cl_kernel kernel = backend_ctx->kernel_convert_block_q5_K; - #endif +#endif cl_uchar mask_0F = 0x0F; cl_uchar mask_F0 = 0xF0; @@ -6232,6 +6454,79 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, cl_buffer_region region; + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // Adreno MoE Q6_K kernel needs special transposed layout + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + size_t moe_size_ql = (size_t)(ggml_nelements(tensor) / 8) * sizeof(uint32_t); // 4 bits per element + size_t moe_size_qh = (size_t)(ggml_nelements(tensor) / 16) * sizeof(uint32_t); // 2 bits per element + size_t moe_size_s = size_s; + size_t moe_size_d = size_d; + + // Subbuffer for ql + region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); + region.size = moe_size_ql; + CL_CHECK((extra->ql = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + auto previous_origin = region.origin; + + // Subbuffer for qh + region.origin = align_to(previous_origin + moe_size_ql, backend_ctx->alignment); + region.size = moe_size_qh; + CL_CHECK((extra->qh = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + previous_origin = region.origin; + + // Subbuffer for scales + region.origin = align_to(previous_origin + moe_size_qh, backend_ctx->alignment); + region.size = moe_size_s; + CL_CHECK((extra->s = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + previous_origin = region.origin; + + // Subbuffer for d + region.origin = align_to(previous_origin + moe_size_s, backend_ctx->alignment); + region.size = moe_size_d; + CL_CHECK((extra->d = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + cl_kernel kernel = backend_ctx->kernel_convert_block_q6_k_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->ql)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_uchar), &mask_F0)); + + size_t global_work_size[] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 256), static_cast(ne02)}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + // Create image for ql + cl_image_format img_format_ql = {CL_R, CL_UNSIGNED_INT32}; + cl_image_desc img_desc_ql = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast(ggml_nelements(tensor) / 8), + 0, 0, 0, 0, 0, 0, 0, + { extra->ql } + }; + extra->ql_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_ql, &img_desc_ql, NULL, &err); + tensor->extra = extra; + + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + // Subbuffer for ql region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); region.size = size_ql; @@ -6825,6 +7120,40 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, cl_uchar mask_F0 = 0xF0; #ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q4_k_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->dm)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_uchar), &mask_F0)); + + size_t global_work_size[] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 256), static_cast(ne02)}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } if (use_adreno_kernels(backend_ctx, tensor)) { int M = tensor->ne[1]; int K = tensor->ne[0]; @@ -6901,6 +7230,40 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, cl_uchar mask_F0 = 0xF0; #ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + cl_kernel kernel = backend_ctx->kernel_restore_block_q5_k_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->dm)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_uchar), &mask_F0)); + + size_t global_work_size[] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 256), static_cast(ne02)}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } if (use_adreno_kernels(backend_ctx, tensor)) { int M = tensor->ne[1]; int K = tensor->ne[0]; @@ -6974,7 +7337,44 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, if (tensor->type == GGML_TYPE_Q6_K) { ggml_tensor_extra_cl_q6_K * extra = (ggml_tensor_extra_cl_q6_K *)tensor->extra; -#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q6_k_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->ql)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_uchar), &mask_F0)); + + size_t global_work_size[] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 256), static_cast(ne02)}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } if (use_adreno_kernels(backend_ctx, tensor)) { static ggml_cl_buffer buf_trans_ql; static ggml_cl_buffer buf_trans_qh; @@ -13733,6 +14133,9 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)src0->extra; ggml_tensor_extra_cl_q5_0 * extra0_q5_0 = (ggml_tensor_extra_cl_q5_0 *)src0->extra; ggml_tensor_extra_cl_q5_1 * extra0_q5_1 = (ggml_tensor_extra_cl_q5_1 *)src0->extra; + ggml_tensor_extra_cl_q4_K * extra0_q4_K = (ggml_tensor_extra_cl_q4_K *)src0->extra; + ggml_tensor_extra_cl_q5_K * extra0_q5_K = (ggml_tensor_extra_cl_q5_K *)src0->extra; + ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)src0->extra; ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra; ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; #endif @@ -13741,6 +14144,9 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, (void)extra0_q4_1; (void)extra0_q5_0; (void)extra0_q5_1; + (void)extra0_q4_K; + (void)extra0_q5_K; + (void)extra0_q6_K; const int ne00 = src0->ne[0]; const int ne01 = src0->ne[1]; @@ -14612,6 +15018,532 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, #endif // GGML_OPENCL_SOA_Q break; } + case GGML_TYPE_Q4_K: { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, src0)) { + cl_int status; + + size_t local_size[3] = {64, 2, 1}; + size_t global_size[3] = {64, 2, 1}; + + if (ne12 == 1) { // for gemv + kernel = backend_ctx->kernel_gemv_moe_q4_k_f32_ns; + + cl_mem src1_sub_buffer, buf_src1_image, buf_src2; + + // create a sub_buffer for src2 + cl_buffer_region region; + region.origin = offset2; + region.size = ne20 * ne21 * sizeof(int); + buf_src2 = clCreateSubBuffer(extra2->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // set thread grid + global_size[0] = static_cast(ne01); + global_size[1] = 4; + global_size[2] = static_cast(ne20); + local_size[1] = 4; + + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // create image for src1 + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}}; + buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_K->q)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_K->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_K->dm)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_K->s)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11)); + + // launch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + // deallocate sub buffers and images + CL_CHECK(clReleaseMemObject(src1_sub_buffer)); + CL_CHECK(clReleaseMemObject(buf_src1_image)); + CL_CHECK(clReleaseMemObject(buf_src2)); + + } else { // for gemm + kernel = backend_ctx->kernel_gemm_moe_q4_k_f32_ns; + + // Reorder router if called from test-backend-ops or when new router is generated. + // Otherwise reuse the reordered result from previous mul_mat_id call. + if ((strstr(src0->name, "as") != NULL) || backend_ctx->toggle_reorder) { + moe_router_reoerder(backend, src2, ne20); + backend_ctx->toggle_reorder = false; + } + + cl_mem sub_buf_src1_pre, buf_src1_reordered, image_src1_reordered, sub_buf_dst, buf_dst_image; + cl_mem buf_src2, buf_src2_emap; + + cl_buffer_region region; + region.origin = 0; + region.size = sizeof(int) * max_post_router_tile * n_tile_size; + buf_src2 = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + region.origin = 0; + region.size = sizeof(short) * max_post_router_tile; + buf_src2_emap = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Reorder activations + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + sub_buf_src1_pre = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Create image for reordered src1 + region.origin = 0; + region.size = ne00 * max_post_router_tile * n_tile_size * sizeof(float); + backend_ctx->prealloc_act_trans.allocate(backend_ctx->context, region.size); + buf_src1_reordered = clCreateSubBuffer( + backend_ctx->prealloc_act_trans.buffer, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}}; + image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + unsigned short map_ratio = ne20 / ne11; + GGML_ASSERT(((map_ratio == 1) || (map_ratio == ne20)) && "Map ratio not supported\n"); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 0, sizeof(cl_mem), &sub_buf_src1_pre)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 1, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 2, sizeof(cl_mem), &buf_src1_reordered)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 3, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 4, sizeof(unsigned int), &ne00)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 5, sizeof(unsigned short), &map_ratio)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 6, sizeof(unsigned int), &n_tile_size)); + + size_t reorder_b_local_size[3] = {256, 1, 1}; + size_t reorder_b_global_size[3] = {static_cast(((ne00 / 4) + 255) / 256 * 256), static_cast(max_post_router_tile * n_tile_size), 1}; + + // Dispatch reorder kernel + backend_ctx->enqueue_ndrange_kernel(backend_ctx->kernel_moe_reorder_b, 3, reorder_b_global_size, reorder_b_local_size, dst); + + // MoE kernel prepare + region.origin = offsetd; + region.size = ne0 * ne1 * ne2 * sizeof(float); + sub_buf_dst = clCreateSubBuffer( + extrad->data_device, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + // Create image for dst + cl_image_format image_format_buf_dst = {CL_R, CL_FLOAT}; + cl_image_desc image_desc_buf_dst = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne0 * ne1 * ne2), 0,0,0,0,0,0,0, {sub_buf_dst}}; + buf_dst_image = clCreateImage(backend_ctx->context, CL_MEM_WRITE_ONLY, &image_format_buf_dst, &image_desc_buf_dst, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_K->q_img)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_K->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_K->dm)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_K->s)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &image_src1_reordered)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2_emap)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_dst_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + + // set thread grid + global_size[1] = static_cast((ne01 + 63) / 64); + global_size[2] = static_cast(max_post_router_tile); + local_size[1] = 1; + local_size[2] = 1; + + // Dispatch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + clReleaseMemObject(sub_buf_src1_pre); + clReleaseMemObject(buf_src1_reordered); + clReleaseMemObject(image_src1_reordered); + clReleaseMemObject(buf_src2); + clReleaseMemObject(buf_src2_emap); + clReleaseMemObject(sub_buf_dst); + clReleaseMemObject(buf_dst_image); + } + return; + } +#endif //GGML_OPENCL_USE_ADRENO_KERNELS + } + case GGML_TYPE_Q5_K: { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, src0)) { + cl_int status; + + size_t local_size[3] = {64, 2, 1}; + size_t global_size[3] = {64, 2, 1}; + + if (ne12 == 1) { // for gemv + kernel = backend_ctx->kernel_gemv_moe_q5_k_f32_ns; + + cl_mem src1_sub_buffer, buf_src1_image, buf_src2; + + // create a sub_buffer for src2 + cl_buffer_region region; + region.origin = offset2; + region.size = ne20 * ne21 * sizeof(int); + buf_src2 = clCreateSubBuffer(extra2->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // set thread grid + global_size[0] = static_cast(ne01); + global_size[1] = 4; + global_size[2] = static_cast(ne20); + local_size[1] = 4; + + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // create image for src1 + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}}; + buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->q)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->qh)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->dm)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->s)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11)); + + // launch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + // deallocate sub buffers and images + CL_CHECK(clReleaseMemObject(src1_sub_buffer)); + CL_CHECK(clReleaseMemObject(buf_src1_image)); + CL_CHECK(clReleaseMemObject(buf_src2)); + + } else { // for gemm + kernel = backend_ctx->kernel_gemm_moe_q5_k_f32_ns; + + // Reorder router if called from test-backend-ops or when new router is generated. + // Otherwise reuse the reordered result from previous mul_mat_id call. + if ((strstr(src0->name, "as") != NULL) || backend_ctx->toggle_reorder) { + moe_router_reoerder(backend, src2, ne20); + backend_ctx->toggle_reorder = false; + } + + cl_mem sub_buf_src1_pre, buf_src1_reordered, image_src1_reordered, sub_buf_dst, buf_dst_image; + cl_mem buf_src2, buf_src2_emap; + + cl_buffer_region region; + region.origin = 0; + region.size = sizeof(int) * max_post_router_tile * n_tile_size; + buf_src2 = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + region.origin = 0; + region.size = sizeof(short) * max_post_router_tile; + buf_src2_emap = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Reorder activations + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + sub_buf_src1_pre = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Create image for reordered src1 + // Use pre-allocated placeholder + region.origin = 0; + region.size = ne00 * max_post_router_tile * n_tile_size * sizeof(float); + backend_ctx->prealloc_act_trans.allocate(backend_ctx->context, region.size); + buf_src1_reordered = clCreateSubBuffer( + backend_ctx->prealloc_act_trans.buffer, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}}; + image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + unsigned short map_ratio = ne20 / ne11; + GGML_ASSERT(((map_ratio == 1) || (map_ratio == ne20)) && "Map ratio not supported\n"); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 0, sizeof(cl_mem), &sub_buf_src1_pre)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 1, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 2, sizeof(cl_mem), &buf_src1_reordered)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 3, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 4, sizeof(unsigned int), &ne00)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 5, sizeof(unsigned short), &map_ratio)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 6, sizeof(unsigned int), &n_tile_size)); + + size_t reorder_b_local_size[3] = {256, 1, 1}; + size_t reorder_b_global_size[3] = {static_cast(((ne00 / 4) + 255) / 256 * 256), static_cast(max_post_router_tile * n_tile_size), 1}; + + // Dispatch reorder kernel + backend_ctx->enqueue_ndrange_kernel(backend_ctx->kernel_moe_reorder_b, 3, reorder_b_global_size, reorder_b_local_size, dst); + + // MoE kernel prepare + // Create sub buffer for dst + region.origin = offsetd; + region.size = ne0 * ne1 * ne2 * sizeof(float); + sub_buf_dst = clCreateSubBuffer( + extrad->data_device, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + // Create image for dst + cl_image_format image_format_buf_dst = {CL_R, CL_FLOAT}; + cl_image_desc image_desc_buf_dst = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne0 * ne1 * ne2), 0,0,0,0,0,0,0, {sub_buf_dst}}; + buf_dst_image = clCreateImage(backend_ctx->context, CL_MEM_WRITE_ONLY, &image_format_buf_dst, &image_desc_buf_dst, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->q_img)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->qh)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->s)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->dm)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &image_src1_reordered)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2_emap)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_dst_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + + // set thread grid + global_size[1] = static_cast((ne01 + 63) / 64); + global_size[2] = static_cast(max_post_router_tile); + local_size[1] = 1; + local_size[2] = 1; + + // Dispatch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + clReleaseMemObject(sub_buf_src1_pre); + clReleaseMemObject(buf_src1_reordered); + clReleaseMemObject(image_src1_reordered); + clReleaseMemObject(buf_src2); + clReleaseMemObject(buf_src2_emap); + clReleaseMemObject(sub_buf_dst); + clReleaseMemObject(buf_dst_image); + } + return; + } +#endif //GGML_OPENCL_USE_ADRENO_KERNELS + } + case GGML_TYPE_Q6_K: { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, src0)) { + cl_int status; + + size_t local_size[3] = {64, 2, 1}; + size_t global_size[3] = {64, 2, 1}; + + if (ne12 == 1) { // for gemv + kernel = backend_ctx->kernel_gemv_moe_q6_k_f32_ns; + + cl_mem src1_sub_buffer, buf_src1_image, buf_src2; + + // create a sub_buffer for src2 + cl_buffer_region region; + region.origin = offset2; + region.size = ne20 * ne21 * sizeof(int); + buf_src2 = clCreateSubBuffer(extra2->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // set thread grid + global_size[0] = static_cast(ne01); + global_size[1] = 4; + global_size[2] = static_cast(ne20); + local_size[1] = 4; + + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // create image for src1 + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}}; + buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q6_K->ql)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q6_K->qh)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q6_K->s)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q6_K->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11)); + + // launch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + // deallocate sub buffers and images + CL_CHECK(clReleaseMemObject(src1_sub_buffer)); + CL_CHECK(clReleaseMemObject(buf_src1_image)); + CL_CHECK(clReleaseMemObject(buf_src2)); + + } else { // for gemm + kernel = backend_ctx->kernel_gemm_moe_q6_k_f32_ns; + + // Reorder router if called from test-backend-ops or when new router is generated. + // Otherwise reuse the reordered result from previous mul_mat_id call. + if ((strstr(src0->name, "as") != NULL) || backend_ctx->toggle_reorder) { + moe_router_reoerder(backend, src2, ne20); + backend_ctx->toggle_reorder = false; + } + + cl_mem sub_buf_src1_pre, buf_src1_reordered, image_src1_reordered, sub_buf_dst, buf_dst_image; + cl_mem buf_src2, buf_src2_emap; + + cl_buffer_region region; + region.origin = 0; + region.size = sizeof(int) * max_post_router_tile * n_tile_size; + buf_src2 = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + region.origin = 0; + region.size = sizeof(short) * max_post_router_tile; + buf_src2_emap = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Reorder activations + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + sub_buf_src1_pre = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Create image for reordered src1 + region.origin = 0; + region.size = ne00 * max_post_router_tile * n_tile_size * sizeof(float); + backend_ctx->prealloc_act_trans.allocate(backend_ctx->context, region.size); + buf_src1_reordered = clCreateSubBuffer( + backend_ctx->prealloc_act_trans.buffer, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}}; + image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + unsigned short map_ratio = ne20 / ne11; + GGML_ASSERT(((map_ratio == 1) || (map_ratio == ne20)) && "Map ratio not supported\n"); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 0, sizeof(cl_mem), &sub_buf_src1_pre)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 1, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 2, sizeof(cl_mem), &buf_src1_reordered)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 3, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 4, sizeof(unsigned int), &ne00)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 5, sizeof(unsigned short), &map_ratio)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 6, sizeof(unsigned int), &n_tile_size)); + + size_t reorder_b_local_size[3] = {256, 1, 1}; + size_t reorder_b_global_size[3] = {static_cast(((ne00 / 4) + 255) / 256 * 256), static_cast(max_post_router_tile * n_tile_size), 1}; + + // Dispatch reorder kernel + backend_ctx->enqueue_ndrange_kernel(backend_ctx->kernel_moe_reorder_b, 3, reorder_b_global_size, reorder_b_local_size, dst); + + // MoE kernel prepare + // Create sub buffer for dst + region.origin = offsetd; + region.size = ne0 * ne1 * ne2 * sizeof(float); + sub_buf_dst = clCreateSubBuffer( + extrad->data_device, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + // Create image for dst + cl_image_format image_format_buf_dst = {CL_R, CL_FLOAT}; + cl_image_desc image_desc_buf_dst = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne0 * ne1 * ne2), 0,0,0,0,0,0,0, {sub_buf_dst}}; + buf_dst_image = clCreateImage(backend_ctx->context, CL_MEM_WRITE_ONLY, &image_format_buf_dst, &image_desc_buf_dst, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q6_K->ql_img)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q6_K->qh)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q6_K->s)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q6_K->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &image_src1_reordered)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2_emap)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_dst_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + + // set thread grid + global_size[1] = static_cast((ne01 + 63) / 64); + global_size[2] = static_cast(max_post_router_tile); + local_size[1] = 1; + local_size[2] = 1; + + // Dispatch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + clReleaseMemObject(sub_buf_src1_pre); + clReleaseMemObject(buf_src1_reordered); + clReleaseMemObject(image_src1_reordered); + clReleaseMemObject(buf_src2); + clReleaseMemObject(buf_src2_emap); + clReleaseMemObject(sub_buf_dst); + clReleaseMemObject(buf_dst_image); + } + return; + } +#endif //GGML_OPENCL_USE_ADRENO_KERNELS + } case GGML_TYPE_MXFP4: { #ifdef GGML_OPENCL_USE_ADRENO_KERNELS if (use_adreno_moe_kernels(backend_ctx, src0)) { diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index 8f06d570587..312366984b6 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -664,6 +664,391 @@ kernel void kernel_restore_block_q5_1_trans4_ns( ((__global ushort8 *)(&(b->qs[0])))[0] = pre_block; } +kernel void kernel_convert_block_q4_k_trans4_ns( + __global struct block_q4_K * src0, + __global uint * dst_q, + __global half * dst_d, + __global half * dst_dm, + __global uchar * dst_s, + uint ne00, + uint ne01, + uchar mask_0F, + uchar mask_F0 +) { + uint i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + uint ne00_blk = ne00 / QK_K; + uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + __global struct block_q4_K * b = src0 + src_blk_offset; + + dst_d [dst_blk_offset] = b->d; + dst_dm[dst_blk_offset] = b->dm; + + uint4 qv[8]; + uchar * qv_bytes = (uchar *)qv; + for (int i = 0; i < QK_K / 64; ++i) { + for (int j = 0; j < 16; ++j) { + uchar x0 = b->q[i*32 + 2*j]; + uchar x1 = b->q[i*32 + 2*j + 1]; + + qv_bytes[i*32 + j ] = convert_uchar(x0 & mask_0F) | convert_uchar((x1 & mask_0F) << 4); + qv_bytes[i*32 + j + 16] = convert_uchar((x0 & mask_F0) >> 4) | convert_uchar(x1 & mask_F0); + } + } + + uint base = i02 * ne00_blk * ne01 * 32 + i00 * ne01 * 32 + i01; + #pragma unroll + for (int p = 0; p < 8; ++p) { + uint4 v = qv[p]; + dst_q[base + (p * 4 + 0) * ne01] = v.x; + dst_q[base + (p * 4 + 1) * ne01] = v.y; + dst_q[base + (p * 4 + 2) * ne01] = v.z; + dst_q[base + (p * 4 + 3) * ne01] = v.w; + } + + __global uchar * s_dst = dst_s + (i02 * ne01 + i01) * ne00_blk * K_SCALE_SIZE + i00 * K_SCALE_SIZE; + #pragma unroll + for (int i = 0; i < K_SCALE_SIZE; ++i) { + s_dst[i] = b->s[i]; + } +} + +kernel void kernel_restore_block_q4_k_trans4_ns( + __global uint * src_q, + __global half * src_d, + __global half * src_dm, + __global uchar * src_s, + __global struct block_q4_K * dst0, + uint ne00, + uint ne01, + uchar mask_0F, + uchar mask_F0 +) { + uint i00 = get_global_id(1); // block index along K + uint i01 = get_global_id(0); // row index + uint i02 = get_global_id(2); // batch index + + uint ne00_blk = ne00 / QK_K; + + uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + + __global struct block_q4_K * b = dst0 + dst_blk_offset; + + b->d = src_d[src_blk_offset]; + b->dm = src_dm[src_blk_offset]; + + __global uchar * s_src = src_s + (i02 * ne01 + i01) * ne00_blk * K_SCALE_SIZE + i00 * K_SCALE_SIZE; + for (int i = 0; i < K_SCALE_SIZE; ++i) { + b->s[i] = s_src[i]; + } + + uint base = i02 * ne00_blk * ne01 * 32 + i00 * ne01 * 32 + i01; + + uint4 qv[8]; + for (int p = 0; p < 8; ++p) { + qv[p].x = src_q[base + (p * 4 + 0) * ne01]; + qv[p].y = src_q[base + (p * 4 + 1) * ne01]; + qv[p].z = src_q[base + (p * 4 + 2) * ne01]; + qv[p].w = src_q[base + (p * 4 + 3) * ne01]; + } + + uchar * qv_bytes = (uchar *)qv; + for (int i = 0; i < QK_K / 64; ++i) { + for (int j = 0; j < 16; ++j) { + uchar lo = qv_bytes[i*32 + j]; + uchar hi = qv_bytes[i*32 + j + 16]; + b->q[i*32 + 2*j] = convert_uchar((lo & mask_0F) | ((hi & mask_0F) << 4)); + b->q[i*32 + 2*j + 1] = convert_uchar(((lo & mask_F0) >> 4) | (hi & mask_F0)); + } + } +} + +kernel void kernel_convert_block_q5_k_trans4_ns( + __global struct block_q5_K * src0, + __global uint * dst_qs, + __global uint * dst_qh, + __global half * dst_d, + __global half * dst_dm, + __global uchar * dst_s, + uint ne00, + uint ne01, + uchar mask_0F, + uchar mask_F0 +) { + uint i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + uint ne00_blk = ne00 / QK_K; + uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + __global struct block_q5_K * b = src0 + src_blk_offset; + + dst_d [dst_blk_offset] = b->d; + dst_dm[dst_blk_offset] = b->dm; + + for (int k = 0; k < 8; k++) { + uchar b0 = 0, b1 = 0, b2 = 0, b3 = 0; + for (int bit = 0; bit < 8; bit++) { + b0 |= (uchar)(((b->qh[bit] >> k) & 1) << bit); + b1 |= (uchar)(((b->qh[8 + bit] >> k) & 1) << bit); + b2 |= (uchar)(((b->qh[16 + bit] >> k) & 1) << bit); + b3 |= (uchar)(((b->qh[24 + bit] >> k) & 1) << bit); + } + uint packed = (uint)b0 | ((uint)b1 << 8) | ((uint)b2 << 16) | ((uint)b3 << 24); + dst_qh[i01 + (i00 * 8 + k) * ne01 + i02 * ne00_blk * 8 * ne01] = packed; + } + + uint4 qv[8]; + uchar * qv_bytes = (uchar *)qv; + for (int i = 0; i < QK_K / 64; ++i) { + for (int j = 0; j < 16; ++j) { + uchar x0 = b->qs[i*32 + 2*j]; + uchar x1 = b->qs[i*32 + 2*j + 1]; + + qv_bytes[i*32 + j ] = convert_uchar(x0 & mask_0F) | convert_uchar((x1 & mask_0F) << 4); + qv_bytes[i*32 + j + 16] = convert_uchar((x0 & mask_F0) >> 4) | convert_uchar(x1 & mask_F0); + } + } + + uint base = i02 * ne00_blk * ne01 * 32 + i00 * ne01 * 32 + i01; + #pragma unroll + for (int p = 0; p < 8; ++p) { + uint4 v = qv[p]; + dst_qs[base + (p * 4 + 0) * ne01] = v.x; + dst_qs[base + (p * 4 + 1) * ne01] = v.y; + dst_qs[base + (p * 4 + 2) * ne01] = v.z; + dst_qs[base + (p * 4 + 3) * ne01] = v.w; + } + + __global uchar * s_dst = dst_s + (i02 * ne01 + i01) * ne00_blk * K_SCALE_SIZE + i00 * K_SCALE_SIZE; + #pragma unroll + for (int i = 0; i < K_SCALE_SIZE; ++i) { + s_dst[i] = b->s[i]; + } +} + +kernel void kernel_restore_block_q5_k_trans4_ns( + __global uint * src_qs, + __global uint * src_qh, + __global half * src_d, + __global half * src_dm, + __global uchar * src_s, + __global struct block_q5_K * dst0, + uint ne00, + uint ne01, + uchar mask_0F, + uchar mask_F0 +) { + uint i00 = get_global_id(1); // block index along K + uint i01 = get_global_id(0); // row index + uint i02 = get_global_id(2); // batch index + + uint ne00_blk = ne00 / QK_K; + + uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + + __global struct block_q5_K * b = dst0 + dst_blk_offset; + + b->d = src_d[src_blk_offset]; + b->dm = src_dm[src_blk_offset]; + + for (int j = 0; j < 32; j++) b->qh[j] = 0; + for (int k = 0; k < 8; k++) { + uint packed = src_qh[i01 + (i00 * 8 + k) * ne01 + i02 * ne00_blk * 8 * ne01]; + uchar b0 = (uchar)(packed & 0xFF); + uchar b1 = (uchar)((packed >> 8) & 0xFF); + uchar b2 = (uchar)((packed >> 16) & 0xFF); + uchar b3 = (uchar)((packed >> 24) & 0xFF); + for (int bit = 0; bit < 8; bit++) { + b->qh[bit] |= (uchar)(((b0 >> bit) & 1) << k); + b->qh[8 + bit] |= (uchar)(((b1 >> bit) & 1) << k); + b->qh[16 + bit] |= (uchar)(((b2 >> bit) & 1) << k); + b->qh[24 + bit] |= (uchar)(((b3 >> bit) & 1) << k); + } + } + + __global uchar * s_src = src_s + (i02 * ne01 + i01) * ne00_blk * K_SCALE_SIZE + i00 * K_SCALE_SIZE; + for (int i = 0; i < K_SCALE_SIZE; ++i) { + b->s[i] = s_src[i]; + } + + uint base = i02 * ne00_blk * ne01 * 32 + i00 * ne01 * 32 + i01; + + uint4 qv[8]; + for (int p = 0; p < 8; ++p) { + qv[p].x = src_qs[base + (p * 4 + 0) * ne01]; + qv[p].y = src_qs[base + (p * 4 + 1) * ne01]; + qv[p].z = src_qs[base + (p * 4 + 2) * ne01]; + qv[p].w = src_qs[base + (p * 4 + 3) * ne01]; + } + + uchar * qv_bytes = (uchar *)qv; + for (int i = 0; i < QK_K / 64; ++i) { + for (int j = 0; j < 16; ++j) { + uchar lo = qv_bytes[i*32 + j]; + uchar hi = qv_bytes[i*32 + j + 16]; + b->qs[i*32 + 2*j] = convert_uchar((lo & mask_0F) | ((hi & mask_0F) << 4)); + b->qs[i*32 + 2*j + 1] = convert_uchar(((lo & mask_F0) >> 4) | (hi & mask_F0)); + } + } +} + +kernel void kernel_convert_block_q6_k_trans4_ns( + __global struct block_q6_K * src0, + __global uint * dst_ql, + __global uint * dst_qh, + __global half * dst_d, + __global char * dst_s, + uint ne00, + uint ne01, + uchar mask_0F, + uchar mask_F0 +) { + uint i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + uint ne00_blk = ne00 / QK_K; + + uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + __global struct block_q6_K * b = src0 + src_blk_offset; + + dst_d[dst_blk_offset] = b->d; + + uint4 qlv[8]; + uchar * qlv_bytes = (uchar *)qlv; + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 16; ++j) { + uchar x0 = b->ql[i*64 + 2*j]; + uchar x1 = b->ql[i*64 + 2*j + 1]; + uchar x2 = b->ql[i*64 + 32 + 2*j]; + uchar x3 = b->ql[i*64 + 32 + 2*j + 1]; + qlv_bytes[i*64 + j ] = convert_uchar(x0 & mask_0F) | convert_uchar((x1 & mask_0F) << 4); + qlv_bytes[i*64 + j + 16] = convert_uchar(x2 & mask_0F) | convert_uchar((x3 & mask_0F) << 4); + qlv_bytes[i*64 + j + 32] = convert_uchar((x0 & mask_F0) >> 4) | convert_uchar(x1 & mask_F0); + qlv_bytes[i*64 + j + 48] = convert_uchar((x2 & mask_F0) >> 4) | convert_uchar(x3 & mask_F0); + } + } + + uint ql_base = i02 * ne00_blk * ne01 * 32 + i00 * ne01 * 32 + i01; + + #pragma unroll + for (int p = 0; p < 8; ++p) { + uint4 v = qlv[p]; + dst_ql[ql_base + (p * 4 + 0) * ne01] = v.x; + dst_ql[ql_base + (p * 4 + 1) * ne01] = v.y; + dst_ql[ql_base + (p * 4 + 2) * ne01] = v.z; + dst_ql[ql_base + (p * 4 + 3) * ne01] = v.w; + } + + uint qhv[16] = {0}; + + for (int n = 0; n < 2; ++n) { + for (int l = 0; l < 32; ++l) { + uchar h = b->qh[n*32 + l]; + int u = l / 16; + int bit_pos = (l % 16) * 2; + qhv[(n*4 + 0)*2 + u] |= ((uint)((h >> 0) & 0x03)) << bit_pos; + qhv[(n*4 + 1)*2 + u] |= ((uint)((h >> 2) & 0x03)) << bit_pos; + qhv[(n*4 + 2)*2 + u] |= ((uint)((h >> 4) & 0x03)) << bit_pos; + qhv[(n*4 + 3)*2 + u] |= ((uint)((h >> 6) & 0x03)) << bit_pos; + } + } + + uint qh_base = i02 * ne00_blk * ne01 * 16 + i00 * ne01 * 16 + i01; + + for (int p = 0; p < 16; ++p) { + dst_qh[qh_base + p * ne01] = qhv[p]; + } + + __global char * s_dst = dst_s + (i02 * ne01 + i01) * ne00_blk * 16 + i00 * 16; + #pragma unroll + for (int i = 0; i < 16; ++i) { + s_dst[i] = b->scales[i]; + } +} + +kernel void kernel_restore_block_q6_k_trans4_ns( + __global uint * src_ql, + __global uint * src_qh, + __global half * src_d, + __global char * src_s, + __global struct block_q6_K * dst0, + uint ne00, + uint ne01, + uchar mask_0F, + uchar mask_F0 +) { + uint i00 = get_global_id(1); // block index along K + uint i01 = get_global_id(0); // row index + uint i02 = get_global_id(2); // batch index + + uint ne00_blk = ne00 / QK_K; + + uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + + __global struct block_q6_K * b = dst0 + dst_blk_offset; + + b->d = src_d[src_blk_offset]; + + uint ql_base = i02 * ne00_blk * ne01 * 32 + i00 * ne01 * 32 + i01; + uint4 qlv[8]; + for (int p = 0; p < 8; ++p) { + qlv[p].x = src_ql[ql_base + (p * 4 + 0) * ne01]; + qlv[p].y = src_ql[ql_base + (p * 4 + 1) * ne01]; + qlv[p].z = src_ql[ql_base + (p * 4 + 2) * ne01]; + qlv[p].w = src_ql[ql_base + (p * 4 + 3) * ne01]; + } + + uchar * qlv_bytes = (uchar *)qlv; + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 16; ++j) { + uchar lo_02 = qlv_bytes[i*64 + j]; + uchar lo_13 = qlv_bytes[i*64 + j + 16]; + uchar hi_02 = qlv_bytes[i*64 + j + 32]; + uchar hi_13 = qlv_bytes[i*64 + j + 48]; + b->ql[i*64 + 2*j] = convert_uchar((lo_02 & mask_0F) | ((hi_02 & mask_0F) << 4)); + b->ql[i*64 + 2*j + 1] = convert_uchar(((lo_02 & mask_F0) >> 4) | (hi_02 & mask_F0)); + b->ql[i*64 + 32 + 2*j] = convert_uchar((lo_13 & mask_0F) | ((hi_13 & mask_0F) << 4)); + b->ql[i*64 + 32 + 2*j + 1] = convert_uchar(((lo_13 & mask_F0) >> 4) | (hi_13 & mask_F0)); + } + } + + uint qh_base = i02 * ne00_blk * ne01 * 16 + i00 * ne01 * 16 + i01; + uint qhv[16]; + for (int p = 0; p < 16; ++p) { + qhv[p] = src_qh[qh_base + p * ne01]; + } + + for (int n = 0; n < 2; ++n) { + for (int l = 0; l < 32; ++l) { + int u = l / 16; + int bit_pos = (l % 16) * 2; + uchar v0 = (uchar)((qhv[(n*4 + 0)*2 + u] >> bit_pos) & 0x03); + uchar v1 = (uchar)((qhv[(n*4 + 1)*2 + u] >> bit_pos) & 0x03); + uchar v2 = (uchar)((qhv[(n*4 + 2)*2 + u] >> bit_pos) & 0x03); + uchar v3 = (uchar)((qhv[(n*4 + 3)*2 + u] >> bit_pos) & 0x03); + b->qh[n*32 + l] = v0 | (v1 << 2) | (v2 << 4) | (v3 << 6); + } + } + + __global char * s_src = src_s + (i02 * ne01 + i01) * ne00_blk * 16 + i00 * 16; + for (int i = 0; i < 16; ++i) { + b->scales[i] = s_src[i]; + } +} + //------------------------------------------------------------------------------ // block_mxfp4 //------------------------------------------------------------------------------ diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl new file mode 100644 index 00000000000..9d24aff6a20 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl @@ -0,0 +1,279 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable +#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable + +#define TILESIZE_K 16 +#define TILESIZE_M 64 +#define TILESIZE_N 32 +#define QK_K 256 +#define K_SCALE_SIZE 12 + +inline void get_scale_min_k4( + int j, + global const uchar * q, + uchar * d, + uchar * m +) { + if (j < 4) { + *d = q[j] & 63; + *m = q[j+4] & 63; + } else { + *d = (q[j+4] & 0x0F) | ((q[j-4] & 0xC0) >> 2); + *m = ((q[j+4] >> 4) & 0x0F) | ((q[j] & 0xC0) >> 2); + } +} + +#define dequantize_q4_k(q4, a_f16, scale, minv) \ + a_f16.s0 = (half)((float)(q4.s0 & 0x000F) * scale - minv); \ + a_f16.s1 = (half)((float)((q4.s0 & 0x00F0) >> 4) * scale - minv); \ + a_f16.s2 = (half)((float)((q4.s0 & 0x0F00) >> 8) * scale - minv); \ + a_f16.s3 = (half)((float)((q4.s0 & 0xF000) >> 12) * scale - minv); \ + a_f16.s4 = (half)((float)(q4.s1 & 0x000F) * scale - minv); \ + a_f16.s5 = (half)((float)((q4.s1 & 0x00F0) >> 4) * scale - minv); \ + a_f16.s6 = (half)((float)((q4.s1 & 0x0F00) >> 8) * scale - minv); \ + a_f16.s7 = (half)((float)((q4.s1 & 0xF000) >> 12) * scale - minv); \ + a_f16.s8 = (half)((float)(q4.s2 & 0x000F) * scale - minv); \ + a_f16.s9 = (half)((float)((q4.s2 & 0x00F0) >> 4) * scale - minv); \ + a_f16.sa = (half)((float)((q4.s2 & 0x0F00) >> 8) * scale - minv); \ + a_f16.sb = (half)((float)((q4.s2 & 0xF000) >> 12) * scale - minv); \ + a_f16.sc = (half)((float)(q4.s3 & 0x000F) * scale - minv); \ + a_f16.sd = (half)((float)((q4.s3 & 0x00F0) >> 4) * scale - minv); \ + a_f16.se = (half)((float)((q4.s3 & 0x0F00) >> 8) * scale - minv); \ + a_f16.sf = (half)((float)((q4.s3 & 0xF000) >> 12) * scale - minv); \ + + +#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \ + acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \ + acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \ + acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \ + acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \ + acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \ + acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \ + acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \ + acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \ + acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \ + acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \ + acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \ + acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \ + acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \ + acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \ + acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \ + acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \ + acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \ + acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \ + acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \ + acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \ + acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \ + acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \ + acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \ + acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \ + acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \ + acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \ + acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \ + acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \ + acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \ + acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \ + acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \ + acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \ + acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \ + acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \ + acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \ + acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \ + acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \ + acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \ + acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \ + acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \ + acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \ + acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \ + acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \ + acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \ + acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \ + acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \ + acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \ + acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \ + acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \ + acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \ + acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \ + acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \ + acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \ + acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \ + acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \ + acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \ + acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \ + acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \ + acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \ + acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \ + acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \ + acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \ + acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + + +__attribute__((qcom_wave_pair_mode(1))) +kernel void kernel_gemm_moe_q4_k_f32_ns( + __read_only image1d_buffer_t src0_q, + __global half * src0_d, + __global half * src0_dm, + __global uchar * src0_s, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global ushort * src2_emap, + __write_only image1d_buffer_t dst, + __global int * total_tiles, + uint ne00, + uint ne01 +) { + uint block_id_m = get_global_id(1); // m_tile + uint block_id_n = get_global_id(2); // n_tile + + // Boundary check + if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + return; + } + + __private half16 reg_a; + __private float32 reg_c = (float32)(0); + __local half4 shared_b[128]; + + const ushort expert_id = src2_emap[block_id_n]; + + const uint row = block_id_m * TILESIZE_M; + const uint col = block_id_n * TILESIZE_N; + + uint sub_block_id_m = get_local_id(0); + uint2 b_global_offset; + b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00; + b_global_offset.y = b_global_offset.x + (16 * ne00); + uint2 b_local_offset; + b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2); + b_local_offset.y = b_local_offset.x + 16; + + uint num_superblocks = ne00 / QK_K; + uint scales_per_row = num_superblocks * K_SCALE_SIZE; + uint row_idx = row + get_global_id(0); + + // Loop along K axis, 32 elements per iteration (one sub-block), divided into 2 halves of 16 + for (uint step = 0; step < ne00; step += TILESIZE_K * 2) { + uint sub = step / 32; + uint sb = sub / 8; + uint j = sub % 8; + + // Load d and dm for super-block + uint d_offset = row + sb * ne01 + expert_id * num_superblocks * ne01 + get_global_id(0); + half d_val = src0_d[d_offset]; + half dm_val = src0_dm[d_offset]; + + // Load sub-block scale and min + global const uchar * sc = src0_s + (expert_id * ne01 + row_idx) * scales_per_row + sb * K_SCALE_SIZE; + uchar sv, mn; + get_scale_min_k4(j, sc, &sv, &mn); + + float scale = (float)d_val * (float)sv; + float minv = (float)dm_val * (float)mn; + + // First sub-block (16 elements) + uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + uint b_sub_offset = col * ne00 + step; + + // Load 16 q (64-bits) in transposed layout + uint2 q4x16; + q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B + float8 bx8_f32; + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + half8 bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + dequantize_q4_k(as_ushort4(q4x16), reg_a, scale, minv); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + half16 acc; + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + + // Second half (next 16 elements, same sub-block scale) + uint half_step = step + TILESIZE_K; + q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + b_sub_offset = col * ne00 + half_step; + + q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + dequantize_q4_k(as_ushort4(q4x16), reg_a, scale, minv); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + } + + // Load post router and share in LM + __local uint out_idx[TILESIZE_N]; + + if (get_local_id(0) < TILESIZE_N) { + uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)]; + if (idx == 0xFFFFFFFF) { + idx = src2[block_id_n * TILESIZE_N + 0]; + } + out_idx[get_local_id(0)] = idx * ne01; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + // Scatter results back to original position in output grid + uint m_offset = row + get_local_id(0); + + write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1)); + write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2)); + write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3)); + write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4)); + write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5)); + write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6)); + write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7)); + write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8)); + write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9)); + write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa)); + write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb)); + write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc)); + write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd)); + write_imagef(dst, out_idx[14] + m_offset, (reg_c.se)); + write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf)); + write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg)); + write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh)); + write_imagef(dst, out_idx[18] + m_offset, (reg_c.si)); + write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj)); + write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk)); + write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl)); + write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm)); + write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn)); + write_imagef(dst, out_idx[24] + m_offset, (reg_c.so)); + write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp)); + write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq)); + write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr)); + write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss)); + write_imagef(dst, out_idx[29] + m_offset, (reg_c.st)); + write_imagef(dst, out_idx[30] + m_offset, (reg_c.su)); + write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv)); + + // Store zero padding parts to the index of first output in tile + barrier(CLK_GLOBAL_MEM_FENCE); + write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0)); +} diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl new file mode 100644 index 00000000000..808a0c7db6a --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl @@ -0,0 +1,284 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable +#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable + +#define TILESIZE_K 16 +#define TILESIZE_M 64 +#define TILESIZE_N 32 +#define QK_K 256 +#define K_SCALE_SIZE 12 + +inline void get_scale_min_k4( + int j, + global const uchar * q, + uchar * d, + uchar * m +) { + if (j < 4) { + *d = q[j] & 63; + *m = q[j+4] & 63; + } else { + *d = (q[j+4] & 0x0F) | ((q[j-4] & 0xC0) >> 2); + *m = ((q[j+4] >> 4) & 0x0F) | ((q[j] & 0xC0) >> 2); + } +} + +#define dequantize_q5_k(qs5x16, qh5x16, a_f16, scale, m) \ + a_f16.s0 = (half)((float)(( qs5x16.s0 & 0x000F) | (( qh5x16.s0 & 0x01) << 4)) * scale + m); \ + a_f16.s1 = (half)((float)((((qs5x16.s0 & 0x00F0) >> 4 ) | (((qh5x16.s0 >> 1) & 0x01) << 4)) * scale + m)); \ + a_f16.s2 = (half)((float)((((qs5x16.s0 & 0x0F00) >> 8 ) | (((qh5x16.s0 >> 2) & 0x01) << 4)) * scale + m)); \ + a_f16.s3 = (half)((float)((((qs5x16.s0 & 0xF000) >> 12) | (((qh5x16.s0 >> 3) & 0x01) << 4)) * scale + m)); \ + a_f16.s4 = (half)((float)((( qs5x16.s1 & 0x000F) | (((qh5x16.s0 >> 4) & 0x01) << 4)) * scale + m)); \ + a_f16.s5 = (half)((float)((((qs5x16.s1 & 0x00F0) >> 4 ) | (((qh5x16.s0 >> 5) & 0x01) << 4)) * scale + m)); \ + a_f16.s6 = (half)((float)(((qs5x16.s1 & 0x0F00) >> 8 ) | (((qh5x16.s0 >> 6) & 0x01) << 4)) * scale + m); \ + a_f16.s7 = (half)((float)((((qs5x16.s1 & 0xF000) >> 12) | (((qh5x16.s0 >> 7) & 0x01) << 4)) * scale + m)); \ + a_f16.s8 = (half)((float)((( qs5x16.s2 & 0x000F) | (( qh5x16.s1 & 0x01) << 4)) * scale + m)); \ + a_f16.s9 = (half)((float)((((qs5x16.s2 & 0x00F0) >> 4 ) | (((qh5x16.s1 >> 1) & 0x01) << 4)) * scale + m)); \ + a_f16.sa = (half)((float)((((qs5x16.s2 & 0x0F00) >> 8 ) | (((qh5x16.s1 >> 2) & 0x01) << 4)) * scale + m)); \ + a_f16.sb = (half)((float)((((qs5x16.s2 & 0xF000) >> 12) | (((qh5x16.s1 >> 3) & 0x01) << 4)) * scale + m)); \ + a_f16.sc = (half)((float)((( qs5x16.s3 & 0x000F) | (((qh5x16.s1 >> 4) & 0x01) << 4)) * scale + m)); \ + a_f16.sd = (half)((float)((((qs5x16.s3 & 0x00F0) >> 4 ) | (((qh5x16.s1 >> 5) & 0x01) << 4)) * scale + m)); \ + a_f16.se = (half)((float)((((qs5x16.s3 & 0x0F00) >> 8 ) | (((qh5x16.s1 >> 6) & 0x01) << 4)) * scale + m)); \ + a_f16.sf = (half)((float)((((qs5x16.s3 & 0xF000) >> 12) | (((qh5x16.s1 >> 7) & 0x01) << 4)) * scale + m)); \ + + +#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \ + acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \ + acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \ + acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \ + acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \ + acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \ + acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \ + acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \ + acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \ + acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \ + acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \ + acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \ + acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \ + acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \ + acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \ + acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \ + acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \ + acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \ + acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \ + acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \ + acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \ + acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \ + acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \ + acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \ + acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \ + acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \ + acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \ + acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \ + acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \ + acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \ + acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \ + acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \ + acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \ + acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \ + acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \ + acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \ + acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \ + acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \ + acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \ + acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \ + acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \ + acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \ + acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \ + acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \ + acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \ + acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \ + acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \ + acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \ + acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \ + acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \ + acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \ + acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \ + acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \ + acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \ + acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \ + acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \ + acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \ + acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \ + acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \ + acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \ + acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \ + acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \ + acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \ + acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + + +__attribute__((qcom_wave_pair_mode(1))) +kernel void kernel_gemm_moe_q5_k_f32_ns( + __read_only image1d_buffer_t src0_q, + __global uint * src0_qh, + __global uchar * src0_s, + __global half * src0_d, + __global half * src0_dm, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global ushort * src2_emap, + __write_only image1d_buffer_t dst, + __global int * total_tiles, + uint ne00, + uint ne01 +) { + uint block_id_m = get_global_id(1); // m_tile + uint block_id_n = get_global_id(2); // n_tile + + // Boundary check + if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + return; + } + + __private half16 reg_a; + __private float32 reg_c = (float32)(0); + __local half4 shared_b[128]; + + const ushort expert_id = src2_emap[block_id_n]; + + const uint row = block_id_m * TILESIZE_M; + const uint col = block_id_n * TILESIZE_N; + + uint sub_block_id_m = get_local_id(0); + uint2 b_global_offset; + b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00; + b_global_offset.y = b_global_offset.x + (16 * ne00); + uint2 b_local_offset; + b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2); + b_local_offset.y = b_local_offset.x + 16; + + uint num_superblocks = ne00 / QK_K; + uint scales_per_row = num_superblocks * K_SCALE_SIZE; + uint row_idx = row + get_global_id(0); + + // Loop along K axis, 32 elements per iteration (one sub-block), divided into 2 halves of 16 + for (uint step = 0; step < ne00; step += TILESIZE_K * 2) { + uint sub = step / 32; + uint sb = sub / 8; + uint j = sub % 8; + + // Load d and dm for super-block + uint d_offset = row + sb * ne01 + expert_id * num_superblocks * ne01 + get_global_id(0); + half d_val = src0_d[d_offset]; + half dm_val = src0_dm[d_offset]; + + // Load sub-block scale and min + global const uchar * sc = src0_s + (expert_id * ne01 + row_idx) * scales_per_row + sb * K_SCALE_SIZE; + uchar sv, mn; + get_scale_min_k4(j, sc, &sv, &mn); + + float scale = (float)d_val * (float)sv; + float minv = -(float)dm_val * (float)mn; + + // qh is stored at sub-block granularity + uint qh_offset = row + sub * ne01 + expert_id * num_superblocks * 8 * ne01 + get_global_id(0); + uchar4 qhx32 = as_uchar4(src0_qh[qh_offset]); + + // First sub-block (16 elements) + uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + uint b_sub_offset = col * ne00 + step; + + // Load 16 q (64-bits) in transposed layout + uint2 q4x16; + q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B + float8 bx8_f32; + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + half8 bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + dequantize_q5_k(as_ushort4(q4x16), qhx32.lo, reg_a, scale, minv); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + half16 acc; + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + + // Second half + uint half_step = step + TILESIZE_K; + q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + b_sub_offset = col * ne00 + half_step; + + q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + dequantize_q5_k(as_ushort4(q4x16), qhx32.hi, reg_a, scale, minv); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + } + + // Load post router and share in LM + __local uint out_idx[TILESIZE_N]; + + if (get_local_id(0) < TILESIZE_N) { + uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)]; + if (idx == 0xFFFFFFFF) { + idx = src2[block_id_n * TILESIZE_N + 0]; + } + out_idx[get_local_id(0)] = idx * ne01; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + // Scatter results back to original position in output grid + uint m_offset = row + get_local_id(0); + + write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1)); + write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2)); + write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3)); + write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4)); + write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5)); + write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6)); + write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7)); + write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8)); + write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9)); + write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa)); + write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb)); + write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc)); + write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd)); + write_imagef(dst, out_idx[14] + m_offset, (reg_c.se)); + write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf)); + write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg)); + write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh)); + write_imagef(dst, out_idx[18] + m_offset, (reg_c.si)); + write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj)); + write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk)); + write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl)); + write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm)); + write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn)); + write_imagef(dst, out_idx[24] + m_offset, (reg_c.so)); + write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp)); + write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq)); + write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr)); + write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss)); + write_imagef(dst, out_idx[29] + m_offset, (reg_c.st)); + write_imagef(dst, out_idx[30] + m_offset, (reg_c.su)); + write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv)); + + // Store zero padding parts to the index of first output in tile + barrier(CLK_GLOBAL_MEM_FENCE); + write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0)); +} diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl new file mode 100644 index 00000000000..a040335adfa --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl @@ -0,0 +1,263 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable +#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable + +#define TILESIZE_K 16 +#define TILESIZE_M 64 +#define TILESIZE_N 32 +#define QK_K 256 + +#define dequantize_q6_k(qs16, qh16, a_f16, scale) \ + a_f16.s0 = (half)(((float)(( qs16.s0 & 0x000F) | ((uint)(( qh16 ) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s1 = (half)(((float)((( qs16.s0 >> 4) & 0x000F) | ((uint)(( qh16 >> 2) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s2 = (half)(((float)((( qs16.s0 >> 8) & 0x000F) | ((uint)(( qh16 >> 4) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s3 = (half)(((float)((( qs16.s0 >>12) & 0x000F) | ((uint)(( qh16 >> 6) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s4 = (half)(((float)(( qs16.s1 & 0x000F) | ((uint)(( qh16 >> 8) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s5 = (half)(((float)((( qs16.s1 >> 4) & 0x000F) | ((uint)(( qh16 >> 10) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s6 = (half)(((float)((( qs16.s1 >> 8) & 0x000F) | ((uint)(( qh16 >> 12) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s7 = (half)(((float)((( qs16.s1 >>12) & 0x000F) | ((uint)(( qh16 >> 14) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s8 = (half)(((float)(( qs16.s2 & 0x000F) | ((uint)(( qh16 >> 16) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s9 = (half)(((float)((( qs16.s2 >> 4) & 0x000F) | ((uint)(( qh16 >> 18) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.sa = (half)(((float)((( qs16.s2 >> 8) & 0x000F) | ((uint)(( qh16 >> 20) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.sb = (half)(((float)((( qs16.s2 >>12) & 0x000F) | ((uint)(( qh16 >> 22) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.sc = (half)(((float)(( qs16.s3 & 0x000F) | ((uint)(( qh16 >> 24) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.sd = (half)(((float)((( qs16.s3 >> 4) & 0x000F) | ((uint)(( qh16 >> 26) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.se = (half)(((float)((( qs16.s3 >> 8) & 0x000F) | ((uint)(( qh16 >> 28) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.sf = (half)(((float)((( qs16.s3 >>12) & 0x000F) | ((uint)(( qh16 >> 30) & 0x3) << 4)) - 32.f) * scale); \ + + +#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \ + acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \ + acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \ + acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \ + acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \ + acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \ + acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \ + acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \ + acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \ + acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \ + acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \ + acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \ + acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \ + acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \ + acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \ + acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \ + acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \ + acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \ + acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \ + acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \ + acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \ + acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \ + acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \ + acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \ + acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \ + acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \ + acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \ + acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \ + acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \ + acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \ + acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \ + acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \ + acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \ + acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \ + acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \ + acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \ + acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \ + acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \ + acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \ + acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \ + acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \ + acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \ + acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \ + acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \ + acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \ + acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \ + acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \ + acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \ + acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \ + acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \ + acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \ + acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \ + acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \ + acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \ + acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \ + acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \ + acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \ + acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \ + acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \ + acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \ + acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \ + acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \ + acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \ + acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + + +__attribute__((qcom_wave_pair_mode(1))) +kernel void kernel_gemm_moe_q6_k_f32_ns( + __read_only image1d_buffer_t src0_ql, + __global uint * src0_qh, + __global char * src0_s, + __global half * src0_d, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global ushort * src2_emap, + __write_only image1d_buffer_t dst, + __global int * total_tiles, + uint ne00, + uint ne01 +) { + uint block_id_m = get_global_id(1); // m_tile + uint block_id_n = get_global_id(2); // n_tile + + // Boundary check + if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + return; + } + + __private half16 reg_a; + __private float32 reg_c = (float32)(0); + __local half4 shared_b[128]; + + const ushort expert_id = src2_emap[block_id_n]; + + const uint row = block_id_m * TILESIZE_M; + const uint col = block_id_n * TILESIZE_N; + + uint sub_block_id_m = get_local_id(0); + uint2 b_global_offset; + b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00; + b_global_offset.y = b_global_offset.x + (16 * ne00); + uint2 b_local_offset; + b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2); + b_local_offset.y = b_local_offset.x + 16; + + uint num_superblocks = ne00 / QK_K; + uint scales_per_row = num_superblocks * 16; + uint row_idx = row + get_global_id(0); + + // Loop along K axis, 32 elements per iteration (one sub-block), divided into 2 halves of 16 + for (uint step = 0; step < ne00; step += TILESIZE_K * 2) { + uint sub = step / 32; // 32-element group index + uint sb = sub / 8; // super-block index + uint j = sub % 8; // group within super-block + + // Load d for super-block + uint d_offset = row + sb * ne01 + expert_id * num_superblocks * ne01 + get_global_id(0); + half d_val = src0_d[d_offset]; + + // Load sub-block scales + global const char * sc = src0_s + (expert_id * ne01 + row_idx) * scales_per_row + sb * 16; + float scale0 = (float)d_val * (float)sc[j * 2]; + float scale1 = (float)d_val * (float)sc[j * 2 + 1]; + + uint qh_base = row + (sub * 2) * ne01 + expert_id * (num_superblocks * 16) * ne01 + get_global_id(0); + uint qh_first16 = src0_qh[qh_base]; + uint qh_second16 = src0_qh[qh_base + ne01]; + + // First half (16 elements) + uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + uint b_sub_offset = col * ne00 + step; + + // Load 16 ql nibbles (2 uints) from image + uint2 q4x16; + q4x16.x = read_imageui(src0_ql, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_ql, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B + float8 bx8_f32; + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + half8 bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantize first 16 elements (scale0) + dequantize_q6_k(as_ushort4(q4x16), qh_first16, reg_a, scale0); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + half16 acc; + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + + // Second half + uint half_step = step + TILESIZE_K; + q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + b_sub_offset = col * ne00 + half_step; + + q4x16.x = read_imageui(src0_ql, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_ql, q_sub_offset + sub_block_id_m + ne01).x; + + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + dequantize_q6_k(as_ushort4(q4x16), qh_second16, reg_a, scale1); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + } + + // Load post router and share in LM + __local uint out_idx[TILESIZE_N]; + + if (get_local_id(0) < TILESIZE_N) { + uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)]; + if (idx == 0xFFFFFFFF) { + idx = src2[block_id_n * TILESIZE_N + 0]; + } + out_idx[get_local_id(0)] = idx * ne01; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + // Scatter results back to original position in output grid + uint m_offset = row + get_local_id(0); + + write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1)); + write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2)); + write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3)); + write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4)); + write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5)); + write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6)); + write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7)); + write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8)); + write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9)); + write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa)); + write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb)); + write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc)); + write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd)); + write_imagef(dst, out_idx[14] + m_offset, (reg_c.se)); + write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf)); + write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg)); + write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh)); + write_imagef(dst, out_idx[18] + m_offset, (reg_c.si)); + write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj)); + write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk)); + write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl)); + write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm)); + write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn)); + write_imagef(dst, out_idx[24] + m_offset, (reg_c.so)); + write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp)); + write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq)); + write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr)); + write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss)); + write_imagef(dst, out_idx[29] + m_offset, (reg_c.st)); + write_imagef(dst, out_idx[30] + m_offset, (reg_c.su)); + write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv)); + + // Store zero padding parts to the index of first output in tile + barrier(CLK_GLOBAL_MEM_FENCE); + write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0)); +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl new file mode 100644 index 00000000000..13d79f2526f --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl @@ -0,0 +1,151 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#define QK_K 256 +#define K_SCALE_SIZE 12 +#define N_SIMDGROUP 4 +#define SIMDGROUP_WIDTH 64 + +inline void get_scale_min_k4( + int j, + global const uchar * q, + uchar * d, + uchar * m +) { + if (j < 4) { + *d = q[j] & 63; + *m = q[j+4] & 63; + } else { + *d = (q[j+4] & 0x0F) | ((q[j-4] & 0xC0) >> 2); + *m = ((q[j+4] >> 4) & 0x0F) | ((q[j] & 0xC0) >> 2); + } +} + +static inline float8 q4_k_to_fp32_packed8(ushort2 q4x8, float scale, float minv) { + float8 fp32x8; + fp32x8.s0 = (q4x8.s0 & 0x000F) * scale - minv; + fp32x8.s1 = ((q4x8.s0 & 0x00F0) >> 4) * scale - minv; + fp32x8.s2 = ((q4x8.s0 & 0x0F00) >> 8) * scale - minv; + fp32x8.s3 = ((q4x8.s0 & 0xF000) >> 12) * scale - minv; + fp32x8.s4 = (q4x8.s1 & 0x000F) * scale - minv; + fp32x8.s5 = ((q4x8.s1 & 0x00F0) >> 4) * scale - minv; + fp32x8.s6 = ((q4x8.s1 & 0x0F00) >> 8) * scale - minv; + fp32x8.s7 = ((q4x8.s1 & 0xF000) >> 12) * scale - minv; + return fp32x8; +} + +__attribute__((qcom_reqd_sub_group_size("half"))) +__kernel void kernel_gemv_moe_q4_k_f32_ns( + __global uint * src0_q, + __global half * src0_d, + __global half * src0_dm, + __global uchar * src0_s, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne11 +) { + uint i01 = get_global_id(0); + uint i20 = get_global_id(2); + uint sgid = get_local_id(1); + uint slid = get_sub_group_local_id(); + + uint i11 = i20 % ne11; + + uint expert_id = src2[i20]; + + int num_superblocks = ne00 / QK_K; + int num_subblocks = ne00 / 32; + int scales_per_row = num_superblocks * K_SCALE_SIZE; + + // Expert offsets in the transposed noshuffle layout + uint expert_q_offset = expert_id * (ne00 / 8) * ne01; + uint expert_d_offset = expert_id * num_superblocks * ne01; + + __private float sum = 0.0f; + + // Loop over sub-blocks of 32 elements, N_SIMDGROUP sub-blocks per iter + for (uint ib = sgid; ib < num_subblocks; ib += N_SIMDGROUP) { + uint sb = ib / 8; + uint j = ib % 8; + + // Load d and dmin for this super-block + half d_val = src0_d[expert_d_offset + sb * ne01 + i01]; + half dm_val = src0_dm[expert_d_offset + sb * ne01 + i01]; + + // Load sub-block scale and min + global const uchar * sc = src0_s + (expert_id * ne01 + i01) * scales_per_row + sb * K_SCALE_SIZE; + uchar sv, mn; + get_scale_min_k4(j, sc, &sv, &mn); + + float scale = (float)d_val * (float)sv; + float minv = (float)dm_val * (float)mn; + + // Load 4 uints of quants (32 nibbles = 32 elements) + uint q_base = expert_q_offset + ib * ne01 * 4 + i01; + + uint4 regQ; + regQ.s0 = src0_q[q_base]; + regQ.s1 = src0_q[q_base + ne01]; + regQ.s2 = src0_q[q_base + ne01 * 2]; + regQ.s3 = src0_q[q_base + ne01 * 3]; + + // Load activations: 32 floats = 8 float4s + uint y_offset = i11 * ne00 / 4 + ib * 8; + + float8 fp32x8 = q4_k_to_fp32_packed8(as_ushort2(regQ.s0), scale, minv); + + float4 shared_y4; + shared_y4 = read_imagef(src1, (y_offset + 0)); + float4 acc = shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 1)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q4_k_to_fp32_packed8(as_ushort2(regQ.s1), scale, minv); + + shared_y4 = read_imagef(src1, (y_offset + 2)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 3)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q4_k_to_fp32_packed8(as_ushort2(regQ.s2), scale, minv); + + shared_y4 = read_imagef(src1, (y_offset + 4)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 5)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q4_k_to_fp32_packed8(as_ushort2(regQ.s3), scale, minv); + + shared_y4 = read_imagef(src1, (y_offset + 6)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 7)); + acc += shared_y4 * fp32x8.hi; + + sum += ((acc.s0 + acc.s1) + (acc.s2 + acc.s3)); + } + + // reduction in local memory, assumes #subgroups=4 + __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)]; + if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum; + if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum; + if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum; + barrier(CLK_LOCAL_MEM_FENCE); + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 1 output per thread in subgroup 0 + if (sgid == 0) { + dst = dst + (offsetd >> 2); + dst[i01 + i20 * ne01] = sum; + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl new file mode 100644 index 00000000000..f128d44340a --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl @@ -0,0 +1,156 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#define QK_K 256 +#define K_SCALE_SIZE 12 +#define N_SIMDGROUP 4 +#define SIMDGROUP_WIDTH 64 + +inline void get_scale_min_k4( + int j, + global const uchar * q, + uchar * d, + uchar * m +) { + if (j < 4) { + *d = q[j] & 63; + *m = q[j+4] & 63; + } else { + *d = (q[j+4] & 0x0F) | ((q[j-4] & 0xC0) >> 2); + *m = ((q[j+4] >> 4) & 0x0F) | ((q[j] & 0xC0) >> 2); + } +} + +static inline float8 q5_k_to_fp32_packed8(ushort2 qs5x8, uchar qh5x8, half s, half m) { + float8 fp32x8; + fp32x8.s0 = (float)((( qs5x8.s0 & 0x000F) | (( qh5x8 & 0x01) << 4)) * s + m); + fp32x8.s1 = (float)((((qs5x8.s0 & 0x00F0) >> 4 ) | (((qh5x8 >> 1) & 0x01) << 4)) * s + m); + fp32x8.s2 = (float)((((qs5x8.s0 & 0x0F00) >> 8 ) | (((qh5x8 >> 2) & 0x01) << 4)) * s + m); + fp32x8.s3 = (float)((((qs5x8.s0 & 0xF000) >> 12) | (((qh5x8 >> 3) & 0x01) << 4)) * s + m); + fp32x8.s4 = (float)((( qs5x8.s1 & 0x000F) | (((qh5x8 >> 4) & 0x01) << 4)) * s + m); + fp32x8.s5 = (float)((((qs5x8.s1 & 0x00F0) >> 4 ) | (((qh5x8 >> 5) & 0x01) << 4)) * s + m); + fp32x8.s6 = (float)((((qs5x8.s1 & 0x0F00) >> 8 ) | (((qh5x8 >> 6) & 0x01) << 4)) * s + m); + fp32x8.s7 = (float)((((qs5x8.s1 & 0xF000) >> 12) | (((qh5x8 >> 7) & 0x01) << 4)) * s + m); + return fp32x8; +} + +__attribute__((qcom_reqd_sub_group_size("half"))) +__kernel void kernel_gemv_moe_q5_k_f32_ns( + __global uint * src0_q, + __global uint * src0_qh, + __global half * src0_d, + __global half * src0_dm, + __global uchar * src0_s, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne11 +) { + uint i01 = get_global_id(0); + uint i20 = get_global_id(2); + uint sgid = get_local_id(1); + uint slid = get_sub_group_local_id(); + + uint i11 = i20 % ne11; + + uint expert_id = src2[i20]; + + int num_superblocks = ne00 / QK_K; + int num_subblocks = ne00 / 32; + int scales_per_row = num_superblocks * K_SCALE_SIZE; + + // Expert offsets in the transposed noshuffle layout + uint expert_q_offset = expert_id * (ne00 / 8) * ne01; + uint expert_d_offset = expert_id * num_superblocks * ne01; + + __private float sum = 0.0f; + + // Loop over sub-blocks of 32 elements, N_SIMDGROUP sub-blocks per iter + for (uint ib = sgid; ib < num_subblocks; ib += N_SIMDGROUP) { + uint sb = ib / 8; + uint j = ib % 8; + + // Load d and dmin for this super-block + half d_val = src0_d[expert_d_offset + sb * ne01 + i01]; + half dm_val = src0_dm[expert_d_offset + sb * ne01 + i01]; + + // sub_block index = sb * 8 + j + uint expert_qh_offset = expert_id * num_superblocks * 8 * ne01; + uchar4 regQh = as_uchar4(src0_qh[expert_qh_offset + (sb * 8 + j) * ne01 + i01]); + + // Load sub-block scale and min + global const uchar * sc = src0_s + (expert_id * ne01 + i01) * scales_per_row + sb * K_SCALE_SIZE; + uchar sv, mn; + get_scale_min_k4(j, sc, &sv, &mn); + + float scale = (float)d_val * (float)sv; + float minv = -(float)dm_val * (float)mn; + + // Load 4 uints of quants (32 nibbles = 32 elements) + uint q_base = expert_q_offset + ib * ne01 * 4 + i01; + + uint4 regQ; + regQ.s0 = src0_q[q_base]; + regQ.s1 = src0_q[q_base + ne01]; + regQ.s2 = src0_q[q_base + ne01 * 2]; + regQ.s3 = src0_q[q_base + ne01 * 3]; + + // Load activations: 32 floats = 8 float4s + uint y_offset = i11 * ne00 / 4 + ib * 8; + + float8 fp32x8 = q5_k_to_fp32_packed8(as_ushort2(regQ.s0), regQh.s0, scale, minv); + + float4 shared_y4; + shared_y4 = read_imagef(src1, (y_offset + 0)); + float4 acc = shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 1)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q5_k_to_fp32_packed8(as_ushort2(regQ.s1), regQh.s1, scale, minv); + + shared_y4 = read_imagef(src1, (y_offset + 2)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 3)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q5_k_to_fp32_packed8(as_ushort2(regQ.s2), regQh.s2, scale, minv); + + shared_y4 = read_imagef(src1, (y_offset + 4)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 5)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q5_k_to_fp32_packed8(as_ushort2(regQ.s3), regQh.s3, scale, minv); + + shared_y4 = read_imagef(src1, (y_offset + 6)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 7)); + acc += shared_y4 * fp32x8.hi; + + sum += ((acc.s0 + acc.s1) + (acc.s2 + acc.s3)); + } + + // reduction in local memory, assumes #subgroups=4 + __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)]; + if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum; + if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum; + if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum; + barrier(CLK_LOCAL_MEM_FENCE); + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 1 output per thread in subgroup 0 + if (sgid == 0) { + dst = dst + (offsetd >> 2); + dst[i01 + i20 * ne01] = sum; + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl new file mode 100644 index 00000000000..526e609dc3a --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl @@ -0,0 +1,137 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#define QK_K 256 +#define N_SIMDGROUP 4 +#define SIMDGROUP_WIDTH 64 + +static inline float8 q6_k_to_fp32_packed8(ushort2 ql8, ushort qh8, float d_scale) { + float8 fp32x8; + fp32x8.s0 = ((float)(( ql8.s0 & 0x000F) | ((uint)((qh8 ) & 0x3) << 4)) - 32.f) * d_scale; + fp32x8.s1 = ((float)((( ql8.s0 >> 4) & 0x000F) | ((uint)((qh8 >> 2) & 0x3) << 4)) - 32.f) * d_scale; + fp32x8.s2 = ((float)((( ql8.s0 >> 8) & 0x000F) | ((uint)((qh8 >> 4) & 0x3) << 4)) - 32.f) * d_scale; + fp32x8.s3 = ((float)((( ql8.s0 >> 12)& 0x000F) | ((uint)((qh8 >> 6) & 0x3) << 4)) - 32.f) * d_scale; + fp32x8.s4 = ((float)(( ql8.s1 & 0x000F) | ((uint)((qh8 >> 8) & 0x3) << 4)) - 32.f) * d_scale; + fp32x8.s5 = ((float)((( ql8.s1 >> 4) & 0x000F) | ((uint)((qh8 >>10) & 0x3) << 4)) - 32.f) * d_scale; + fp32x8.s6 = ((float)((( ql8.s1 >> 8) & 0x000F) | ((uint)((qh8 >>12) & 0x3) << 4)) - 32.f) * d_scale; + fp32x8.s7 = ((float)((( ql8.s1 >> 12)& 0x000F) | ((uint)((qh8 >>14) & 0x3) << 4)) - 32.f) * d_scale; + return fp32x8; +} + +__attribute__((qcom_reqd_sub_group_size("half"))) +__kernel void kernel_gemv_moe_q6_k_f32_ns( + __global uint * src0_ql, + __global uint * src0_qh, + __global char * src0_s, + __global half * src0_d, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne11 +) { + uint i01 = get_global_id(0); + uint i20 = get_global_id(2); + uint sgid = get_local_id(1); + uint slid = get_sub_group_local_id(); + + uint i11 = i20 % ne11; + + uint expert_id = src2[i20]; + + int num_superblocks = ne00 / QK_K; + int num_subblocks = ne00 / 32; // 8 sub-blocks of 32 per super-block + int scales_per_row = num_superblocks * 16; + + // Expert offsets in the transposed noshuffle layout + uint expert_ql_offset = expert_id * (ne00 / 8) * ne01; // 32 uints per super-block + uint expert_qh_offset = expert_id * (ne00 / 16) * ne01; // 16 uints per super-block + uint expert_d_offset = expert_id * num_superblocks * ne01; + + __private float sum = 0.0f; + + // Loop over sub-blocks of 32 elements, N_SIMDGROUP sub-blocks per iter + for (uint ib = sgid; ib < num_subblocks; ib += N_SIMDGROUP) { + uint sb = ib / 8; // super-block index + uint j = ib % 8; // 32-element group within super-block + + // Load d for this super-block + half d_val = src0_d[expert_d_offset + sb * ne01 + i01]; + + // Load 2 sub-block scales + global const char * sc = src0_s + (expert_id * ne01 + i01) * scales_per_row + sb * 16; + float scale0 = (float)d_val * (float)sc[j * 2]; + float scale1 = (float)d_val * (float)sc[j * 2 + 1]; + + // Load 4 uints of ql + uint ql_base = expert_ql_offset + (ib * 4) * ne01 + i01; + uint4 regQL; + regQL.s0 = src0_ql[ql_base]; + regQL.s1 = src0_ql[ql_base + ne01]; + regQL.s2 = src0_ql[ql_base + ne01 * 2]; + regQL.s3 = src0_ql[ql_base + ne01 * 3]; + + // Load 2 uints of qh + uint qh_base = expert_qh_offset + (ib * 2) * ne01 + i01; + uint2 regQH; + regQH.s0 = src0_qh[qh_base]; + regQH.s1 = src0_qh[qh_base + ne01]; + + // Load activations: 32 floats = 8 float4s + uint y_offset = i11 * ne00 / 4 + ib * 8; + + float8 fp32x8 = q6_k_to_fp32_packed8(as_ushort2(regQL.s0), (ushort)(regQH.s0 & 0xFFFF), scale0); + + float4 shared_y4; + shared_y4 = read_imagef(src1, (y_offset + 0)); + float4 acc = shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 1)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q6_k_to_fp32_packed8(as_ushort2(regQL.s1), (ushort)(regQH.s0 >> 16), scale0); + + shared_y4 = read_imagef(src1, (y_offset + 2)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 3)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q6_k_to_fp32_packed8(as_ushort2(regQL.s2), (ushort)(regQH.s1 & 0xFFFF), scale1); + + shared_y4 = read_imagef(src1, (y_offset + 4)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 5)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q6_k_to_fp32_packed8(as_ushort2(regQL.s3), (ushort)(regQH.s1 >> 16), scale1); + + shared_y4 = read_imagef(src1, (y_offset + 6)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 7)); + acc += shared_y4 * fp32x8.hi; + + sum += ((acc.s0 + acc.s1) + (acc.s2 + acc.s3)); + } + + // reduction in local memory, assumes #subgroups=4 + __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)]; + if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum; + if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum; + if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum; + barrier(CLK_LOCAL_MEM_FENCE); + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 1 output per thread in subgroup 0 + if (sgid == 0) { + dst = dst + (offsetd >> 2); + dst[i01 + i20 * ne01] = sum; + } +} From 89f3135ab19879dfaee0eeb4129abb72f45d2cdb Mon Sep 17 00:00:00 2001 From: ravel7524 <58877666+ravel7524@users.noreply.github.com> Date: Wed, 20 May 2026 03:52:21 +0200 Subject: [PATCH 29/55] ggml-cuda: tune RDNA3 Q6_K MMVQ nwarps (llama/23349) --- ggml/src/ggml-cuda/mmvq.cu | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index da48f313a38..73a0991e206 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -359,7 +359,9 @@ static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_d case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q4_K: + return 8; case GGML_TYPE_Q6_K: + return 2; case GGML_TYPE_IQ4_NL: return 8; default: From 34d3c6b96b5b165bab865ed935a4b21afacd50ef Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 20 May 2026 09:42:00 +0300 Subject: [PATCH 30/55] metal : optimize pad + cpy (llama/23354) * metal : optimize pad * metal : optinmize cpy * cont : better row packing in threadgroup --- ggml/src/ggml-metal/ggml-metal-device.cpp | 8 +- ggml/src/ggml-metal/ggml-metal-ops.cpp | 17 ++- ggml/src/ggml-metal/ggml-metal.metal | 128 ++++++++++++---------- 3 files changed, 90 insertions(+), 63 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index e288a27f992..ba006d9b31a 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -1897,7 +1897,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad(ggml_metal_l char base[256]; char name[256]; - snprintf(base, 256, "kernel_pad_%s", ggml_type_name(op->src[0]->type)); + // note: this is slower + //const bool is_c4 = op->src[0]->ne[0] % 4 == 0 && op->ne[0] % 4 == 0; + const bool is_c4 = false; + + snprintf(base, 256, "kernel_pad_%s%s", ggml_type_name(op->src[0]->type), is_c4 ? "_4" : ""); snprintf(name, 256, "%s", base); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); @@ -1907,6 +1911,8 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad(ggml_metal_l res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + res.c4 = is_c4; + return res; } diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index a114391c2e8..8506000b6c0 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -816,9 +816,7 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); } else { const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); - const int nth = MIN(args.ne00, nth_max); - const int nk0 = (args.ne00 + nth - 1)/nth; ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne01, ne02, ne03, nth, 1, 1); @@ -1863,7 +1861,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { nk0 = ne00/ggml_blck_size(op->type); } - int nth = std::min(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + int nth = std::min(nk0*ne01, 256); // when rows are small, we can batch them together in a single threadgroup int nrptg = 1; @@ -1874,7 +1872,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { nrptg = (nth + nk0 - 1)/nk0; nth = nk0; - if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + if (nrptg*nth > 256) { nrptg--; } } @@ -4039,14 +4037,21 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) { auto pipeline = ggml_metal_library_get_pipeline_pad(lib, op); - const int nth = std::min(1024, ne0); + if (pipeline.c4) { + args.ne00 = ne00/4; + args.ne0 = ne0/4; + } + + const int nth_max = MIN(64, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + const int nth = MIN(args.ne0, nth_max); + const int nk0 = (args.ne0 + 1024 - 1)/1024; // note: 1024 is hardcoded in the kernel! ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); - ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1); + ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne1, ne2, ne3, nth, 1, 1); return 1; } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index f6ffb2b3a1c..4cf9dbea946 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2643,7 +2643,7 @@ kernel void kernel_gated_delta_net_impl( b_ptr += args.ne21; g_ptr += args.ne21*G; - if (K > 1u) { + if (K > 1) { const int target_slot = (int)t - shift; if (target_slot >= 0 && target_slot < (int)K) { device float * dst_state = (device float *) (dst) + attn_size + (uint)target_slot * state_size_per_snap + state_out_base; @@ -2655,7 +2655,7 @@ kernel void kernel_gated_delta_net_impl( } } - if (K == 1u) { + if (K == 1) { device float * dst_state = (device float *) (dst) + attn_size + state_out_base; FOR_UNROLL (short j = 0; j < NSG; j++) { const short is = tx*NSG + j; @@ -5104,7 +5104,7 @@ kernel void kernel_upscale_bilinear_f32( for (int64_t sx = x_min; sx < x_max; ++sx) { const float wx = MAX(0.0f, 1.0f - fabs((float)sx - f00) * invscale0); const float w = wx * wy; - const device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00); + device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00); sum += (*src_ptr) * w; wsum += w; } @@ -5286,7 +5286,7 @@ kernel void kernel_upscale_bicubic_f32( const int64_t ix = MAX(0, MIN(args.ne00 - 1, i00 + dx)); const float wx = (dx == -1) ? w_x0 : (dx == 0) ? w_x1 : (dx == 1) ? w_x2 : w_x3; - const device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00); + device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00); sum += (*src_ptr) * wx * wy; } } @@ -5329,42 +5329,46 @@ kernel void kernel_roll_f32( } } -kernel void kernel_pad_f32( +template +kernel void kernel_pad_impl( constant ggml_metal_kargs_pad & args, device const char * src0, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { + const int32_t i3 = tgpig.z; + const int32_t i2 = tgpig.y; + const int32_t k0 = tgpig.x/args.ne1; + const int32_t i1 = tgpig.x - k0*args.ne1; - const int64_t i3 = tgpig.z; - const int64_t i2 = tgpig.y; - const int64_t i1 = tgpig.x; + const int32_t i03 = i3; + const int32_t i02 = i2; + const int32_t i01 = i1; - const int64_t i03 = i3; - const int64_t i02 = i2; - const int64_t i01 = i1; - - device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); - device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1); + device const T * src0_ptr = (device const T *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); + device T * dst_ptr = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1); - if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) { - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - if (i0 < args.ne00) { - dst_ptr[i0] = src0_ptr[i0]; - } else { - dst_ptr[i0] = 0.0f; - } + for (int32_t l0 = 0; l0 < 1024; l0 += ntg.x) { + const int32_t i0 = k0*1024 + tpitg.x + l0; + if (i0 >= args.ne0) { + break; } - return; - } - - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - dst_ptr[i0] = 0.0f; + if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) { + dst_ptr[i0] = src0_ptr[i0]; + } else { + dst_ptr[i0] = 0.0f; + } } } +typedef decltype(kernel_pad_impl) kernel_pad_t; + +template [[host_name("kernel_pad_f32")]] kernel kernel_pad_t kernel_pad_impl; +template [[host_name("kernel_pad_f32_4")]] kernel kernel_pad_t kernel_pad_impl; + +// TODO: this is slow - optimize kernel void kernel_pad_reflect_1d_f32( constant ggml_metal_kargs_pad_reflect_1d & args, device const char * src0, @@ -7328,23 +7332,27 @@ kernel void kernel_cpy_t_t( device const char * src0, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tpitg[[thread_position_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; - const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; + const int32_t i03 = tgpig[2]; + const int32_t i02 = tgpig[1]; + const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y; + const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; + + if (i01 >= args.ne01) { + return; + } const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - const int64_t i3 = n/(args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0); + const int32_t i3 = n/(args.ne2*args.ne1*args.ne0); + const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0); + const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0; + const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0); device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.ne00; ) { + for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.ne00;) { device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); dst_data[i00] = (T1) src[0]; break; @@ -7376,23 +7384,27 @@ kernel void kernel_cpy_f32_q( device const char * src0, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tpitg[[thread_position_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; - const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; + const int32_t i03 = tgpig[2]; + const int32_t i02 = tgpig[1]; + const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y; + const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; + + if (i01 >= args.ne01) { + return; + } const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK; + const int32_t i3 = n / (args.ne2*args.ne1*args.ne0); + const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); + const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; + const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK; device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) { + for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) { device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00); quantize_func(src, dst_data[i00]); @@ -7417,24 +7429,28 @@ kernel void kernel_cpy_q_f32( device const char * src0, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tpitg[[thread_position_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; - const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; + const int32_t i03 = tgpig[2]; + const int32_t i02 = tgpig[1]; + const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y; + const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; + + if (i01 >= args.ne01) { + return; + } const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - const int64_t i3 = n/(args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0); + const int32_t i3 = n/(args.ne2*args.ne1*args.ne0); + const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0); + const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0; + const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0); device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) { + for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) { T4x4 temp; dequantize_func(src_data + i00/nl, i00%nl, temp); dst_data[i00] = temp; From a34c0240df9677e8f8a72e44d1eabfb9c82e976e Mon Sep 17 00:00:00 2001 From: Andreas Kieslinger <47689530+aendk@users.noreply.github.com> Date: Wed, 20 May 2026 13:59:02 +0200 Subject: [PATCH 31/55] Programmatic Dependent Launch (PDL) for more performance on newer NVIDIA GPUs (Hopper+) (llama/22522) * Adds initial PDL setup. * Adds PDL barriers based on simple heuristic: place "sync" before first input pointer access, and "launch" after last write, e.g. to tensors like dst. * Further optimization pass of the first half of kernels * Optimized PDL barriers for the second batch of kernels * Further refinements after rebase. * Moves pdl logic to separate function, removes some whitespace * Strips post-hoc PDL logic * Adds stream capture PDL setup. Enrolls quantize_q8_1 to leverage pdl to overlap execution with previous kernels * Enrolls mul_mat_vec_q, rms_norm_f32 and k_bin_bcast (partly) into PDL * Enrolls mmvf, rope, set-rows and topk kernels for gpt-oss into PDL * Introduce ggml_cuda_kernel_launch, to abstract away cudaLaunchKernelEx, to enable hip/musa compatibility * Enrolls cpy_scalar_contiguous, k_get_rows_float and rms_norm_f32 * Enrolls flash_attn_combine_results * Fix: Drops needless and broken check of CUDA arch for PDL. PDL either works or is without effect. * Enrolls flash-attention kernels to pdl * Fix: inlines ggml_cuda_kernel_launch, and uses perfect forwarding for kernels args. This fixes PDL. * Perf: Enrolls k_bin_bcast variadic template invocation into PDL, via and template alias and template expansion * Enrolls all remaining kernels for qwen3-coder-next into PDL * Remove all PDL LC calls to create a baseline * Added LC according to internal guidance and tested kernel performance. * Enrols missing qwen3-5 kernels passively into PDL. * Kernel optimizations (LC signals) for qwen3.5 * Enrolls ssm-scan kernels into PDL * Adds GGML_CUDA_PDL command line option to toggle PDL. * Fix: Ada and lower compilation by guarding PDL calls correctly * Cleanup: Removes commented out GGML_CUDA_PDL_LC * Cleanup: Removes experimental comments * Adds 90-virtual to build script so that Hopper GPUs can leverage PDL. * Adds stricter checks to enable PDL, adds env-check to disable it, and removes now superfluous compile option to enable PDL. * Fix: Correct PDL en/disablement based on device-side arch check. Host side check is UB. Required moving from macros to inlined functions * Fix: default-disable PDL. Enable by setting GGML_CUDA_ENABLE_PDL=1 * Enable PDL by default for Hopper+ devices * Enrolls softcap_f32 and two flash_attn kernels into PDL. * Improves flash attn PDL barrier placement * Fix: Perf regression on ada; excludes ada and below from PDL launches * Improves some sync barrier placements * Drops superfluous constructor * Adds #endif guard comments * Reverts experimental change to top-k-moe.cu, which moved expensive allocations in front of the PDL barrier. It did not have a meaningful impact. * Exchanges GGML_CUDA_DISABLE_PDL with GGML_CUDA_PDL. IFF GGML_CUDA_PDL=0 PDL is disabled * Revert "Drops superfluous constructor". Adds const to remaining arguments This reverts commit 12b1d250da0089ae02a9bb71bbb3fd6d70f6f2f1. * Cleanup: Removes and fixes some comments and whitespace * Clarifies comment of sync-barrier position * Relocates and refactors PDL launch functions and accessories * Adds error checking to the regular kernel launch path * Drops "auto" in favor of "ggml_cuda_kernel_params" * Adds "const" to ggml_cuda_kernel_launch_params * [Whitespace] Adds final newline to common.cuh to make editorconfig CI job happy --- ggml/src/ggml-cuda/CMakeLists.txt | 3 +- ggml/src/ggml-cuda/binbcast.cu | 32 +++++----- ggml/src/ggml-cuda/common.cuh | 86 +++++++++++++++++++++++++++ ggml/src/ggml-cuda/concat.cu | 5 +- ggml/src/ggml-cuda/cpy.cu | 20 +++++-- ggml/src/ggml-cuda/fattn-common.cuh | 28 +++++---- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 1 + ggml/src/ggml-cuda/fattn-tile.cuh | 2 + ggml/src/ggml-cuda/fattn-vec.cuh | 3 + ggml/src/ggml-cuda/fattn-wmma-f16.cu | 1 + ggml/src/ggml-cuda/gated_delta_net.cu | 11 ++-- ggml/src/ggml-cuda/getrows.cu | 7 ++- ggml/src/ggml-cuda/mean.cu | 6 +- ggml/src/ggml-cuda/mmvf.cu | 12 ++-- ggml/src/ggml-cuda/mmvq.cu | 11 ++-- ggml/src/ggml-cuda/norm.cu | 46 ++++++++++---- ggml/src/ggml-cuda/quantize.cu | 7 ++- ggml/src/ggml-cuda/reduce_rows.cuh | 2 + ggml/src/ggml-cuda/rope.cu | 15 +++-- ggml/src/ggml-cuda/scale.cu | 5 +- ggml/src/ggml-cuda/set-rows.cu | 11 +++- ggml/src/ggml-cuda/softcap.cu | 5 +- ggml/src/ggml-cuda/ssm-conv.cu | 8 ++- ggml/src/ggml-cuda/ssm-scan.cu | 27 +++++---- ggml/src/ggml-cuda/sumrows.cu | 12 ++-- ggml/src/ggml-cuda/topk-moe.cu | 47 ++++++++------- ggml/src/ggml-cuda/unary.cu | 10 +++- 27 files changed, 310 insertions(+), 113 deletions(-) diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index b54d4a6b107..d3953eee962 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -15,6 +15,7 @@ if (CUDAToolkit_FOUND) # 80 == Ampere, asynchronous data loading, faster tensor core instructions # 86 == RTX 3000, needs CUDA v11.1 # 89 == RTX 4000, needs CUDA v11.8 + # 90 == Hopper H100/200, needs CUDA v11.8 # 120 == Blackwell, needs CUDA v12.8, FP4 tensor cores # # XX-virtual == compile CUDA code as PTX, do JIT compilation to binary code on first run @@ -33,7 +34,7 @@ if (CUDAToolkit_FOUND) list(APPEND CMAKE_CUDA_ARCHITECTURES 75-virtual 80-virtual 86-real) if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.8") - list(APPEND CMAKE_CUDA_ARCHITECTURES 89-real) + list(APPEND CMAKE_CUDA_ARCHITECTURES 89-real 90-virtual) endif() if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8") diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu index adb4d5f0cb9..c25f42b32bb 100644 --- a/ggml/src/ggml-cuda/binbcast.cu +++ b/ggml/src/ggml-cuda/binbcast.cu @@ -2,6 +2,9 @@ #include #include +template +using type_for_index = T; + static __device__ __forceinline__ float op_repeat(const float a, const float b) { return b; GGML_UNUSED(a); @@ -52,6 +55,7 @@ static __global__ void k_bin_bcast(const src0_t * src0, const int s12, const int s13, src1_ptrs... src1s) { + ggml_cuda_pdl_lc(); const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x; const uint32_t i1 = (blockDim.y * blockIdx.y + threadIdx.y); const uint32_t i2 = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3); @@ -72,6 +76,7 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr; dst_t * dst_row = dst + i_dst; + ggml_cuda_pdl_sync(); for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) { const uint32_t i10 = fastmodulo(i0, ne10); @@ -141,6 +146,7 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const int i10 = fastmodulo(i0, ne10); + ggml_cuda_pdl_sync(); float result = src0_row ? (float) src0_row[i0*s00] : 0.0f; if constexpr (sizeof...(src1_ptrs) > 0) { result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10]))); @@ -282,35 +288,24 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * const uint3 ne1_fastdiv = init_fastdiv_values((uint32_t) ne1); const uint3 ne2_fastdiv = init_fastdiv_values((uint32_t) ne2); - if constexpr (sizeof...(I) > 0) { - k_bin_bcast_unravel<<>>( + { + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)block_num, block_size, 0, stream); + ggml_cuda_kernel_launch(k_bin_bcast_unravel...>, launch_params, src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13, /*s0,*/ s1, s2, s3, s00, s01, s02, s03, s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); - } else { - k_bin_bcast_unravel - <<>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, - ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13, - /*s0,*/ s1, s2, s3, - s00, s01, s02, s03, - s10, s11, s12, s13); } } else { const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3); - if constexpr (sizeof...(I) > 0) { - k_bin_bcast<<>>( + { + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(k_bin_bcast...>, launch_params, src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13, /*s0,*/ s1, s2, s3, - s00 ,s01, s02, s03, - s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); - } else { - k_bin_bcast<<>>( - src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13, - /*s0,*/ s1, s2, s3, s00, s01, s02, s03, - s10, s11, s12, s13); + s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); } } } @@ -333,6 +328,7 @@ static __global__ void k_repeat_back( } T sum = 0; + ggml_cuda_pdl_sync(); for (int64_t i3 = tid3; i3 < ne03; i3 += ne3) { for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) { for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) { diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 10817505d9f..9c73fe7e6fa 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -5,6 +5,7 @@ #include "ggml-cuda.h" #include +#include #include #if defined(GGML_USE_HIP) @@ -27,6 +28,7 @@ #include #include #include +#include #include #if defined(GGML_USE_HIP) @@ -50,6 +52,7 @@ #define GGML_CUDA_CC_TURING 750 #define GGML_CUDA_CC_AMPERE 800 #define GGML_CUDA_CC_ADA_LOVELACE 890 +#define GGML_CUDA_CC_HOPPER 900 // While BW spans CC 1000, 1100 & 1200, we are integrating Tensor Core instructions available to 1200 family, see // https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html#blackwell-sm120-gemms #define GGML_CUDA_CC_BLACKWELL 1200 @@ -107,6 +110,24 @@ # define GGML_CUDA_USE_CUB #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070 +// PDL host-side support (cudaLaunchKernelEx) requires CUDART >= 11.8 and excludes HIP/MUSA. +// __CUDA_ARCH__ is undefined in host passes; GPU arch check happens in device-side code. +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11080 +# define GGML_CUDA_USE_PDL +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11080 + +static __device__ __forceinline__ void ggml_cuda_pdl_sync() { +#if defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER + cudaGridDependencySynchronize(); +#endif // defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER +} + +static __device__ __forceinline__ void ggml_cuda_pdl_lc() { +#if defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER + cudaTriggerProgrammaticLaunchCompletion(); +#endif // defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER +} + #ifdef __CUDA_ARCH_LIST__ constexpr bool ggml_cuda_has_arch_impl(int) { return false; @@ -165,6 +186,7 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in #define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString) + #if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA) static const char * cublas_get_error_str(const cublasStatus_t err) { return cublasGetStatusString(err); @@ -1487,3 +1509,67 @@ struct ggml_cuda_mm_fusion_args_device { const void * gate_bias = nullptr; ggml_glu_op glu_op; }; + +struct ggml_cuda_kernel_launch_params { + dim3 block_nums; + dim3 block_dims; + size_t shmem; + cudaStream_t stream; + + // size_t shmem + ggml_cuda_kernel_launch_params(const dim3& block_nums_, const dim3& block_dims_, const size_t shmem_, const cudaStream_t stream_) + : block_nums(block_nums_), block_dims(block_dims_), shmem(shmem_), stream(stream_) {} + + // Some call sites pass ints instead of the required size_t. This 2nd constructor casts int->size_t to avoid these -Wnarrowing warnings. + ggml_cuda_kernel_launch_params(const dim3& block_nums_, const dim3& block_dims_, const int shmem_, const cudaStream_t stream_) + : block_nums(block_nums_), block_dims(block_dims_), shmem((size_t)shmem_), stream(stream_) {} +}; + +#if defined(GGML_CUDA_USE_PDL) +struct ggml_cuda_pdl_config { + cudaLaunchAttribute attr; + cudaLaunchConfig_t cfg; + + ggml_cuda_pdl_config(const ggml_cuda_kernel_launch_params & params) { + attr.id = cudaLaunchAttributeProgrammaticStreamSerialization; + attr.val.programmaticStreamSerializationAllowed = 1; + + cfg = {}; + cfg.gridDim = params.block_nums; + cfg.blockDim = params.block_dims; + cfg.dynamicSmemBytes = params.shmem; + cfg.stream = params.stream; + cfg.attrs = &attr; + cfg.numAttrs = 1; + } + + // Delete due to &attr + ggml_cuda_pdl_config(const ggml_cuda_pdl_config&) = delete; + ggml_cuda_pdl_config& operator=(const ggml_cuda_pdl_config&) = delete; + ggml_cuda_pdl_config& operator=(ggml_cuda_pdl_config&&) = delete; + +}; +#endif //defined(GGML_CUDA_USE_PDL) + + +template +static __inline__ void ggml_cuda_kernel_launch(Kernel kernel, const ggml_cuda_kernel_launch_params & launch_params, Args&&... args) { +#if defined(GGML_CUDA_USE_PDL) + + static const bool env_pdl_enabled = []() { + const char * env = getenv("GGML_CUDA_PDL"); + return env == nullptr || std::atoi(env) != 0; + }(); + + if (env_pdl_enabled && ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= GGML_CUDA_CC_HOPPER) { + auto pdl_cfg = ggml_cuda_pdl_config(launch_params); + + CUDA_CHECK(cudaLaunchKernelEx(&pdl_cfg.cfg, kernel, std::forward(args)... )); + return; + } +#endif //defined(GGML_CUDA_USE_PDL) + + kernel<<>>(std::forward(args)... ); + CUDA_CHECK(cudaGetLastError()); +} + diff --git a/ggml/src/ggml-cuda/concat.cu b/ggml/src/ggml-cuda/concat.cu index 102f944f924..adba4d522a4 100644 --- a/ggml/src/ggml-cuda/concat.cu +++ b/ggml/src/ggml-cuda/concat.cu @@ -15,6 +15,7 @@ static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) concat_f32_cont const int64_t n = ne0 * ne1 * ne2; + ggml_cuda_pdl_sync(); for (int64_t i = (int64_t) blockIdx.x * blockDim.x + threadIdx.x; i < n; i += (int64_t) blockDim.x * gridDim.x) { if constexpr (dim == 0) { const int64_t row = i / ne0; @@ -64,8 +65,8 @@ static void concat_f32_cuda(const float * x, const int num_blocks = (n + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE; if (dim == 0) { - concat_f32_cont<0> - <<>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream); + ggml_cuda_kernel_launch(concat_f32_cont<0>, launch_params,x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); return; } if (dim == 1) { diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index d208acf2d5f..121472ec228 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -16,6 +16,7 @@ static __global__ void cpy_scalar(const char * cx, char * cdst, const int64_t ne const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13) { + ggml_cuda_pdl_lc(); const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; if (i >= ne) { @@ -36,6 +37,7 @@ static __global__ void cpy_scalar(const char * cx, char * cdst, const int64_t ne const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13; + ggml_cuda_pdl_sync(); cpy_1(cx + x_offset, cdst + dst_offset); } @@ -59,6 +61,7 @@ static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const __shared__ float tile[2][CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D+1]; int cur_tile_buf = 0; + ggml_cuda_pdl_sync(); #pragma unroll for (int i = 0; i < CUDA_CPY_BLOCK_NM; ++i) { @@ -142,6 +145,7 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst, const int64_t ne, const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; const int64_t dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13; + ggml_cuda_pdl_sync(); cpy_blck(cx + x_offset, cdst + dst_offset); } @@ -168,6 +172,7 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst, const int64_t ne, const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13; + ggml_cuda_pdl_sync(); cpy_blck(cx + x_offset, cdst + dst_offset); } @@ -182,6 +187,7 @@ static __global__ void cpy_scalar_contiguous(const char * cx, char * cdst, const const src_t * x = (const src_t *) cx; dst_t * dst = (dst_t *) cdst; + ggml_cuda_pdl_sync(); dst[i] = ggml_cuda_cast(x[i]); } @@ -192,8 +198,8 @@ cudaStream_t stream) { const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; GGML_ASSERT(num_blocks < UINT_MAX); - cpy_scalar_contiguous<<>> - (cx, cdst, ne); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream); + ggml_cuda_kernel_launch(cpy_scalar_contiguous, launch_params, cx, cdst, ne); } template @@ -223,13 +229,15 @@ static void ggml_cpy_scalar_cuda( GGML_ASSERT(grid_z < USHRT_MAX); dim3 dimGrid(grid_x, grid_y, grid_z); dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1); - cpy_scalar_transpose<<>> - (cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(dimGrid, dimBlock, 0, stream); + ggml_cuda_kernel_launch(cpy_scalar_transpose, launch_params, + cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } else { const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; GGML_ASSERT(num_blocks < UINT_MAX); - cpy_scalar><<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream); + ggml_cuda_kernel_launch(cpy_scalar>, launch_params, + cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } } diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index beeb5238946..debcb6e5447 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -636,6 +636,7 @@ static __global__ void flash_attn_mask_to_KV_max( if (tid < WARP_SIZE) { buf_iw[tid] = 1; } + ggml_cuda_pdl_sync(); __syncthreads(); int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE; @@ -687,6 +688,7 @@ static __global__ void flash_attn_stream_k_fixup_uniform( const uint3 fd_iter_j_z, const uint3 fd_iter_j) { constexpr int ncols = ncols1*ncols2; + ggml_cuda_pdl_lc(); const int tile_idx = blockIdx.x; // One block per output tile. const int j = blockIdx.y; @@ -718,6 +720,7 @@ static __global__ void flash_attn_stream_k_fixup_uniform( dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid; + ggml_cuda_pdl_sync(); // Load the partial result that needs a fixup float dst_val = *dst; float max_val; @@ -809,6 +812,7 @@ static __global__ void flash_attn_stream_k_fixup_general( float dst_val = 0.0f; float max_val = 0.0f; float rowsum = 0.0f; + ggml_cuda_pdl_sync(); { dst_val = *dst; @@ -867,6 +871,7 @@ static __global__ void flash_attn_combine_results( const float2 * __restrict__ VKQ_meta, float * __restrict__ dst, const int parallel_blocks) { + ggml_cuda_pdl_lc(); // Dimension 0: threadIdx.x // Dimension 1: blockIdx.x // Dimension 2: blockIdx.y @@ -890,6 +895,7 @@ static __global__ void flash_attn_combine_results( __builtin_assume(tid < D); extern __shared__ float2 meta[]; + ggml_cuda_pdl_sync(); for (int i = tid; i < 2*parallel_blocks; i += D) { ((float *) meta)[i] = ((const float *)VKQ_meta) [i]; } @@ -1146,7 +1152,9 @@ void launch_fattn( const uint3 ne01 = init_fastdiv_values(Q->ne[1]); GGML_ASSERT(block_dim.x % warp_size == 0); - fattn_kernel<<>>( + + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks_num, block_dim, nbytes_shared, main_stream); + ggml_cuda_kernel_launch(fattn_kernel, launch_params, (const char *) Q->data, K_data, V_data, @@ -1176,9 +1184,9 @@ void launch_fattn( const dim3 block_dim_combine(DV, 1, 1); const dim3 blocks_num_combine = {(unsigned)ntiles_dst, ncols1, ncols2}; - flash_attn_stream_k_fixup_uniform - <<>> - ((float *) KQV->data, dst_tmp_meta.ptr, + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks_num_combine, block_dim_combine, 0, main_stream); + ggml_cuda_kernel_launch(flash_attn_stream_k_fixup_uniform, launch_params, + (float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[2], nblocks_sk, gqa_ratio, bpt, fd0, fd1, fd2); } else if (ntiles_dst % blocks_num.x != 0) { @@ -1193,9 +1201,9 @@ void launch_fattn( const dim3 block_dim_combine(DV, 1, 1); const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2}; - flash_attn_stream_k_fixup_general - <<>> - ((float *) KQV->data, dst_tmp_meta.ptr, + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks_num_combine, block_dim_combine, 0, main_stream); + ggml_cuda_kernel_launch(flash_attn_stream_k_fixup_general, launch_params, + (float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], gqa_ratio, total_work, fd_k_j_z_ne12, fd_k_j_z, fd_k_j, fd_k); } @@ -1204,9 +1212,9 @@ void launch_fattn( const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]); const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2); - flash_attn_combine_results - <<>> - (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream); + ggml_cuda_kernel_launch(flash_attn_combine_results, launch_params, + dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks); } CUDA_CHECK(cudaGetLastError()); } diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index a25e912c4d2..4871b90df86 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -1724,6 +1724,7 @@ static __global__ void flash_attn_ext_f16( const int32_t nb21, const int32_t nb22, const int64_t nb23, const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { + ggml_cuda_pdl_sync(); // TODO optimize placement #if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)) // Skip unused kernel variants for faster compilation: diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index 7b0a5e5cf49..fac76f13593 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -894,6 +894,8 @@ static __global__ void flash_attn_tile( } float KQ_sum[cpw] = {0.0f}; + ggml_cuda_pdl_sync(); + // Load Q data, convert to FP16 if fast: #pragma unroll for (int jc0 = 0; jc0 < cpw; ++jc0) { diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh index f0bd42a5761..b0a6cf67f1a 100644 --- a/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ggml/src/ggml-cuda/fattn-vec.cuh @@ -40,6 +40,7 @@ static __global__ void flash_attn_ext_vec( const int32_t nb21, const int32_t nb22, const int64_t nb23, const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { + ggml_cuda_pdl_lc(); #ifdef FLASH_ATTN_AVAILABLE // Skip unused kernel variants for faster compilation: @@ -136,6 +137,8 @@ static __global__ void flash_attn_ext_vec( #endif // V_DOT2_F32_F16_AVAILABLE int Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)]; float2 Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)]; + + ggml_cuda_pdl_sync(); if constexpr (Q_q8_1) { #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index f19defbff93..4b6f6501094 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -86,6 +86,7 @@ static __global__ void flash_attn_ext_f16( constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half); + ggml_cuda_pdl_sync(); const int sequence = blockIdx.z / ne02; const int head = blockIdx.z - sequence*ne02; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index b4c9845e7a7..018d5d37d47 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -1,4 +1,5 @@ #include "gated_delta_net.cuh" +#include "ggml-cuda/common.cuh" template __global__ void __launch_bounds__((ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v) * 4, 2) @@ -53,6 +54,7 @@ gated_delta_net_cuda(const float * q, float s_shard[rows_per_lane]; // state is stored transposed: M[col][i] = S[i][col], row col is contiguous + ggml_cuda_pdl_sync(); #pragma unroll for (int r = 0; r < rows_per_lane; r++) { const int i = r * warp_size + lane; @@ -189,28 +191,29 @@ static void launch_gated_delta_net( int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(grid_dims, block_dims, 0, stream); switch (S_v) { case 16: - gated_delta_net_cuda<16, KDA, keep_rs_t><<>>( + ggml_cuda_kernel_launch(gated_delta_net_cuda<16, KDA, keep_rs_t>, launch_params, q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); break; case 32: - gated_delta_net_cuda<32, KDA, keep_rs_t><<>>( + ggml_cuda_kernel_launch(gated_delta_net_cuda<32, KDA, keep_rs_t>, launch_params, q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); break; case 64: { - gated_delta_net_cuda<64, KDA, keep_rs_t><<>>( + ggml_cuda_kernel_launch(gated_delta_net_cuda<64, KDA, keep_rs_t>, launch_params, q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); break; } case 128: { - gated_delta_net_cuda<128, KDA, keep_rs_t><<>>( + ggml_cuda_kernel_launch(gated_delta_net_cuda<128, KDA, keep_rs_t>, launch_params, q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu index 36b840e8148..457b695eb2a 100644 --- a/ggml/src/ggml-cuda/getrows.cu +++ b/ggml/src/ggml-cuda/getrows.cu @@ -11,6 +11,7 @@ static __global__ void k_get_rows( /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { + ggml_cuda_pdl_sync(); for (int64_t z = blockIdx.z; z < ne11*(int64_t)ne12_fdv.z; z += gridDim.z) { for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) { // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. @@ -48,6 +49,8 @@ static __global__ void k_get_rows_float( /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { + ggml_cuda_pdl_lc(); + ggml_cuda_pdl_sync(); for (int64_t z = blockIdx.z; z < ne11*(int64_t)ne12_fdv.z; z += gridDim.z) { for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) { // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. @@ -83,6 +86,7 @@ static __global__ void k_get_rows_back_float( float sum = 0.0f; + ggml_cuda_pdl_sync(); for (int64_t i = 0; i < nrows_grad; ++i) { if (rows[i] != dst_row) { continue; @@ -156,7 +160,8 @@ static void get_rows_cuda_float( GGML_ASSERT(ne11 <= std::numeric_limits::max() / ne12); const uint3 ne12_fdv = init_fastdiv_values(ne12); - k_get_rows_float<<>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{block_nums, block_dims, 0, stream}; + ggml_cuda_kernel_launch(k_get_rows_float, launch_params, src0_d, src1_d, dst_d, ne00, /*ne01, ne02, ne03,*/ /*ne10,*/ ne11, ne12_fdv, /*ne13,*/ diff --git a/ggml/src/ggml-cuda/mean.cu b/ggml/src/ggml-cuda/mean.cu index 49af5389957..a8f6046e46d 100644 --- a/ggml/src/ggml-cuda/mean.cu +++ b/ggml/src/ggml-cuda/mean.cu @@ -67,9 +67,11 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { // See discussion in: https://github.com/ggml-org/llama.cpp/pull/15132 if ((nrows / nsm) < 2) { const dim3 block_dims(512, 1, 1); - reduce_rows_f32<<>>(src0_d, dst_d, ncols); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(reduce_rows_f32, launch_params, src0_d, dst_d, ncols); } else { const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1); - reduce_rows_f32<<>>(src0_d, dst_d, ncols); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(reduce_rows_f32, launch_params, src0_d, dst_d, ncols); } } diff --git a/ggml/src/ggml-cuda/mmvf.cu b/ggml/src/ggml-cuda/mmvf.cu index d9147202429..09d95f309b4 100644 --- a/ggml/src/ggml-cuda/mmvf.cu +++ b/ggml/src/ggml-cuda/mmvf.cu @@ -21,6 +21,7 @@ static __global__ void mul_mat_vec_f( int channel_y; int sample_dst; + ggml_cuda_pdl_sync(); if constexpr (is_multi_token_id) { // Multi-token MUL_MAT_ID path, adding these in the normal path causes a perf regression for n_tokens=1 case token_idx = blockIdx.z; @@ -298,6 +299,7 @@ static __global__ void mul_mat_vec_f( static_assert(std::is_same_v, "unsupported type"); } + ggml_cuda_pdl_lc(); #pragma unroll for (int j = 0; j < ncols_dst; ++j) { sumf[j] = warp_reduce_sum(sumf[j]); @@ -382,11 +384,13 @@ static void mul_mat_vec_f_switch_fusion( const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const int ids_stride, const cudaStream_t stream) { + const ggml_cuda_kernel_launch_params launch_params = {block_nums, block_dims, nbytes_shared, stream}; + const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; if constexpr (ncols_dst == 1) { if (has_fusion) { - mul_mat_vec_f<<>> - (x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst, + ggml_cuda_kernel_launch(mul_mat_vec_f, launch_params, + x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); return; @@ -395,8 +399,8 @@ static void mul_mat_vec_f_switch_fusion( GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1"); - mul_mat_vec_f<<>> - (x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst, + ggml_cuda_kernel_launch(mul_mat_vec_f, launch_params, + x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 73a0991e206..13b8b855282 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -424,6 +424,7 @@ static __global__ void mul_mat_vec_q( uint32_t channel_y; uint32_t sample_dst; + ggml_cuda_pdl_sync(); channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio); channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst; sample_dst = blockIdx.z; @@ -683,8 +684,9 @@ static void mul_mat_vec_q_switch_fusion( const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; if constexpr (c_ncols_dst == 1) { if (has_fusion) { - mul_mat_vec_q<<>> - (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, nbytes_shared, stream); + ggml_cuda_kernel_launch(mul_mat_vec_q, launch_params, + vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); return; @@ -693,8 +695,9 @@ static void mul_mat_vec_q_switch_fusion( GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1"); - mul_mat_vec_q<<>> - (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, nbytes_shared, stream); + ggml_cuda_kernel_launch(mul_mat_vec_q, launch_params, + vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); } diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu index ef98f675aa7..09d9f3a7d62 100644 --- a/ggml/src/ggml-cuda/norm.cu +++ b/ggml/src/ggml-cuda/norm.cu @@ -18,6 +18,7 @@ static __global__ void norm_f32( float2 mean_var = make_float2(0.0f, 0.0f); + ggml_cuda_pdl_sync(); for (int col = tid; col < ncols; col += block_size) { const float xi = x[col]; mean_var.x += xi; @@ -46,6 +47,7 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr float tmp = 0.0f; // partial sum for thread in warp + ggml_cuda_pdl_sync(); for (int j = start; j < end; j += block_size) { tmp += x[j]; } @@ -95,6 +97,7 @@ static __global__ void rms_norm_f32(const float * x, const uint3 add_nrows_packed = make_uint3(0, 0, 0), const uint3 add_nchannels_packed = make_uint3(0, 0, 0), const uint3 add_nsamples_packed = make_uint3(0, 0, 0)) { + ggml_cuda_pdl_lc(); const int nrows = gridDim.x; const int nchannels = gridDim.y; @@ -124,6 +127,7 @@ static __global__ void rms_norm_f32(const float * x, float tmp = 0.0f; // partial sum for thread in warp + ggml_cuda_pdl_sync(); for (int col = tid; col < ncols; col += block_size) { const float xi = x[col]; tmp += xi * xi; @@ -163,6 +167,7 @@ static __global__ void rms_norm_back_f32( float sum_xx = 0.0f; // sum for squares of x, equivalent to forward pass float sum_xg = 0.0f; // sum for x * gradient, needed because RMS norm mixes inputs + ggml_cuda_pdl_sync(); for (int col = tid; col < ncols; col += block_size) { const float xfi = xf[col]; sum_xx += xfi * xfi; @@ -253,6 +258,7 @@ static __global__ void l2_norm_f32( float tmp = 0.0f; // partial sum for thread in warp + ggml_cuda_pdl_sync(); for (int col = tid; col < ncols; col += block_size) { const float xi = x[col]; tmp += xi * xi; @@ -261,6 +267,7 @@ static __global__ void l2_norm_f32( // sum up partial sums extern __shared__ float s_sum[]; tmp = block_reduce(tmp, s_sum); + ggml_cuda_pdl_lc(); // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html const float scale = rsqrtf(fmaxf(tmp, eps * eps)); @@ -300,10 +307,19 @@ static void rms_norm_f32_cuda( const dim3 blocks_num(nrows, nchannels, nsamples); if (ncols < 1024) { const dim3 block_dims(256, 1, 1); - rms_norm_f32<256, false><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + const ggml_cuda_kernel_launch_params launch_params = {blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream}; + ggml_cuda_kernel_launch(rms_norm_f32<256, false>, launch_params, + x, dst, ncols, stride_row, stride_channel, stride_sample, eps, + // underlying cudaLaunchKernelEx does not support default params + nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), + nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0)); } else { const dim3 block_dims(1024, 1, 1); - rms_norm_f32<1024, false><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream}; + ggml_cuda_kernel_launch(rms_norm_f32<1024, false>, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps, + // underlying cudaLaunchKernelEx does not support default params + nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), + nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0)); } } @@ -346,14 +362,20 @@ static void rms_norm_mul_f32_cuda(const float * x, const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples); if (ncols < 1024) { const dim3 block_dims(256, 1, 1); - rms_norm_f32<256, true><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream}; + ggml_cuda_kernel_launch(rms_norm_f32<256, true>, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, - mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed); + mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, + // underlying cudaLaunchKernelEx does not support default params + nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0)); } else { const dim3 block_dims(1024, 1, 1); - rms_norm_f32<1024, true><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream}; + ggml_cuda_kernel_launch(rms_norm_f32<1024, true>, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, - mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed); + mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, + // underlying cudaLaunchKernelEx does not support default params + nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0)); } } else { const uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols); @@ -367,14 +389,16 @@ static void rms_norm_mul_f32_cuda(const float * x, const uint3 add_nsamples_packed = init_fastdiv_values(add_nsamples); if (ncols < 1024) { const dim3 block_dims(256, 1, 1); - rms_norm_f32<256, true, true><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims,block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream}; + ggml_cuda_kernel_launch(rms_norm_f32<256, true, true>, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add, add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed, add_nchannels_packed, add_nsamples_packed); } else { const dim3 block_dims(1024, 1, 1); - rms_norm_f32<1024, true, true><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream}; + ggml_cuda_kernel_launch(rms_norm_f32<1024, true, true>, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add, add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed, @@ -399,10 +423,12 @@ static void l2_norm_f32_cuda( const dim3 blocks_num(nrows, nchannels, nsamples); if (ncols < 1024) { const dim3 block_dims(WARP_SIZE, 1, 1); - l2_norm_f32<<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, 0, stream}; + ggml_cuda_kernel_launch(l2_norm_f32, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps); } else { const dim3 block_dims(1024, 1, 1); - l2_norm_f32<1024><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream}; + ggml_cuda_kernel_launch(l2_norm_f32<1024>, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps); } } diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu index 52f664719ae..49516965cad 100644 --- a/ggml/src/ggml-cuda/quantize.cu +++ b/ggml/src/ggml-cuda/quantize.cu @@ -6,6 +6,7 @@ static __global__ void quantize_q8_1( const float * __restrict__ x, void * __restrict__ vy, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, const int64_t ne0, const uint32_t ne1, const uint3 ne2) { + ggml_cuda_pdl_lc(); const int64_t i0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; if (i0 >= ne0) { @@ -28,6 +29,7 @@ static __global__ void quantize_q8_1( const int64_t ib = i_cont / QK8_1; // block index const int64_t iqs = i_cont % QK8_1; // quant index + ggml_cuda_pdl_sync(); const float xi = i0 < ne00 ? x[i03*s03 + i02*s02 + i01*s01 + i00] : 0.0f; float amax = fabsf(xi); float sum = xi; @@ -196,6 +198,7 @@ static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x, const int64_t i2 = blockIdx.z % ne2; const int64_t i3 = blockIdx.z / ne2; + ggml_cuda_pdl_sync(); const int64_t i01 = ids ? ids[i1] : i1; const int64_t i02 = i2; const int64_t i03 = i3; @@ -288,6 +291,7 @@ static __global__ void quantize_mmq_q8_1( const int64_t i3 = blockIdx.z / ne2; const int64_t i00 = i0; + ggml_cuda_pdl_sync(); const int64_t i01 = ids ? ids[i1] : i1; const int64_t i02 = i2; const int64_t i03 = i3; @@ -378,7 +382,8 @@ void quantize_row_q8_1_cuda( const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; const dim3 num_blocks(block_num_x, ne1, ne2*ne3); const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1); - quantize_q8_1<<>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2_fastdiv); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(num_blocks, block_size, 0, stream); + ggml_cuda_kernel_launch(quantize_q8_1, launch_params, x, vy, ne00, s01, s02, s03, ne0, ne1, ne2_fastdiv); GGML_UNUSED(type_src0); } diff --git a/ggml/src/ggml-cuda/reduce_rows.cuh b/ggml/src/ggml-cuda/reduce_rows.cuh index de240fd4413..5895d3bf8e5 100644 --- a/ggml/src/ggml-cuda/reduce_rows.cuh +++ b/ggml/src/ggml-cuda/reduce_rows.cuh @@ -10,6 +10,8 @@ static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __r const int num_unroll = 8; float temp[num_unroll]; float sum_temp[num_unroll] = { 0.0f }; + + ggml_cuda_pdl_sync(); for (int i = col; i < ncols;) { for (int j = 0; j < num_unroll; ++j) { if (i < ncols) { diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu index 45a49a5dc2a..e20a5cb6bed 100644 --- a/ggml/src/ggml-cuda/rope.cu +++ b/ggml/src/ggml-cuda/rope.cu @@ -134,6 +134,7 @@ static __global__ void rope_neox(const T * x, const float * freq_factors, const int64_t * row_indices, const int set_rows_stride) { + ggml_cuda_pdl_lc(); const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); if (i0 >= ne00) { @@ -148,6 +149,7 @@ static __global__ void rope_neox(const T * x, int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3; const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03; + ggml_cuda_pdl_sync(); // Fusion optimization: ROPE + VIEW + SET_ROWS. // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices. @@ -216,6 +218,7 @@ static __global__ void rope_multi(const T * x, int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3; const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03; + ggml_cuda_pdl_sync(); if (i0 >= n_dims) { dst[idst + i0/2 + 0] = x[ix + i0/2 + 0]; dst[idst + i0/2 + 1] = x[ix + i0/2 + 1]; @@ -300,6 +303,7 @@ static __global__ void rope_vision(const T * x, int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3; const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03; + ggml_cuda_pdl_sync(); const int sect_dims = sections.v[0] + sections.v[1]; const int sec_w = sections.v[1] + sections.v[0]; const int sector = (i0 / 2) % sect_dims; @@ -399,13 +403,14 @@ static void rope_neox_cuda(const T * x, const dim3 block_nums(nr, n_blocks_x, 1); const float theta_scale = powf(freq_base, -2.0f / n_dims); + const ggml_cuda_kernel_launch_params launch_params = {block_nums, block_dims, 0, stream}; if (freq_factors == nullptr) { - rope_neox<<>>( + ggml_cuda_kernel_launch(rope_neox, launch_params, x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride); } else { - rope_neox<<>>( + ggml_cuda_kernel_launch(rope_neox, launch_params, x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride); } @@ -443,11 +448,13 @@ static void rope_multi_cuda(const T * x, const float theta_scale = powf(freq_base, -2.0f / n_dims); if (freq_factors == nullptr) { - rope_multi<<>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(rope_multi, launch_params, x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope); } else { - rope_multi<<>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(rope_multi, launch_params, x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope); } diff --git a/ggml/src/ggml-cuda/scale.cu b/ggml/src/ggml-cuda/scale.cu index 0ddeff6a175..7b2e59a4383 100644 --- a/ggml/src/ggml-cuda/scale.cu +++ b/ggml/src/ggml-cuda/scale.cu @@ -3,9 +3,11 @@ #define MAX_GRIDDIM_X 0x7FFFFFFF static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int64_t nelements) { + ggml_cuda_pdl_lc(); int64_t tid = (int64_t)blockIdx.x * (int64_t)blockDim.x + (int64_t)threadIdx.x; int64_t stride = (int64_t)blockDim.x * (int64_t)gridDim.x; + ggml_cuda_pdl_sync(); for (int64_t i = tid; i < nelements; i += stride) { dst[i] = scale * x[i] + bias; } @@ -13,7 +15,8 @@ static __global__ void scale_f32(const float * x, float * dst, const float scale static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int64_t nelements, cudaStream_t stream) { const int64_t num_blocks = (nelements + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE; - scale_f32<<>>(x, dst, scale, bias, nelements); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(MIN(MAX_GRIDDIM_X, num_blocks), CUDA_SCALE_BLOCK_SIZE, 0, stream); + ggml_cuda_kernel_launch(scale_f32, launch_params, x, dst, scale, bias, nelements); } void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { diff --git a/ggml/src/ggml-cuda/set-rows.cu b/ggml/src/ggml-cuda/set-rows.cu index 631de7e8fa5..e14f96b824c 100644 --- a/ggml/src/ggml-cuda/set-rows.cu +++ b/ggml/src/ggml-cuda/set-rows.cu @@ -53,6 +53,7 @@ static __global__ void k_set_rows_quant(const float * __restrict__ src0, const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd); const int64_t i10 = i01; + ggml_cuda_pdl_sync(); const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12); const float * src0_row = src0 + i01*s01 + i02*s02 + i03*s03; @@ -157,7 +158,9 @@ static __global__ void k_set_rows(const src_t * __restrict__ src0, const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd); const int64_t i10 = i01; + ggml_cuda_pdl_sync(); const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12); + ggml_cuda_pdl_lc(); const src_t * src0_row = src0 + i01*s01 + i02*s02 + i03*s03; dst_t * dst_row_ptr = dst + dst_row*s1 + i02*s2 + i03*s3; @@ -203,9 +206,11 @@ static void set_rows_cuda( const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11); const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12); - k_set_rows<<>>(src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01, - s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd, ne01_fd, ne02_fd, - ne11_fd, ne12_fd); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(grid_size, block_size, 0, stream); + ggml_cuda_kernel_launch(k_set_rows, launch_params, + src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01, + s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd, ne01_fd, ne02_fd, + ne11_fd, ne12_fd); } } diff --git a/ggml/src/ggml-cuda/softcap.cu b/ggml/src/ggml-cuda/softcap.cu index 40dfe45d65c..9f0fa1051cf 100644 --- a/ggml/src/ggml-cuda/softcap.cu +++ b/ggml/src/ggml-cuda/softcap.cu @@ -1,18 +1,21 @@ #include "softcap.cuh" static __global__ void softcap_f32(const float * x, float * dst, const float scale, const float softcap, const int k) { + ggml_cuda_pdl_lc(); const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { return; } + ggml_cuda_pdl_sync(); dst[i] = tanhf(scale * x[i]) * softcap; } static void softcap_f32_cuda(const float * x, float * dst, const float scale, const float softcap, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_SOFTCAP_BLOCK_SIZE - 1) / CUDA_SOFTCAP_BLOCK_SIZE; - softcap_f32<<>>(x, dst, scale, softcap, k); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(num_blocks, CUDA_SOFTCAP_BLOCK_SIZE, 0, stream); + ggml_cuda_kernel_launch(softcap_f32, launch_params, x, dst, scale, softcap, k); } // fused GGML_OP_SCALE + GGML_UNARY_OP_TANH + GGML_OP_SCALE diff --git a/ggml/src/ggml-cuda/ssm-conv.cu b/ggml/src/ggml-cuda/ssm-conv.cu index 4c4daf85dc6..48787b4b890 100644 --- a/ggml/src/ggml-cuda/ssm-conv.cu +++ b/ggml/src/ggml-cuda/ssm-conv.cu @@ -1,3 +1,4 @@ +#include "common.cuh" #include "ssm-conv.cuh" #include "unary.cuh" @@ -7,6 +8,7 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2, const int64_t n_t) { + ggml_cuda_pdl_lc(); GGML_UNUSED(src0_nb0); const int tid = threadIdx.x; const int bidx = blockIdx.x; @@ -23,6 +25,7 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float float x[d_conv] = { 0.0f }; float w[d_conv] = { 0.0f }; + ggml_cuda_pdl_sync(); #pragma unroll for (size_t j = 0; j < d_conv; j++) { w[j] = w_block[tid * stride_w + j]; @@ -128,8 +131,9 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const floa constexpr int kNC = decltype(NC)::value; if (n_t <= 32) { const dim3 blocks(n_s, (nr + threads - 1) / threads, 1); - ssm_conv_f32<<>>(src0, src1, bias, src0_nb0, src0_nb1, src0_nb2, src1_nb1, - dst, dst_nb0, dst_nb1, dst_nb2, n_t); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, 0, stream); + ggml_cuda_kernel_launch(ssm_conv_f32, launch_params, src0, src1, bias, src0_nb0, src0_nb1, + src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); } else { const int64_t split_n_t = 32; dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t); diff --git a/ggml/src/ggml-cuda/ssm-scan.cu b/ggml/src/ggml-cuda/ssm-scan.cu index c1d4e2bc8df..412980376ac 100644 --- a/ggml/src/ggml-cuda/ssm-scan.cu +++ b/ggml/src/ggml-cuda/ssm-scan.cu @@ -26,6 +26,7 @@ __global__ void __launch_bounds__(splitD, 1) const int64_t s_off, const int64_t d_inner, const int64_t L_param) { const size_t L = L_template == 0 ? L_param : L_template; + ggml_cuda_pdl_sync(); const float *s0_block = (const float *)((const char *)src0 + src6[blockIdx.x] * src0_nb3 + blockIdx.y * splitD * src0_nb2); const float *x_block = (const float *)((const char *)src1 + (blockIdx.x * src1_nb3) + blockIdx.y * splitD * sizeof(float)); const float *dt_block = (const float *)((const char *)src2 + (blockIdx.x * src2_nb2) + blockIdx.y * splitD * sizeof(float)); @@ -135,6 +136,7 @@ __global__ void __launch_bounds__(d_state, 1) const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float); + ggml_cuda_pdl_sync(); // TODO: refactor strides to be in elements/floats instead of bytes to be cleaner and consistent with the rest of the codebase const float * s0_warp = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); const float * x_warp = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + (warp_idx * sizeof(float))); @@ -206,7 +208,8 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa constexpr int num_warps = threads/WARP_SIZE; const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1); - ssm_scan_f32_group<128/WARP_SIZE, 128><<>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, 0, stream); + ggml_cuda_kernel_launch(ssm_scan_f32_group<128/WARP_SIZE, 128>, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok); @@ -215,7 +218,8 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa constexpr int num_warps = threads/WARP_SIZE; const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1); - ssm_scan_f32_group<256/WARP_SIZE, 256><<>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, 0, stream); + ggml_cuda_kernel_launch(ssm_scan_f32_group<256/WARP_SIZE, 256>, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok); @@ -231,58 +235,59 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1); const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float); if (d_state == 16) { + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, smem_size, stream); switch (n_tok) { case 1: - ssm_scan_f32<<>>( + ggml_cuda_kernel_launch(ssm_scan_f32, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); break; case 2: - ssm_scan_f32<<>>( + ggml_cuda_kernel_launch(ssm_scan_f32, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); break; case 3: - ssm_scan_f32<<>>( + ggml_cuda_kernel_launch(ssm_scan_f32, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); break; case 4: - ssm_scan_f32<<>>( + ggml_cuda_kernel_launch(ssm_scan_f32, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); break; case 5: - ssm_scan_f32<<>>( + ggml_cuda_kernel_launch(ssm_scan_f32, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); break; case 6: - ssm_scan_f32<<>>( + ggml_cuda_kernel_launch(ssm_scan_f32, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); break; case 7: - ssm_scan_f32<<>>( + ggml_cuda_kernel_launch(ssm_scan_f32, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); break; case 8: - ssm_scan_f32<<>>( + ggml_cuda_kernel_launch(ssm_scan_f32, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); break; default: - ssm_scan_f32<<>>( + ggml_cuda_kernel_launch(ssm_scan_f32, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); diff --git a/ggml/src/ggml-cuda/sumrows.cu b/ggml/src/ggml-cuda/sumrows.cu index 4025771aadb..0003658ca95 100644 --- a/ggml/src/ggml-cuda/sumrows.cu +++ b/ggml/src/ggml-cuda/sumrows.cu @@ -7,10 +7,12 @@ void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int const dim3 block_nums(nrows, 1, 1); if ((nrows / nsm) < 2) { const dim3 block_dims(512, 1, 1); - reduce_rows_f32<<>>(x, dst, ncols); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(reduce_rows_f32, launch_params, x, dst, ncols); } else { const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1); - reduce_rows_f32<<>>(x, dst, ncols); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(reduce_rows_f32, launch_params, x, dst, ncols); } } @@ -34,10 +36,12 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { if ((nrows / nsm) < 2) { // Increase num threads to 512 for small nrows to better hide the latency const dim3 block_dims(512, 1, 1); - reduce_rows_f32<<>>(src0_d, dst_d, ncols); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(reduce_rows_f32, launch_params, src0_d, dst_d, ncols); } else { // Enough active SMs to hide latency, use smaller blocks to allow better scheduling const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1); - reduce_rows_f32<<>>(src0_d, dst_d, ncols); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(reduce_rows_f32, launch_params, src0_d, dst_d, ncols); } } diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu index 3020e5c7433..da20c9aab7c 100644 --- a/ggml/src/ggml-cuda/topk-moe.cu +++ b/ggml/src/ggml-cuda/topk-moe.cu @@ -105,6 +105,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * wt[i] = -INFINITY; } + ggml_cuda_pdl_sync(); #pragma unroll for (int i = 0; i < n_experts; i += WARP_SIZE) { const int expert = i + threadIdx.x; @@ -161,6 +162,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * output_weights[i] = 0.f; } + ggml_cuda_pdl_lc(); for (int k = 0; k < n_expert_used; k++) { float max_val = wt[0]; int max_expert = threadIdx.x; @@ -271,51 +273,52 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx, dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1); dim3 block_dims(WARP_SIZE, rows_per_block, 1); cudaStream_t stream = ctx.stream(); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(grid_dims, block_dims, 0, stream); switch (n_expert) { case 1: - topk_moe_cuda<1, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<1, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 2: - topk_moe_cuda<2, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<2, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 4: - topk_moe_cuda<4, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<4, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 8: - topk_moe_cuda<8, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<8, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 16: - topk_moe_cuda<16, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<16, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 32: - topk_moe_cuda<32, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<32, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 64: - topk_moe_cuda<64, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<64, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 128: - topk_moe_cuda<128, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<128, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 256: - topk_moe_cuda<256, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<256, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 512: - topk_moe_cuda<512, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<512, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 576: - topk_moe_cuda<576, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<576, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; default: GGML_ASSERT(false && "fatal error"); diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 2aeba26f414..4cb805fa601 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -116,19 +116,22 @@ static __device__ __forceinline__ float op_trunc(float x) { template static __global__ void unary_op_kernel(const T * x, T * dst, const int k) { + ggml_cuda_pdl_lc(); const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { return; } + ggml_cuda_pdl_sync(); dst[i] = (T)op((float)x[i]); } template static void unary_cuda(const T * x, T * dst, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE; - unary_op_kernel<<>>(x, dst, k); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream); + ggml_cuda_kernel_launch(unary_op_kernel, launch_params, x, dst, k); } template @@ -258,6 +261,7 @@ void ggml_cuda_op_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { template static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1) { + ggml_cuda_pdl_lc(); const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x; if (i >= k) { @@ -268,13 +272,15 @@ static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, const int64_t j0 = (i / n) * o0 + (i % n); const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n); + ggml_cuda_pdl_sync(); dst[i] = (T)(op((float)x[j0]) * (float)g[j1]); } template static void unary_gated_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, cudaStream_t stream) { const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE; - unary_gated_op_kernel<<>>(x, g, dst, k, n, o0, o1); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream); + ggml_cuda_kernel_launch(unary_gated_op_kernel, launch_params, x, g, dst, k, n, o0, o1); } template From 64bdb605e062e9b99b3ef5ca8b85f712ee00cfbb Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Wed, 20 May 2026 07:39:01 -0700 Subject: [PATCH 32/55] hexagon: HMX quantized matmul rework (llama/23368) * hmx-mm: update debug logging in hmx-mm * hmx-mm: update dequant logic to use HVX_vector_x2/4 * hmx-mm: remove non-pipelined version of the quantize matmul It seems that we don't reall need non-pipelined version * hmx-mm: use activation depth mode and update naming Co-authored-by: Kim-Chyan Gan * hex-mm: minor hmx matmul naming updates * hmx-mm: remove unused vars * snapdragon: scripts bump default ubatch-size to 1K * hexagon: combine HMX and power and clock settings into a single set_power call * hmx-mm: remove leftover of the scale repl helper * hexagon: fix editconf error --------- Co-authored-by: Kim-Chyan Gan --- ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c | 642 ++++++--------------- ggml/src/ggml-hexagon/htp/hmx-ops.h | 13 +- ggml/src/ggml-hexagon/htp/main.c | 38 +- ggml/src/ggml-hexagon/htp/matmul-ops.c | 10 +- 4 files changed, 196 insertions(+), 507 deletions(-) diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c index e05ccfd5fc7..3ef0bcdb26d 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c @@ -201,11 +201,10 @@ static inline HVX_Vector dequantize_x4x2_q4_0_group_hvx(const uint8_t *packed_32 // Batch-dequantize 4 contiguous x4x2 Q4_0 groups (4x32 = 128 packed bytes) using // full HVX vector width. One vmemu + one vlut16 replaces 4 separate calls. -// Output: out[0..3] each hold 32 FP16 values in the first 64 bytes. -static inline void dequantize_x4x2_q4_0_x4groups_hvx( +// Output: vector_x2 each hold 32 FP16 values in the first 64 bytes. +static inline HVX_Vector_x2 dequantize_x4x2_q4_0_x4groups_hvx( const uint8_t *packed_128, bool upper_nibbles, - const __fp16 *scales_4, const HVX_Vector vlut_cvt, - HVX_Vector out[4]) { + const __fp16 *scales_4, const HVX_Vector vlut_cvt) { // Load all 128 packed bytes (4 contiguous 32-byte groups) HVX_Vector vq = hvx_vmemu(packed_128); const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); @@ -221,8 +220,7 @@ static inline void dequantize_x4x2_q4_0_x4groups_hvx( HVX_Vector v_hi = Q6_V_hi_W(vp); // [group2: 32 fp16 | group3: 32 fp16] // Build per-group scale vectors: first 64 bytes use scale_a, last 64 use scale_b - volatile HVX_Vector vscale = hvx_vmemu(scales_4); - + HVX_Vector vscale = hvx_vmemu(scales_4); HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vscale); HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vscale, 4)); @@ -230,8 +228,9 @@ static inline void dequantize_x4x2_q4_0_x4groups_hvx( v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23)); // Extract individual groups: scatter uses q_mask64 so only first 64 bytes matter - out[0] = v_lo; // group0 already in [0:63] - out[1] = v_hi; // group2 already in [0:63] + HVX_Vector_x2 r = { v_lo,/* group1 already in [0:63] */ + v_hi /* group2 already in [0:63] */ }; + return r; } // Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes. @@ -292,12 +291,11 @@ static inline HVX_Vector dequantize_x4x2_mxfp4_group_hvx(const uint8_t * packed } // Batch-dequantize 4 contiguous x4x2 MXFP4 groups (4x32 = 128 packed bytes). -static inline void dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t * packed_128, +static inline HVX_Vector_x4 dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t * packed_128, bool upper_nibbles, int sub_blk_base, const HVX_Vector vlut_cvt, - mxfp4_scales_t scales, - HVX_Vector out[4]) { + mxfp4_scales_t scales) { HVX_Vector vq = hvx_vmemu(packed_128); const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq; @@ -318,10 +316,8 @@ static inline void dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t * packed_12 v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01)); v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23)); - out[0] = v_lo; - out[1] = Q6_V_vror_VR(v_lo, 64); - out[2] = v_hi; - out[3] = Q6_V_vror_VR(v_hi, 64); + HVX_Vector_x4 r = { v_lo, Q6_V_vror_VR(v_lo, 64), v_hi, Q6_V_vror_VR(v_hi, 64) }; + return r; } // Dequantize a tile range from x4x2 weight data (already in VTCM) to tile-major FP16. @@ -372,18 +368,18 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1; for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { - HVX_Vector v0[2]; const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; - dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt, v0); - Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[0]); - Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[1]); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride; + HVX_Vector_x2 dv0 = dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); + HVX_Vector_x2 dv1 = dequantize_x4x2_q4_0_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt); + + Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]); + Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - r0 = vtcm_src + row_offset; row_offset += row_stride; - dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt, v0); - Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[0]); - Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[1]); + Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]); + Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]); v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); } @@ -415,21 +411,21 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( // Batch-convert all 8 E8M0 scales once per row (stays in HVX register) mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off); - HVX_Vector v0[4], v1[4]; - dequantize_x4x2_mxfp4_x4groups_hvx(r0 + packed_off, upper, sub_blk_base, vlut_cvt, r0_e8, v0); + HVX_Vector_x4 dv0, dv1; + dv0 = dequantize_x4x2_mxfp4_x4groups_hvx(r0 + packed_off, upper, sub_blk_base, vlut_cvt, r0_e8); if (row1 < n_cols) { mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off); - dequantize_x4x2_mxfp4_x4groups_hvx(r1 + packed_off, upper, sub_blk_base, vlut_cvt, r1_e8, v1); + dv1 = dequantize_x4x2_mxfp4_x4groups_hvx(r1 + packed_off, upper, sub_blk_base, vlut_cvt, r1_e8); } else { - v1[0] = v1[1] = v1[2] = v1[3] = Q6_V_vzero(); + dv1.v[0] = dv1.v[1] = dv1.v[2] = dv1.v[3] = Q6_V_vzero(); } for (int g = 0; g < 4; g++) { - Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v0[g]); + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[g]); } v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); for (int g = 0; g < 4; g++) { - Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v1[g]); + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[g]); } v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); } @@ -612,11 +608,13 @@ static void core_dot_chunk_fp16(__fp16 *restrict output, const __fp16 *restrict const __fp16 *row_tiles = activation + r * n_dot_tiles * HMX_FP16_TILE_N_ELMS; const __fp16 *col_tiles = weight + c * n_dot_tiles * HMX_FP16_TILE_N_ELMS; - for (int k = 0; k < n_dot_tiles; ++k) { - Q6_activation_hf_mxmem_RR((unsigned int)row_tiles, 2047); - Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, 2047); - row_tiles += HMX_FP16_TILE_N_ELMS; - col_tiles += HMX_FP16_TILE_N_ELMS; + for (int k = 0, k_block; k < n_dot_tiles; k += k_block) { + k_block = hex_smin(n_dot_tiles - k, 32); + const uint32_t range = 2048u * (uint32_t)k_block - 1; + Q6_activation_hf_mxmem_RR_deep((unsigned int)row_tiles, range); + Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, range); + row_tiles += k_block * HMX_FP16_TILE_N_ELMS; + col_tiles += k_block * HMX_FP16_TILE_N_ELMS; } __fp16 *out_tile = output + (r * n_col_tiles + c) * HMX_FP16_TILE_N_ELMS; @@ -832,10 +830,6 @@ static void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 * worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_worker_fn, &state, ctx->n_threads); } -// - -#define FALLBACK_TO_STANDARD 1 - // C += AB static void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, const __fp16 *restrict b, const __fp16 *restrict col_scales, const __fp16 *restrict eye_tile, @@ -861,314 +855,80 @@ static void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, co Q6_weight_hf_mxmem_RR((unsigned int)eye_tile, 2047); } - for (int k = 0; k < n_dot_tiles; ++k) { - Q6_activation_hf_mxmem_RR((unsigned int)row_tiles, 2047); - Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, 2047); - row_tiles += HMX_FP16_TILE_N_ELMS; - col_tiles += HMX_FP16_TILE_N_ELMS; - } - Q6_mxmem_AR_after_hf(accum_tile, 0); - } - } -} - -static __attribute__((noinline)) int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, - float *restrict out, const float *restrict x, const uint8_t *restrict w, - int m, int k, int n, int weight_type) { - // assume k % 32 == 0 && n % 32 == 0 - const size_t row_stride = get_x4x2_row_stride(weight_type, k); - if (row_stride == 0) { - return -1; - } - - const size_t vtcm_budget = ctx->vtcm_size; - - const size_t K_BLOCK_SIZE = 1024; - - // Fallback: if k doesn't need K-blocking, out-stationary has no advantage - const size_t k_iters_check = (k + K_BLOCK_SIZE - 1) / K_BLOCK_SIZE; - if (k_iters_check <= 1) { - FARF(HIGH, "%s: K_BLK=%zu >= k=%d, fallback to standard path", __func__, K_BLOCK_SIZE, k); - return FALLBACK_TO_STANDARD; - } - - // Dynamic M,N search via hmx_compute_chunks - const size_t sub_row_stride_alloc = get_x4x2_row_stride(weight_type, K_BLOCK_SIZE); - const size_t per_m = K_BLOCK_SIZE * sizeof(float) // scratch1: M×K×4 (act DMA staging F32) - + K_BLOCK_SIZE * sizeof(__fp16); // activation: M×K×2 (F16 tiles) - const size_t per_n = sub_row_stride_alloc // scratch0: N×sub_row(K) (packed quant) - + K_BLOCK_SIZE * sizeof(__fp16); // weight: N×K×2 (F16 tiles) - const size_t per_mn = sizeof(__fp16); // output: M×N×2 (out-stationary) - - // Alignment margin: hex_align_up can add up to 2047 bytes per buffer; - // scratch1 (mc×6144) is naturally 2048-aligned, remaining 4 buffers need margin - const size_t align_margin = 4 * HMX_FP16_TILE_SIZE; - const size_t overhead = HMX_FP16_TILE_SIZE + 256 + align_margin; // eye_tile + scales + alignment - - size_t M_BLOCK_SIZE, N_BLOCK_SIZE, vtcm_used; - // Cost-based search: minimize ceil(m/mc)*m_block_cost + ceil(n/nc)*n_block_cost. - // From profiling: wt_dequant per element ≈ 1.5× activation load per element. - // m_block_cost = n*3: each extra M-block re-dequants all N×K weight (expensive). - // n_block_cost = m*2: each extra N-block re-loads all M×K activation (cheaper). - const size_t m_block_cost = (size_t) n * 3; - const size_t n_block_cost = (size_t) m * 2; - if (hmx_compute_chunks(vtcm_budget, overhead, per_n, per_m, per_mn, - hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, - m_block_cost, n_block_cost, &M_BLOCK_SIZE, - &N_BLOCK_SIZE, &vtcm_used) != 0) { - FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); - return -1; - } - - // Compute precise buffer sizes from searched M,N and fixed K - const size_t weight_size = hex_align_up(N_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); - const size_t act_size = hex_align_up(M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); - const size_t out_size = hex_align_up(M_BLOCK_SIZE * N_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); - const size_t scratch0_sz = hex_align_up(N_BLOCK_SIZE * sub_row_stride_alloc, HMX_FP16_TILE_SIZE); - const size_t scratch1_sz = hex_align_up(M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(float), HMX_FP16_TILE_SIZE); - - const size_t total_vtcm = weight_size + act_size + out_size + scratch0_sz + scratch1_sz + HMX_FP16_TILE_SIZE + 256; - if (total_vtcm > vtcm_budget) { - FARF(HIGH, "%s: VTCM overflow after search: need %zu have %zu (M=%zu N=%zu K=%zu)", __func__, total_vtcm, - vtcm_budget, M_BLOCK_SIZE, N_BLOCK_SIZE, K_BLOCK_SIZE); - return -1; - } - - uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; - __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_size); - __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_size); - __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, out_size); - uint8_t *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_sz); - uint8_t *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch1_sz); - __fp16 *vtcm_eye_tile = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, HMX_FP16_TILE_SIZE); - __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); - assert((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) <= vtcm_budget); - - FARF(HIGH, "hmx-mm: m=%d k=%d n=%d wtype=%d block M=%zu N=%zu K=%zu vtcm=%zu/%zu", m, k, n, weight_type, - M_BLOCK_SIZE, N_BLOCK_SIZE, K_BLOCK_SIZE, (size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget); - - // initialize eye tile (32x32 identity matrix) - { - HVX_Vector v; - v = Q6_V_vzero(); - v = Q6_Vw_vinsert_VwR(v, 0x3c000000); - v = Q6_V_vror_VR(v, VLEN - 4); - v = Q6_Vw_vinsert_VwR(v, 0x00003c00); - for (int i = 0; i < 16; ++i) { - ((HVX_Vector *) vtcm_eye_tile)[i] = v; - v = Q6_V_vror_VR(v, VLEN - 8); - } - } - hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 - - TIMER_DEFINE(fetch); - TIMER_DEFINE(act_load); - TIMER_DEFINE(wt_dequant); - TIMER_DEFINE(core); - - HAP_compute_res_hmx_lock(ctx->vtcm_rctx); - - for (size_t mr = 0; mr < m; mr += M_BLOCK_SIZE) { - size_t m_blk_sz = hex_smin(m - mr, M_BLOCK_SIZE); - for (size_t nc = 0; nc < n; nc += N_BLOCK_SIZE) { - size_t n_blk_sz = hex_smin(n - nc, N_BLOCK_SIZE); - - const int n_row_tiles = hmx_ceil_div(m_blk_sz, HMX_FP16_TILE_N_ROWS); - const int n_col_tiles = hmx_ceil_div(n_blk_sz, HMX_FP16_TILE_N_COLS); - - for (size_t kk = 0; kk < k; kk += K_BLOCK_SIZE) { - const size_t k_blk_sz = hex_smin(k - kk, K_BLOCK_SIZE); - - TIMER_START(fetch); - // fetch activation block into VTCM - { - const float *activation_block = x + mr * k + kk; - - dma_queue_push(ctx->dma[0], - dma_make_ptr(vtcm_scratch1, activation_block), - k_blk_sz * sizeof(float), - k * sizeof(float), - k_blk_sz * sizeof(float), - m_blk_sz); - } - - // fetch weight block into VTCM (x4x2 sub-block: quants + scales) - const size_t sub_row_stride = get_x4x2_row_stride(weight_type, k_blk_sz); - { - const int blk_start = kk / QK_Q4_0x4x2; - const int nb_sub = (k_blk_sz + QK_Q4_0x4x2 - 1) / QK_Q4_0x4x2; - const int full_qrow = (weight_type == HTP_TYPE_Q8_0) ? k : (k / 2); - const int scale_blk_size = (weight_type == HTP_TYPE_MXFP4) ? HMX_X4X2_MXFP4_EBLK_SIZE : HMX_X4X2_DBLK_SIZE; - uint8_t *dst = vtcm_scratch0; - const uint8_t *src = w + nc * row_stride; - const size_t n_rows = n_blk_sz; - const size_t src_stride = row_stride; - const size_t dst_stride = sub_row_stride; - const size_t quant_off = (weight_type == HTP_TYPE_Q8_0) ? (blk_start * QK_Q8_0x4x2) : (blk_start * (QK_Q4_0x4x2 / 2)); - const size_t quant_width = (weight_type == HTP_TYPE_Q8_0) ? (nb_sub * QK_Q8_0x4x2) : (nb_sub * (QK_Q4_0x4x2 / 2)); - const size_t scale_off = full_qrow + blk_start * scale_blk_size; - const size_t scale_width = nb_sub * scale_blk_size; - - // 2D DMA: quants sub-range - dma_queue_push(ctx->dma[0], dma_make_ptr(dst, src + quant_off), dst_stride, src_stride, quant_width, n_rows); - // 2D DMA: scales sub-range - dma_queue_push(ctx->dma[0], dma_make_ptr(dst + quant_width, src + scale_off), dst_stride, src_stride, scale_width, n_rows); - } - TIMER_STOP(fetch); - - TIMER_START(act_load); - // load activation block - { - dma_queue_pop(ctx->dma[0]); // wait for act DNA - transfer_activation_chunk_threaded(ctx, vtcm_activation, (float *) vtcm_scratch1, m_blk_sz, k_blk_sz, k_blk_sz); - } - TIMER_STOP(act_load); - - TIMER_START(wt_dequant); - // dequantize weight block - { - dma_queue_pop(ctx->dma[0]); - dma_queue_pop(ctx->dma[0]); - // vtcm_scratch0 is used to store the qweight chunk - // worker_pool_run_func already returned, so fetch is done - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight, vtcm_scratch0, - n_blk_sz, k_blk_sz, sub_row_stride, weight_type); - } - TIMER_STOP(wt_dequant); - - // core mma - TIMER_START(core); - { - core_mma_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, vtcm_eye_tile, n_row_tiles, - n_col_tiles, k_blk_sz / HMX_FP16_TILE_N_COLS, kk == 0); - } - TIMER_STOP(core); + for (int k = 0, k_block; k < n_dot_tiles; k += k_block) { + k_block = hex_smin(n_dot_tiles - k, 32); + const uint32_t range = 2048u * (uint32_t)k_block - 1; + Q6_activation_hf_mxmem_RR_deep((unsigned int)row_tiles, range); + Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, range); + row_tiles += k_block * HMX_FP16_TILE_N_ELMS; + col_tiles += k_block * HMX_FP16_TILE_N_ELMS; } - // store output block - { - float *output_block = out + (mr * n + nc); - transfer_output_chunk_threaded(ctx, output_block, vtcm_output, m_blk_sz, n_blk_sz, n); - } + Q6_mxmem_AR_after_hf(accum_tile, 0); } } - - HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); - -#if defined(ENABLE_PROFILE_TIMERS) - FARF(HIGH, "fetch: %lld us, act_load: %lld us, wt_dequant: %lld us, core: %lld us", - TIMER_US(fetch), TIMER_US(act_load), TIMER_US(wt_dequant), TIMER_US(core)); -#endif - return 0; } -int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, +int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, const uint8_t *restrict permuted_weight, int m, int k, int n, int weight_type) { - if (!dst || !activation || !permuted_weight || !m || !n || !k) { return -1; } if (k % 32 != 0 || n % 32 != 0) { return -1; } if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) { return -1; } - // for large m, k (e.g. prefill FFN Down), use out-stationary version - if (m >= 128 && k > n && n > 1024) { - int rc = mat_mul_qk_0_d16a32_out_stationary(ctx, dst, activation, permuted_weight, m, k, n, weight_type); - if (rc != FALLBACK_TO_STANDARD) { - return rc; // 0 success, -1 error - } - FARF(HIGH, "hmx_matmul_qk: out-stationary fallback to standard m=%d k=%d n=%d", m, k, n); - // fall through to standard path - } - size_t row_stride = get_x4x2_row_stride(weight_type, k); if (row_stride == 0) { return -1; } - FARF(HIGH, "hmx_matmul_qk: STANDARD path m=%d k=%d n=%d type=%d", m, k, n, weight_type); - // --- Dynamic VTCM layout --- - const size_t vtcm_budget = ctx->vtcm_size; - const size_t vec_dot_size = k * sizeof(__fp16); + const size_t vec_dot_size = k * sizeof(__fp16); + const size_t vtcm_budget = ctx->vtcm_size; + size_t vtcm_used = 0; // Pipeline = 4-stage DMA→dequant→HMX→store with HMX worker overlap. - // Only pays off when the chunker yields >=2 n-chunks, so the main loop can - // overlap HMX (C) with HVX (B/D); with a single n-chunk the extra VTCM for - // double-buffered output and the worker-dispatch overhead are pure loss. - // Try pipeline costs first; fall back to sequential if the layout collapses - // to one n-chunk. m >= 128 floor keeps HMX utilization reasonable. - const size_t pipe_per_n = row_stride + 2 * vec_dot_size; // Q + S0 + S1 (dequant bufs) - const size_t pipe_per_mn = 2 * sizeof(__fp16); // O x 2 (output double buffer) - const size_t seq_per_n = vec_dot_size + 2 * row_stride; // W + S0 + S1 (x4x2 DMA bufs) - const size_t seq_per_mn = sizeof(__fp16); // O x 1 - - size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; - bool use_pipeline = false; - - if (m >= 128) { - size_t mc = 0, nc = 0, used = 0; - if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, pipe_per_n, /*per_m=*/vec_dot_size, pipe_per_mn, - hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, - /*m_block_cost=*/(size_t) n * 3, - /*n_block_cost=*/(size_t) m * 2, &mc, &nc, &used) == 0 && - hmx_ceil_div((size_t) n, nc) >= 2) { - m_chunk_n_rows = mc; - n_chunk_n_cols = nc; - vtcm_used = used; - use_pipeline = true; - } - } + const size_t size_per_n = row_stride + 2 * vec_dot_size; // Q + S0 + S1 (dequant bufs) + const size_t size_per_mn = 2 * sizeof(__fp16); // O x 2 (output double buffer) - if (!use_pipeline) { - if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, seq_per_n, /*per_m=*/vec_dot_size, seq_per_mn, - hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, - /*m_block_cost=*/(size_t) n * 3, - /*n_block_cost=*/(size_t) m * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { - FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); - return -1; - } + size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0; + if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, size_per_n, /*per_m=*/vec_dot_size, size_per_mn, + hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, + /*m_block_cost=*/(size_t) n * 3, + /*n_block_cost=*/(size_t) m * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used)) { + FARF(HIGH, "hmx-mm-q: VTCM too small : m %d k %d n %d budget %zu", m, k, n, vtcm_budget); + return -1; } - // Compute precise buffer sizes per execution path - const size_t weight_area_size = hex_align_up( - n_chunk_n_cols * (use_pipeline ? row_stride : vec_dot_size), HMX_FP16_TILE_SIZE); - const size_t activation_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t output_area_size = hex_align_up( - m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); + const size_t weight_area_size = hex_align_up(n_chunk_n_cols * row_stride, HMX_FP16_TILE_SIZE); + const size_t act_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); size_t scratch0_size, scratch1_size, scratch2_size; - if (use_pipeline) { - scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); // dequant buf 0 - scratch1_size = scratch0_size; // dequant buf 1 - scratch2_size = output_area_size; // output buf 1 - } else { - scratch0_size = hex_align_up(n_chunk_n_cols * row_stride, HMX_FP16_TILE_SIZE); // x4x2 DMA buf 0 - scratch1_size = scratch0_size; // x4x2 DMA buf 1 - scratch2_size = 0; // unused - } + scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); // dequant buf 0 + scratch1_size = scratch0_size; // dequant buf 1 + scratch2_size = output_area_size; // output buf 1 uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); - __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size); + __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_area_size); __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_size); void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch1_size); void *vtcm_scratch2 = scratch2_size ? vtcm_seq_alloc(&vtcm_ptr, scratch2_size) : NULL; __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); - if ((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) > vtcm_budget) { - FARF(ERROR, "%s: vtcm overflow: used=%zu limit=%zu", __func__, - (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + + vtcm_used = vtcm_ptr - (uint8_t *) ctx->vtcm_base; + if (vtcm_used > vtcm_budget) { + FARF(ERROR, "hmx-mm-q: VTCM overflow: used %zu budget %zu", vtcm_used, vtcm_budget); return -1; } hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 - FARF(HIGH, "%s: m=%d k=%d n=%d wtype=%d pipe=%d mc=%zu nc=%zu vtcm=%zu/%zu", - __func__, m, k, n, weight_type, use_pipeline, - m_chunk_n_rows, n_chunk_n_cols, - (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + FARF(HIGH, "hmx-mm-q: standard : m %d k %d n %d wtype %d mc %zu nc %zu vtcm %zu/%zu", + m, k, n, weight_type, m_chunk_n_rows, n_chunk_n_cols, vtcm_used, vtcm_budget); TIMER_DEFINE(activation_load); TIMER_DEFINE(weight_load); @@ -1178,184 +938,115 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds TIMER_DEFINE(total); TIMER_START(total); - FARF(HIGH, "hmx_matmul_qk: %s mc=%zu nc=%zu vtcm=%zu/%zu", - use_pipeline ? "PIPELINE" : "SEQUENTIAL", m_chunk_n_rows, n_chunk_n_cols, - (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + // 4-stage pipeline: DMA load (A), dequantize (B), HMX matmul (C), store (D) + // HMX compute (C) runs on dedicated worker thread, overlapping with HVX stages (B, D). - if (!use_pipeline) { - HAP_compute_res_hmx_lock(ctx->vtcm_rctx); - for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { - // transfer activation matrix chunk into VTCM - const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); - const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); + // A --> B: vtcm_qweight, 1 buffer + // B --> C: vtcm_weight0/vtcm_weight1, 2 buffers + // C --> D: vtcm_output0/vtcm_output1, 2 buffers - TIMER_START(activation_load); - { - const float *activation_chunk = activation + mr * k; - transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, k); - } - TIMER_STOP(activation_load); + // Async timeline (C overlaps B+D): + // main+HVX: [A0][Act][B0][A1][sub C0][B1‖C0][A2][wait,sub C1][D0+B2‖C1][wait,sub C2][D1‖C2][wait][D2] + // HMX queue: [████ C0 ████████][████ C1 ████████████][████ C2 ████████] - void *buf_curr = vtcm_scratch0; - void *buf_next = vtcm_scratch1; - - { - const size_t n_cols_first = hex_smin(n, n_chunk_n_cols); - dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), row_stride, row_stride, row_stride, n_cols_first); - } - - for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) { - const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); - const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); - - TIMER_START(weight_load); - { - dma_queue_pop(ctx->dma[0]); // wait until current weight chunk become ready - - const size_t nc_next = nc + n_chunk_n_cols; - if (nc_next < n) { - const size_t n_cols_next = hex_smin(n - nc_next, n_chunk_n_cols); - - const uint8_t *next_weight_chunk = permuted_weight + nc_next * row_stride; - - dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), row_stride, row_stride, row_stride, n_cols_next); - } + int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols); + hmx_matmul_job_t job_slots[2]; // persistent double-buffered job descriptors - // Dequant + vscatter writes directly to [K, N] transposed tiles. - // HMX computes C = A x B, where A=[M,K] activation, B=[K,N] weight. - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight, buf_curr, n_cols, k, row_stride, weight_type); + for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); - hex_swap_ptr(&buf_curr, &buf_next); - } - TIMER_STOP(weight_load); + void *vtcm_qweight = vtcm_weight; + void *vtcm_weight_bufs[2] = { vtcm_scratch0, vtcm_scratch1 }; + void *vtcm_output_bufs[2] = { vtcm_output, vtcm_scratch2 }; - TIMER_START(hmx_core); - { - core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, k / 32); - } - TIMER_STOP(hmx_core); - - TIMER_START(output_store); - { - float *output = dst + (mr * n + nc); - transfer_output_chunk_threaded(ctx, output, vtcm_output, n_rows, n_cols, n); - } - TIMER_STOP(output_store); - } + // prologue: A0 + const size_t n_cols_A0 = hex_smin(n - 0 * n_chunk_n_cols, n_chunk_n_cols); + { + const uint8_t *qweight_chunk_A0 = permuted_weight; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A0), row_stride, row_stride, row_stride, n_cols_A0); } - HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); - } else { - // 4-stage pipeline: DMA load (A), dequantize (B), HMX matmul (C), store (D) - // HMX compute (C) runs on dedicated worker thread, overlapping with HVX stages (B, D). - - // A --> B: vtcm_qweight, 1 buffer - // B --> C: vtcm_weight0/vtcm_weight1, 2 buffers - // C --> D: vtcm_output0/vtcm_output1, 2 buffers - - // Async timeline (C overlaps B+D): - // main+HVX: [A0][Act][B0][A1][sub C0][B1‖C0][A2][wait,sub C1][D0+B2‖C1][wait,sub C2][D1‖C2][wait][D2] - // HMX queue: [████ C0 ████████][████ C1 ████████████][████ C2 ████████] - - int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols); - hmx_matmul_job_t job_slots[2]; // persistent double-buffered job descriptors - for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { - const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); - - void *vtcm_qweight = vtcm_weight; - void *vtcm_weight_bufs[2] = { vtcm_scratch0, vtcm_scratch1 }; - void *vtcm_output_bufs[2] = { vtcm_output, vtcm_scratch2 }; + { + const float *activation_chunk = activation + mr * k; + transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, k); + } - // prologue: A0 - const size_t n_cols_A0 = hex_smin(n - 0 * n_chunk_n_cols, n_chunk_n_cols); - { - // Use 2D DMA (n_cols rows x row_stride) to avoid 16-bit roiwidth overflow. - const uint8_t *qweight_chunk_A0 = permuted_weight; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A0), row_stride, row_stride, row_stride, n_cols_A0); + // prologue: B0, A1, submit C0 (async), B1 (overlaps C0) + { + // B0: wait for DMA, dequant weight chunk 0 + dma_queue_pop(ctx->dma[0]); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type); + + // A1: issue DMA for weight chunk 1 + const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols); + if (1 < n_chunk_cnt) { + const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * row_stride; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, row_stride, row_stride, n_cols_A1); } - { - const float *activation_chunk = activation + mr * k; - transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, k); - } + // submit C0 (non-blocking — HMX worker executes in parallel) + hmx_matmul_job_init(&job_slots[0], (__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_activation, + (__fp16 *) vtcm_weight_bufs[0], vtcm_scales, + hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), + hmx_ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); + hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[0])); - // prologue: B0, A1, submit C0 (async), B1 (overlaps C0) - { - // B0: wait for DMA, dequant weight chunk 0 + // B1: DMA pop + dequant (runs in parallel with C0 on HMX worker) + if (1 < n_chunk_cnt) { dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type); - - // A1: issue DMA for weight chunk 1 - const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols); - if (1 < n_chunk_cnt) { - const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * row_stride; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, row_stride, row_stride, n_cols_A1); - } - - // submit C0 (non-blocking — HMX worker executes in parallel) - hmx_matmul_job_init(&job_slots[0], (__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_activation, - (__fp16 *) vtcm_weight_bufs[0], vtcm_scales, - hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), - hmx_ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); - hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[0])); - - // B1: DMA pop + dequant (runs in parallel with C0 on HMX worker) - if (1 < n_chunk_cnt) { - dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type); - } + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type); } + } - // main loop: wait C_i → submit C_{i+1} → D_i + B_{i+2} (parallel with C_{i+1}) - for (int i = 0; i < n_chunk_cnt; ++i) { - const size_t nc = i * n_chunk_n_cols; - const size_t nc_p1 = nc + 1 * n_chunk_n_cols; - const size_t nc_p2 = nc + 2 * n_chunk_n_cols; + // main loop: wait C_i → submit C_{i+1} → D_i + B_{i+2} (parallel with C_{i+1}) + for (int i = 0; i < n_chunk_cnt; ++i) { + const size_t nc = i * n_chunk_n_cols; + const size_t nc_p1 = nc + 1 * n_chunk_n_cols; + const size_t nc_p2 = nc + 2 * n_chunk_n_cols; - const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); - const size_t n_cols_p1 = hex_smin(n - nc_p1, n_chunk_n_cols); - const size_t n_cols_p2 = hex_smin(n - nc_p2, n_chunk_n_cols); + const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + const size_t n_cols_p1 = hex_smin(n - nc_p1, n_chunk_n_cols); + const size_t n_cols_p2 = hex_smin(n - nc_p2, n_chunk_n_cols); - // issue A_{i+2}: DMA push (non-blocking) - if (i + 2 < n_chunk_cnt) { - const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * row_stride; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, row_stride, row_stride, n_cols_p2); - } + // issue A_{i+2}: DMA push (non-blocking) + if (i + 2 < n_chunk_cnt) { + const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * row_stride; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, row_stride, row_stride, n_cols_p2); + } - // wait C_i: block until prologue/previous C completes - hmx_queue_pop(ctx->hmx_queue); - - // submit C_{i+1} (non-blocking, overlaps with D_i + B_{i+2} below) - // job_slots[(i+1)%2] is safe: C_i just completed, freeing slot i%2's - // counterpart — and (i+1)%2 was last used by C_{i-1} which completed - // before C_i was submitted. - if (i + 1 < n_chunk_cnt) { - hmx_matmul_job_init(&job_slots[(i + 1) % 2], (__fp16 *) vtcm_output_bufs[(i + 1) % 2], - (__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2], - vtcm_scales, hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), - hmx_ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); - hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[(i + 1) % 2])); - } + // wait C_i: block until prologue/previous C completes + hmx_queue_pop(ctx->hmx_queue); + + // submit C_{i+1} (non-blocking, overlaps with D_i + B_{i+2} below) + // job_slots[(i+1)%2] is safe: C_i just completed, freeing slot i%2's + // counterpart — and (i+1)%2 was last used by C_{i-1} which completed + // before C_i was submitted. + if (i + 1 < n_chunk_cnt) { + hmx_matmul_job_init(&job_slots[(i + 1) % 2], (__fp16 *) vtcm_output_bufs[(i + 1) % 2], + (__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2], + vtcm_scales, hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), + hmx_ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); + hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[(i + 1) % 2])); + } - // D_i: store output (multi-thread HVX, parallel with C_{i+1}) - float *output_chunk = dst + (mr * n + nc); - transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output_bufs[i % 2], n_rows, n_cols, n); + // D_i: store output (multi-thread HVX, parallel with C_{i+1}) + float *output_chunk = dst + (mr * n + nc); + transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output_bufs[i % 2], n_rows, n_cols, n); - // B_{i+2}: DMA pop + dequant (multi-thread HVX, parallel with C_{i+1}) - if (i + 2 < n_chunk_cnt) { - dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type); - } + // B_{i+2}: DMA pop + dequant (multi-thread HVX, parallel with C_{i+1}) + if (i + 2 < n_chunk_cnt) { + dma_queue_pop(ctx->dma[0]); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type); } } - - hmx_queue_suspend(ctx->hmx_queue); } + hmx_queue_suspend(ctx->hmx_queue); + TIMER_STOP(total); #if defined(ENABLE_PROFILE_TIMERS) - FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d pipeline=%d", __func__, TIMER_US(total), m, k, n, use_pipeline); + FARF(HIGH, "hex-mm-q: %lld us : m %d k %d n %d", TIMER_US(total), m, k, n); if (!use_pipeline) { FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us", TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); @@ -1370,15 +1061,15 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds // -static inline int hmx_matmul_batch_r2(const hmx_matmul_w16a32_batched_params_t *params) { +static inline int hmx_matmul_batch_r2(const hmx_matmul_f16_f32_batched_params_t *params) { return params->ne02 > 0 ? params->ne12 / params->ne02 : 1; } -static inline int hmx_matmul_batch_r3(const hmx_matmul_w16a32_batched_params_t *params) { +static inline int hmx_matmul_batch_r3(const hmx_matmul_f16_f32_batched_params_t *params) { return params->ne03 > 0 ? params->ne13 / params->ne03 : 1; } -static inline const __fp16 *hmx_matmul_weight_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, +static inline const __fp16 *hmx_matmul_weight_batch_ptr(const hmx_matmul_f16_f32_batched_params_t *params, int dst_b2, int dst_b3) { const int r2 = hmx_matmul_batch_r2(params); const int r3 = hmx_matmul_batch_r3(params); @@ -1387,37 +1078,36 @@ static inline const __fp16 *hmx_matmul_weight_batch_ptr(const hmx_matmul_w16a32_ (size_t) (dst_b3 / r3) * params->src0_nb3); } -static inline const float *hmx_matmul_activation_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, +static inline const float *hmx_matmul_activation_batch_ptr(const hmx_matmul_f16_f32_batched_params_t *params, int dst_b2, int dst_b3) { return (const float *) ((const uint8_t *) params->activation + (size_t) dst_b2 * params->src1_nb2 + (size_t) dst_b3 * params->src1_nb3); } -static inline float *hmx_matmul_dst_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, +static inline float *hmx_matmul_dst_batch_ptr(const hmx_matmul_f16_f32_batched_params_t *params, int dst_b2, int dst_b3) { return (float *) ((uint8_t *) params->dst + (size_t) dst_b2 * params->dst_nb2 + (size_t) dst_b3 * params->dst_nb3); } -static int hmx_mat_mul_permuted_w16a32_batched_legacy(struct htp_context *ctx, - const hmx_matmul_w16a32_batched_params_t *params) { +static int hmx_matmul_f16_f32_batched_legacy(struct htp_context *ctx, + const hmx_matmul_f16_f32_batched_params_t *params) { int ret = 0; for (int b3 = 0; b3 < params->ne13 && ret == 0; ++b3) { for (int b2 = 0; b2 < params->ne12 && ret == 0; ++b2) { - ret = hmx_mat_mul_permuted_w16a32(ctx, - hmx_matmul_dst_batch_ptr(params, b2, b3), - hmx_matmul_activation_batch_ptr(params, b2, b3), - hmx_matmul_weight_batch_ptr(params, b2, b3), - params->m, params->k, params->n, - params->act_stride, params->weight_stride); + ret = hmx_matmul_f16_f32(ctx, hmx_matmul_dst_batch_ptr(params, b2, b3), + hmx_matmul_activation_batch_ptr(params, b2, b3), + hmx_matmul_weight_batch_ptr(params, b2, b3), + params->m, params->k, params->n, + params->act_stride, params->weight_stride); } } return ret; } -int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmul_w16a32_batched_params_t *params) { +int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32_batched_params_t *params) { if (!ctx || !params || !params->dst || !params->activation || !params->permuted_weight) { return -1; } if (!params->m || !params->k || !params->n) { return -1; } if (params->act_stride < params->k || params->weight_stride < params->k || params->dst_stride < params->n) { return -1; } @@ -1435,7 +1125,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu if (group_size <= 1) { FARF(HIGH, "%s: no dim2 GQA reuse (group=%d), using legacy batched loop", __func__, group_size); - return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); + return hmx_matmul_f16_f32_batched_legacy(ctx, params); } // Grouped path: reuse interleaved weight across all q_heads sharing a @@ -1464,7 +1154,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu /*m_block_cost=*/(size_t) params->n, /*n_block_cost=*/(size_t) params->m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { FARF(HIGH, "%s: grouped path does not fit VTCM, falling back to legacy batched loop", __func__); - return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); + return hmx_matmul_f16_f32_batched_legacy(ctx, params); } const size_t act_head_stride = m_chunk_n_rows * (size_t) params->k; // fp16 elements between heads @@ -1486,7 +1176,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu if ((size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base) > vtcm_budget) { FARF(HIGH, "%s: grouped layout overflowed VTCM, falling back to legacy batched loop", __func__); - return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); + return hmx_matmul_f16_f32_batched_legacy(ctx, params); } hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 @@ -1614,7 +1304,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu // -int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, +int hmx_matmul_f16_f32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, const __fp16 *restrict permuted_weight, int m, int k, int n, int act_stride, int weight_stride) { if (!dst || !activation || !permuted_weight || !m || !n || !k) { return -1; } diff --git a/ggml/src/ggml-hexagon/htp/hmx-ops.h b/ggml/src/ggml-hexagon/htp/hmx-ops.h index 1c78ffadd1c..f114edb822f 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-ops.h +++ b/ggml/src/ggml-hexagon/htp/hmx-ops.h @@ -33,14 +33,14 @@ typedef struct { size_t src1_nb3; size_t dst_nb2; size_t dst_nb3; -} hmx_matmul_w16a32_batched_params_t; +} hmx_matmul_f16_f32_batched_params_t; // HMX matrix multiplication — tile-permuted FP16 weights, FP32 activation/output // act_stride: activation row stride in elements (= k for contiguous, or // nb[1]/sizeof(float) for permuted tensors like attention Q). // weight_stride: weight row stride in elements (= k for compact weights, or // nb[1]/sizeof(__fp16) for permuted KV-cache views used by QK). -int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, +int hmx_matmul_f16_f32(struct htp_context *ctx, float *restrict dst, const float *activation, const __fp16 *permuted_weight, @@ -48,13 +48,12 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, int act_stride, int weight_stride); -// Batched F16 wrapper over hmx_mat_mul_permuted_w16a32. +// Batched F16 wrapper over hmx_mat_mul_f16_f32. // Batch semantics match ggml_mul_mat(): src0 broadcasts to src1 in dims 2/3. -int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, - const hmx_matmul_w16a32_batched_params_t *params); +int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32_batched_params_t *params); -// HMX matrix multiplication — tile-permuted quantised weights (Q4_0/Q8_0/IQ4_NL) -int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, +// HMX matrix multiplication — quantised weights (Q4_0/Q8_0/IQ4_NL/MXFP4) +int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float *activation, const uint8_t *permuted_weight, diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 8e54536f619..e8619388478 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -87,35 +87,37 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { } } +#if __HVX_ARCH__ >= 75 { - // Power on HMX + // Power on HMX and set HMX clock HAP_power_request_t request; memset(&request, 0, sizeof(HAP_power_request_t)); - request.type = HAP_power_set_HMX; - request.hmx.power_up = TRUE; - FARF(ALWAYS, "Powering HMX on\n"); - err = HAP_power_set((void *) &ctx, &request); + request.type = HAP_power_set_HMX_v2; + request.hmx_v2.set_power = TRUE; + request.hmx_v2.power_up = TRUE; + request.hmx_v2.set_clock = TRUE; + request.hmx_v2.target_corner = HAP_DCVS_EXP_VCORNER_MAX; + request.hmx_v2.min_corner = HAP_DCVS_EXP_VCORNER_MAX; + request.hmx_v2.max_corner = HAP_DCVS_EXP_VCORNER_MAX; + request.hmx_v2.perf_mode = HAP_CLK_PERF_HIGH; + FARF(ALWAYS, "Setting HMX clock\n"); + err = HAP_power_set((void *) ctx, &request); if (err != AEE_SUCCESS) { - FARF(ERROR, "Error powering on HMX."); + FARF(ERROR, "Error setting HMX clock."); return err; } } - -#if __HVX_ARCH__ >= 75 +#else { - // Set HMX clock + // Power on HMX HAP_power_request_t request; memset(&request, 0, sizeof(HAP_power_request_t)); - request.type = HAP_power_set_HMX_v2; - request.hmx_v2.set_clock = TRUE; - request.hmx_v2.target_corner = HAP_DCVS_EXP_VCORNER_MAX; - request.hmx_v2.min_corner = HAP_DCVS_EXP_VCORNER_MAX; - request.hmx_v2.max_corner = HAP_DCVS_EXP_VCORNER_MAX; - request.hmx_v2.perf_mode = HAP_CLK_PERF_HIGH; - FARF(ALWAYS, "Setting HMX clock\n"); - err = HAP_power_set((void *) &ctx, &request); + request.type = HAP_power_set_HMX; + request.hmx.power_up = TRUE; + FARF(ALWAYS, "Powering HMX on\n"); + err = HAP_power_set((void *) ctx, &request); if (err != AEE_SUCCESS) { - FARF(ERROR, "Error setting HMX clock."); + FARF(ERROR, "Error powering on HMX."); return err; } } diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index 2461ae617fa..46fc5602dc9 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -2995,7 +2995,6 @@ int op_matmul(struct htp_ops_context * octx) { // is handled by HMX itself; when M < 32 fall back to HVX. const int m_total = (int) src1->ne[1]; const int m_hmx = m_total & ~31; // 0 when M < 32 - if (m_hmx == 0) { return op_matmul_hvx(octx); } @@ -3020,7 +3019,7 @@ int op_matmul(struct htp_ops_context * octx) { if (src0->type == HTP_TYPE_F16) { if (is_batched) { - hmx_matmul_w16a32_batched_params_t batch_params = { + hmx_matmul_f16_f32_batched_params_t batch_params = { .dst = (float *) dst->data, .activation = (float *) src1->data, .permuted_weight = (const __fp16 *) src0->data, @@ -3041,15 +3040,14 @@ int op_matmul(struct htp_ops_context * octx) { .dst_nb2 = dst->nb[2], .dst_nb3 = dst->nb[3], }; - ret = hmx_mat_mul_permuted_w16a32_batched(octx->ctx, &batch_params); + ret = hmx_matmul_f16_f32_batched(octx->ctx, &batch_params); } else { - ret = hmx_mat_mul_permuted_w16a32(octx->ctx, + ret = hmx_matmul_f16_f32(octx->ctx, (float*) dst->data, (float*) src1->data, (const __fp16 *) src0->data, m_total, k, n, act_stride, wgt_stride); } } else { - ret = hmx_mat_mul_permuted_qk_0_d16a32(octx->ctx, - (float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data, + ret = hmx_matmul_q_f32(octx->ctx, (float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data, m_total, k, n, (int) src0->type); } From e9b7cc8c55dbc854dd1131ea623ea6b4937900aa Mon Sep 17 00:00:00 2001 From: Daniele <57776841+daniandtheweb@users.noreply.github.com> Date: Wed, 20 May 2026 17:15:13 +0200 Subject: [PATCH 33/55] vulkan: optimize operations in the IM2COL shader (llama/22685) * vulkan: optimize operations in the IM2COL shader * Add comments and improve the code formatting --- .../ggml-vulkan/vulkan-shaders/im2col.comp | 73 +++++++++++++++---- 1 file changed, 59 insertions(+), 14 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp index ba4c2103f0c..f4130d223b1 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp @@ -44,36 +44,81 @@ void im2col(const uint ow, const uint z_idx) { const uint KHKW = p.KH * p.KW; + // Precompute base input coordinates + const int base_iw = int(ow * p.s0) - p.p0; + const int base_ih = int(oh * p.s1) - p.p1; + + // Precompute step deltas + const uint delta_ic = BLOCK_SIZE / KHKW; + const uint delta_rem = BLOCK_SIZE % KHKW; + + const uint delta_ky = delta_rem / p.KW; + const uint delta_kx = delta_rem % p.KW; + + const uint delta_ic_offset = delta_ic * p.offset_delta; + + // If using BDA mode, precompute the base pointer and step size +#if BDA + const BDA_STORAGE_T base_dst_addr = p.dst_addr + D_SIZE * dst_row; + const uint bda_step = D_SIZE * BLOCK_SIZE; +#endif + uint wg_x = gl_WorkGroupID.x; do { const uint wg_offset = wg_x * 512; - [[unroll]] for (uint i = 0; i < NUM_ITER; ++i) { - const uint chw_idx = wg_offset + gidx + i * BLOCK_SIZE; + uint chw_idx = wg_offset + gidx; + + uint ic = chw_idx / KHKW; + uint rem = chw_idx % KHKW; + + uint ky = rem / p.KW; + uint kx = rem % p.KW; + uint ic_offset = src_batch + ic * p.offset_delta; + + // Initialize running pointer/index for the destination buffer +#if BDA + BDA_STORAGE_T current_dst_addr = base_dst_addr + D_SIZE * chw_idx; +#else + uint current_dst_idx = dst_row + chw_idx; +#endif + + [[unroll]] for (uint i = 0; i < NUM_ITER; ++i) { if (chw_idx >= p.CHW) { return; } - const uint ic = chw_idx / KHKW; - const uint rem = chw_idx - ic * KHKW; - const uint ky = rem / p.KW; - const uint kx = rem - ky * p.KW; - - const uint iiw = ow * p.s0 + kx * p.d0 - p.p0; - const uint iih = oh * p.s1 + ky * p.d1 - p.p1; + const int iiw = base_iw + int(kx * p.d0); + const int iih = base_ih + int(ky * p.d1); A_TYPE val = A_TYPE(0); - if (iih < p.IH && iiw < p.IW) { - val = data_a[src_batch + ic * p.offset_delta + iih * p.IW + iiw]; + if (uint(iih) < p.IH && uint(iiw) < p.IW) { + val = data_a[ic_offset + uint(iih) * p.IW + uint(iiw)]; } #if BDA - D_ptr out_ptr = D_ptr(p.dst_addr + D_SIZE * (dst_row + chw_idx)); - out_ptr.d = D_TYPE(val); + D_ptr(current_dst_addr).d = D_TYPE(val); + current_dst_addr += bda_step; #else - data_d[dst_row + chw_idx] = D_TYPE(val); + data_d[current_dst_idx] = D_TYPE(val); + current_dst_idx += BLOCK_SIZE; #endif + + chw_idx += BLOCK_SIZE; + ic_offset += delta_ic_offset; + kx += delta_kx; + ky += delta_ky; + + // Handle X axis wrap + uint kx_wrap = uint(kx >= p.KW); + kx -= kx_wrap * p.KW; + ky += kx_wrap; + + // Handle Y axis wrap + uint ky_wrap = uint(ky >= p.KH); + ky -= ky_wrap * p.KH; + ic_offset += ky_wrap * p.offset_delta; } wg_x += gl_NumWorkGroups.x; From 6ce303bc4ebc30d8c40191660b84a8316bc90430 Mon Sep 17 00:00:00 2001 From: lhez Date: Wed, 20 May 2026 09:57:36 -0700 Subject: [PATCH 34/55] opencl: refactor backend initilization (llama/23318) * opencl: refactor initialization * opencl: refactor GPU identification * opencl: rename for consistency * opencl: cache global mem size in dev_ctx * opencl: adjust log level * opencl: load argsort and flash_attn kernels in supports_op * argsort kernel must be built for supports_op for querying the max workgroups * flash_attn kernel has many variants, only load them when needed --- ggml/src/ggml-opencl/ggml-opencl.cpp | 429 ++++++++++++++++----------- 1 file changed, 254 insertions(+), 175 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index a3af8c2da41..5fc46f789ec 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -375,6 +375,11 @@ struct ggml_backend_opencl_device_context { ggml_backend_buffer_type buffer_type; cl_context context = nullptr; + + GPU_FAMILY gpu_family = GPU_FAMILY::UNKNOWN; + ADRENO_GPU_GEN adreno_gen = ADRENO_GPU_GEN::ADRENO_UNKNOWN; + + size_t global_mem_size = 0; }; // backend context @@ -384,6 +389,18 @@ struct ggml_backend_opencl_context { cl_device_id device; std::string device_name; + ggml_cl_version platform_version; + ggml_cl_version opencl_c_version; + + // argsort is loaded in supports_op because its availability depends on how + // many workgroups are allowed, which requires kernel compilation. + bool kernels_loaded_argsort = false; + // flash attn is loaded in supports_op because it contains multiple variants + // and takes time to compile, so we want to only compile it when needed. + bool kernels_loaded_flash_attn = false; + // rest of the kernels are currently always loaded in alloc_buffer. + bool kernels_loaded = false; + std::string driver_version; GPU_FAMILY gpu_family; @@ -781,6 +798,8 @@ struct ggml_backend_opencl_context { #endif // GGML_OPENCL_USE_ADRENO_KERNELS void free() { + clFinish(queue); + ref_count--; if (ref_count == 0) { #ifdef GGML_OPENCL_PROFILING @@ -793,6 +812,9 @@ struct ggml_backend_opencl_context { // All registered devices with a default device in the front. static std::vector g_ggml_backend_opencl_devices; +// All device contexts associated with the devices above. +// The devices live as long as the process, so do the contexts. +static std::vector> g_ggml_backend_opencl_dev_ctxs; inline std::string read_file(const std::string &path) { std::ifstream ifs(path); @@ -836,12 +858,120 @@ static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, co return p; } -static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_version opencl_c_version) { +static void load_cl_kernels_argsort(ggml_backend_opencl_context *backend_ctx) { + // compiler options for general kernels + auto opencl_c_std = + std::string("CL") + std::to_string(backend_ctx->opencl_c_version.major) + "." + std::to_string(backend_ctx->opencl_c_version.minor); + std::string compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable -cl-unsafe-math-optimizations" + " -cl-finite-math-only -cl-fast-relaxed-math"; + + // argsort + if (!backend_ctx->kernels_loaded_argsort) { + cl_int err; +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "argsort.cl.h" + }; +#else + const std::string kernel_src = read_file("argsort.cl"); +#endif + backend_ctx->program_argsort_f32_i32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_argsort_f32_i32 = clCreateKernel(backend_ctx->program_argsort_f32_i32, "kernel_argsort_f32_i32", &err), err)); + backend_ctx->kernels_loaded_argsort = true; + } +} + +static void load_cl_kernels_flash_attn(ggml_backend_opencl_context *backend_ctx) { + // compiler options for general kernels + auto opencl_c_std = + std::string("CL") + std::to_string(backend_ctx->opencl_c_version.major) + "." + std::to_string(backend_ctx->opencl_c_version.minor); + std::string compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable -cl-unsafe-math-optimizations" + " -cl-finite-math-only -cl-fast-relaxed-math"; + + // flash_attn + if (!backend_ctx->kernels_loaded_flash_attn) { + cl_int err; + + #ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_f16 { + #include "flash_attn_f16.cl.h" + }; + const std::string kernel_src_f32 { + #include "flash_attn_f32.cl.h" + }; + const std::string kernel_src_f32_f16 { + #include "flash_attn_f32_f16.cl.h" + }; + #else + const std::string kernel_src_f16 = read_file("flash_attn_f16.cl"); + const std::string kernel_src_f32 = read_file("flash_attn_f32.cl"); + const std::string kernel_src_f32_f16 = read_file("flash_attn_f32_f16.cl"); + #endif + + if (!kernel_src_f16.empty() && !kernel_src_f32.empty() && !kernel_src_f32_f16.empty()) { + const struct { int dk; int dv; int bm; int bn; } fa_dims[] = { + { 40, 40, 32, 32}, { 64, 64, 64, 64}, { 80, 80, 64, 32}, { 96, 96, 64, 32}, + {112, 112, 32, 32}, {128, 128, 32, 32}, {192, 128, 16, 16}, + {192, 192, 16, 16}, {256, 256, 16, 16}, + }; + + for (size_t i = 0; i < sizeof(fa_dims)/sizeof(fa_dims[0]); ++i) { + const int dk = fa_dims[i].dk; + const int dv = fa_dims[i].dv; + const int bm = fa_dims[i].bm; + const int bn = fa_dims[i].bn; + std::string OPTS = compile_opts + + " -D DK=" + std::to_string(dk) + + " -D DV=" + std::to_string(dv) + + " -D BLOCK_M=" + std::to_string(bm) + + " -D BLOCK_N=" + std::to_string(bn); + + cl_program prog_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f16.c_str(), OPTS); + cl_kernel k_f16, k_f16_q1; + CL_CHECK((k_f16 = clCreateKernel(prog_f16, "flash_attn_f16", &err), err)); + CL_CHECK((k_f16_q1 = clCreateKernel(prog_f16, "flash_attn_f16_q1", &err), err)); + backend_ctx->kernels_flash_attn_f16[{dk, dv}] = k_f16; + backend_ctx->kernels_flash_attn_f16_q1[{dk, dv}] = k_f16_q1; + CL_CHECK(clReleaseProgram(prog_f16)); + + cl_program prog_f32 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32.c_str(), OPTS); + cl_kernel k_f32, k_f32_q1; + CL_CHECK((k_f32 = clCreateKernel(prog_f32, "flash_attn_f32", &err), err)); + CL_CHECK((k_f32_q1 = clCreateKernel(prog_f32, "flash_attn_f32_q1", &err), err)); + backend_ctx->kernels_flash_attn_f32[{dk, dv}] = k_f32; + backend_ctx->kernels_flash_attn_f32_q1[{dk, dv}] = k_f32_q1; + CL_CHECK(clReleaseProgram(prog_f32)); + + cl_program prog_f32_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32_f16.c_str(), OPTS); + cl_kernel k_f32_f16, k_f32_f16_q1; + CL_CHECK((k_f32_f16 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16", &err), err)); + CL_CHECK((k_f32_f16_q1 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16_q1", &err), err)); + backend_ctx->kernels_flash_attn_f32_f16[{dk, dv}] = k_f32_f16; + backend_ctx->kernels_flash_attn_f32_f16_q1[{dk, dv}] = k_f32_f16_q1; + CL_CHECK(clReleaseProgram(prog_f32_f16)); + + backend_ctx->kernels_flash_attn_bm[{dk, dv}] = bm; + backend_ctx->kernels_flash_attn_bn[{dk, dv}] = bn; + } + backend_ctx->kernels_loaded_flash_attn = true; + } + } +} + +static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { + if (backend_ctx->kernels_loaded) { + return; + } + cl_int err; // compiler options for general kernels auto opencl_c_std = - std::string("CL") + std::to_string(opencl_c_version.major) + "." + std::to_string(opencl_c_version.minor); + std::string("CL") + std::to_string(backend_ctx->opencl_c_version.major) + "." + std::to_string(backend_ctx->opencl_c_version.minor); std::string compile_opts = std::string("-cl-std=") + opencl_c_std + " -cl-mad-enable -cl-unsafe-math-optimizations" " -cl-finite-math-only -cl-fast-relaxed-math"; @@ -1986,89 +2116,6 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } - // flash_attn - { - #ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src_f16 { - #include "flash_attn_f16.cl.h" - }; - const std::string kernel_src_f32 { - #include "flash_attn_f32.cl.h" - }; - const std::string kernel_src_f32_f16 { - #include "flash_attn_f32_f16.cl.h" - }; - #else - const std::string kernel_src_f16 = read_file("flash_attn_f16.cl"); - const std::string kernel_src_f32 = read_file("flash_attn_f32.cl"); - const std::string kernel_src_f32_f16 = read_file("flash_attn_f32_f16.cl"); - #endif - - if (!kernel_src_f16.empty() && !kernel_src_f32.empty() && !kernel_src_f32_f16.empty()) { - const struct { int dk; int dv; int bm; int bn; } fa_dims[] = { - { 40, 40, 32, 32}, { 64, 64, 64, 64}, { 80, 80, 64, 32}, { 96, 96, 64, 32}, - {112, 112, 32, 32}, {128, 128, 32, 32}, {192, 128, 16, 16}, - {192, 192, 16, 16}, {256, 256, 16, 16}, - }; - - for (size_t i = 0; i < sizeof(fa_dims)/sizeof(fa_dims[0]); ++i) { - const int dk = fa_dims[i].dk; - const int dv = fa_dims[i].dv; - const int bm = fa_dims[i].bm; - const int bn = fa_dims[i].bn; - std::string OPTS = compile_opts + - " -D DK=" + std::to_string(dk) + - " -D DV=" + std::to_string(dv) + - " -D BLOCK_M=" + std::to_string(bm) + - " -D BLOCK_N=" + std::to_string(bn); - - cl_program prog_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f16.c_str(), OPTS); - cl_kernel k_f16, k_f16_q1; - CL_CHECK((k_f16 = clCreateKernel(prog_f16, "flash_attn_f16", &err), err)); - CL_CHECK((k_f16_q1 = clCreateKernel(prog_f16, "flash_attn_f16_q1", &err), err)); - backend_ctx->kernels_flash_attn_f16[{dk, dv}] = k_f16; - backend_ctx->kernels_flash_attn_f16_q1[{dk, dv}] = k_f16_q1; - CL_CHECK(clReleaseProgram(prog_f16)); - - cl_program prog_f32 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32.c_str(), OPTS); - cl_kernel k_f32, k_f32_q1; - CL_CHECK((k_f32 = clCreateKernel(prog_f32, "flash_attn_f32", &err), err)); - CL_CHECK((k_f32_q1 = clCreateKernel(prog_f32, "flash_attn_f32_q1", &err), err)); - backend_ctx->kernels_flash_attn_f32[{dk, dv}] = k_f32; - backend_ctx->kernels_flash_attn_f32_q1[{dk, dv}] = k_f32_q1; - CL_CHECK(clReleaseProgram(prog_f32)); - - cl_program prog_f32_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32_f16.c_str(), OPTS); - cl_kernel k_f32_f16, k_f32_f16_q1; - CL_CHECK((k_f32_f16 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16", &err), err)); - CL_CHECK((k_f32_f16_q1 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16_q1", &err), err)); - backend_ctx->kernels_flash_attn_f32_f16[{dk, dv}] = k_f32_f16; - backend_ctx->kernels_flash_attn_f32_f16_q1[{dk, dv}] = k_f32_f16_q1; - CL_CHECK(clReleaseProgram(prog_f32_f16)); - - backend_ctx->kernels_flash_attn_bm[{dk, dv}] = bm; - backend_ctx->kernels_flash_attn_bn[{dk, dv}] = bn; - } - GGML_LOG_CONT("."); - } - } - - // argsort - { -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src { - #include "argsort.cl.h" - }; -#else - const std::string kernel_src = read_file("argsort.cl"); -#endif - backend_ctx->program_argsort_f32_i32 = - build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - - CL_CHECK((backend_ctx->kernel_argsort_f32_i32 = clCreateKernel(backend_ctx->program_argsort_f32_i32, "kernel_argsort_f32_i32", &err), err)); - GGML_LOG_CONT("."); - } - // div { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -3335,13 +3382,15 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve } #endif // GGML_OPENCL_USE_ADRENO_KERNELS GGML_LOG_CONT("\n"); + backend_ctx->kernels_loaded = true; } // XXX static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { // XXX static bool initialized = false; // XXX static ggml_backend_opencl_context *backend_ctx = nullptr; -static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev); +static ggml_backend_opencl_context * ggml_cl_init(ggml_backend_dev_t dev); +static bool ggml_opencl_is_device_supported(ggml_backend_dev_t dev); namespace /* anonymous */ { extern struct ggml_backend_device_i ggml_backend_opencl_device_i; @@ -3554,13 +3603,13 @@ static std::vector ggml_opencl_probe_devices(ggml_backend_r /* .context = */ dev_ctx.get(), }); - if (!ggml_cl2_init(&found_devices.back())) { + if (!ggml_opencl_is_device_supported(&found_devices.back())) { found_devices.pop_back(); - GGML_LOG_INFO("ggml_opencl: drop unsupported device.\n"); + GGML_LOG_WARN("ggml_opencl: drop unsupported device '%s'.\n", dev->name); continue; } - dev_ctx.release(); + g_ggml_backend_opencl_dev_ctxs.push_back(std::move(dev_ctx)); } if (found_devices.size()) { @@ -3577,8 +3626,79 @@ static std::vector ggml_opencl_probe_devices(ggml_backend_r return found_devices; } +// check if device should be accepted +static bool ggml_opencl_is_device_supported(ggml_backend_dev_t dev) { + GGML_ASSERT(dev); + GGML_ASSERT(dev->context); + + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) dev->context; + GGML_ASSERT(dev_ctx->platform); + GGML_ASSERT(dev_ctx->device); + + if (strstr(dev_ctx->device_name.c_str(), "Adreno") || + strstr(dev_ctx->device_name.c_str(), "Qualcomm") || + strstr(dev_ctx->device_version.c_str(), "Adreno")) { + dev_ctx->gpu_family = GPU_FAMILY::ADRENO; + + // Usually device version contains the detailed device name + dev_ctx->adreno_gen = get_adreno_gpu_gen(dev_ctx->device_version.c_str()); + if (dev_ctx->adreno_gen == ADRENO_GPU_GEN::ADRENO_UNKNOWN) { + dev_ctx->adreno_gen = get_adreno_gpu_gen(dev_ctx->device_name.c_str()); + } + } else if (strstr(dev_ctx->device_name.c_str(), "Intel")) { + dev_ctx->gpu_family = GPU_FAMILY::INTEL; + } else { + GGML_LOG_WARN("ggml_opencl: unsupported GPU '%s'.\n", dev_ctx->device_name.c_str()); + dev_ctx->gpu_family = GPU_FAMILY::UNKNOWN; + return false; + } + + ggml_cl_version platform_version = get_opencl_platform_version(dev_ctx->platform); + + // Check device OpenCL version, OpenCL 2.0 or above is required + ggml_cl_version opencl_c_version = get_opencl_c_version(platform_version, dev_ctx->device); + if (opencl_c_version.major < 2) { + GGML_LOG_WARN("ggml_opencl: OpenCL 2.0 or above is required\n"); + return false; + } + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (dev_ctx->gpu_family != GPU_FAMILY::ADRENO) { + GGML_LOG_WARN("ggml_opencl: Adreno-specific kernels should not be enabled for non-Adreno GPUs; " + "run on an Adreno GPU or recompile with CMake option `-DGGML_OPENCL_USE_ADRENO_KERNELS=OFF`\n"); + return false; + } +#endif + + size_t ext_str_size; + clGetDeviceInfo(dev_ctx->device, CL_DEVICE_EXTENSIONS, 0, NULL, &ext_str_size); + + char *ext_buffer = (char *)alloca(ext_str_size + 1); + clGetDeviceInfo(dev_ctx->device, CL_DEVICE_EXTENSIONS, ext_str_size, ext_buffer, NULL); + ext_buffer[ext_str_size] = '\0'; + + // Check if ext_buffer contains cl_khr_fp16 + bool fp16_support = strstr(ext_buffer, "cl_khr_fp16") != NULL; + if (!fp16_support) { + GGML_LOG_WARN("ggml_opencl: device does not support FP16\n"); + return false; + } + + // If OpenCL 3.0 is supported, then check for cl_khr_subgroups, which becomes + // optional in OpenCL 3.0 (cl_khr_subgroup is mandatory in OpenCL 2.x) + if (opencl_c_version.major == 3 && strstr(ext_buffer, "cl_khr_subgroups") == NULL && + strstr(ext_buffer, "cl_intel_subgroups") == NULL) { + GGML_LOG_WARN("ggml_opencl: device does not support subgroups (cl_khr_subgroups or cl_intel_subgroups) " + "(note that subgroups is an optional feature in OpenCL 3.0)\n"); + return false; + } + + clGetDeviceInfo(dev_ctx->device, CL_DEVICE_GLOBAL_MEM_SIZE, sizeof(size_t), &dev_ctx->global_mem_size, NULL); + return true; +} + // Initialize device if it is supported (returns nullptr if it is not). -static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { +static ggml_backend_opencl_context * ggml_cl_init(ggml_backend_dev_t dev) { GGML_ASSERT(dev); GGML_ASSERT(dev->context); @@ -3600,33 +3720,12 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { // when the associated device is initialized backend_ctx->ref_count = 0; - if (strstr(dev_ctx->device_name.c_str(), "Adreno") || - strstr(dev_ctx->device_name.c_str(), "Qualcomm") || - strstr(dev_ctx->device_version.c_str(), "Adreno")) { - backend_ctx->gpu_family = GPU_FAMILY::ADRENO; - // Usually device version contains the detailed device name - backend_ctx->adreno_gen = get_adreno_gpu_gen(dev_ctx->device_version.c_str()); - if (backend_ctx->adreno_gen == ADRENO_GPU_GEN::ADRENO_UNKNOWN) { - backend_ctx->adreno_gen = get_adreno_gpu_gen(dev_ctx->device_name.c_str()); - } - + backend_ctx->gpu_family = dev_ctx->gpu_family; + backend_ctx->adreno_gen = dev_ctx->adreno_gen; + if (backend_ctx->gpu_family == GPU_FAMILY::ADRENO) { // Use wave size of 64 for all Adreno GPUs. backend_ctx->adreno_wave_size = 64; - } else if (strstr(dev_ctx->device_name.c_str(), "Intel")) { - backend_ctx->gpu_family = GPU_FAMILY::INTEL; - } else { - GGML_LOG_ERROR("Unsupported GPU: %s\n", dev_ctx->device_name.c_str()); - backend_ctx->gpu_family = GPU_FAMILY::UNKNOWN; - return nullptr; - } - -#ifdef GGML_OPENCL_USE_ADRENO_KERNELS - if (backend_ctx->gpu_family != GPU_FAMILY::ADRENO) { - GGML_LOG_ERROR("ggml_opencl: Adreno-specific kernels should not be enabled for non-Adreno GPUs; " - "run on an Adreno GPU or recompile with CMake option `-DGGML_OPENCL_USE_ADRENO_KERNELS=OFF`\n"); - return nullptr; } -#endif // Populate backend device name backend_ctx->device_name = dev_ctx->device_name; @@ -3635,13 +3734,10 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { cl_device_id device = backend_ctx->device; ggml_cl_version platform_version = get_opencl_platform_version(dev_ctx->platform); - - // Check device OpenCL version, OpenCL 2.0 or above is required ggml_cl_version opencl_c_version = get_opencl_c_version(platform_version, device); - if (opencl_c_version.major < 2) { - GGML_LOG_ERROR("ggml_opencl: OpenCL 2.0 or above is required\n"); - return nullptr; - } + + backend_ctx->platform_version = platform_version; + backend_ctx->opencl_c_version = opencl_c_version; // Check driver version size_t driver_version_str_size; @@ -3664,34 +3760,21 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { char *ext_buffer = (char *)alloca(ext_str_size + 1); clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, ext_str_size, ext_buffer, NULL); ext_buffer[ext_str_size] = '\0'; // ensure it is null terminated + // Check if ext_buffer contains cl_khr_fp16 backend_ctx->fp16_support = strstr(ext_buffer, "cl_khr_fp16") != NULL; GGML_LOG_INFO("ggml_opencl: device FP16 support: %s\n", backend_ctx->fp16_support ? "true" : "false"); + // check Adreno large buffer support backend_ctx->adreno_has_large_buffer = strstr(ext_buffer, "cl_qcom_large_buffer") != NULL; - // fp16 is required - if (!backend_ctx->fp16_support) { - GGML_LOG_ERROR("ggml_opencl: device does not support FP16\n"); - return nullptr; - } - - // If OpenCL 3.0 is supported, then check for cl_khr_subgroups, which becomes - // optional in OpenCL 3.0 (cl_khr_subgroup is mandatory in OpenCL 2.x) - if (opencl_c_version.major == 3 && strstr(ext_buffer, "cl_khr_subgroups") == NULL && - strstr(ext_buffer, "cl_intel_subgroups") == NULL) { - GGML_LOG_ERROR("ggml_opencl: device does not support subgroups (cl_khr_subgroups or cl_intel_subgroups) " - "(note that subgroups is an optional feature in OpenCL 3.0)\n"); - return nullptr; - } - cl_uint base_align_in_bits; CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_MEM_BASE_ADDR_ALIGN, sizeof(cl_uint), &base_align_in_bits, NULL)); GGML_ASSERT(base_align_in_bits % 8u == 0); backend_ctx->alignment = base_align_in_bits / 8u; GGML_LOG_INFO("ggml_opencl: mem base addr align: %u\n", backend_ctx->alignment); - clGetDeviceInfo(device, CL_DEVICE_GLOBAL_MEM_SIZE, sizeof(size_t), &backend_ctx->global_mem_size, NULL); + backend_ctx->global_mem_size = dev_ctx->global_mem_size; GGML_LOG_INFO("ggml_opencl: global mem size: %zu MB\n", backend_ctx->global_mem_size/1024/1024); clGetDeviceInfo(device, CL_DEVICE_MAX_MEM_ALLOC_SIZE, sizeof(size_t), &backend_ctx->max_alloc_size, NULL); @@ -3779,8 +3862,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { #endif CL_CHECK((backend_ctx->queue = clCreateCommandQueue(context, device, command_queue_props, &err), err)); - // Load kernels - load_cl_kernels(backend_ctx.get(), opencl_c_version); + // delay kernel loading until the first buffer is created + // load_cl_kernels(backend_ctx.get()); #ifdef GGML_OPENCL_USE_ADRENO_KERNELS // Allocate intermediate buffers and images @@ -3822,22 +3905,9 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { return dev_ctx->backend_ctx; } -static void ggml_cl2_free(ggml_backend_t backend) { +static void ggml_cl_free(ggml_backend_t backend) { ggml_backend_opencl_context * ctx = (ggml_backend_opencl_context *) backend->context; ctx->free(); - - // The CL context is shared by all backends, release it if all backends have been released - bool should_release_opencl = true; - for (auto device : g_ggml_backend_opencl_devices) { - ggml_backend_opencl_device_context * ctx_dev = (ggml_backend_opencl_device_context *) device.context; - if (ctx_dev->backend_ctx->ref_count > 0) { - should_release_opencl = false; - } - } - - if (should_release_opencl) { - CL_CHECK(clReleaseContext(ctx->context)); - } } #ifdef GGML_OPENCL_USE_ADRENO_KERNELS @@ -4421,7 +4491,7 @@ static const char * ggml_backend_opencl_name(ggml_backend_t backend) { } static void ggml_backend_opencl_free(ggml_backend_t backend) { - ggml_cl2_free(backend); + ggml_cl_free(backend); } static void ggml_backend_opencl_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { @@ -4460,14 +4530,17 @@ static void ggml_backend_opencl_synchronize(ggml_backend_t backend) { // enqueued to it won't start until commands in the other devices have // completed. static void sync_with_other_backends(ggml_backend_opencl_context * backend_ctx) { - if (g_ggml_backend_opencl_devices.size() < 2) - return; // No other devices to synchronize with. + if (g_ggml_backend_opencl_devices.size() < 2) { + return; // No other devices to synchronize with. + } std::vector events; events.reserve(g_ggml_backend_opencl_devices.size()); for (ggml_backend_device & backend_dev : g_ggml_backend_opencl_devices) { - auto * other_backend_ctx = ggml_cl2_init(&backend_dev); + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) backend_dev.context; + auto * other_backend_ctx = dev_ctx->backend_ctx; + if (backend_ctx != other_backend_ctx) { cl_event ev; CL_CHECK(clEnqueueMarkerWithWaitList(other_backend_ctx->queue, 0, nullptr, &ev)); @@ -4880,6 +4953,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_OP_IM2COL: return true; case GGML_OP_ARGSORT: { + load_cl_kernels_argsort(backend_ctx); + cl_kernel kernel = backend_ctx->kernel_argsort_f32_i32; int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel); @@ -4897,6 +4972,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_FLASH_ATTN_EXT: { + load_cl_kernels_flash_attn(backend_ctx); + const ggml_tensor * q = op->src[0]; const ggml_tensor * k = op->src[1]; const ggml_tensor * v = op->src[2]; @@ -4964,7 +5041,7 @@ static ggml_backend_i ggml_backend_opencl_i = { ggml_backend_t ggml_backend_opencl_init(void) { ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_opencl_reg(), 0); - ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(dev); + ggml_backend_opencl_context *backend_ctx = ggml_cl_init(dev); ggml_backend_t backend = new ggml_backend { /* .guid = */ ggml_backend_opencl_guid(), @@ -5343,15 +5420,13 @@ static void ggml_backend_opencl_buffer_free_buffer(ggml_backend_buffer_t buffer) } static void * ggml_backend_opencl_buffer_get_base(ggml_backend_buffer_t buffer) { - ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(buffer->buft->device); - return (void *) (uintptr_t) backend_ctx->alignment; + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) buffer->buft->device->context; + return (void *) (uintptr_t) dev_ctx->backend_ctx->alignment; } static enum ggml_status ggml_backend_opencl_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; - ggml_cl2_init(buffer->buft->device); - if (tensor->view_src != nullptr) { GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft); @@ -5391,7 +5466,8 @@ static enum ggml_status ggml_backend_opencl_buffer_init_tensor(ggml_backend_buff } static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer->buft->device); + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) buffer->buft->device->context; + ggml_backend_opencl_context * backend_ctx = dev_ctx->backend_ctx; cl_context context = backend_ctx->context; cl_command_queue queue = backend_ctx->queue; @@ -6626,7 +6702,8 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { GGML_ASSERT(tensor->extra); - ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer->buft->device); + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) buffer->buft->device->context; + ggml_backend_opencl_context *backend_ctx = dev_ctx->backend_ctx; cl_context context = backend_ctx->context; cl_command_queue queue = backend_ctx->queue; @@ -7470,8 +7547,9 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, } static void ggml_backend_opencl_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { - ggml_backend_dev_t dev = buffer->buft->device; - ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(dev); + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) buffer->buft->device->context; + ggml_backend_opencl_context * backend_ctx = dev_ctx->backend_ctx; + cl_command_queue queue = backend_ctx->queue; ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; @@ -7511,7 +7589,8 @@ static const char * ggml_backend_opencl_buffer_type_get_name(ggml_backend_buffer } static ggml_backend_buffer_t ggml_backend_opencl_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buffer_type, size_t size) { - ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer_type->device); + ggml_backend_opencl_context *backend_ctx = ggml_cl_init(buffer_type->device); + load_cl_kernels(backend_ctx); // clCreateBuffer returns -61 for size 0 size = std::max(size, (size_t)1); @@ -7534,15 +7613,15 @@ static ggml_backend_buffer_t ggml_backend_opencl_buffer_type_alloc_buffer(ggml_b } static size_t ggml_backend_opencl_buffer_type_get_alignment(ggml_backend_buffer_type_t buffer_type) { - ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(buffer_type->device); - return backend_ctx->alignment; + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) buffer_type->device->context; + return dev_ctx->backend_ctx->alignment; } static size_t ggml_backend_opencl_buffer_type_get_max_size(ggml_backend_buffer_type_t buffer_type) { static size_t max_size = -1; if (max_size == (size_t)-1) { - ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(buffer_type->device); - max_size = backend_ctx->max_alloc_size; + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) buffer_type->device->context; + max_size = dev_ctx->backend_ctx->max_alloc_size; } return max_size; } @@ -7579,14 +7658,13 @@ static const char * ggml_backend_opencl_device_get_description(ggml_backend_dev_ static void ggml_backend_opencl_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) dev->context; - ggml_backend_opencl_context * backend_ctx = (ggml_backend_opencl_context *) dev_ctx->backend_ctx; static const size_t opencl_extra_margin = 1024ull*1024ull*1024ull; // OpenCL does not provide reliable currently-free device memory. // Use total/global memory as a best-effort upper bound. // Improved safety: Reduce by a 1GiB extra margin for common --fit - *total = backend_ctx->global_mem_size; + *total = dev_ctx->global_mem_size; *free = *total > opencl_extra_margin ? *total - opencl_extra_margin : 0; } @@ -7610,7 +7688,7 @@ static void ggml_backend_opencl_device_get_props(ggml_backend_dev_t dev, struct } static ggml_backend_t ggml_backend_opencl_device_init(ggml_backend_dev_t dev, const char * params) { - ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(dev); + ggml_backend_opencl_context * backend_ctx = ggml_cl_init(dev); // Getting a new reference to the backend, increase ref_count backend_ctx->ref_count++; @@ -7647,6 +7725,7 @@ static ggml_backend_buffer_t ggml_backend_opencl_device_buffer_from_ptr(ggml_bac } static bool ggml_backend_opencl_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { + ggml_cl_init(dev); return ggml_opencl_supports_op(dev, op); } @@ -7659,8 +7738,8 @@ static bool ggml_backend_opencl_device_supports_buft(ggml_backend_dev_t dev, ggm // Check cl_context is the same. clEnqueue* commands may not use // buffers from another cl_context. - ggml_backend_opencl_context * backend_ctx0 = ggml_cl2_init(dev); - ggml_backend_opencl_context * backend_ctx1 = ggml_cl2_init(buft->device); + ggml_backend_opencl_context * backend_ctx0 = ggml_cl_init(dev); + ggml_backend_opencl_context * backend_ctx1 = ggml_cl_init(buft->device); return backend_ctx0->context == backend_ctx1->context; } From 3d596aff1297d3b9e7f151c661a2464040cf8330 Mon Sep 17 00:00:00 2001 From: Todor Boinovski Date: Wed, 20 May 2026 22:14:13 -0700 Subject: [PATCH 35/55] hexagon: ssm-conv fix for large prompts (llama/23307) * hexagon: remove gathers and better handling of vtcm in ssm-conv * hexagon: relax ssm-conv gating requirements * hexagon: add new prefill ssm-conv backend test * hexagon: remove trailing white space * hex-rope: uninline rope_cache_init, otherwise it breaks after rebaseing with SSM_CONV changes --------- Co-authored-by: Max Krasnyansky --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 7 +- ggml/src/ggml-hexagon/htp/rope-ops.c | 4 +- ggml/src/ggml-hexagon/htp/ssm-conv.c | 388 +++++++++++++++---------- 3 files changed, 246 insertions(+), 153 deletions(-) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 080fb7f47e3..9db99cb0f3a 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2735,9 +2735,10 @@ static bool ggml_hexagon_supported_ssm_conv(const struct ggml_hexagon_session * if (dst->ne[0] != d_inner || dst->ne[1] != n_t || dst->ne[2] != n_s) { return false; } - - // TODO: add support for non-contiguous tensors - if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) { + if (src0->nb[0] != sizeof(float) || src1->nb[0] != sizeof(float) || dst->nb[0] != sizeof(float)) { + return false; + } + if (src0->nb[1] != src0->ne[0] * sizeof(float) || src1->nb[1] != src1->ne[0] * sizeof(float)) { return false; } diff --git a/ggml/src/ggml-hexagon/htp/rope-ops.c b/ggml/src/ggml-hexagon/htp/rope-ops.c index 9901453e91e..b398e19f06e 100644 --- a/ggml/src/ggml-hexagon/htp/rope-ops.c +++ b/ggml/src/ggml-hexagon/htp/rope-ops.c @@ -107,7 +107,7 @@ static inline void rope_yarn_one(float theta, float freq_scale, float * corr_dim cache[i0 + 1] = sinf(theta_final) * mscale_final; } -static void rope_cache_init(const float theta_base, +static __attribute__((noinline)) void rope_cache_init(const float theta_base, const float freq_scale, const float * freq_factors, float * corr_dims, @@ -129,7 +129,7 @@ static void rope_cache_init(const float theta_base, // pos_t/h/w/e: the four position ids for this sequence step (t=time, h=height, w=width, e=extra). // sections[4]: number of head dims assigned to each position component. -static void mrope_cache_init(const float pos_t, +static __attribute__((noinline)) void mrope_cache_init(const float pos_t, const float pos_h, const float pos_w, const float pos_e, diff --git a/ggml/src/ggml-hexagon/htp/ssm-conv.c b/ggml/src/ggml-hexagon/htp/ssm-conv.c index a28fd03e978..d574da2e2bc 100644 --- a/ggml/src/ggml-hexagon/htp/ssm-conv.c +++ b/ggml/src/ggml-hexagon/htp/ssm-conv.c @@ -20,55 +20,56 @@ #include "htp-ops.h" #include "hvx-utils.h" -#define htp_ssm_conv_tensors_preamble \ - const struct htp_tensor * restrict src0 = octx->src[0]; \ - const struct htp_tensor * restrict src1 = octx->src[1]; \ - const struct htp_tensor * restrict dst = octx->dst; \ - struct htp_spad * restrict src0_spad = &octx->src0_spad; \ - struct htp_spad * restrict src1_spad = &octx->src1_spad; \ - struct htp_spad * restrict dst_spad = &octx->dst_spad; \ - \ - const uint32_t ne00 = src0->ne[0]; \ - const uint32_t ne01 = src0->ne[1]; \ - const uint32_t ne02 = src0->ne[2]; \ - const uint32_t ne03 = src0->ne[3]; \ - \ - const uint32_t ne10 = src1->ne[0]; \ - const uint32_t ne11 = src1->ne[1]; \ - const uint32_t ne12 = src1->ne[2]; \ - const uint32_t ne13 = src1->ne[3]; \ - \ - const uint32_t ne0 = dst->ne[0]; \ - const uint32_t ne1 = dst->ne[1]; \ - const uint32_t ne2 = dst->ne[2]; \ - const uint32_t ne3 = dst->ne[3]; \ - \ - const uint32_t nb00 = src0->nb[0]; \ - const uint32_t nb01 = src0->nb[1]; \ - const uint32_t nb02 = src0->nb[2]; \ - const uint32_t nb03 = src0->nb[3]; \ - \ - const uint32_t nb10 = src1->nb[0]; \ - const uint32_t nb11 = src1->nb[1]; \ - const uint32_t nb12 = src1->nb[2]; \ - const uint32_t nb13 = src1->nb[3]; \ - \ - const uint32_t nb0 = dst->nb[0]; \ - const uint32_t nb1 = dst->nb[1]; \ - const uint32_t nb2 = dst->nb[2]; \ +#define htp_ssm_conv_tensors_preamble \ + const struct htp_tensor * restrict src0 = octx->src[0]; \ + const struct htp_tensor * restrict src1 = octx->src[1]; \ + const struct htp_tensor * restrict dst = octx->dst; \ + struct htp_spad * restrict src0_spad = &octx->src0_spad; \ + struct htp_spad * restrict src1_spad = &octx->src1_spad; \ + struct htp_spad * restrict dst_spad = &octx->dst_spad; \ + \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t ne10 = src1->ne[0]; \ + const uint32_t ne11 = src1->ne[1]; \ + const uint32_t ne12 = src1->ne[2]; \ + const uint32_t ne13 = src1->ne[3]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t nb10 = src1->nb[0]; \ + const uint32_t nb11 = src1->nb[1]; \ + const uint32_t nb12 = src1->nb[2]; \ + const uint32_t nb13 = src1->nb[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; struct htp_ssm_conv_context { struct htp_ops_context * octx; uint32_t nrows_per_thread; + uint32_t d_inner_tile; uint64_t t_start; }; -#define htp_ssm_conv_preamble \ +#define htp_ssm_conv_preamble \ struct htp_ssm_conv_context * scctx = (struct htp_ssm_conv_context *) data; \ - struct htp_ops_context * octx = scctx->octx; \ - htp_ssm_conv_tensors_preamble; \ - dma_queue * dma_queue = octx->ctx->dma[ith]; + struct htp_ops_context * octx = scctx->octx; \ + htp_ssm_conv_tensors_preamble; \ + dma_queue * dma_queue = octx->ctx->dma[ith]; // Scalar FP32 SSM_CONV implementation static void ssm_conv_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) { @@ -128,118 +129,211 @@ static void ssm_conv_thread_f32_f32(unsigned int nth, unsigned int ith, void *da dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -// HVX FP32 SSM_CONV implementation - vectorizes across d_inner dimension -static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void *data) { - htp_ssm_conv_preamble; - - uint64_t t1, t2; - t1 = HAP_perf_get_qtimer_count(); - const int nc = src1->ne[0]; // d_conv - const int ncs = src0->ne[0]; // d_conv - 1 + n_t +// In-register 32x32 fp32 transpose using std 5-stage HVX vshuff butterfly. +static inline void hvx_transpose_32x32_f32(HVX_Vector m[32]) { + HVX_Vector tmp[32]; - const uint32_t d_conv = src1->ne[0]; - const uint32_t d_inner = src0->ne[1]; - const uint32_t n_t = dst->ne[1]; - const uint32_t n_s = dst->ne[2]; + // Stage 0 (R = -4): pair (2i, 2i+1) for i = 0..15. m -> tmp. + for (int i = 0; i < 16; ++i) { + HVX_VectorPair p = Q6_W_vshuff_VVR(m[2*i + 1], m[2*i], -4); + tmp[2*i + 0] = Q6_V_lo_W(p); + tmp[2*i + 1] = Q6_V_hi_W(p); + } - const float * src0_data = (const float *) src0->data; - const float * src1_data = (const float *) src1->data; - float * dst_data = (float *) dst->data; + // Stage 1 (R = -8): per block of 4, pair (b+0, b+2) and (b+1, b+3). tmp -> m. + for (int b = 0; b < 32; b += 4) { + HVX_VectorPair p0 = Q6_W_vshuff_VVR(tmp[b + 2], tmp[b + 0], -8); + HVX_VectorPair p1 = Q6_W_vshuff_VVR(tmp[b + 3], tmp[b + 1], -8); + m[b + 0] = Q6_V_lo_W(p0); m[b + 1] = Q6_V_hi_W(p0); + m[b + 2] = Q6_V_lo_W(p1); m[b + 3] = Q6_V_hi_W(p1); + } - // Calculate row range for this thread - const int dr = scctx->nrows_per_thread; - const uint32_t ir0 = dr * ith; - const uint32_t ir1 = MIN(ir0 + dr, d_inner); - const uint32_t ir = ir1 - ir0; + // Stage 2 (R = -16): per block of 8, pair (b+i, b+i+4) for i = 0..3. m -> tmp. + for (int b = 0; b < 32; b += 8) { + for (int i = 0; i < 4; ++i) { + HVX_VectorPair p = Q6_W_vshuff_VVR(m[b + i + 4], m[b + i], -16); + tmp[b + 2*i + 0] = Q6_V_lo_W(p); + tmp[b + 2*i + 1] = Q6_V_hi_W(p); + } + } - if (ir0 >= ir1) { - return; // No work for this thread + // Stage 3 (R = -32): per block of 16, pair (b+i, b+i+8) for i = 0..7. tmp -> m. + for (int b = 0; b < 32; b += 16) { + for (int i = 0; i < 8; ++i) { + HVX_VectorPair p = Q6_W_vshuff_VVR(tmp[b + i + 8], tmp[b + i], -32); + m[b + 2*i + 0] = Q6_V_lo_W(p); + m[b + 2*i + 1] = Q6_V_hi_W(p); + } } - // src0 and src1 gather offsets - uint32_t __attribute__((aligned(VLEN))) src0_offsets[VLEN_FP32] = { 0 }; - uint32_t __attribute__((aligned(VLEN))) src1_offsets[VLEN_FP32] = { 0 }; + // Stage 4 (R = -64): pair (i, i+16) for i = 0..15. m -> tmp -> m. + for (int i = 0; i < 16; ++i) { + HVX_VectorPair p = Q6_W_vshuff_VVR(m[i + 16], m[i], -64); + tmp[2 * i + 0] = Q6_V_lo_W(p); + tmp[2 * i + 1] = Q6_V_hi_W(p); + } - for (uint32_t i = 0; i < VLEN_FP32; ++i) { - src0_offsets[i] = i * (ncs) * sizeof(float); - src1_offsets[i] = i * (d_conv) * sizeof(float); + for (int i = 0; i < 32; ++i) { + m[i] = tmp[i]; } +} - const uint32_t src0_gather_len = VLEN * ncs; - const uint32_t src1_gather_len = VLEN * d_conv; +// HVX FP32 SSM_CONV implementation - channel-vectorized HVX kernel with src0/src1 +// transposed into VTCM. +// +// VTCM layouts (per thread): +// src1_T : {d_inner_per_thread, d_conv} — staged once per launch (small). +// src0_T : {d_inner_tile, ncs} — staged per d_inner-tile. +// +// d_inner_tile is chosen so that per-thread VTCM stays under the budget. +// Each thread iterates ceil(d_inner_per_thread d_inner_tile) tiles serially. +#define HTP_SSM_CONV_VTCM_BUDGET (1u << 20) // 1 MiB per thread + +// Scalar transpose: src1 {d_conv, d_inner} (DDR) -> {d_inner_per_thread, d_conv} (VTCM) +static inline void transpose_src1(const float * src1_data, + uint32_t src1_stride_inner, + uint32_t i1_off, + uint32_t d_inner_per_thread, + uint32_t d_conv, + float * src1_T) { + for (uint32_t i = 0; i < d_inner_per_thread; ++i) { + const float * src_row = src1_data + (i1_off + i) * src1_stride_inner; + for (uint32_t j = 0; j < d_conv; ++j) { + src1_T[j * d_inner_per_thread + i] = src_row[j]; + } + } +} - // gather scratchpads - HVX_Vector * src0_vec = (HVX_Vector *) (octx->ctx->vtcm_base + ith * VLEN*2 + 0); - HVX_Vector * src1_vec = (HVX_Vector *) (octx->ctx->vtcm_base + ith * VLEN*2 + VLEN); +// HVX 32x32 src0 transpose: src0 {ncs, d_inner} (DDR) -> src0_T {d_inner_tile, ncs} (VTCM) +static inline void transpose_src0_block(const float * src0_block, + uint32_t ncs, + uint32_t cb_n, + uint32_t d_inner_tile, + float * src0_T_block_dst, + uint32_t cb /* dst column offset */) { + const uint32_t T_TILE = VLEN_FP32; + + HVX_Vector __attribute__((aligned(VLEN))) sub[32]; + + for (uint32_t t0 = 0; t0 < ncs; t0 += T_TILE) { + const uint32_t t_n = MIN(T_TILE, ncs - t0); + + // Load 32 rows (channels) of T_TILE samples; pad missing channels with zeros. + for (uint32_t r = 0; r < cb_n; ++r) { + const float * src_row = src0_block + r * ncs + t0; + if (t_n == T_TILE) { + sub[r] = *(const HVX_UVector *) src_row; + } else { + HVX_Vector v = hvx_vec_splat_f32(0.0f); + hvx_vec_store_u(&v, t_n * sizeof(float), hvx_vec_splat_f32(0.0f)); + + float __attribute__((aligned(VLEN))) tmp[VLEN_FP32] = { 0 }; + for (uint32_t k = 0; k < t_n; ++k) tmp[k] = src_row[k]; + v = *(const HVX_Vector *) tmp; + sub[r] = v; + } + } + for (uint32_t r = cb_n; r < T_TILE; ++r) { + sub[r] = hvx_vec_splat_f32(0.0f); + } - float * data_src0 = (float *) ((char *) src0->data + ir0 * src0->nb[1]); - float * data_src1 = (float *) ((char *) src1->data + ir0 * src1->nb[1]); + hvx_transpose_32x32_f32(sub); - uint8_t * spad_src0 = octx->src0_spad.data + ith * octx->src0_spad.size_per_thread; - uint8_t * spad_src1 = octx->src1_spad.data + ith * octx->src1_spad.size_per_thread; + // Store transposed sub-tile to src0_T at offsets (t0 + j) * d_inner_tile + cb. + // Only write the valid t_n rows of the transposed result. + for (uint32_t r = 0; r < t_n; ++r) { + float * dst = src0_T_block_dst + (t0 + r) * d_inner_tile + cb; + if (cb_n == T_TILE) { + *(HVX_UVector *) dst = sub[r]; + } else { + hvx_vec_store_u(dst, cb_n * sizeof(float), sub[r]); + } + } + } +} - // copy src1 workload to VTCM - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src1, data_src1), nb11, nb11, ir); +static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void *data) { + htp_ssm_conv_preamble; - // FARF(HIGH, "ssm-conv-src1-fetch %d: ir0 %u size %u\n", ith, ir0, nb11 * ir); + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); - for (uint32_t i3 = 0; i3 < n_s; ++i3) { - float * src0_data_ptr = (float *) ((char *) data_src0 + i3 * (src0->nb[2])); + const uint32_t d_conv = src1->ne[0]; + const uint32_t d_inner = src0->ne[1]; + const uint32_t n_t = dst->ne[1]; + const uint32_t n_s = dst->ne[2]; + const uint32_t ncs = src0->ne[0]; - // copy src0 workload to VTCM - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0, src0_data_ptr), nb01, nb01, ir); + const uint32_t src0_stride_inner = src0->nb[1] / sizeof(float); + const uint32_t src0_stride_seq = src0->nb[2] / sizeof(float); + const uint32_t src1_stride_inner = src1->nb[1] / sizeof(float); + const uint32_t dst_stride_token = dst->nb[1] / sizeof(float); + const uint32_t dst_stride_seq = dst->nb[2] / sizeof(float); - // FARF(HIGH, "ssm-conv-src0-fetch %d: ir0 %u i3 %u size %u\n", ith, ir0, i3, nb01 * ir); + const uint32_t dr = scctx->nrows_per_thread; + const uint32_t ir0 = dr * ith; + const uint32_t ir1 = MIN(ir0 + dr, d_inner); - dma_queue_flush(dma_queue); + if (ir0 >= ir1) { + return; + } - for (uint32_t i2 = 0; i2 < n_t; ++i2) { - float * dst_ptr = (float *) ((char *) dst->data + ir0 * (dst->nb[0]) + i2 * (dst->nb[1]) + i3 * (dst->nb[2])); + const uint32_t d_inner_per_thread = ir1 - ir0; + const uint32_t d_inner_tile = scctx->d_inner_tile; - const uint32_t nvec = ir / VLEN_FP32; - const uint32_t nloe = ir % VLEN_FP32; - uint32_t i1 = 0; + const float * src0_data = (const float *) src0->data; + const float * src1_data = (const float *) src1->data; + float * dst_data = (float *) dst->data; - for (uint32_t vi1 = 0; vi1 < nvec; vi1++) { - HVX_Vector acc_vec = Q6_V_vsplat_R(0); + // Per-thread VTCM regions. + float * src0_T = (float *)(octx->src0_spad.data + ith * octx->src0_spad.size_per_thread); + float * src1_T = (float *)(octx->src1_spad.data + ith * octx->src1_spad.size_per_thread); - for (uint32_t i0 = 0; i0 < d_conv; ++i0) { - uint32_t src0_base = (uint32_t) spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0]); - uint32_t src1_base = (uint32_t) spad_src1 + (i0 + i1 * nc) * sizeof(float); - Q6_vgather_ARMVw(src0_vec, src0_base, src0_gather_len, (*(const HVX_Vector *) src0_offsets)); - Q6_vgather_ARMVw(src1_vec, src1_base, src1_gather_len, (*(const HVX_Vector *) src1_offsets)); + // Stage src1 weights once into VTCM in {d_inner_per_thread, d_conv} layout. + transpose_src1(src1_data, src1_stride_inner, ir0, d_inner_per_thread, d_conv, src1_T); - HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec); - acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod); - } + const uint32_t C_TILE = VLEN_FP32; - *(HVX_UVector *) (dst_ptr + i1) = Q6_Vsf_equals_Vqf32(acc_vec); - i1 += VLEN_FP32; - } + for (uint32_t i3 = 0; i3 < n_s; ++i3) { + for (uint32_t tile_off = 0; tile_off < d_inner_per_thread; tile_off += d_inner_tile) { + const uint32_t tile_n = MIN(d_inner_tile, d_inner_per_thread - tile_off); - if (nloe) { - HVX_Vector acc_vec = Q6_V_vsplat_R(0); + // Place src0 chunk into VTCM in {d_inner_tile, ncs} layout. + const float * src0_block = src0_data + i3 * src0_stride_seq + (ir0 + tile_off) * src0_stride_inner; - for (uint32_t i0 = 0; i0 < d_conv; ++i0) { - uint32_t src0_base = (uint32_t) spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0]); - uint32_t src1_base = (uint32_t) spad_src1 + (i0 + i1 * nc) * sizeof(float); - Q6_vgather_ARMVw(src0_vec, src0_base, src0_gather_len, (*(const HVX_Vector *) src0_offsets)); - Q6_vgather_ARMVw(src1_vec, src1_base, src1_gather_len, (*(const HVX_Vector *) src1_offsets)); + for (uint32_t cb = 0; cb < tile_n; cb += C_TILE) { + const uint32_t cb_n = MIN(C_TILE, tile_n - cb); + transpose_src0_block(src0_block + cb * src0_stride_inner, ncs, cb_n, d_inner_tile, src0_T, cb); + } - HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec); - acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod); + for (uint32_t t = 0; t < n_t; ++t) { + for (uint32_t cb = 0; cb < tile_n; cb += C_TILE) { + const uint32_t cb_n = MIN(C_TILE, tile_n - cb); + + HVX_Vector acc = hvx_vec_splat_f32(0.0f); + for (uint32_t j = 0; j < d_conv; ++j) { + HVX_Vector x = *(const HVX_Vector *) (src0_T + (t + j) * d_inner_tile + cb); + HVX_Vector w = *(const HVX_Vector *) (src1_T + j * d_inner_per_thread + tile_off + cb); + acc = Q6_Vqf32_vadd_Vqf32Vqf32(acc, Q6_Vqf32_vmpy_VsfVsf(x, w)); + } + HVX_Vector res = Q6_Vsf_equals_Vqf32(acc); + + float * dst_ptr = dst_data + i3 * dst_stride_seq + t * dst_stride_token + (ir0 + tile_off + cb); + if (cb_n == C_TILE) { + *(HVX_UVector *) dst_ptr = res; + } else { + hvx_vec_store_u(dst_ptr, cb_n * sizeof(float), res); + } } - - hvx_vec_store_u(dst_ptr + i1, (ir - i1) * 4, Q6_Vsf_equals_Vqf32(acc_vec)); } } } t2 = HAP_perf_get_qtimer_count(); - FARF(HIGH, "ssm-conv-f32-hvx %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", - ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0, ir1, + FARF(HIGH, "ssm-conv-f32-hvx %d/%d: %ux%ux%ux%u (%u:%u) tile=%u * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", + ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0, ir1, d_inner_tile, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } @@ -264,46 +358,44 @@ int op_ssm_conv_f32(struct htp_ops_context * octx) { if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { uint32_t use_hvx = 0; - if (d_inner >= VLEN_FP32 && d_inner % VLEN_FP32 == 0) { - int is_aligned = hex_is_aligned((void *) src0->data, VLEN) && - hex_is_aligned((void *) src1->data, VLEN) && - hex_is_aligned((void *) dst->data, VLEN); - - if (is_aligned) { - use_hvx = 1; - } + if (d_inner >= VLEN_FP32 && n_t >= VLEN_FP32) { + use_hvx = 1; } - if (use_hvx) { - scctx.nrows_per_thread = (d_inner + n_threads - 1) / n_threads; // d_inner chunks per thread - scctx.nrows_per_thread += (scctx.nrows_per_thread & 1); // round up to even + scctx.nrows_per_thread = (d_inner + n_threads - 1) / n_threads; + scctx.nrows_per_thread += (scctx.nrows_per_thread & 1); - octx->src0_spad.size_per_thread = hex_round_up(scctx.nrows_per_thread * nb01, 256); - octx->src1_spad.size_per_thread = hex_round_up(scctx.nrows_per_thread * nb11, 256); - octx->dst_spad.size_per_thread = hex_round_up(scctx.nrows_per_thread * sizeof(float), 256); + const uint32_t d_inner_per_thread = scctx.nrows_per_thread; + const uint32_t ncs = src0->ne[0]; - octx->src0_spad.size = octx->src0_spad.size_per_thread * n_threads; - octx->src1_spad.size = octx->src1_spad.size_per_thread * n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * n_threads; + const uint32_t src1_T_size = hex_round_up(d_conv * d_inner_per_thread * sizeof(float), 256); + const uint32_t src0_T_max = HTP_SSM_CONV_VTCM_BUDGET > src1_T_size ? HTP_SSM_CONV_VTCM_BUDGET - src1_T_size : 0; - // Compute gather scratchpad size for src0 and src1 - const size_t gather_spad_size = n_threads * VLEN * 2; + uint32_t d_inner_tile = (src0_T_max / sizeof(float)) / ncs; + d_inner_tile -= (d_inner_tile % VLEN_FP32); + if (d_inner_tile == 0) { + FARF(HIGH, "ssm_conv-f32: inner tile rounds to 0 (ncs=%u), falling back to scalar\n", ncs); + use_hvx = 0; + } else { + scctx.d_inner_tile = d_inner_tile; - octx->src0_spad.data = octx->ctx->vtcm_base + gather_spad_size; octx->src0_spad.src = NULL; - octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->src1_spad.src = NULL; - octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; octx->dst_spad.src = NULL; + octx->src0_spad.size_per_thread = hex_round_up(d_inner_tile * ncs * sizeof(float), 256); + octx->src1_spad.size_per_thread = src1_T_size; + octx->dst_spad.size_per_thread = 0; - FARF(HIGH, "ssm_conv-f32: gather-spad:%zu spad-per-thread:(%u:%u:%u) spad-sizes:(%u:%u:%u) spad-data:(%p:%p:%p)\n", - gather_spad_size, octx->src0_spad.size_per_thread, octx->src1_spad.size_per_thread, - octx->dst_spad.size_per_thread, octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size, - octx->src0_spad.data, octx->src1_spad.data, octx->dst_spad.data); + octx->src0_spad.size = octx->src0_spad.size_per_thread * n_threads; + octx->src1_spad.size = octx->src1_spad.size_per_thread * n_threads; + octx->dst_spad.size = 0; - const size_t total_spad_size = - gather_spad_size + octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size; + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->src0_spad.src = NULL; + octx->src1_spad.src = NULL; - if (total_spad_size > octx->ctx->vtcm_size) { - FARF(HIGH, "ssm_conv-f32: HVX scratchpad size %zu exceeds VTCM size %zu", total_spad_size, - octx->ctx->vtcm_size); + const size_t total_spad = octx->src0_spad.size + octx->src1_spad.size; + if (total_spad > octx->ctx->vtcm_size) { + FARF(HIGH, "ssm_conv-f32: scratchpad %zu exceeds VTCM %zu, falling back to scalar\n", + total_spad, octx->ctx->vtcm_size); use_hvx = 0; } } From 2b98710046566cde3856b736e75f485067a282b6 Mon Sep 17 00:00:00 2001 From: Matt Corallo <649246+TheBlueMatt@users.noreply.github.com> Date: Thu, 21 May 2026 06:24:40 +0000 Subject: [PATCH 36/55] ggml : Check the right iface method before using the fallback 2d get (llama/23306) Probably no backends implement only one of 2d get/set, but this might be annoying for some future backend developer trying to add 2d get/set. --- ggml/src/ggml-backend.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 4e36909f45e..5c0e5b1b9e2 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -379,7 +379,7 @@ void ggml_backend_tensor_get_2d(const struct ggml_tensor * tensor, void * data, ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; GGML_ASSERT(buf != NULL && "tensor buffer not set"); - if (n_copies <= 1 || buf->iface.set_tensor_2d == NULL) { + if (n_copies <= 1 || buf->iface.get_tensor_2d == NULL) { for (size_t i = 0; i < n_copies; i++) { ggml_backend_tensor_get(tensor, (char *) data + i*stride_data, offset + i*stride_tensor, size); } From 10254e3b8500dc0c1cc4a6fbe1beafa98caefac7 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 21 May 2026 13:34:08 +0300 Subject: [PATCH 37/55] metal : optimize concat kernel and fix set kernel threads (llama/23411) * metal : fix GGML_OP_SET kernel threads * tests : extend test_cpy to support different src/dst shapes Extend test_cpy to support different source and destination tensor shapes for CPY operations (reshaping), where the total number of elements must match. - Renamed ne -> ne_src, added ne_dst parameter (default: use src shape) - Added 50 new reshaping test cases covering 1D<->2D<->3D<->4D conversions - Tests exercise 1024 boundary, small shapes, and large dimensionality changes - Fixed dangling reference bug (storing & to temporary std::array) - Updated all existing test calls with permute/transpose args for compatibility Assisted-by: llama.cpp:local pi * metal : optimize concat kernel with row batching for small widths When ne0 < 256, batch multiple rows into a single threadgroup to improve occupancy. This avoids underutilizing the GPU when processing narrow tensors. - Dispatch nth = min(256, ne0) threads per group - Calculate nrptg (rows per threadgroup) to fill up to 256 threads - Update kernel index calculation to handle the row batching - Add boundary check for i1 >= ne1 Assisted-by: llama.cpp:local pi * tests : clean-up * tests : refactor CPY shape tests to use dimension permutations Replace 75 hardcoded test cases with a loop over permutations of {3, 5, 7, 32} (total elements: 3360). Each src permutation is tested against canonical sorted and reverse dst, skipping identical shapes. Covers F32, F16, and Q4_0 (when both src and dst ne0 == 32). Assisted-by: llama.cpp:local pi --- ggml/src/ggml-metal/ggml-metal-ops.cpp | 19 +++++++++++++++---- ggml/src/ggml-metal/ggml-metal.metal | 6 +++++- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 8506000b6c0..206af227a2c 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -564,9 +564,20 @@ int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); - const int nth = std::min(1024, ne0); + int nth = std::min(256, ne0); - ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1); + // when rows are small, we can batch them together in a single threadgroup + int nrptg = 1; + if (nth < 256) { + nrptg = std::min((256 + nth - 1) / nth, ne1); + if (nrptg * nth > 256) { + nrptg = 256 / nth; + } + } + + const int nw0 = (ne1 + nrptg - 1) / nrptg; + + ggml_metal_encoder_dispatch_threadgroups(enc, nw0, ne2, ne3, nth, nrptg, 1); return 1; } @@ -1786,7 +1797,7 @@ int ggml_metal_op_set(ggml_metal_op_t ctx, int idx) { nk0 = ne10/ggml_blck_size(op->type); } - int nth = std::min(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + int nth = std::min(nk0*ne11, 256); // when rows are small, we can batch them together in a single threadgroup int nrptg = 1; @@ -1797,7 +1808,7 @@ int ggml_metal_op_set(ggml_metal_op_t ctx, int idx) { nrptg = (nth + nk0 - 1)/nk0; nth = nk0; - if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + if (nrptg*nth > 256) { nrptg--; } } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 4cf9dbea946..e772664ba91 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -7486,7 +7486,11 @@ kernel void kernel_concat( const int i3 = tgpig.z; const int i2 = tgpig.y; - const int i1 = tgpig.x; + const int i1 = ntg.y == 1 ? tgpig.x : tgpig.x*ntg.y + tpitg.y; + + if (i1 >= args.ne1) { + return; + } int o[4] = {0, 0, 0, 0}; o[args.dim] = args.dim == 0 ? args.ne00 : (args.dim == 1 ? args.ne01 : (args.dim == 2 ? args.ne02 : args.ne03)); From 0e74cab17b43dc6869f49db3d41095d985cea72c Mon Sep 17 00:00:00 2001 From: Chen Yuan Date: Thu, 21 May 2026 10:58:49 -0400 Subject: [PATCH 38/55] fix(flash-attn): replace f32 with kv_type and q_type (llama/23372) --- .../wgsl-shaders/flash_attn_tile.wgsl | 44 +++++++++---------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl index ae8036b9ac5..4133f0ab564 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl @@ -122,9 +122,9 @@ const V_CHUNKS: u32 = HEAD_DIM_V / 4u; const SCORE_REGS_PER_LANE: u32 = (KV_TILE + MIN_SUBGROUP_SIZE - 1u) / MIN_SUBGROUP_SIZE; const OUT_REGS_PER_LANE: u32 = (V_CHUNKS + MIN_SUBGROUP_SIZE - 1u) / MIN_SUBGROUP_SIZE; -var q_shmem: array; -var kv_shmem: array; -var p_shmem: array; +var q_shmem: array; +var kv_shmem: array; +var p_shmem: array; @compute @workgroup_size(WG_SIZE) fn main(@builtin(workgroup_id) wg_id: vec3, @@ -169,10 +169,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let head = f32(head_idx); let slope = select(1.0, - select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), - pow(params.m0, head + 1.0), - head < params.n_head_log2), - params.max_bias > 0.0); + select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), + pow(params.m0, head + 1.0), + head < params.n_head_log2), + params.max_bias > 0.0); for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) { let q_tile_row = elem_idx / HEAD_DIM_QK; @@ -181,7 +181,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1; q_shmem[elem_idx] = select( 0.0, - f32(Q[global_q_row_offset + q_col]) * params.scale, + Q_TYPE(Q[global_q_row_offset + q_col]) * params.scale, head_q_row < params.seq_len_q); } @@ -213,10 +213,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let k_vec_index = (k_head_offset + global_k_row * params.stride_k1 + chunk * 4u) >> 2u; let k4 = K[k_vec_index]; let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; - kv_shmem[kv_off + 0u] = f32(k4.x); - kv_shmem[kv_off + 1u] = f32(k4.y); - kv_shmem[kv_off + 2u] = f32(k4.z); - kv_shmem[kv_off + 3u] = f32(k4.w); + kv_shmem[kv_off + 0u] = KV_TYPE(k4.x); + kv_shmem[kv_off + 1u] = KV_TYPE(k4.y); + kv_shmem[kv_off + 2u] = KV_TYPE(k4.z); + kv_shmem[kv_off + 3u] = KV_TYPE(k4.w); } workgroupBarrier(); @@ -233,18 +233,18 @@ fn main(@builtin(workgroup_id) wg_id: vec3, var dot_val = 0.0; for (var chunk = 0u; chunk < Q_CHUNKS; chunk += 1u) { let q_off = q_base + chunk * 4u; - let qv = vec4( + let qv = vec4( q_shmem[q_off + 0u], q_shmem[q_off + 1u], q_shmem[q_off + 2u], q_shmem[q_off + 3u]); let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; - let kv = vec4( + let kv = vec4( kv_shmem[kv_off + 0u], kv_shmem[kv_off + 1u], kv_shmem[kv_off + 2u], kv_shmem[kv_off + 3u]); - dot_val += dot(qv, kv); + dot_val += dot(vec4(qv), vec4(kv)); } #ifdef LOGIT_SOFTCAP dot_val = params.logit_softcap * tanh(dot_val); @@ -271,7 +271,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let kv_local = sg_inv_id + slot * subgroup_size; if (row_active && kv_local < kv_count) { let p = exp(local_scores[slot] - new_max); - p_shmem[subgroup_p_offset + kv_local] = p; + p_shmem[subgroup_p_offset + kv_local] = KV_TYPE(p); local_sum += p; } } @@ -285,10 +285,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let v_vec_index = (v_head_offset + global_v_row * params.stride_v1 + chunk * 4u) >> 2u; let v4 = V[v_vec_index]; let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; - kv_shmem[kv_off + 0u] = f32(v4.x); - kv_shmem[kv_off + 1u] = f32(v4.y); - kv_shmem[kv_off + 2u] = f32(v4.z); - kv_shmem[kv_off + 3u] = f32(v4.w); + kv_shmem[kv_off + 0u] = KV_TYPE(v4.x); + kv_shmem[kv_off + 1u] = KV_TYPE(v4.y); + kv_shmem[kv_off + 2u] = KV_TYPE(v4.z); + kv_shmem[kv_off + 3u] = KV_TYPE(v4.w); } workgroupBarrier(); @@ -308,12 +308,12 @@ fn main(@builtin(workgroup_id) wg_id: vec3, for (var kv_local = 0u; kv_local < kv_count; kv_local += 1u) { let p = p_shmem[subgroup_p_offset + kv_local]; let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; - let v4 = vec4( + let v4 = vec4( kv_shmem[kv_off + 0u], kv_shmem[kv_off + 1u], kv_shmem[kv_off + 2u], kv_shmem[kv_off + 3u]); - acc += p * v4; + acc += f32(p) * vec4(v4); } out_regs[reg_idx] = acc; } From 9c206f7a2000ce6c4ce44bee954d9fb5a43608ba Mon Sep 17 00:00:00 2001 From: Pascal Date: Thu, 21 May 2026 19:39:42 +0200 Subject: [PATCH 39/55] vulkan: fuse snake activation (mul, sin, sqr, mul, add) (llama/22855) * vulkan: fuse snake activation (mul, sin, sqr, mul, add) Add snake.comp shader with F32 / F16 / BF16 pipelines and ggml_vk_snake_dispatch_fused. The matcher recognizes the naive 5 op decomposition emitted by audio decoders (BigVGAN, Vocos) for snake activation y = x + sin(a*x)^2 * inv_b and rewrites it to a single elementwise kernel. test_snake_fuse from the CUDA PR now also compares CPU naive vs Vulkan fused across F32 / F16 / BF16. * vulkan: address jeffbolznv review for fused snake activation Rename T / C to ne0 / ne1 in the shader and push constants to match the standard naming convention used across the Vulkan backend. Tighten ggml_vk_can_fuse_snake: require x and dst to be contiguous (the shader uses idx = i0 + i1 * ne0) and require a / inv_b to be tightly packed on the broadcast dim (the shader reads data_a[i1]). * vulkan: tighten snake fusion type checks for all operands (address jeffbolznv review) * vulkan: reject snake fusion when ne[2] or ne[3] > 1 (address jeffbolznv review) * vulkan: address 0cc4m review for fused snake activation snake.comp is renamed to follow the ggml DATA_A_* / A_TYPE convention. A_TYPE now applies to the activation tensor data_a instead of the broadcast multiplier, and the bindings become data_a (A_TYPE), data_b (float), data_c (float) and data_d (D_TYPE). A header at the top of the shader maps each buffer to its role in y = x + sin(b * x)^2 * c. On the C++ side, ggml_vk_can_fuse_snake reuses the existing snake_pattern constant instead of duplicating the op list, sin_node is extracted as a named local alongside the other chain nodes, and the broadcast operands a and inv_b are now required to be GGML_TYPE_F32 to match the hardcoded float bindings on data_b and data_c (the previous a->type == x->type would silently reject any future BF16 or F16 chain once the supports_op gate for SIN / SQR is lifted). ggml_vk_snake_dispatch_fused gets an explicit GGML_TYPE_F32 case and GGML_ABORT on default in place of the silent f32 fallback, and a stale comment about data_a[i1] / data_inv_b[i1] is refreshed to match the new binding names. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 136 +++++++++++++++++- .../src/ggml-vulkan/vulkan-shaders/snake.comp | 49 +++++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 4 + 3 files changed, 187 insertions(+), 2 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/snake.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index d3fb19048d9..aa289220a90 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -499,6 +499,12 @@ static constexpr std::initializer_list topk_moe_late_softmax { GGM GGML_OP_GET_ROWS, GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE }; +// Snake activation: y = x + sin(a*x)^2 * inv_b. Used by the optimize_graph reorder +// pass so it keeps the chain contiguous and by the dispatcher to detect the fusion. +static constexpr std::initializer_list snake_pattern { GGML_OP_MUL, GGML_OP_SIN, + GGML_OP_SQR, GGML_OP_MUL, + GGML_OP_ADD }; + //node #978 ( SOFT_MAX): ffn_moe_probs-15 ( 0K) [Vulka ] use=2: ffn_moe_logits-15 ( 0K) [Vulka ] //node #979 ( RESHAPE): ffn_moe_probs-15 (re ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ] //node #980 ( ARGSORT): ffn_moe_argsort-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ] @@ -846,6 +852,9 @@ struct vk_device_struct { vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16; vk_pipeline pipeline_timestep_embedding_f32; vk_pipeline pipeline_conv_transpose_1d_f32; + vk_pipeline pipeline_snake_f32; + vk_pipeline pipeline_snake_f16; + vk_pipeline pipeline_snake_bf16; vk_pipeline pipeline_pool2d_f32; vk_pipeline pipeline_rwkv_wkv6_f32; vk_pipeline pipeline_rwkv_wkv7_f32; @@ -1475,6 +1484,11 @@ struct vk_op_conv_transpose_1d_push_constants { int32_t s0; }; +struct vk_op_snake_push_constants { + uint32_t ne0; + uint32_t ne1; +}; + struct vk_op_pool2d_push_constants { uint32_t IW; uint32_t IH; uint32_t OW; uint32_t OH; @@ -4845,6 +4859,10 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_conv_transpose_1d_f32, "conv_transpose_1d_f32", conv_transpose_1d_f32_len, conv_transpose_1d_f32_data, "main", 3, sizeof(vk_op_conv_transpose_1d_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_snake_f32, "snake_f32", snake_f32_len, snake_f32_data, "main", 4, sizeof(vk_op_snake_push_constants), {256, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_snake_f16, "snake_f16", snake_f16_len, snake_f16_data, "main", 4, sizeof(vk_op_snake_push_constants), {256, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_snake_bf16, "snake_bf16", snake_bf16_len, snake_bf16_data, "main", 4, sizeof(vk_op_snake_push_constants), {256, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); @@ -12110,6 +12128,45 @@ static void ggml_vk_conv_transpose_1d(ggml_backend_vk_context * ctx, vk_context& ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(p)); } +// Dispatch the fused snake activation: y = x + sin^2(a * x) * inv_b. +// Match the naive mul -> sin -> sqr -> mul -> add chain and run the +// dedicated kernel directly. The pattern is validated by +// ggml_vk_can_fuse_snake before this call. +static void ggml_vk_snake_dispatch_fused(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) { + const ggml_tensor * mul0 = cgraph->nodes[node_idx + 0]; + const ggml_tensor * sqr = cgraph->nodes[node_idx + 2]; + const ggml_tensor * mul1 = cgraph->nodes[node_idx + 3]; + ggml_tensor * add = cgraph->nodes[node_idx + 4]; + + // x carries the full activation shape, a is the broadcast operand + const ggml_tensor * x = ggml_are_same_shape(mul0, mul0->src[0]) ? mul0->src[0] : mul0->src[1]; + const ggml_tensor * a = (x == mul0->src[0]) ? mul0->src[1] : mul0->src[0]; + + // mul1 reads sqr and inv_b in either operand order + const ggml_tensor * inv_b = (mul1->src[0] == sqr) ? mul1->src[1] : mul1->src[0]; + + vk_pipeline pipeline = nullptr; + switch (x->type) { + case GGML_TYPE_F32: pipeline = ctx->device->pipeline_snake_f32; break; + case GGML_TYPE_F16: pipeline = ctx->device->pipeline_snake_f16; break; + case GGML_TYPE_BF16: pipeline = ctx->device->pipeline_snake_bf16; break; + default: GGML_ABORT("unsupported type"); + } + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + + vk_subbuffer x_buf = ggml_vk_tensor_subbuffer(ctx, x); + vk_subbuffer a_buf = ggml_vk_tensor_subbuffer(ctx, a); + vk_subbuffer inv_b_buf = ggml_vk_tensor_subbuffer(ctx, inv_b); + vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, add); + + vk_op_snake_push_constants pc{}; + pc.ne0 = static_cast(x->ne[0]); + pc.ne1 = static_cast(x->ne[1]); + + std::array elements = { pc.ne0, pc.ne1, 1 }; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { x_buf, a_buf, inv_b_buf, dst_buf }, pc, elements); +} + static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { uint32_t op = static_cast(dst->op_params[0]); const int32_t k1 = dst->op_params[1]; @@ -13318,7 +13375,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr break; case GGML_OP_MUL: - ggml_vk_mul(ctx, compute_ctx, src0, src1, node); + if (ctx->num_additional_fused_ops) { + ggml_vk_snake_dispatch_fused(ctx, compute_ctx, cgraph, node_idx); + } else { + ggml_vk_mul(ctx, compute_ctx, src0, src1, node); + } break; case GGML_OP_DIV: @@ -14691,6 +14752,65 @@ static bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const return true; } +// Pattern check for the 5-op Snake fusion: mul -> sin -> sqr -> mul -> add. +// Verifies the chain shape, the closure x_in_add == x_in_mul0, and that +// the broadcast operands a and inv_b share a [1, C] layout. +static bool ggml_vk_can_fuse_snake(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) { + GGML_UNUSED(ctx); + if (!ggml_can_fuse(cgraph, node_idx, snake_pattern)) { + return false; + } + + const ggml_tensor * mul0 = cgraph->nodes[node_idx + 0]; + const ggml_tensor * sin_node = cgraph->nodes[node_idx + 1]; + const ggml_tensor * sqr = cgraph->nodes[node_idx + 2]; + const ggml_tensor * mul1 = cgraph->nodes[node_idx + 3]; + const ggml_tensor * add = cgraph->nodes[node_idx + 4]; + + const ggml_tensor * x = ggml_are_same_shape(mul0, mul0->src[0]) ? mul0->src[0] : mul0->src[1]; + const ggml_tensor * a = (x == mul0->src[0]) ? mul0->src[1] : mul0->src[0]; + + const ggml_tensor * inv_b = (mul1->src[0] == sqr) ? mul1->src[1] : mul1->src[0]; + const ggml_tensor * x_in_add = (add->src[0] == mul1) ? add->src[1] : add->src[0]; + + if (x_in_add != x) { + return false; + } + if (x->type != GGML_TYPE_F32 && x->type != GGML_TYPE_F16 && x->type != GGML_TYPE_BF16) { + return false; + } + // Shader bindings: data_a is A_TYPE so it follows x's precision, while + // data_b and data_c are hardcoded float, so the broadcast operands must + // be F32 regardless of x's type. + if (a->type != GGML_TYPE_F32) return false; + if (inv_b->type != GGML_TYPE_F32) return false; + // Chain intermediates and output share x's precision (single A_TYPE / D_TYPE pipeline). + if (mul0->type != x->type) return false; + if (sin_node->type != x->type) return false; + if (sqr->type != x->type) return false; + if (mul1->type != x->type) return false; + if (add->type != x->type) return false; + if (!ggml_are_same_shape(a, inv_b)) { + return false; + } + if (a->ne[0] != 1 || a->ne[1] != x->ne[1]) { + return false; + } + // Dispatch is 2D over (ne0, ne1), so x and add must be 2D and a / inv_b + // must collapse to [1, C, 1, 1]. Higher dims are not handled by the shader. + if (x->ne[2] != 1 || x->ne[3] != 1) return false; + if (add->ne[2] != 1 || add->ne[3] != 1) return false; + if (a->ne[2] != 1 || a->ne[3] != 1) return false; + if (inv_b->ne[2] != 1 || inv_b->ne[3] != 1) return false; + // Shader uses idx = i0 + i1 * ne0 and reads data_b[i1] / data_c[i1], + // so every operand must be contiguous. + if (!ggml_is_contiguous(x) || !ggml_is_contiguous(add) || + !ggml_is_contiguous(a) || !ggml_is_contiguous(inv_b)) { + return false; + } + return true; +} + // Check whether the tensors overlap in memory. // Fusions can potentially overwrite src tensors in ways that are not prevented // by ggml-alloc. If the fusion src is being applied in a way that's elementwise @@ -14998,6 +15118,14 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg op_srcs_fused_elementwise[0] = false; op_srcs_fused_elementwise[1] = false; op_srcs_fused_elementwise[2] = false; + } else if (ggml_vk_can_fuse_snake(ctx, cgraph, i)) { + ctx->num_additional_fused_ops = 4; + fusion_string = "SNAKE"; + // elementwise=true: snake.comp is safe under exact aliasing because each + // thread reads data_x[idx] into a register before writing data_d[idx] + // with a data dependency on that register. The overlap check still + // rejects partial overlaps (different base or size). + std::fill_n(op_srcs_fused_elementwise, 5, true); } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) && ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) && ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) { @@ -15288,6 +15416,9 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * if (keep_pattern(topk_moe_late_softmax)) { continue; } + if (keep_pattern(snake_pattern)) { + continue; + } // First, grab the next unused node. current_set.push_back(first_unused); @@ -15310,7 +15441,8 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * if (match_pattern(topk_moe_early_softmax_norm, j) || match_pattern(topk_moe_sigmoid_norm_bias, j) || match_pattern(topk_moe_early_softmax, j) || - match_pattern(topk_moe_late_softmax, j)) { + match_pattern(topk_moe_late_softmax, j) || + match_pattern(snake_pattern, j)) { continue; } bool ok = true; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp b/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp new file mode 100644 index 00000000000..8585538cbb0 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp @@ -0,0 +1,49 @@ +#version 450 + +#include "types.glsl" + +// Fused snake activation: y = x + sin(b * x)^2 * c +// data_a [ne0, ne1] per element activation x (A_TYPE) +// data_b [1, ne1] per channel multiplier (float) +// data_c [1, ne1] per channel inverse scale (float, precomputed as 1 / freq) +// data_d [ne0, ne1] output y (D_TYPE) +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {float data_b[];}; +layout (binding = 2) readonly buffer C {float data_c[];}; +layout (binding = 3) writeonly buffer D {D_TYPE data_d[];}; + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (push_constant) uniform parameter { + uint32_t ne0; + uint32_t ne1; +} p; + +// Load A_TYPE to float +float load_val(uint32_t idx) { +#if defined(DATA_A_BF16) + return bf16_to_fp32(uint32_t(data_a[idx])); +#else + return float(data_a[idx]); +#endif +} + +// Store float as D_TYPE +void store_val(uint32_t idx, float v) { +#if defined(DATA_D_BF16) + data_d[idx] = D_TYPE(fp32_to_bf16(v)); +#else + data_d[idx] = D_TYPE(v); +#endif +} + +void main() { + const uint32_t i0 = gl_GlobalInvocationID.x; + const uint32_t i1 = gl_GlobalInvocationID.y; + if (i0 >= p.ne0 || i1 >= p.ne1) return; + + const uint32_t idx = i0 + i1 * p.ne0; + const float xi = load_val(idx); + const float s = sin(data_b[i1] * xi); + store_val(idx, xi + s * s * data_c[i1]); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index e3a9d61a558..a1d735150fd 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -952,6 +952,10 @@ void process_shaders() { string_to_spv("conv_transpose_1d_f32", "conv_transpose_1d.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("snake_f32", "snake.comp", {{"DATA_A_F32", "1"}, {"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("snake_f16", "snake.comp", {{"DATA_A_F16", "1"}, {"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("snake_bf16", "snake.comp", {{"DATA_A_BF16", "1"}, {"DATA_D_BF16", "1"}, {"A_TYPE", "uint16_t"}, {"D_TYPE", "uint16_t"}}); + string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); From ad494e3dc8a7bea9ab2d08e13d9cd5de1ff87083 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Thu, 21 May 2026 23:35:29 +0200 Subject: [PATCH 40/55] CUDA: fix PDL CC check for JIT compilation (llama/23471) --- ggml/src/ggml-cuda/common.cuh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 9c73fe7e6fa..e54ecb29308 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -1561,7 +1561,8 @@ static __inline__ void ggml_cuda_kernel_launch(Kernel kernel, const ggml_cuda_ke return env == nullptr || std::atoi(env) != 0; }(); - if (env_pdl_enabled && ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= GGML_CUDA_CC_HOPPER) { + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + if (env_pdl_enabled && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_HOPPER) { auto pdl_cfg = ggml_cuda_pdl_config(launch_params); CUDA_CHECK(cudaLaunchKernelEx(&pdl_cfg.cfg, kernel, std::forward(args)... )); From f86ab6f28244181f253f12c8f0c4a3c14fe89e8b Mon Sep 17 00:00:00 2001 From: Sachin Sharma Date: Fri, 22 May 2026 16:46:55 +0530 Subject: [PATCH 41/55] ggml-zendnn : add Q8_0 quantization support (llama/23414) * ggml-zendnn : add Q8_0 quantization support * ggml-zendnn : sync with latest ZenDNN * ggml-zendnn : address review comments for Q8_0 --- ggml/src/ggml-zendnn/CMakeLists.txt | 2 +- ggml/src/ggml-zendnn/ggml-zendnn.cpp | 56 ++++++++++++++++++++++------ 2 files changed, 46 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-zendnn/CMakeLists.txt b/ggml/src/ggml-zendnn/CMakeLists.txt index f1e4f991fae..e4ba9cfbd0f 100644 --- a/ggml/src/ggml-zendnn/CMakeLists.txt +++ b/ggml/src/ggml-zendnn/CMakeLists.txt @@ -28,7 +28,7 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF") ExternalProject_Add( zendnn GIT_REPOSITORY https://github.com/amd/ZenDNN.git - GIT_TAG ac9e580d9434b7b98985f2627a7ebfb5eba4bb0d # ZenDNN-2026-WW17 + GIT_TAG 253b94ce0d7e9284c265fefb485714944caff9d3 # ZenDNN-2026-WW19 PREFIX ${ZENDNN_PREFIX} SOURCE_DIR ${ZENDNN_SOURCE_DIR} BINARY_DIR ${ZENDNN_BUILD_DIR} diff --git a/ggml/src/ggml-zendnn/ggml-zendnn.cpp b/ggml/src/ggml-zendnn/ggml-zendnn.cpp index 6a83bb6b1ec..6051d082003 100644 --- a/ggml/src/ggml-zendnn/ggml-zendnn.cpp +++ b/ggml/src/ggml-zendnn/ggml-zendnn.cpp @@ -2,6 +2,10 @@ #include "ggml-backend-impl.h" #include "ggml-impl.h" + +#define GGML_COMMON_DECL_CPP +#include "ggml-common.h" + #include "zendnnl.hpp" #include @@ -19,6 +23,8 @@ zendnnl::common::data_type_t ggml_to_zendnn_type() { return zendnnl::common::data_type_t::f32; } else if constexpr (std::is_same_v) { return zendnnl::common::data_type_t::bf16; + } else if constexpr (std::is_same_v) { + return zendnnl::common::data_type_t::s8; } else { return zendnnl::common::data_type_t::none; } @@ -48,6 +54,17 @@ static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int params.num_threads = ctx->n_threads; zendnnl::lowoha::matmul::matmul_batch_params_t batch_params; + + if constexpr (std::is_same_v) { + params.dtypes.compute = zendnnl::common::data_type_t::s8; + const int64_t num_groups = k / QK8_0; + params.dynamic_quant = true; + params.quant_params.src_scale.buff = nullptr; + params.quant_params.src_scale.dt = zendnnl::common::data_type_t::bf16; + params.quant_params.src_scale.dims = {n, num_groups}; + params.packing.pack_format_b = 1; + } + zendnnl::error_handling::status_t status = zendnnl::lowoha::matmul::matmul_direct( 'r', false, true, // row-major, don't transpose B, transpose A (because it's column-major) n, // M: rows of B and C @@ -108,6 +125,14 @@ static bool ggml_zendnn_sgemm(ggml_backend_zendnn_context * ctx, int64_t m, int6 (const ggml_bf16_t *)B, ldb, (float *)C, ldc); return false; + case GGML_TYPE_Q8_0: + if (Btype != GGML_TYPE_F32 || Ctype != GGML_TYPE_F32) + return false; + return ggml_zendnn_matmul( + ctx, m, n, k, + (const block_q8_0 *)A, lda, + (const float *)B, ldb, + (float *)C, ldc); default: return false; // unsupported type } @@ -145,7 +170,9 @@ static void ggml_zendnn_compute_forward_mul_mat( const int64_t r3 = ne13/ne03; void * work_data = ctx->work_data.get(); - if (src1->type != vec_dot_type) { + + // ZenDNN requires FP32 for dynamic quantization, so conversion is skipped + if (src1->type != vec_dot_type && src0->type != GGML_TYPE_Q8_0) { const size_t nbw1 = ggml_row_size(vec_dot_type, ne10); const size_t nbw2 = nbw1 * ne11; const size_t nbw3 = nbw2 * ne12; @@ -171,7 +198,7 @@ static void ggml_zendnn_compute_forward_mul_mat( for (int64_t i13 = 0; i13 < ne13; i13++) { for (int64_t i12 = 0; i12 < ne12; i12++) { - const void* wdata = src1->type == vec_dot_type ? src1->data : work_data; + const void* wdata = (src1->type == vec_dot_type || src0->type == GGML_TYPE_Q8_0) ? src1->data : work_data; const size_t row_size = ggml_row_size(vec_dot_type, ne10); if (!ggml_zendnn_sgemm(ctx, ne01, // m @@ -184,7 +211,7 @@ static void ggml_zendnn_compute_forward_mul_mat( static_cast(dst->data) + i12*nb2 + i13*nb3, ne01, // ldc src0->type, - vec_dot_type, + src0->type == GGML_TYPE_Q8_0 ? GGML_TYPE_F32 : vec_dot_type, dst->type)) GGML_ABORT("%s: ZenDNN sgemm failed\n", __func__); } @@ -261,10 +288,15 @@ static void ggml_zendnn_compute_forward_mul_mat_id( const size_t nbw1 = row_size; const size_t nbw2 = nbw1 * ne11; const size_t nbw3 = nbw2 * ne12; - const size_t src1_conv_size = (src1->type != vec_dot_type) ? ne13 * nbw3 : 0; + const size_t src1_conv_size = (src1->type != vec_dot_type && src0->type != GGML_TYPE_Q8_0) ? ne13 * nbw3 : 0; + + // For Q8_0, src1 is always F32; the gather buffer must hold F32 rows (ne10*4 bytes), + // not Q8_0-encoded rows (row_size ≈ ne10/32*34 bytes) — they differ by ~4x. + const size_t f32_row_size = (size_t)ne10 * sizeof(float); + const size_t gather_row_size = (src0->type == GGML_TYPE_Q8_0) ? f32_row_size : row_size; // size for MoE gather/scatter buffers - const size_t wdata_cur_size = max_rows * row_size; + const size_t wdata_cur_size = max_rows * gather_row_size; const size_t dst_cur_size = max_rows * ggml_row_size(dst->type, ne01); // allocate single buffer for all needs @@ -279,7 +311,8 @@ static void ggml_zendnn_compute_forward_mul_mat_id( char * wdata_cur = work_data + src1_conv_size; char * dst_cur = wdata_cur + wdata_cur_size; - if (src1->type != vec_dot_type) { + // ZenDNN requires FP32 for dynamic quantization, so conversion is skipped + if (src1->type != vec_dot_type && src0->type != GGML_TYPE_Q8_0) { GGML_ASSERT(src1->type == GGML_TYPE_F32); #pragma omp parallel for collapse(3) num_threads(ctx->n_threads) schedule(static) @@ -294,7 +327,7 @@ static void ggml_zendnn_compute_forward_mul_mat_id( } } - const void * wdata = src1->type == vec_dot_type ? src1->data : work_data; + const void * wdata = (src1->type == vec_dot_type || src0->type == GGML_TYPE_Q8_0) ? src1->data : work_data; // process each expert with gather -> gemm -> scatter pattern for (int64_t cur_a = 0; cur_a < n_as; ++cur_a) { @@ -315,9 +348,9 @@ static void ggml_zendnn_compute_forward_mul_mat_id( const int64_t i12 = row_mapping.i2; std::memcpy( - wdata_cur + ir1 * row_size, - (const char *) wdata + (i11 + i12*ne11) * row_size, - row_size + wdata_cur + ir1 * gather_row_size, + (const char *) wdata + (i11 + i12*ne11) * gather_row_size, + gather_row_size ); } @@ -333,7 +366,7 @@ static void ggml_zendnn_compute_forward_mul_mat_id( dst_cur, ne01, // ldc src0->type, - vec_dot_type, + src0->type == GGML_TYPE_Q8_0 ? GGML_TYPE_F32 : vec_dot_type, dst->type)) { GGML_ABORT("%s: ZenDNN sgemm failed\n", __func__); } @@ -577,6 +610,7 @@ static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const switch (weights->type) { case GGML_TYPE_F32: case GGML_TYPE_BF16: + case GGML_TYPE_Q8_0: return true; default: return false; From e3988d4f5fdfbe9b0ebbaf7e305ce9de7d2c1395 Mon Sep 17 00:00:00 2001 From: Katostrofik Date: Fri, 22 May 2026 08:48:24 -0400 Subject: [PATCH 42/55] SYCL: add BF16 to DMMV kernel path (~4x tg speedup on Intel Arc) (llama/21580) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * SYCL: add BF16 to DMMV kernel path for ~4x token generation speedup BF16 models had no dedicated token generation kernel — they fell through to the generic full-GEMM path, resulting in ~14% memory bandwidth utilization on Intel Arc GPUs. This adds BF16 support to the DMMV (dequantize mul-mat-vec) path, matching the existing F16 implementation. Fixes #20478 * SYCL: fix BF16 DMMV out-of-bounds when ncols % 64 != 0 The qk=1 kernel (used for F16 and BF16) iterates with stride 2*GGML_SYCL_DMMV_X (= 64 on Intel targets where WARP_SIZE=16). When ncols is a multiple of DMMV_X (32) but not of 2*DMMV_X (64), the last warp iteration accesses elements at col >= ncols, producing NaN for the final row and wrong values for interior rows. Fix: tighten can_use_dequantize_mul_mat_vec to require ne[0] % (2*DMMV_X) == 0 for F16/BF16 types, and update the ASSERT in the BF16 launcher to match. Quantized types use block-structured kernels with different access patterns and keep the existing DMMV_X check. Verified: test-backend-ops MUL_MAT passes 913/913 on Intel Arc Pro B70. Previously failing: m=128/129 n=1 k=1056 cases (NaN and ERR > 0.0005). Co-Authored-By: Claude Sonnet 4.6 --------- Co-authored-by: Claude Sonnet 4.6 --- ggml/src/ggml-sycl/dmmv.cpp | 47 +++++++++++++++++++++++++++++++- ggml/src/ggml-sycl/ggml-sycl.cpp | 8 +++++- 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-sycl/dmmv.cpp b/ggml/src/ggml-sycl/dmmv.cpp index 5577bf73b28..4ae431a962e 100644 --- a/ggml/src/ggml-sycl/dmmv.cpp +++ b/ggml/src/ggml-sycl/dmmv.cpp @@ -3,6 +3,13 @@ #include "dequantize.hpp" #include "presets.hpp" +#if defined(__INTEL_LLVM_COMPILER) + #if __has_include() + #include + #define GGML_SYCL_DMMV_HAS_BF16 + #endif +#endif + static void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ const sycl::half *x = (const sycl::half *)vx; @@ -11,6 +18,16 @@ static void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat v.y() = x[ib + iqs + 1]; } +#ifdef GGML_SYCL_DMMV_HAS_BF16 +static void convert_bf16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ + const sycl::ext::oneapi::bfloat16 *x = (const sycl::ext::oneapi::bfloat16 *)vx; + + // automatic bfloat16 -> float type cast if dfloat == float + v.x() = x[ib + iqs + 0]; + v.y() = x[ib + iqs + 1]; +} +#endif + static void convert_f32(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ const float * x = (const float *) vx; @@ -217,6 +234,28 @@ static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y, } } +#ifdef GGML_SYCL_DMMV_HAS_BF16 +static void convert_mul_mat_vec_bf16_sycl(const void *vx, const dfloat *y, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + // The qk=1 kernel iterates with stride 2*GGML_SYCL_DMMV_X, so ncols must be a + // multiple of that — not just GGML_SYCL_DMMV_X — to avoid out-of-bounds reads. + GGML_ASSERT(ncols % (2*GGML_SYCL_DMMV_X) == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + dequantize_mul_mat_vec<1, 1, convert_bf16>(vx, y, dst, ncols, + nrows, item_ct1); + }); + } +} +#endif + /* DPCT1110:4: The total declared local variable size in device function dequantize_mul_mat_vec_q2_k exceeds 128 bytes and may cause high register @@ -1497,7 +1536,8 @@ void ggml_sycl_op_dequantize_mul_mat_vec( bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 || - src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16; + src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16 || + src0->type == GGML_TYPE_BF16; if (src1_convert_f16) { scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_sycl", dst, /*num_src=*/2, @@ -1565,6 +1605,11 @@ void ggml_sycl_op_dequantize_mul_mat_vec( case GGML_TYPE_F16: convert_mul_mat_vec_f16_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); break; +#ifdef GGML_SYCL_DMMV_HAS_BF16 + case GGML_TYPE_BF16: + convert_mul_mat_vec_bf16_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + break; +#endif default: printf("ggml_sycl_op_dequantize_mul_mat_vec unsupported GGML_TYPE %d\n", src0->type); GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 2ea47f7153a..bba37a6f884 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -3455,6 +3455,7 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) { case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: case GGML_TYPE_F16: + case GGML_TYPE_BF16: return true; default: return false; @@ -3818,8 +3819,13 @@ static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * static bool can_use_dequantize_mul_mat_vec(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + // The F16/BF16 qk=1 kernel iterates with stride 2*DMMV_X, requiring ne[0] to be + // a multiple of 2*DMMV_X. Quantized types use block-structured kernels that only + // need ne[0] % DMMV_X == 0. + const int64_t dmmv_x_required = (src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F16) ? + 2*GGML_SYCL_DMMV_X : GGML_SYCL_DMMV_X; return ggml_sycl_supports_dmmv(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && - src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1; + src0->ne[0] % dmmv_x_required == 0 && src1->ne[1] == 1; } static bool can_use_mul_mat_vec_q(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { From 4e99dde77a71b672ecfd54e6094f922e37e64b94 Mon Sep 17 00:00:00 2001 From: karavayev <192749314+karavayev@users.noreply.github.com> Date: Fri, 22 May 2026 08:48:56 -0400 Subject: [PATCH 43/55] SYCL : gated_delta_net K>1 (llama/23174) * sycl_gated_delta_net K>1 * editor_config --- ggml/src/ggml-sycl/gated_delta_net.cpp | 91 +++++++++++++++++++------- 1 file changed, 66 insertions(+), 25 deletions(-) diff --git a/ggml/src/ggml-sycl/gated_delta_net.cpp b/ggml/src/ggml-sycl/gated_delta_net.cpp index ebc587524bf..9c2449aba0c 100644 --- a/ggml/src/ggml-sycl/gated_delta_net.cpp +++ b/ggml/src/ggml-sycl/gated_delta_net.cpp @@ -6,7 +6,7 @@ #include -template +template void gated_delta_net_sycl(const float * q, const float * k, const float * v, @@ -28,7 +28,8 @@ void gated_delta_net_sycl(const float * q, int64_t sb3, const sycl::uint3 neqk1_magic, const sycl::uint3 rq3_magic, - float scale) { + float scale, + int K) { auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); const uint32_t h_idx = item_ct1.get_group(2); const uint32_t sequence = item_ct1.get_group(1); @@ -43,9 +44,13 @@ void gated_delta_net_sycl(const float * q, float * attn_data = dst; float * state = dst + attn_score_elems; - const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v; - state += state_offset; - curr_state += state_offset; + // input state layout (D, K, n_seqs) — seq stride is K * D = K * H * S_v * S_v. + // output state layout (per-slot D * n_seqs) — same per-(seq,head) offset as before. + const int64_t state_in_offset = sequence * K * H * S_v * S_v + h_idx * S_v * S_v; + const int64_t state_out_offset = (sequence * H + h_idx) * S_v * S_v; + const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output + state += state_out_offset; + curr_state += state_in_offset + col * S_v; attn_data += (sequence * n_tokens * H + h_idx) * S_v; constexpr int warp_size = ggml_sycl_get_physical_warp_size() < S_v ? ggml_sycl_get_physical_warp_size() : S_v; @@ -55,9 +60,13 @@ void gated_delta_net_sycl(const float * q, #pragma unroll for (int r = 0; r < rows_per_lane; r++) { const int i = r * warp_size + lane; - s_shard[r] = curr_state[col * S_v + i]; + s_shard[r] = curr_state[i]; } + // slot mapping: target_slot = t - shift. When n_tokens < K only the last n_tokens slots + // are written; earlier slots are left untouched (caller-owned). + const int shift = (int) n_tokens - K; + for (int t = 0; t < n_tokens; t++) { const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1; const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1; @@ -131,17 +140,32 @@ void gated_delta_net_sycl(const float * q, } attn_data += S_v * H; - } + // Write state back to global memory + if constexpr (keep_rs_t) { + const int target_slot = t - shift; + if (target_slot >= 0 && target_slot < K) { + float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset; #pragma unroll - for (int r = 0; r < rows_per_lane; r++) { - const int i = r * warp_size + lane; - state[col * S_v + i] = s_shard[r]; + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + curr_state[col * S_v + i] = s_shard[r]; + } + } + } + } + + if constexpr (!keep_rs_t) { +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + state[col * S_v + i] = s_shard[r]; + } } } -template +template static void launch_gated_delta_net(const float * q_d, const float * k_d, const float * v_d, @@ -165,6 +189,7 @@ static void launch_gated_delta_net(const float * q_d, int64_t neqk1, int64_t rq3, float scale, + int K, dpct::queue_ptr stream) { //TODO: Add chunked kernel for even faster pre-fill const int warp_size = ggml_sycl_info().devices[ggml_sycl_get_device()].warp_size; @@ -182,9 +207,9 @@ static void launch_gated_delta_net(const float * q_d, constexpr int sv = 16; stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - gated_delta_net_sycl(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, + gated_delta_net_sycl(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, - sb3, neqk1_magic, rq3_magic, scale); + sb3, neqk1_magic, rq3_magic, scale, K); }); } break; @@ -193,9 +218,9 @@ static void launch_gated_delta_net(const float * q_d, constexpr int sv = 32; stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - gated_delta_net_sycl(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, + gated_delta_net_sycl(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, - sb3, neqk1_magic, rq3_magic, scale); + sb3, neqk1_magic, rq3_magic, scale, K); }); } break; @@ -204,9 +229,9 @@ static void launch_gated_delta_net(const float * q_d, constexpr int sv = 64; stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - gated_delta_net_sycl( + gated_delta_net_sycl( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, - sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); }); } break; @@ -216,9 +241,9 @@ static void launch_gated_delta_net(const float * q_d, constexpr int sv = 128; stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - gated_delta_net_sycl( + gated_delta_net_sycl( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, - sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); }); } break; @@ -290,14 +315,30 @@ void ggml_sycl_op_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dpct::queue_ptr stream = ctx.stream(); + // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. + const int K = (int) src_state->ne[1]; + const bool keep_rs = K > 1; + if (kda) { - launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, - S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1, rq3, scale, stream); + if (keep_rs) { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } else { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } } else { - launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, - S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1, rq3, scale, stream); + if (keep_rs) { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } else { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } } } From 4dbad751a61772f29da19dbdd00aac5a4be20f43 Mon Sep 17 00:00:00 2001 From: Alexey Kopytko Date: Fri, 22 May 2026 21:49:45 +0900 Subject: [PATCH 44/55] sycl : Level Zero detection in ggml_sycl_init (llama/23097) * [SYCL] Centralize Level Zero detection in ggml_sycl_init * use the same wording * get back the warning --- ggml/src/ggml-sycl/common.hpp | 2 ++ ggml/src/ggml-sycl/ggml-sycl.cpp | 26 ++++++++------------------ 2 files changed, 10 insertions(+), 18 deletions(-) diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 96bc1c98bd9..6d19538215e 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -238,6 +238,8 @@ struct ggml_sycl_device_info { std::array default_tensor_split = {}; int max_work_group_sizes[GGML_SYCL_MAX_DEVICES] = {0}; + + bool ext_oneapi_level_zero = true; // sycl::backend::ext_oneapi_level_zero used by all enumerated GPU devices }; const ggml_sycl_device_info & ggml_sycl_info(); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index bba37a6f884..46795f43602 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -98,7 +98,7 @@ static ggml_sycl_device_info ggml_sycl_init() { for (int i = 0; i < info.device_count; ++i) { info.devices[i].vmm = 0; dpct::device_info prop; - sycl::device device = dpct::dev_mgr::instance().get_device(i); + auto & device = dpct::dev_mgr::instance().get_device(i); SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info( prop, device))); @@ -117,6 +117,12 @@ static ggml_sycl_device_info ggml_sycl_init() { info.devices[i].max_wg_per_cu = info.max_work_group_sizes[i] / prop.get_max_compute_units(); info.devices[i].hw_info = get_device_hw_info(&device); + // Only check GPU devices; CPU devices use OpenCL and would otherwise + // disable Level Zero for the GPUs on systems without ONEAPI_DEVICE_SELECTOR set. + if (device.is_gpu() && device.default_queue().get_backend() != sycl::backend::ext_oneapi_level_zero) { + GGML_LOG_WARN("SYCL GPU device %d does not use Level Zero backend, disabling Level Zero memory API\n", i); + info.ext_oneapi_level_zero = false; + } } for (int id = 0; id < info.device_count; ++id) { @@ -230,26 +236,10 @@ static void ggml_check_sycl() try { g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0); g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0); #ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO - g_ggml_sycl_enable_level_zero = get_sycl_env("GGML_SYCL_ENABLE_LEVEL_ZERO", 1); + g_ggml_sycl_enable_level_zero = get_sycl_env("GGML_SYCL_ENABLE_LEVEL_ZERO", ggml_sycl_info().ext_oneapi_level_zero); #else g_ggml_sycl_enable_level_zero = 0; #endif - if (g_ggml_sycl_enable_level_zero) { - // Verify all GPU devices use the Level Zero backend before enabling L0 APIs. - // Only check GPU devices; CPU devices use OpenCL and would otherwise - // disable Level Zero for the GPUs on systems without ONEAPI_DEVICE_SELECTOR set. - for (unsigned int i = 0; i < dpct::dev_mgr::instance().device_count(); i++) { - auto & q = dpct::dev_mgr::instance().get_device(i).default_queue(); - if (!q.get_device().is_gpu()) { - continue; - } - if (q.get_backend() != sycl::backend::ext_oneapi_level_zero) { - GGML_LOG_WARN("SYCL GPU device %d does not use Level Zero backend, disabling Level Zero memory API\n", i); - g_ggml_sycl_enable_level_zero = 0; - break; - } - } - } #ifdef SYCL_FLASH_ATTN g_ggml_sycl_enable_flash_attention = get_sycl_env("GGML_SYCL_ENABLE_FLASH_ATTN", 1); From 7084cf00ecbcda9e0d802a07170927d20e87a828 Mon Sep 17 00:00:00 2001 From: Alexey Kopytko Date: Fri, 22 May 2026 21:50:17 +0900 Subject: [PATCH 45/55] SYCL: improve MoE prefill throughput (llama/23142) - change `k_copy_src1_to_contiguous` so that uses a precomputed contiguous mapping where all rows "owned" by an expert are in one slice with a know starts and ends - switch the `O(n_as * n_routed_rows)` contraption to a counting sort-based procedure with `O(n_as + n_routed_rows)` complexity --- ggml/src/ggml-sycl/ggml-sycl.cpp | 195 +++++++++++++++++-------------- 1 file changed, 105 insertions(+), 90 deletions(-) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 46795f43602..b3fbb621196 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -3919,35 +3919,17 @@ struct mmid_row_mapping { __dpct_inline__ static void k_copy_src1_to_contiguous( const char *__restrict__ src1_original, char *__restrict__ src1_contiguous, - int *__restrict__ cur_src1_row, mmid_row_mapping *__restrict__ row_mapping, - const char *__restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0, + const mmid_row_mapping *__restrict__ row_mapping, int64_t ne11, int64_t ne10, size_t nb11, size_t nb12, - const sycl::nd_item<3> &item_ct1, int &src1_row) { - int32_t iid1 = item_ct1.get_group(2); - int32_t id = item_ct1.get_group(1); - - const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0); + const sycl::nd_item<3> &item_ct1) { + const int32_t src1_row = item_ct1.get_group(2); - if (row_id_i != i02) { - return; - } + const int32_t iid1 = row_mapping[src1_row].i2; + const int32_t id = row_mapping[src1_row].i1; const int64_t i11 = id % ne11; const int64_t i12 = iid1; - if (item_ct1.get_local_id(2) == 0) { - src1_row = - dpct::atomic_fetch_add( - cur_src1_row, 1); - row_mapping[src1_row] = {id, iid1}; - } - /* - DPCT1065:194: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. - */ - item_ct1.barrier(); - const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12); float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11); @@ -4022,6 +4004,47 @@ static bool ggml_sycl_mul_mat_id_mmvq_fused( src1_row_stride, stream); } +// counting sort of the routed rows by expert id (row_id_i, as chosen by the router): +// builds a projection of a memory layout where each expert's slice is contiguous +static void mmid_counting_sort_rows( + const ggml_tensor * ids, const char * ids_host, + int64_t n_ids, int64_t n_as, int64_t n_routed_rows, + std::vector & expert_counts, + std::vector & expert_row_offsets, + std::vector & routed_row_src) { + + // frequencies: how many routed rows each expert "owns" + expert_counts.assign(n_as, 0); + for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { + for (int64_t id = 0; id < n_ids; id++) { + const int32_t row_id_i = *(const int32_t *) (ids_host + iid1*ids->nb[1] + id*ids->nb[0]); + GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as); + expert_counts[row_id_i]++; + } + } + + // where each expert's slice starts (row indices) and the previous ends + expert_row_offsets.assign(n_as + 1, 0); + for (int64_t i02 = 0; i02 < n_as; i02++) { + expert_row_offsets[i02 + 1] = expert_row_offsets[i02] + expert_counts[i02]; + } + + std::vector expert_row_next = expert_row_offsets; + routed_row_src.resize(n_routed_rows); + for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { + for (int64_t id = 0; id < n_ids; id++) { + const int32_t row_id_i = *(const int32_t *) (ids_host + iid1*ids->nb[1] + id*ids->nb[0]); + GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as); + + // find and validate the next free row for a given expert (row_id_i) + const int64_t routed_row = expert_row_next[row_id_i]++; + GGML_ASSERT(routed_row >= expert_row_offsets[row_id_i]); + GGML_ASSERT(routed_row < expert_row_offsets[row_id_i + 1]); + routed_row_src[routed_row] = {(int32_t) id, (int32_t) iid1}; + } + } +} + static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, ggml_tensor *dst) try { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3); @@ -4100,99 +4123,91 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, src1_row.data = src1_contiguous.get(); dst_row.data = dst_contiguous.get(); - for (int64_t i02 = 0; i02 < n_as; i02++) { - int64_t num_src1_rows = 0; - for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { - for (int64_t id = 0; id < n_ids; id++) { - const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); + // how many "owned" routed rows to pass to each expert + std::vector expert_row_counts; + // where each expert's slice starts and the previous ends (row indices, right-exclusive) + std::vector expert_row_offsets; + // the sources (slot/token pairs) of contiguous rows to guide k_copy_src1_to_contiguous + std::vector routed_row_src; - GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as); + mmid_counting_sort_rows(ids, ids_host.data(), n_ids, n_as, n_routed_rows, + expert_row_counts, expert_row_offsets, routed_row_src); - if (row_id_i != i02) { - continue; - } + ggml_sycl_pool_alloc dev_row_mapping(ctx.pool(), n_routed_rows); + SYCL_CHECK(CHECK_TRY_ERROR( + stream->memcpy(dev_row_mapping.get(), routed_row_src.data(), n_routed_rows*sizeof(mmid_row_mapping)))); - num_src1_rows++; - } - } + const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device]; + assert(max_work_group_size % (WARP_SIZE * WARP_SIZE) == 0); + + { + sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, max_work_group_size)); + sycl::range<3> grid_dims(1, 1, n_routed_rows); + stream->submit([&](sycl::handler &cgh) { + char *__restrict src1_contiguous_get = + src1_contiguous.get(); + mmid_row_mapping *__restrict dev_row_mapping_get = + dev_row_mapping.get(); + + cgh.parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + k_copy_src1_to_contiguous( + src1_original, src1_contiguous_get, + dev_row_mapping_get, + ne11, ne10, nb11, nb12, + item_ct1); + }); + }); + } + + for (int64_t i02 = 0; i02 < n_as; i02++) { + const int64_t num_src1_rows = expert_row_counts[i02]; if (num_src1_rows == 0) { continue; } - - ggml_sycl_pool_alloc dev_cur_src1_row(ctx.pool(), 1); - ggml_sycl_pool_alloc dev_row_mapping(ctx.pool(), num_src1_rows); - SYCL_CHECK(CHECK_TRY_ERROR( - stream->memset(dev_cur_src1_row.get(), 0, sizeof(int)))); - - const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device]; - assert(max_work_group_size % (WARP_SIZE * WARP_SIZE) == 0); - - { - sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, max_work_group_size)); - sycl::range<3> grid_dims(1, n_ids, ids->ne[1]); - stream->submit([&](sycl::handler &cgh) { - sycl::local_accessor src1_row_acc(cgh); - - char *__restrict src1_contiguous_get = - src1_contiguous.get(); - int *__restrict dev_cur_src1_row_get = - dev_cur_src1_row.get(); - mmid_row_mapping *__restrict dev_row_mapping_get = - dev_row_mapping.get(); - size_t ids_nb_ct6 = ids->nb[1]; - size_t ids_nb_ct7 = ids->nb[0]; - - cgh.parallel_for( - sycl::nd_range<3>(grid_dims * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - k_copy_src1_to_contiguous( - src1_original, src1_contiguous_get, - dev_cur_src1_row_get, - dev_row_mapping_get, ids_dev, i02, - ids_nb_ct6, ids_nb_ct7, ne11, ne10, nb11, nb12, - item_ct1, src1_row_acc); - }); - }); - } + const int64_t expert_row_offset = expert_row_offsets[i02]; src0_row.data = src0_original + i02*nb02; GGML_ASSERT(nb11 == sizeof(float)*ne10); GGML_ASSERT(nb1 == sizeof(float)*ne0); + src1_row.data = src1_contiguous.get() + expert_row_offset*nb11; src1_row.ne[1] = num_src1_rows; src1_row.nb[1] = nb11; src1_row.nb[2] = num_src1_rows*nb11; src1_row.nb[3] = num_src1_rows*nb11; + dst_row.data = dst_contiguous.get() + expert_row_offset*nb1; dst_row.ne[1] = num_src1_rows; dst_row.nb[1] = nb1; dst_row.nb[2] = num_src1_rows*nb1; dst_row.nb[3] = num_src1_rows*nb1; ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row); + } - { - sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, max_work_group_size)); - sycl::range<3> grid_dims(1, 1, num_src1_rows); - stream->submit([&](sycl::handler &cgh) { - const char *__restrict dst_contiguous_get = - dst_contiguous.get(); - const mmid_row_mapping *__restrict dev_row_mapping_get = - dev_row_mapping.get(); - - cgh.parallel_for( - sycl::nd_range<3>(grid_dims * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - k_copy_dst_from_contiguous(dst_original, - dst_contiguous_get, - dev_row_mapping_get, - ne0, nb1, nb2, item_ct1); - }); - }); - } + { + sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, max_work_group_size)); + sycl::range<3> grid_dims(1, 1, n_routed_rows); + stream->submit([&](sycl::handler &cgh) { + const char *__restrict dst_contiguous_get = + dst_contiguous.get(); + const mmid_row_mapping *__restrict dev_row_mapping_get = + dev_row_mapping.get(); + + cgh.parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + k_copy_dst_from_contiguous(dst_original, + dst_contiguous_get, + dev_row_mapping_get, + ne0, nb1, nb2, item_ct1); + }); + }); } } } From 594775350d89caa8ce8ed4e24045acdd1c1c090c Mon Sep 17 00:00:00 2001 From: Shawn Gu Date: Fri, 22 May 2026 17:08:41 -0700 Subject: [PATCH 46/55] opencl: generalize Adreno MoE kernels on M (llama/23449) --- ggml/src/ggml-opencl/ggml-opencl.cpp | 18 +++--- ggml/src/ggml-opencl/kernels/cvt.cl | 64 +++++++++++++++++++ .../kernels/gemm_moe_mxfp4_f32_ns.cl | 6 +- .../kernels/gemm_moe_q4_0_f32_ns.cl | 6 +- .../kernels/gemm_moe_q4_1_f32_ns.cl | 6 +- .../kernels/gemm_moe_q4_k_f32_ns.cl | 6 +- .../kernels/gemm_moe_q5_0_f32_ns.cl | 6 +- .../kernels/gemm_moe_q5_1_f32_ns.cl | 6 +- .../kernels/gemm_moe_q5_k_f32_ns.cl | 6 +- .../kernels/gemm_moe_q6_k_f32_ns.cl | 6 +- .../kernels/gemv_moe_mxfp4_f32_ns.cl | 4 ++ .../kernels/gemv_moe_q4_0_f32_ns.cl | 4 ++ .../kernels/gemv_moe_q4_1_f32_ns.cl | 4 ++ .../kernels/gemv_moe_q4_k_f32_ns.cl | 4 ++ .../kernels/gemv_moe_q5_0_f32_ns.cl | 4 ++ .../kernels/gemv_moe_q5_1_f32_ns.cl | 4 ++ .../kernels/gemv_moe_q5_k_f32_ns.cl | 4 ++ .../kernels/gemv_moe_q6_k_f32_ns.cl | 4 ++ 18 files changed, 145 insertions(+), 17 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 5fc46f789ec..ea0b44feea2 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -4693,7 +4693,7 @@ inline bool use_adreno_kernels(const ggml_backend_opencl_context *backend_ctx, c inline bool use_adreno_moe_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { GGML_UNUSED(backend_ctx); int ne01 = tensor->ne[1]; - return (((strstr(tensor->name, "ffn") != NULL) && (strstr(tensor->name, "exps") != NULL)) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0); + return (((strstr(tensor->name, "ffn") != NULL) && (strstr(tensor->name, "exps") != NULL)) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 32 == 0); } inline bool enable_adreno_trans_weight(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { @@ -14297,7 +14297,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, CL_CHECK(status); // set thread grid - global_size[0] = static_cast(ne01); + global_size[0] = static_cast(((ne01 + 63) / 64) * 64); global_size[1] = 4; global_size[2] = static_cast(ne20); local_size[1] = 4; @@ -14513,7 +14513,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, CL_CHECK(status); // set thread grid - global_size[0] = static_cast(ne01); + global_size[0] = static_cast(((ne01 + 63) / 64) * 64); global_size[1] = 4; global_size[2] = static_cast(ne20); local_size[1] = 4; @@ -14689,7 +14689,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, CL_CHECK(status); // set thread grid - global_size[0] = static_cast(ne01); + global_size[0] = static_cast(((ne01 + 63) / 64) * 64); global_size[1] = 4; global_size[2] = static_cast(ne20); local_size[1] = 4; @@ -14865,7 +14865,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, CL_CHECK(status); // set thread grid - global_size[0] = static_cast(ne01); + global_size[0] = static_cast(((ne01 + 63) / 64) * 64); global_size[1] = 4; global_size[2] = static_cast(ne20); local_size[1] = 4; @@ -15118,7 +15118,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, CL_CHECK(status); // set thread grid - global_size[0] = static_cast(ne01); + global_size[0] = static_cast(((ne01 + 63) / 64) * 64); global_size[1] = 4; global_size[2] = static_cast(ne20); local_size[1] = 4; @@ -15291,7 +15291,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, CL_CHECK(status); // set thread grid - global_size[0] = static_cast(ne01); + global_size[0] = static_cast(((ne01 + 63) / 64) * 64); global_size[1] = 4; global_size[2] = static_cast(ne20); local_size[1] = 4; @@ -15469,7 +15469,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, CL_CHECK(status); // set thread grid - global_size[0] = static_cast(ne01); + global_size[0] = static_cast(((ne01 + 63) / 64) * 64); global_size[1] = 4; global_size[2] = static_cast(ne20); local_size[1] = 4; @@ -15644,7 +15644,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, CL_CHECK(status); // set thread grid - global_size[0] = static_cast(ne01); + global_size[0] = static_cast(((ne01 + 63) / 64) * 64); global_size[1] = 4; global_size[2] = static_cast(ne20); local_size[1] = 4; diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index 312366984b6..c25eabdd72b 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -220,6 +220,10 @@ kernel void kernel_convert_block_q4_0_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK4_0; uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -263,6 +267,10 @@ kernel void kernel_restore_block_q4_0_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK4_0; uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint src_d_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -401,6 +409,10 @@ kernel void kernel_convert_block_q4_1_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK4_1; uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -446,6 +458,10 @@ kernel void kernel_restore_block_q4_1_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK4_1; uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint src_dm_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -491,6 +507,10 @@ kernel void kernel_convert_block_q5_0_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK5_0; uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -536,6 +556,10 @@ kernel void kernel_restore_block_q5_0_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK5_0; uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -583,6 +607,10 @@ kernel void kernel_convert_block_q5_1_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK5_1; uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -630,6 +658,10 @@ kernel void kernel_restore_block_q5_1_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK5_1; uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -679,6 +711,10 @@ kernel void kernel_convert_block_q4_k_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK_K; uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -732,6 +768,10 @@ kernel void kernel_restore_block_q4_k_trans4_ns( uint i01 = get_global_id(0); // row index uint i02 = get_global_id(2); // batch index + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK_K; uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -784,6 +824,10 @@ kernel void kernel_convert_block_q5_k_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK_K; uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -850,6 +894,10 @@ kernel void kernel_restore_block_q5_k_trans4_ns( uint i01 = get_global_id(0); // row index uint i02 = get_global_id(2); // batch index + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK_K; uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -916,6 +964,10 @@ kernel void kernel_convert_block_q6_k_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK_K; uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; @@ -993,6 +1045,10 @@ kernel void kernel_restore_block_q6_k_trans4_ns( uint i01 = get_global_id(0); // row index uint i02 = get_global_id(2); // batch index + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK_K; uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -1147,6 +1203,10 @@ kernel void kernel_convert_block_mxfp4_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK_MXFP4; uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -1190,6 +1250,10 @@ kernel void kernel_restore_block_mxfp4_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK_MXFP4; uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint src_d_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl index e404f392bdd..02cdbdd9fb1 100644 --- a/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl @@ -163,7 +163,7 @@ kernel void kernel_gemm_moe_mxfp4_f32_ns( uint block_id_n = get_global_id(2); // n_tile // Boundary check - if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + if (block_id_n >= total_tiles[0]) { return; } @@ -248,6 +248,10 @@ kernel void kernel_gemm_moe_mxfp4_f32_ns( dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); } + if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) { + return; + } + // Load poster router and share in LM __local uint out_idx[TILESIZE_N]; diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl index 02290c17eb1..d403ed0cab1 100644 --- a/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl @@ -115,7 +115,7 @@ kernel void kernel_gemm_moe_q4_0_f32_ns( uint block_id_n = get_global_id(2); // n_tile // Boundary check - if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + if (block_id_n >= total_tiles[0]) { return; } @@ -198,6 +198,10 @@ kernel void kernel_gemm_moe_q4_0_f32_ns( dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); } + if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) { + return; + } + // Load poster router and share in LM __local uint out_idx[TILESIZE_N]; diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl index e2574ae0187..b2bddf3f73a 100644 --- a/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl @@ -116,7 +116,7 @@ kernel void kernel_gemm_moe_q4_1_f32_ns( uint block_id_n = get_global_id(2); // n_tile // Boundary check - if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + if (block_id_n >= total_tiles[0]) { return; } @@ -200,6 +200,10 @@ kernel void kernel_gemm_moe_q4_1_f32_ns( dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); } + if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) { + return; + } + // Load poster router and share in LM __local uint out_idx[TILESIZE_N]; diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl index 9d24aff6a20..ab8228d18ca 100644 --- a/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl @@ -133,7 +133,7 @@ kernel void kernel_gemm_moe_q4_k_f32_ns( uint block_id_n = get_global_id(2); // n_tile // Boundary check - if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + if (block_id_n >= total_tiles[0]) { return; } @@ -225,6 +225,10 @@ kernel void kernel_gemm_moe_q4_k_f32_ns( dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); } + if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) { + return; + } + // Load post router and share in LM __local uint out_idx[TILESIZE_N]; diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl index 3524cb1bdbd..d1a35d58bb2 100644 --- a/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl @@ -116,7 +116,7 @@ kernel void kernel_gemm_moe_q5_0_f32_ns( uint block_id_n = get_global_id(2); // n_tile // Boundary check - if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + if (block_id_n >= total_tiles[0]) { return; } @@ -202,6 +202,10 @@ kernel void kernel_gemm_moe_q5_0_f32_ns( dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); } + if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) { + return; + } + // Load poster router and share in LM __local uint out_idx[TILESIZE_N]; diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl index 5fc2a523234..90d345ecf51 100644 --- a/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl @@ -117,7 +117,7 @@ kernel void kernel_gemm_moe_q5_1_f32_ns( uint block_id_n = get_global_id(2); // n_tile // Boundary check - if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + if (block_id_n >= total_tiles[0]) { return; } @@ -204,6 +204,10 @@ kernel void kernel_gemm_moe_q5_1_f32_ns( dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); } + if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) { + return; + } + // Load poster router and share in LM __local uint out_idx[TILESIZE_N]; diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl index 808a0c7db6a..13c26f6f3b6 100644 --- a/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl @@ -134,7 +134,7 @@ kernel void kernel_gemm_moe_q5_k_f32_ns( uint block_id_n = get_global_id(2); // n_tile // Boundary check - if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + if (block_id_n >= total_tiles[0]) { return; } @@ -230,6 +230,10 @@ kernel void kernel_gemm_moe_q5_k_f32_ns( dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); } + if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) { + return; + } + // Load post router and share in LM __local uint out_idx[TILESIZE_N]; diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl index a040335adfa..85ccebec78c 100644 --- a/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl @@ -117,7 +117,7 @@ kernel void kernel_gemm_moe_q6_k_f32_ns( uint block_id_n = get_global_id(2); // n_tile // Boundary check - if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + if (block_id_n >= total_tiles[0]) { return; } @@ -209,6 +209,10 @@ kernel void kernel_gemm_moe_q6_k_f32_ns( dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); } + if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) { + return; + } + // Load post router and share in LM __local uint out_idx[TILESIZE_N]; diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl index e4b44c1a56a..75129e20c65 100644 --- a/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl @@ -82,6 +82,10 @@ __kernel void kernel_gemv_moe_mxfp4_f32_ns( uint sgid = get_local_id(1); uint slid = get_sub_group_local_id(); + if (i01 >= ne01) { + return; + } + uint i11 = i20 % ne11; uint expert_id = src2[i20]; diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl index 6f4d3f53216..2d28db63ec5 100644 --- a/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl @@ -37,6 +37,10 @@ __kernel void kernel_gemv_moe_q4_0_f32_ns( uint sgid = get_local_id(1); uint slid = get_sub_group_local_id(); + if (i01 >= ne01) { + return; + } + uint i11 = i20 % ne11; uint expert_id = src2[i20]; diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl index 3739a215705..b98bdc0f12e 100644 --- a/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl @@ -38,6 +38,10 @@ __kernel void kernel_gemv_moe_q4_1_f32_ns( uint sgid = get_local_id(1); uint slid = get_sub_group_local_id(); + if (i01 >= ne01) { + return; + } + uint i11 = i20 % ne11; uint expert_id = src2[i20]; diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl index 13d79f2526f..12464e9826e 100644 --- a/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl @@ -54,6 +54,10 @@ __kernel void kernel_gemv_moe_q4_k_f32_ns( uint sgid = get_local_id(1); uint slid = get_sub_group_local_id(); + if (i01 >= ne01) { + return; + } + uint i11 = i20 % ne11; uint expert_id = src2[i20]; diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl index 938054cf982..b43613638a8 100644 --- a/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl @@ -38,6 +38,10 @@ __kernel void kernel_gemv_moe_q5_0_f32_ns( uint sgid = get_local_id(1); uint slid = get_sub_group_local_id(); + if (i01 >= ne01) { + return; + } + uint i11 = i20 % ne11; uint expert_id = src2[i20]; diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl index f33a4ef2757..7a666006e68 100644 --- a/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl @@ -39,6 +39,10 @@ __kernel void kernel_gemv_moe_q5_1_f32_ns( uint sgid = get_local_id(1); uint slid = get_sub_group_local_id(); + if (i01 >= ne01) { + return; + } + uint i11 = i20 % ne11; uint expert_id = src2[i20]; diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl index f128d44340a..7d868d7abd9 100644 --- a/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl @@ -55,6 +55,10 @@ __kernel void kernel_gemv_moe_q5_k_f32_ns( uint sgid = get_local_id(1); uint slid = get_sub_group_local_id(); + if (i01 >= ne01) { + return; + } + uint i11 = i20 % ne11; uint expert_id = src2[i20]; diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl index 526e609dc3a..c166bad5ba5 100644 --- a/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl @@ -38,6 +38,10 @@ __kernel void kernel_gemv_moe_q6_k_f32_ns( uint sgid = get_local_id(1); uint slid = get_sub_group_local_id(); + if (i01 >= ne01) { + return; + } + uint i11 = i20 % ne11; uint expert_id = src2[i20]; From 2282c7f5dffaa260c4c726d32318666e1828138e Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Sat, 23 May 2026 02:44:46 -0500 Subject: [PATCH 47/55] vulkan: fix windows find_package of SPIRV-Headers (llama/23215) * vulkan: fix windows find_package of SPIRV-Headers * not windows-only --- ggml/src/ggml-vulkan/CMakeLists.txt | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt index 6dbcea065b3..65785ae4566 100644 --- a/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/CMakeLists.txt @@ -8,7 +8,10 @@ endif() find_package(Vulkan COMPONENTS glslc REQUIRED) -find_package(SPIRV-Headers REQUIRED) +if (DEFINED ENV{VULKAN_SDK}) + list(APPEND CMAKE_PREFIX_PATH "$ENV{VULKAN_SDK}") +endif() +find_package(SPIRV-Headers CONFIG REQUIRED) if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") # Parallel build object files From 945fb5f23d529eb16c4d8e6be2238e822643d151 Mon Sep 17 00:00:00 2001 From: dskwe Date: Sat, 23 May 2026 18:49:24 +0800 Subject: [PATCH 48/55] ggml : Check the right iface method before using the fallback 2d get (llama/23514) --- ggml/src/ggml-backend.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 5c0e5b1b9e2..87615921c09 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -306,7 +306,7 @@ void ggml_backend_tensor_get_2d_async(ggml_backend_t backend, const struct ggml_ GGML_ASSERT(tensor); GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); - if (n_copies <= 1 || backend->iface.set_tensor_2d_async == NULL) { + if (n_copies <= 1 || backend->iface.get_tensor_2d_async == NULL) { for (size_t i = 0; i < n_copies; i++) { ggml_backend_tensor_get_async(backend, tensor, (char *) data + i*stride_data, offset + i*stride_tensor, size); } @@ -317,7 +317,7 @@ void ggml_backend_tensor_get_2d_async(ggml_backend_t backend, const struct ggml_ } GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); - GGML_ASSERT(offset + (n_copies-1)*stride_tensor + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); + GGML_ASSERT(offset + (n_copies-1)*stride_tensor + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); backend->iface.get_tensor_2d_async(backend, tensor, data, offset, size, n_copies, stride_tensor, stride_data); } From 89cb85fb3f6763fa9915bc4f6a985c7291e2d5ec Mon Sep 17 00:00:00 2001 From: Yiwei Shao <44545837+njsyw1997@users.noreply.github.com> Date: Sat, 23 May 2026 19:56:59 -0700 Subject: [PATCH 49/55] hexagon: apply repl optimization in flash attn softmax as #22993 (llama/23455) --- ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c index 4a4ff0b331d..9e1b778b01f 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c @@ -852,9 +852,10 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) { v_s_rowmax1 = hvx_vec_reduce_max_f16(v_s_rowmax1); // Splat m_prev[r], m_prev[r+1] from the per-row accumulator. - // vror brings the target lane to lane 0, then extract + re-splat. - HVX_Vector v_m_prev0 = hvx_vec_splat_f16(hvx_vec_get_f16(Q6_V_vror_VR(m_prev_v, r_vec_off * 2))); - HVX_Vector v_m_prev1 = hvx_vec_splat_f16(hvx_vec_get_f16(Q6_V_vror_VR(m_prev_v, (r_vec_off + 1) * 2))); + // vror brings the target lane to lane 0, then vdelta replicates it + // across all lanes — stays in the vector domain (no store/reload). + HVX_Vector v_m_prev0 = hvx_vec_repl_f16(Q6_V_vror_VR(m_prev_v, r_vec_off * 2)); + HVX_Vector v_m_prev1 = hvx_vec_repl_f16(Q6_V_vror_VR(m_prev_v, (r_vec_off + 1) * 2)); // HVX max — both operands are splats, so result is splat of m_new. HVX_Vector v_dup_m0 = Q6_Vhf_vmax_VhfVhf(v_m_prev0, v_s_rowmax0); From 56ee08620afdbadc623be85452cf0134c4009c68 Mon Sep 17 00:00:00 2001 From: shaofeiqi Date: Sat, 23 May 2026 23:11:43 -0700 Subject: [PATCH 50/55] opencl: batch profiling to improve speed and prevent memory leaks (llama/23495) --- ggml/src/ggml-opencl/ggml-opencl.cpp | 36 +++++++++++++++++++++------- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index ea0b44feea2..42286435bc6 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -661,11 +661,10 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mm_iq4_nl_f32_l4_lm; std::vector profiling_info; + std::vector profiling_results; - void write_profiling_info() { - FILE * fperf = fopen("cl_profiling.csv", "w"); - if (!fperf) { - GGML_LOG_ERROR("Failed to open cl_profiling.csv\n"); + void flush_profiling_batch() { + if (profiling_info.empty()) { return; } @@ -689,6 +688,7 @@ struct ggml_backend_opencl_context { CL_CHECK(clGetEventProfilingInfo( info.evt, CL_PROFILING_COMMAND_COMPLETE, sizeof(cl_ulong), &cmd_complete, NULL)); CL_CHECK(clReleaseEvent(info.evt)); + info.evt = nullptr; char kernel_name[512]; CL_CHECK(clGetKernelInfo(info.kernel, CL_KERNEL_FUNCTION_NAME, @@ -706,10 +706,26 @@ struct ggml_backend_opencl_context { info.cmd_complete_duration_ns = cmd_complete - cmd_end; info.cmd_total_duration_ns = cmd_complete - cmd_queued; } + profiling_results.insert(profiling_results.end(), + std::make_move_iterator(profiling_info.begin()), + std::make_move_iterator(profiling_info.end())); + profiling_info.clear(); + } + + void write_profiling_info() { + if (profiling_results.empty()) { + return; + } // Dump a csv + FILE * fperf = fopen("cl_profiling.csv", "w"); + if (!fperf) { + GGML_LOG_ERROR("Failed to open cl_profiling.csv\n"); + return; + } + fprintf(fperf, "op name, kernel name, exec duration (ms), global size, local size, output size\n"); - for (const ProfilingInfo & info : profiling_info) { + for (const ProfilingInfo & info : profiling_results) { fprintf(fperf, "%s,%s,%f,%zux%zux%zu,%zux%zux%zu,%zux%zux%zux%zu\n", info.op_name.c_str(), info.kernel_name.c_str(), info.cmd_duration_ns/1.e6f, @@ -720,14 +736,14 @@ struct ggml_backend_opencl_context { fclose(fperf); // Dump a simple chrome trace - FILE* ftrace = fopen("cl_trace.json", "w"); + FILE * ftrace = fopen("cl_trace.json", "w"); if (!ftrace) { GGML_LOG_ERROR("Failed to open cl_trace.json\n"); return; } fprintf(ftrace, "[\n"); - for (const ProfilingInfo & info : profiling_info) { + for (const ProfilingInfo & info : profiling_results) { fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %" PRIu64 ", \"pid\": \"\", \"tid\": \"Host\"},\n", info.kernel_name.c_str(), info.cmd_queued/1000); fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %" PRIu64 ", \"pid\": \"\", \"tid\": \"Host\"},\n", @@ -738,6 +754,7 @@ struct ggml_backend_opencl_context { fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %" PRIu64 ", \"pid\": \"\", \"tid\": \"Device\"},\n", info.kernel_name.c_str(), info.cmd_end/1000); } + fprintf(ftrace, "]\n"); fclose(ftrace); } @@ -758,6 +775,9 @@ struct ggml_backend_opencl_context { profiling_info.emplace_back(); populateProfilingInfo(profiling_info.back(), evt, kernel, work_dim, global_work_size, local_work_size, tensor); + if (profiling_info.size() >= 2048) { + flush_profiling_batch(); + } #else GGML_UNUSED(tensor); CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, work_dim, NULL, global_work_size, local_work_size, 0, NULL, NULL)); @@ -804,7 +824,7 @@ struct ggml_backend_opencl_context { if (ref_count == 0) { #ifdef GGML_OPENCL_PROFILING write_profiling_info(); - profiling_info.clear(); + profiling_results.clear(); #endif } } From 2bb19cb4df4fb260bb8b418a4b92c6d38a004656 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 24 May 2026 08:19:33 +0200 Subject: [PATCH 51/55] TP: fix entirely zero-sized slices per device (llama/23525) --- ggml/include/ggml-alloc.h | 1 + ggml/src/ggml-backend-meta.cpp | 36 ++++++++++++++++++++++++++++++++-- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/ggml/include/ggml-alloc.h b/ggml/include/ggml-alloc.h index 78aa059dde3..a7926a21a9a 100644 --- a/ggml/include/ggml-alloc.h +++ b/ggml/include/ggml-alloc.h @@ -76,6 +76,7 @@ GGML_API size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_i // Utils // Create a buffer and allocate all the tensors in a ggml_context // ggml_backend_alloc_ctx_tensors_from_buft_size returns the size of the buffer that would be allocated by ggml_backend_alloc_ctx_tensors_from_buft +// ggml_backend_alloc_ctx_tensors_from_buft returns NULL on failure or if all tensors in ctx are already allocated or zero-sized GGML_API size_t ggml_backend_alloc_ctx_tensors_from_buft_size(struct ggml_context * ctx, ggml_backend_buffer_type_t buft); GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft); GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend); diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index df0f405ed9f..5f9ae9c1bc5 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -1275,6 +1275,9 @@ static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, gg for (size_t j = 0; j < n_bufs; j++) { ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; + if (chunk_size_j == 0) { + continue; + } const size_t simple_offset = i_start * chunk_size_j; ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_j, simple_offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full); offset_j += chunk_size_j; @@ -1382,6 +1385,9 @@ static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, co for (size_t j = 0; j < n_bufs; j++){ const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; + if (chunk_size_j == 0) { + continue; + } const size_t simple_offset = i_start * chunk_size_j; ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_j, simple_offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full); offset_j += chunk_size_j; @@ -1445,6 +1451,7 @@ static ggml_backend_buffer_t ggml_backend_meta_buffer_type_alloc_buffer(ggml_bac buf_ctx->buf_configs.reserve(n_simple_bufts); for (size_t i = 0; i < n_simple_bufts; i++) { ggml_backend_buffer_t simple_buf = ggml_backend_buft_alloc_buffer(ggml_backend_meta_buft_simple_buft(buft, i), size); + GGML_ASSERT(simple_buf != nullptr); max_size = std::max(max_size, ggml_backend_buffer_get_size(simple_buf)); buf_ctx->buf_configs.emplace_back(ggml_init(params), simple_buf); } @@ -1474,8 +1481,27 @@ struct ggml_backend_buffer * ggml_backend_meta_alloc_ctx_tensors_from_buft(struc t->data = (void *) 0x2000000000000000; // FIXME } for (size_t i = 0; i < n_simple_bufts; i++) { - meta_buf_ctx->buf_configs[i].buf = ggml_backend_alloc_ctx_tensors_from_buft( - meta_buf_ctx->buf_configs[i].ctx, ggml_backend_meta_buft_simple_buft(buft, i)); + ggml_context * ctx = meta_buf_ctx->buf_configs[i].ctx; + ggml_backend_buffer_type_t simple_buft = ggml_backend_meta_buft_simple_buft(buft, i); + + // If a ggml_context only has zero-sized tensors, ggml_backend_alloc_ctx_tensors_from_buft returns NULL. + // For those edge cases, allocate a dummy buffer instead. + bool any_nonzero_slice = false; + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { + if (ggml_nelements(t) != 0) { + any_nonzero_slice = true; + break; + } + } + if (any_nonzero_slice) { + meta_buf_ctx->buf_configs[i].buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, simple_buft); + } else { + meta_buf_ctx->buf_configs[i].buf = ggml_backend_buft_alloc_buffer(simple_buft, 0); + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { + t->buffer = meta_buf_ctx->buf_configs[i].buf; + } + } + GGML_ASSERT(meta_buf_ctx->buf_configs[i].buf != nullptr); meta_buf->size = std::max(meta_buf->size, ggml_backend_buffer_get_size(meta_buf_ctx->buf_configs[i].buf)); } return meta_buf; @@ -1605,6 +1631,9 @@ static void ggml_backend_meta_set_tensor_async(ggml_backend_t backend, ggml_tens ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, j); ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; + if (chunk_size_j == 0) { + continue; + } ggml_backend_tensor_set_2d_async(simple_backend, simple_tensor, (const char *) data + offset_j, offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full); offset_j += chunk_size_j; @@ -1646,6 +1675,9 @@ static void ggml_backend_meta_get_tensor_async(ggml_backend_t backend, const ggm ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, j); const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; + if (chunk_size_j == 0) { + continue; + } ggml_backend_tensor_get_2d_async(simple_backend, simple_tensor, (char *) data + offset_j, offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full); offset_j += chunk_size_j; From a16e642ad33d2efc423a87db4c2058ed06cb6de3 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Mon, 25 May 2026 02:15:46 -0500 Subject: [PATCH 52/55] ggml : Parallelize quant LUT init (llama/23595) - Use OpenMP to parallelize iq2xs_init_impl and iq3xs_init_impl. - Move the OpenMP detection from ggml-cpu to ggml-base. - Update OpenMP dependencies in ggml-config.cmake.in. --- ggml/cmake/ggml-config.cmake.in | 14 +- ggml/src/CMakeLists.txt | 17 ++ ggml/src/ggml-cpu/CMakeLists.txt | 14 +- ggml/src/ggml-quants.c | 328 ++++++++++++++++++++----------- 4 files changed, 246 insertions(+), 127 deletions(-) diff --git a/ggml/cmake/ggml-config.cmake.in b/ggml/cmake/ggml-config.cmake.in index 91c9d5cd343..23a3066f56d 100644 --- a/ggml/cmake/ggml-config.cmake.in +++ b/ggml/cmake/ggml-config.cmake.in @@ -6,6 +6,7 @@ include(CMakeFindDependencyMacro) find_dependency(Threads) if (NOT GGML_SHARED_LIB) + set(GGML_BASE_INTERFACE_LINK_LIBRARIES "") set(GGML_CPU_INTERFACE_LINK_LIBRARIES "") set(GGML_CPU_INTERFACE_LINK_OPTIONS "") @@ -20,7 +21,15 @@ if (NOT GGML_SHARED_LIB) if (GGML_OPENMP_ENABLED) find_dependency(OpenMP) - list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES OpenMP::OpenMP_C OpenMP::OpenMP_CXX) + set(GGML_OPENMP_INTERFACE_LINK_LIBRARIES "") + if (TARGET OpenMP::OpenMP_C) + list(APPEND GGML_OPENMP_INTERFACE_LINK_LIBRARIES OpenMP::OpenMP_C) + endif() + if (TARGET OpenMP::OpenMP_CXX) + list(APPEND GGML_OPENMP_INTERFACE_LINK_LIBRARIES OpenMP::OpenMP_CXX) + endif() + list(APPEND GGML_BASE_INTERFACE_LINK_LIBRARIES ${GGML_OPENMP_INTERFACE_LINK_LIBRARIES}) + list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES ${GGML_OPENMP_INTERFACE_LINK_LIBRARIES}) endif() if (GGML_CPU_HBM) @@ -122,7 +131,8 @@ if(NOT TARGET ggml::ggml) add_library(ggml::ggml-base UNKNOWN IMPORTED) set_target_properties(ggml::ggml-base PROPERTIES - IMPORTED_LOCATION "${GGML_BASE_LIBRARY}") + IMPORTED_LOCATION "${GGML_BASE_LIBRARY}" + INTERFACE_LINK_LIBRARIES "${GGML_BASE_INTERFACE_LINK_LIBRARIES}") set(_ggml_all_targets "") if (NOT GGML_BACKEND_DL) diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 3e48860bfc8..c26c3f1470d 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -222,6 +222,23 @@ if (GGML_SCHED_NO_REALLOC) target_compile_definitions(ggml-base PUBLIC GGML_SCHED_NO_REALLOC) endif() +if (GGML_OPENMP) + find_package(OpenMP) + if (OpenMP_FOUND) + set(GGML_OPENMP_ENABLED "ON" CACHE INTERNAL "") + else() + set(GGML_OPENMP_ENABLED "OFF" CACHE INTERNAL "") + message(WARNING "OpenMP not found") + endif() +else() + set(GGML_OPENMP_ENABLED "OFF" CACHE INTERNAL "") +endif() + +if (GGML_OPENMP_ENABLED) + target_compile_definitions(ggml-base PRIVATE GGML_USE_OPENMP) + target_link_libraries(ggml-base PRIVATE OpenMP::OpenMP_C OpenMP::OpenMP_CXX) +endif() + add_library(ggml ggml-backend-dl.cpp ggml-backend-reg.cpp) diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index f3eccff7d72..8c735a045b3 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -72,17 +72,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name) endif() endif() - if (GGML_OPENMP) - find_package(OpenMP) - if (OpenMP_FOUND) - set(GGML_OPENMP_ENABLED "ON" CACHE INTERNAL "") - target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_OPENMP) - - target_link_libraries(${GGML_CPU_NAME} PRIVATE OpenMP::OpenMP_C OpenMP::OpenMP_CXX) - else() - set(GGML_OPENMP_ENABLED "OFF" CACHE INTERNAL "") - message(WARNING "OpenMP not found") - endif() + if (GGML_OPENMP_ENABLED) + target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_OPENMP) + target_link_libraries(${GGML_CPU_NAME} PRIVATE OpenMP::OpenMP_C OpenMP::OpenMP_CXX) endif() if (GGML_LLAMAFILE) diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 15443aa554a..15d231f70c0 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -13,6 +13,10 @@ #include // for qsort #include // for GGML_ASSERT +#ifdef GGML_USE_OPENMP +#include +#endif + #define GROUP_MAX_EPS 1e-15f #define GROUP_MAX_EPS_IQ3_XXS 1e-8f #define GROUP_MAX_EPS_IQ2_S 1e-8f @@ -3064,70 +3068,121 @@ void iq2xs_init_impl(enum ggml_type type) { } kmap_q2xs[index] = i; } - int8_t pos[8]; - int * dist2 = (int *)malloc(2*grid_size*sizeof(int)); + // The neighbour search runs in three passes: + // 1. Parallel: for each i, qsort and count its neighbours into n_per_i, + // and reduce the totals (num_neighbors, num_not_in_map). + // 2. Serial: prefix-sum n_per_i into offsets[], so each i has a + // pre-assigned slice of kneighbors_q2xs to write into. + // 3. Parallel: redo the qsort and write each i's neighbour list at + // offsets[i]. + int * n_per_i = (int *)malloc(kmap_size*sizeof(int)); + GGML_ASSERT(n_per_i); int num_neighbors = 0, num_not_in_map = 0; - for (int i = 0; i < kmap_size; ++i) { - if (kmap_q2xs[i] >= 0) continue; - ++num_not_in_map; - for (int k = 0; k < 8; ++k) { - int l = (i >> 2*k) & 0x3; - pos[k] = 2*l + 1; - } - for (int j = 0; j < grid_size; ++j) { - const int8_t * pg = (const int8_t *)(kgrid_q2xs + j); - int d2 = 0; - for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); - dist2[2*j+0] = d2; - dist2[2*j+1] = j; - } - qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func); - int n = 0; int d2 = dist2[0]; - int nhave = 1; - for (int j = 0; j < grid_size; ++j) { - if (dist2[2*j] > d2) { - if (nhave == nwant) break; - d2 = dist2[2*j]; - ++nhave; - } - ++n; - } - num_neighbors += n; +#ifdef GGML_USE_OPENMP + #pragma omp parallel reduction(+:num_neighbors,num_not_in_map) +#endif + { + int * dist2 = (int *)malloc(2*grid_size*sizeof(int)); + GGML_ASSERT(dist2); + int8_t pos[8]; + int i; +#ifdef GGML_USE_OPENMP + #pragma omp for schedule(dynamic, 64) +#endif + for (i = 0; i < kmap_size; ++i) { + if (kmap_q2xs[i] >= 0) { + n_per_i[i] = 0; + continue; + } + ++num_not_in_map; + for (int k = 0; k < 8; ++k) { + int l = (i >> 2*k) & 0x3; + pos[k] = 2*l + 1; + } + for (int j = 0; j < grid_size; ++j) { + const int8_t * pg = (const int8_t *)(kgrid_q2xs + j); + int d2 = 0; + for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); + dist2[2*j+0] = d2; + dist2[2*j+1] = j; + } + qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func); + int n = 0; int d2 = dist2[0]; + int nhave = 1; + for (int j = 0; j < grid_size; ++j) { + if (dist2[2*j] > d2) { + if (nhave == nwant) break; + d2 = dist2[2*j]; + ++nhave; + } + ++n; + } + n_per_i[i] = n; + num_neighbors += n; + } + free(dist2); } //printf("%s: %d neighbours in total\n", __func__, num_neighbors); kneighbors_q2xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t)); iq2_data[gindex].neighbours = kneighbors_q2xs; + + int * offsets = (int *)malloc(kmap_size*sizeof(int)); + GGML_ASSERT(offsets); int counter = 0; for (int i = 0; i < kmap_size; ++i) { - if (kmap_q2xs[i] >= 0) continue; - for (int k = 0; k < 8; ++k) { - int l = (i >> 2*k) & 0x3; - pos[k] = 2*l + 1; - } - for (int j = 0; j < grid_size; ++j) { - const int8_t * pg = (const int8_t *)(kgrid_q2xs + j); - int d2 = 0; - for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); - dist2[2*j+0] = d2; - dist2[2*j+1] = j; - } - qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func); - kmap_q2xs[i] = -(counter + 1); - int d2 = dist2[0]; - uint16_t * start = &kneighbors_q2xs[counter++]; - int n = 0, nhave = 1; - for (int j = 0; j < grid_size; ++j) { - if (dist2[2*j] > d2) { - if (nhave == nwant) break; - d2 = dist2[2*j]; - ++nhave; - } - kneighbors_q2xs[counter++] = dist2[2*j+1]; - ++n; - } - *start = n; - } - free(dist2); + if (kmap_q2xs[i] >= 0) { + offsets[i] = -1; + continue; + } + offsets[i] = counter; + counter += 1 + n_per_i[i]; + } + +#ifdef GGML_USE_OPENMP + #pragma omp parallel +#endif + { + int * dist2 = (int *)malloc(2*grid_size*sizeof(int)); + GGML_ASSERT(dist2); + int8_t pos[8]; + int i; +#ifdef GGML_USE_OPENMP + #pragma omp for schedule(dynamic, 64) +#endif + for (i = 0; i < kmap_size; ++i) { + if (kmap_q2xs[i] >= 0) continue; + for (int k = 0; k < 8; ++k) { + int l = (i >> 2*k) & 0x3; + pos[k] = 2*l + 1; + } + for (int j = 0; j < grid_size; ++j) { + const int8_t * pg = (const int8_t *)(kgrid_q2xs + j); + int d2 = 0; + for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); + dist2[2*j+0] = d2; + dist2[2*j+1] = j; + } + qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func); + int local_counter = offsets[i]; + kmap_q2xs[i] = -(local_counter + 1); + int d2 = dist2[0]; + uint16_t * start = &kneighbors_q2xs[local_counter++]; + int n = 0, nhave = 1; + for (int j = 0; j < grid_size; ++j) { + if (dist2[2*j] > d2) { + if (nhave == nwant) break; + d2 = dist2[2*j]; + ++nhave; + } + kneighbors_q2xs[local_counter++] = dist2[2*j+1]; + ++n; + } + *start = n; + } + free(dist2); + } + free(offsets); + free(n_per_i); } void iq2xs_free_impl(enum ggml_type type) { @@ -3663,70 +3718,115 @@ void iq3xs_init_impl(int grid_size) { } kmap_q3xs[index] = i; } - int8_t pos[4]; - int * dist2 = (int *)malloc(2*grid_size*sizeof(int)); + // See explanation of parallelism in iq2xs_init_impl + int * n_per_i = (int *)malloc(kmap_size*sizeof(int)); + GGML_ASSERT(n_per_i); int num_neighbors = 0, num_not_in_map = 0; - for (int i = 0; i < kmap_size; ++i) { - if (kmap_q3xs[i] >= 0) continue; - ++num_not_in_map; - for (int k = 0; k < 4; ++k) { - int l = (i >> 3*k) & 0x7; - pos[k] = 2*l + 1; - } - for (int j = 0; j < grid_size; ++j) { - const int8_t * pg = (const int8_t *)(kgrid_q3xs + j); - int d2 = 0; - for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); - dist2[2*j+0] = d2; - dist2[2*j+1] = j; - } - qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func); - int n = 0; int d2 = dist2[0]; - int nhave = 1; - for (int j = 0; j < grid_size; ++j) { - if (dist2[2*j] > d2) { - if (nhave == nwant) break; - d2 = dist2[2*j]; - ++nhave; - } - ++n; - } - num_neighbors += n; +#ifdef GGML_USE_OPENMP + #pragma omp parallel reduction(+:num_neighbors,num_not_in_map) +#endif + { + int * dist2 = (int *)malloc(2*grid_size*sizeof(int)); + GGML_ASSERT(dist2); + int8_t pos[4]; + int i; +#ifdef GGML_USE_OPENMP + #pragma omp for schedule(dynamic, 64) +#endif + for (i = 0; i < kmap_size; ++i) { + if (kmap_q3xs[i] >= 0) { + n_per_i[i] = 0; + continue; + } + ++num_not_in_map; + for (int k = 0; k < 4; ++k) { + int l = (i >> 3*k) & 0x7; + pos[k] = 2*l + 1; + } + for (int j = 0; j < grid_size; ++j) { + const int8_t * pg = (const int8_t *)(kgrid_q3xs + j); + int d2 = 0; + for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); + dist2[2*j+0] = d2; + dist2[2*j+1] = j; + } + qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func); + int n = 0; int d2 = dist2[0]; + int nhave = 1; + for (int j = 0; j < grid_size; ++j) { + if (dist2[2*j] > d2) { + if (nhave == nwant) break; + d2 = dist2[2*j]; + ++nhave; + } + ++n; + } + n_per_i[i] = n; + num_neighbors += n; + } + free(dist2); } //printf("%s: %d neighbours in total\n", __func__, num_neighbors); kneighbors_q3xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t)); iq3_data[gindex].neighbours = kneighbors_q3xs; + + int * offsets = (int *)malloc(kmap_size*sizeof(int)); + GGML_ASSERT(offsets); int counter = 0; for (int i = 0; i < kmap_size; ++i) { - if (kmap_q3xs[i] >= 0) continue; - for (int k = 0; k < 4; ++k) { - int l = (i >> 3*k) & 0x7; - pos[k] = 2*l + 1; - } - for (int j = 0; j < grid_size; ++j) { - const int8_t * pg = (const int8_t *)(kgrid_q3xs + j); - int d2 = 0; - for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); - dist2[2*j+0] = d2; - dist2[2*j+1] = j; - } - qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func); - kmap_q3xs[i] = -(counter + 1); - int d2 = dist2[0]; - uint16_t * start = &kneighbors_q3xs[counter++]; - int n = 0, nhave = 1; - for (int j = 0; j < grid_size; ++j) { - if (dist2[2*j] > d2) { - if (nhave == nwant) break; - d2 = dist2[2*j]; - ++nhave; - } - kneighbors_q3xs[counter++] = dist2[2*j+1]; - ++n; - } - *start = n; - } - free(dist2); + if (kmap_q3xs[i] >= 0) { + offsets[i] = -1; + continue; + } + offsets[i] = counter; + counter += 1 + n_per_i[i]; + } + +#ifdef GGML_USE_OPENMP + #pragma omp parallel +#endif + { + int * dist2 = (int *)malloc(2*grid_size*sizeof(int)); + GGML_ASSERT(dist2); + int8_t pos[4]; + int i; +#ifdef GGML_USE_OPENMP + #pragma omp for schedule(dynamic, 64) +#endif + for (i = 0; i < kmap_size; ++i) { + if (kmap_q3xs[i] >= 0) continue; + for (int k = 0; k < 4; ++k) { + int l = (i >> 3*k) & 0x7; + pos[k] = 2*l + 1; + } + for (int j = 0; j < grid_size; ++j) { + const int8_t * pg = (const int8_t *)(kgrid_q3xs + j); + int d2 = 0; + for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); + dist2[2*j+0] = d2; + dist2[2*j+1] = j; + } + qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func); + int local_counter = offsets[i]; + kmap_q3xs[i] = -(local_counter + 1); + int d2 = dist2[0]; + uint16_t * start = &kneighbors_q3xs[local_counter++]; + int n = 0, nhave = 1; + for (int j = 0; j < grid_size; ++j) { + if (dist2[2*j] > d2) { + if (nhave == nwant) break; + d2 = dist2[2*j]; + ++nhave; + } + kneighbors_q3xs[local_counter++] = dist2[2*j+1]; + ++n; + } + *start = n; + } + free(dist2); + } + free(offsets); + free(n_per_i); } void iq3xs_free_impl(int grid_size) { From 624bac199e16dcebf8dc64d32e3e80963f1278ed Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 25 May 2026 12:13:21 +0300 Subject: [PATCH 53/55] ggml : bump version to 0.12.1 (ggml/1508) --- ggml/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 4aac5094d1c..03020888f97 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -5,7 +5,7 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) set(GGML_VERSION_MINOR 12) -set(GGML_VERSION_PATCH 0) +set(GGML_VERSION_PATCH 1) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") From 77ab0a002d35b3f406c9df5fee445d84113729bb Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 25 May 2026 12:14:40 +0300 Subject: [PATCH 54/55] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 5a605ba344e..2c680ce9f5d 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -57ea0bc119d722d74594196cc5b494a34dd87be4 +0a37c2167fc5b81830a32d9b1691610180ed86d6 From 9ff9972c06a6642b3c1f930ff8f2f71c4ec2f1e0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 25 May 2026 12:18:31 +0300 Subject: [PATCH 55/55] talk-llama : sync llama.cpp --- examples/talk-llama/llama-arch.cpp | 27 +- examples/talk-llama/llama-arch.h | 1 + examples/talk-llama/llama-chat.cpp | 8 +- examples/talk-llama/llama-chat.h | 2 +- examples/talk-llama/llama-context.cpp | 196 ++++++++- examples/talk-llama/llama-context.h | 9 + examples/talk-llama/llama-cparams.h | 4 + examples/talk-llama/llama-ext.h | 16 + examples/talk-llama/llama-graph.cpp | 43 +- examples/talk-llama/llama-graph.h | 6 +- examples/talk-llama/llama-hparams.cpp | 6 + examples/talk-llama/llama-hparams.h | 2 + .../talk-llama/llama-memory-hybrid-iswa.cpp | 14 +- .../talk-llama/llama-memory-hybrid-iswa.h | 1 + examples/talk-llama/llama-memory-hybrid.cpp | 14 +- examples/talk-llama/llama-memory-hybrid.h | 1 + .../talk-llama/llama-memory-recurrent.cpp | 131 +++++- examples/talk-llama/llama-memory-recurrent.h | 9 + examples/talk-llama/llama-memory.h | 3 + examples/talk-llama/llama-model-loader.cpp | 13 +- examples/talk-llama/llama-model-loader.h | 2 +- examples/talk-llama/llama-model-saver.cpp | 2 + examples/talk-llama/llama-model.cpp | 57 ++- examples/talk-llama/llama-model.h | 21 +- examples/talk-llama/llama-vocab.cpp | 125 +++++- examples/talk-llama/llama.h | 11 +- examples/talk-llama/models/afmoe.cpp | 2 +- examples/talk-llama/models/apertus.cpp | 2 +- examples/talk-llama/models/arcee.cpp | 2 +- examples/talk-llama/models/arctic.cpp | 2 +- examples/talk-llama/models/arwkv7.cpp | 2 +- examples/talk-llama/models/baichuan.cpp | 2 +- examples/talk-llama/models/bailingmoe.cpp | 2 +- examples/talk-llama/models/bailingmoe2.cpp | 2 +- examples/talk-llama/models/bloom.cpp | 2 +- examples/talk-llama/models/chameleon.cpp | 2 +- examples/talk-llama/models/chatglm.cpp | 2 +- examples/talk-llama/models/codeshell.cpp | 2 +- examples/talk-llama/models/cogvlm.cpp | 2 +- examples/talk-llama/models/cohere2.cpp | 2 +- examples/talk-llama/models/command-r.cpp | 2 +- examples/talk-llama/models/dbrx.cpp | 2 +- examples/talk-llama/models/deci.cpp | 2 +- examples/talk-llama/models/deepseek.cpp | 2 +- examples/talk-llama/models/delta-net-base.cpp | 164 +++++++- examples/talk-llama/models/dots1.cpp | 2 +- examples/talk-llama/models/dream.cpp | 2 +- examples/talk-llama/models/ernie4-5-moe.cpp | 2 +- examples/talk-llama/models/ernie4-5.cpp | 2 +- examples/talk-llama/models/exaone-moe.cpp | 2 +- examples/talk-llama/models/exaone.cpp | 2 +- examples/talk-llama/models/exaone4.cpp | 2 +- examples/talk-llama/models/falcon-h1.cpp | 2 +- examples/talk-llama/models/falcon.cpp | 2 +- examples/talk-llama/models/gemma.cpp | 2 +- examples/talk-llama/models/gemma2.cpp | 2 +- examples/talk-llama/models/gemma3.cpp | 2 +- examples/talk-llama/models/gemma3n.cpp | 2 +- examples/talk-llama/models/gemma4.cpp | 2 +- examples/talk-llama/models/glm4-moe.cpp | 2 +- examples/talk-llama/models/glm4.cpp | 2 +- examples/talk-llama/models/gpt2.cpp | 2 +- examples/talk-llama/models/gptneox.cpp | 2 +- examples/talk-llama/models/granite-hybrid.cpp | 2 +- examples/talk-llama/models/granite.cpp | 2 +- examples/talk-llama/models/grok.cpp | 2 +- examples/talk-llama/models/grovemoe.cpp | 2 +- examples/talk-llama/models/hunyuan-moe.cpp | 2 +- examples/talk-llama/models/hunyuan-vl.cpp | 2 +- examples/talk-llama/models/internlm2.cpp | 2 +- examples/talk-llama/models/jais.cpp | 2 +- examples/talk-llama/models/jais2.cpp | 2 +- examples/talk-llama/models/jamba.cpp | 2 +- examples/talk-llama/models/lfm2.cpp | 2 +- examples/talk-llama/models/llada-moe.cpp | 2 +- examples/talk-llama/models/llada.cpp | 2 +- examples/talk-llama/models/llama.cpp | 2 +- examples/talk-llama/models/llama4.cpp | 2 +- examples/talk-llama/models/maincoder.cpp | 2 +- examples/talk-llama/models/mamba.cpp | 2 +- examples/talk-llama/models/mimo2.cpp | 2 +- examples/talk-llama/models/minicpm3.cpp | 2 +- examples/talk-llama/models/minimax-m2.cpp | 2 +- examples/talk-llama/models/mistral3.cpp | 2 +- examples/talk-llama/models/models.h | 33 +- examples/talk-llama/models/mpt.cpp | 2 +- examples/talk-llama/models/nemotron-h.cpp | 2 +- examples/talk-llama/models/nemotron.cpp | 2 +- examples/talk-llama/models/olmo.cpp | 2 +- examples/talk-llama/models/olmo2.cpp | 2 +- examples/talk-llama/models/olmoe.cpp | 2 +- examples/talk-llama/models/openai-moe.cpp | 2 +- examples/talk-llama/models/openelm.cpp | 2 +- examples/talk-llama/models/orion.cpp | 2 +- examples/talk-llama/models/paddleocr.cpp | 2 +- examples/talk-llama/models/pangu-embed.cpp | 2 +- examples/talk-llama/models/phi2.cpp | 2 +- examples/talk-llama/models/phi3.cpp | 2 +- examples/talk-llama/models/plamo.cpp | 2 +- examples/talk-llama/models/plamo2.cpp | 2 +- examples/talk-llama/models/plamo3.cpp | 2 +- examples/talk-llama/models/plm.cpp | 2 +- examples/talk-llama/models/qwen.cpp | 2 +- examples/talk-llama/models/qwen2.cpp | 2 +- examples/talk-llama/models/qwen2moe.cpp | 2 +- examples/talk-llama/models/qwen2vl.cpp | 2 +- examples/talk-llama/models/qwen3.cpp | 2 +- examples/talk-llama/models/qwen35.cpp | 321 +++++++++++---- examples/talk-llama/models/qwen35moe.cpp | 373 ++++++++++++++---- examples/talk-llama/models/qwen3moe.cpp | 2 +- examples/talk-llama/models/qwen3next.cpp | 46 +-- examples/talk-llama/models/qwen3vl.cpp | 2 +- examples/talk-llama/models/qwen3vlmoe.cpp | 2 +- examples/talk-llama/models/refact.cpp | 2 +- examples/talk-llama/models/rnd1.cpp | 2 +- examples/talk-llama/models/rwkv6.cpp | 2 +- examples/talk-llama/models/rwkv6qwen2.cpp | 2 +- examples/talk-llama/models/rwkv7.cpp | 2 +- examples/talk-llama/models/seed-oss.cpp | 2 +- examples/talk-llama/models/smallthinker.cpp | 2 +- examples/talk-llama/models/smollm3.cpp | 2 +- examples/talk-llama/models/stablelm.cpp | 2 +- examples/talk-llama/models/starcoder.cpp | 2 +- examples/talk-llama/models/starcoder2.cpp | 2 +- examples/talk-llama/models/step35.cpp | 2 +- examples/talk-llama/models/t5.cpp | 2 +- .../talk-llama/models/wavtokenizer-dec.cpp | 2 +- examples/talk-llama/models/xverse.cpp | 2 +- examples/talk-llama/unicode.cpp | 133 +++++++ 129 files changed, 1593 insertions(+), 395 deletions(-) diff --git a/examples/talk-llama/llama-arch.cpp b/examples/talk-llama/llama-arch.cpp index 59dde99e362..c9eead18aa3 100644 --- a/examples/talk-llama/llama-arch.cpp +++ b/examples/talk-llama/llama-arch.cpp @@ -757,14 +757,15 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - // NextN/MTP tensors are currently ignored (reserved for future MTP support) - // These tensors only exist in the last layer(s) and are treated as output tensors - {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, - {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + // NextN/MTP tensors are stored per-block (blk.%d.nextn.*) even though only the + // last nextn_predict_layers blocks carry them. Classify as LAYER_REPEATING so + // the model loader doesn't fault on the block index. + {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // Nemotron 3 Super {LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, @@ -877,6 +878,16 @@ bool llm_arch_is_diffusion(const llm_arch & arch) { } } +bool llm_arch_supports_rs_rollback(const llm_arch & arch) { + switch (arch) { + case LLM_ARCH_QWEN35: + case LLM_ARCH_QWEN35MOE: + return true; + default: + return false; + } +} + bool llm_arch_supports_sm_tensor(const llm_arch & arch) { switch (arch) { case LLM_ARCH_GROK: diff --git a/examples/talk-llama/llama-arch.h b/examples/talk-llama/llama-arch.h index e37d548c98e..89cf16cc37c 100644 --- a/examples/talk-llama/llama-arch.h +++ b/examples/talk-llama/llama-arch.h @@ -637,3 +637,4 @@ bool llm_arch_is_recurrent (const llm_arch & arch); bool llm_arch_is_hybrid (const llm_arch & arch); bool llm_arch_is_diffusion (const llm_arch & arch); bool llm_arch_supports_sm_tensor(const llm_arch & arch); +bool llm_arch_supports_rs_rollback(const llm_arch & arch); diff --git a/examples/talk-llama/llama-chat.cpp b/examples/talk-llama/llama-chat.cpp index 6554a89b28a..f10397747b0 100644 --- a/examples/talk-llama/llama-chat.cpp +++ b/examples/talk-llama/llama-chat.cpp @@ -73,7 +73,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE }, { "gpt-oss", LLM_CHAT_TEMPLATE_OPENAI_MOE }, { "hunyuan-dense", LLM_CHAT_TEMPLATE_HUNYUAN_DENSE }, - { "hunyuan-ocr", LLM_CHAT_TEMPLATE_HUNYUAN_OCR }, + { "hunyuan-vl", LLM_CHAT_TEMPLATE_HUNYUAN_VL }, { "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 }, { "seed_oss", LLM_CHAT_TEMPLATE_SEED_OSS }, { "grok-2", LLM_CHAT_TEMPLATE_GROK_2 }, @@ -218,7 +218,7 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { } else if (tmpl_contains("<|start|>") && tmpl_contains("<|channel|>")) { return LLM_CHAT_TEMPLATE_OPENAI_MOE; } else if (tmpl_contains("<|hy_Assistant|>") && tmpl_contains("<|hy_begin▁of▁sentence|>")) { - return LLM_CHAT_TEMPLATE_HUNYUAN_OCR; + return LLM_CHAT_TEMPLATE_HUNYUAN_VL; } else if (tmpl_contains("<|hy_Assistant|>") && tmpl_contains("<|hy_place▁holder▁no▁3|>")) { return LLM_CHAT_TEMPLATE_HUNYUAN_DENSE; } else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) { @@ -825,8 +825,8 @@ int32_t llm_chat_apply_template( ss << "<|hy_User|>" << chat[i]->content << "<|hy_Assistant|>"; } } - } else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_OCR) { - // tencent/HunyuanOCR + } else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_VL) { + // tencent/HunyuanOCR & tencent/HunyuanVL ss << "<|hy_begin▁of▁sentence|>"; for (size_t i = 0; i < chat.size(); i++) { std::string role(chat[i]->role); diff --git a/examples/talk-llama/llama-chat.h b/examples/talk-llama/llama-chat.h index 13f936a946c..ea6540c0be7 100644 --- a/examples/talk-llama/llama-chat.h +++ b/examples/talk-llama/llama-chat.h @@ -53,7 +53,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_HUNYUAN_MOE, LLM_CHAT_TEMPLATE_OPENAI_MOE, LLM_CHAT_TEMPLATE_HUNYUAN_DENSE, - LLM_CHAT_TEMPLATE_HUNYUAN_OCR, + LLM_CHAT_TEMPLATE_HUNYUAN_VL, LLM_CHAT_TEMPLATE_KIMI_K2, LLM_CHAT_TEMPLATE_SEED_OSS, LLM_CHAT_TEMPLATE_GROK_2, diff --git a/examples/talk-llama/llama-context.cpp b/examples/talk-llama/llama-context.cpp index 3d9714ab166..ad36c06667d 100644 --- a/examples/talk-llama/llama-context.cpp +++ b/examples/talk-llama/llama-context.cpp @@ -2,6 +2,7 @@ #include "ggml.h" #include "llama-arch.h" +#include "llama-graph.h" #include "llama-impl.h" #include "llama-batch.h" #include "llama-io.h" @@ -21,6 +22,14 @@ // llama_context // +static llm_graph_type ctx_type_to_graph_type(llama_context_type ctx_type) { + switch (ctx_type) { + case LLAMA_CONTEXT_TYPE_DEFAULT: return LLM_GRAPH_TYPE_DEFAULT; + case LLAMA_CONTEXT_TYPE_MTP : return LLM_GRAPH_TYPE_DECODER_MTP; + } + throw std::runtime_error("Unsupported ctx type"); +} + llama_context::llama_context( const llama_model & model, llama_context_params params) : @@ -42,13 +51,22 @@ llama_context::llama_context( throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ)); } + cparams.n_rs_seq = params.n_rs_seq; + if (cparams.n_rs_seq > 0 && !llm_arch_supports_rs_rollback(model.arch)) { + LLAMA_LOG_DEBUG("%s: n_rs_seq=%u requested but model arch does not support recurrent partial rollback; clamping to 0\n", + __func__, cparams.n_rs_seq); + cparams.n_rs_seq = 0; + } + cparams.n_threads = params.n_threads; cparams.n_threads_batch = params.n_threads_batch; cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor; cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor; cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast; cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow; - cparams.embeddings = params.embeddings; + cparams.embeddings = params.embeddings; + cparams.embeddings_pre_norm = false; + cparams.embeddings_pre_norm_masked = false; cparams.offload_kqv = params.offload_kqv; cparams.no_perf = params.no_perf; cparams.pooling_type = params.pooling_type; @@ -65,6 +83,8 @@ llama_context::llama_context( cparams.cb_eval = params.cb_eval; cparams.cb_eval_user_data = params.cb_eval_user_data; + cparams.ctx_type = params.ctx_type; + // Initialize backend samplers here so they are part of the sampling graph // before the reserve passes run later in this function. This avoids a later // re-reserve when graph nodes change. @@ -206,6 +226,7 @@ llama_context::llama_context( LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false"); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); + LLAMA_LOG_INFO("%s: n_rs_seq = %u\n", __func__, cparams.n_rs_seq); if (cparams.n_ctx_seq < hparams.n_ctx_train) { LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n", @@ -278,6 +299,7 @@ llama_context::llama_context( /*.type_k =*/ params.type_k, /*.type_v =*/ params.type_v, /*.swa_full =*/ params.swa_full, + /*.ctx_type= */ cparams.ctx_type, }; memory.reset(model.create_memory(params_mem, cparams)); @@ -860,6 +882,42 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) { return it->second.data(); } +float * llama_context::get_embeddings_pre_norm() { + output_reorder(); + + return embd_pre_norm.data; +} + +float * llama_context::get_embeddings_pre_norm_ith(int32_t i) { + output_reorder(); + + try { + if (embd_pre_norm.data == nullptr) { + throw std::runtime_error("no pre-norm embeddings"); + } + + const uint32_t n_embd = model.hparams.n_embd; + + if (!cparams.embeddings_pre_norm_masked) { + // unmasked: pre-norm rows are stored densely, indexed by raw token position. + if (i < 0 || (size_t)(i + 1) * n_embd > embd_pre_norm.size) { + throw std::runtime_error(format("out of range [0, %zu)", embd_pre_norm.size / n_embd)); + } + return embd_pre_norm.data + (size_t) i * n_embd; + } + + const int64_t j = output_resolve_row(i); + return embd_pre_norm.data + j*n_embd; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid pre-norm embeddings id %d, reason: %s\n", __func__, i, err.what()); +#ifndef NDEBUG + GGML_ABORT("fatal error"); +#else + return nullptr; +#endif + } +} + llama_token llama_context::get_sampled_token_ith(int32_t idx) { output_reorder(); @@ -1040,6 +1098,13 @@ void llama_context::set_embeddings(bool value) { //sched_need_reserve = true; } +void llama_context::set_embeddings_pre_norm(bool value, bool masked) { + LLAMA_LOG_DEBUG("%s: value = %d, masked = %d\n", __func__, value, masked); + + cparams.embeddings_pre_norm = value; + cparams.embeddings_pre_norm_masked = masked; +} + void llama_context::set_causal_attn(bool value) { LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); @@ -1072,6 +1137,19 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) { LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler); + if (sampler && model.split_mode() == LLAMA_SPLIT_MODE_TENSOR) { + static bool warned = false; + if (!warned) { + LLAMA_LOG_WARN("%s: backend sampling not supported with SPLIT_MODE_TENSOR; using CPU\n", __func__); + warned = true; + } + if (sampling.samplers.count(seq_id) > 0) { + sched_need_reserve = true; + } + sampling.samplers.erase(seq_id); + return false; + } + const bool can_offload = sampler && sampler->iface->backend_init && @@ -1241,7 +1319,9 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll } int llama_context::encode(const llama_batch & batch_inp) { - GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT + // MTP hook batches carry both token (next-token id) and embd (h_pre_norm row), + // so accept either present rather than requiring exactly one. + GGML_ASSERT(batch_inp.token || batch_inp.embd); if (batch_inp.n_tokens == 0) { LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); @@ -1312,8 +1392,9 @@ int llama_context::encode(const llama_batch & batch_inp) { } } - auto * t_logits = res->get_logits(); - auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); + auto * t_logits = res->get_logits(); + auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); + auto * t_h_pre_norm = cparams.embeddings_pre_norm ? res->get_h_pre_norm() : nullptr; // extract logits if (logits.data && t_logits) { @@ -1379,6 +1460,16 @@ int llama_context::encode(const llama_batch & batch_inp) { } } + // extract pre-norm embeddings (hidden state before the final output norm) + if (embd_pre_norm.data && t_h_pre_norm && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { + ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm); + GGML_ASSERT(backend_h != nullptr); + + const uint32_t n_embd = hparams.n_embd; + GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_pre_norm.size); + ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm.data, 0, n_tokens*n_embd*sizeof(float)); + } + // TODO: hacky solution if (model.arch == LLM_ARCH_T5 && t_embd) { //cross.t_embd = t_embd; @@ -1531,7 +1622,9 @@ static bool needs_raw_logits(const llama_ubatch & ubatch, const std::mapget_ubatch(); @@ -1689,7 +1783,8 @@ int llama_context::decode(const llama_batch & batch_inp) { } ggml_status status; - const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status); + + const auto * res = process_ubatch(ubatch, ctx_type_to_graph_type(cparams.ctx_type), mctx.get(), status); if (!res) { // the last ubatch failed or was aborted -> remove all positions of that ubatch from the memory module @@ -1727,8 +1822,9 @@ int llama_context::decode(const llama_batch & batch_inp) { // ggml_graph_dump_dot(gf, NULL, "llama.dot"); //} - auto * t_logits = res->get_logits(); - auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr; + auto * t_logits = res->get_logits(); + auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr; + auto * t_h_pre_norm = cparams.embeddings_pre_norm ? res->get_h_pre_norm() : nullptr; if (t_embd && res->get_embd_pooled()) { t_embd = res->get_embd_pooled(); @@ -1809,6 +1905,25 @@ int llama_context::decode(const llama_batch & batch_inp) { } } + // extract pre-norm embeddings (hidden state before the final output norm) + // only meaningful in LLAMA_POOLING_TYPE_NONE (per-token); other pooling modes are ignored. + { + const bool masked = cparams.embeddings_pre_norm_masked; + const int64_t n_rows = masked ? n_outputs : (int64_t) ubatch.n_tokens; + const int64_t offset = masked ? n_outputs_prev : n_tokens_prev; + + if (embd_pre_norm.data && t_h_pre_norm && n_rows > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { + ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm); + GGML_ASSERT(backend_h != nullptr); + + const uint32_t n_embd = hparams.n_embd; + float * embd_pre_norm_out = embd_pre_norm.data + offset*n_embd; + + GGML_ASSERT((offset + n_rows)*n_embd <= (int64_t) embd_pre_norm.size); + ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm_out, 0, n_rows*n_embd*sizeof(float)); + } + } + // Copy backend sampling output if this ubatch produced any sampling tensors. if (has_samplers && (!res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty())) { const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev); @@ -1823,6 +1938,7 @@ int llama_context::decode(const llama_batch & batch_inp) { } n_outputs_prev += n_outputs; + n_tokens_prev += ubatch.n_tokens; } while (mctx->next()); // set to total number of outputs in the batch, for use in llama_get_logits_ith @@ -1893,10 +2009,12 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { const auto n_batch = cparams.n_batch; const auto n_vocab = vocab.n_tokens(); + const auto n_embd = hparams.n_embd; const auto n_embd_out = hparams.n_embd_out(); - bool has_logits = true; - bool has_embd = cparams.embeddings; + bool has_logits = true; + bool has_embd = cparams.embeddings; + bool has_embd_pre_norm = cparams.embeddings_pre_norm; // TODO: hacky enc-dec support if (model.arch == LLM_ARCH_T5) { @@ -1908,8 +2026,15 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { size_t backend_float_count = 0; size_t backend_token_count = 0; - logits.size = has_logits ? n_vocab*n_outputs_max : 0; - embd.size = has_embd ? n_embd_out*n_outputs_max : 0; + logits.size = has_logits ? n_vocab*n_outputs_max : 0; + embd.size = has_embd ? n_embd_out*n_outputs_max : 0; + embd_pre_norm.size = has_embd_pre_norm ? n_embd*n_outputs_max : 0; + + if (has_embd_pre_norm && !cparams.embeddings_pre_norm_masked) { + // unmasked: pre-norm row exists for every token in the batch, not just + // those flagged via batch.logits[i] -> size by token count instead. + embd_pre_norm.size = (size_t) n_embd * n_batch; + } // Allocate backend sampling output buffers if there are backend samplers configured. const bool has_sampling = !sampling.samplers.empty(); @@ -1925,8 +2050,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0; const size_t new_size = - (logits.size + embd.size + backend_float_count) * sizeof(float) + - ( backend_token_count) * sizeof(llama_token); + (logits.size + embd.size + embd_pre_norm.size + backend_float_count) * sizeof(float) + + ( backend_token_count) * sizeof(llama_token); // alloc only when more than the current capacity is required // TODO: also consider shrinking the buffer @@ -1942,6 +2067,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { buf_output = nullptr; logits.data = nullptr; embd.data = nullptr; + embd_pre_norm.data = nullptr; } auto * buft = ggml_backend_cpu_buffer_type(); @@ -1970,6 +2096,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { embd = has_embd ? buffer_view{(float *) (base + offset), embd.size} : buffer_view{nullptr, 0}; offset += embd.size * sizeof(float); + embd_pre_norm = has_embd_pre_norm ? buffer_view{(float *) (base + offset), embd_pre_norm.size} : buffer_view{nullptr, 0}; + offset += embd_pre_norm.size * sizeof(float); + if (has_sampling) { sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)}; offset += sampling.logits.size * sizeof(float); @@ -2034,6 +2163,12 @@ void llama_context::output_reorder() { } } + if (embd_pre_norm.size > 0) { + for (uint64_t k = 0; k < n_embd; k++) { + std::swap(embd_pre_norm.data[i0*n_embd + k], embd_pre_norm.data[i1*n_embd + k]); + } + } + if (!sampling.samplers.empty()) { assert(sampling.logits.size > 0); assert(sampling.probs.size > 0); @@ -2121,7 +2256,7 @@ ggml_cgraph * llama_context::graph_reserve( auto * res = gf_res_reserve.get(); - const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT); + const auto gparams = graph_params(res, ubatch, mctx, ctx_type_to_graph_type(cparams.ctx_type)); res->reset(); @@ -3100,7 +3235,7 @@ void llama_context::opt_epoch_iter( auto * res = gf_res_prev.get(); - const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT); + const auto gparams = graph_params(res, ubatch, mctx.get(), ctx_type_to_graph_type(cparams.ctx_type)); res->reset(); @@ -3201,8 +3336,10 @@ llama_context_params llama_context_default_params() { /*.n_batch =*/ 2048, /*.n_ubatch =*/ 512, /*.n_seq_max =*/ 1, + /*.n_rs_seq =*/ 0, /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS, + /*.ctx_type =*/ LLAMA_CONTEXT_TYPE_DEFAULT, /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED, /*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED, /*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED, @@ -3306,6 +3443,13 @@ llama_context * llama_init_from_model( model->hparams.pooling_type, params.pooling_type); } + if (params.ctx_type == LLAMA_CONTEXT_TYPE_MTP && + model->hparams.nextn_predict_layers == 0) { + LLAMA_LOG_WARN("%s: context type MTP requested but model doesn't contain MTP layers\n", __func__); + return nullptr; + } + + try { auto * ctx = new llama_context(*model, params); return ctx; @@ -3347,6 +3491,10 @@ uint32_t llama_n_seq_max(const llama_context * ctx) { return ctx->n_seq_max(); } +uint32_t llama_n_rs_seq(const llama_context * ctx) { + return ctx->get_cparams().n_rs_seq; +} + const llama_model * llama_get_model(const llama_context * ctx) { return &ctx->get_model(); } @@ -3436,6 +3584,22 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) { return ctx->get_embeddings_seq(seq_id); } +void llama_set_embeddings_pre_norm(llama_context * ctx, bool value, bool masked) { + ctx->set_embeddings_pre_norm(value, masked); +} + +float * llama_get_embeddings_pre_norm(llama_context * ctx) { + ctx->synchronize(); + + return ctx->get_embeddings_pre_norm(); +} + +float * llama_get_embeddings_pre_norm_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return ctx->get_embeddings_pre_norm_ith(i); +} + bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) { return ctx->set_sampler(seq_id, smpl); } diff --git a/examples/talk-llama/llama-context.h b/examples/talk-llama/llama-context.h index 92d1b0cf95a..d03f681d4a1 100644 --- a/examples/talk-llama/llama-context.h +++ b/examples/talk-llama/llama-context.h @@ -84,6 +84,9 @@ struct llama_context { float * get_embeddings_ith(int32_t i); float * get_embeddings_seq(llama_seq_id seq_id); + float * get_embeddings_pre_norm(); + float * get_embeddings_pre_norm_ith(int32_t i); + llama_token * get_sampled_tokens() const; llama_token get_sampled_token_ith(int32_t idx); @@ -107,6 +110,7 @@ struct llama_context { void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data); void set_embeddings (bool value); + void set_embeddings_pre_norm(bool value, bool masked); void set_causal_attn(bool value); void set_warmup(bool value); @@ -278,6 +282,11 @@ struct llama_context { // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE buffer_view embd = {nullptr, 0}; + // hidden state before the final output norm (2-dimensional array: [n_outputs][n_embd]) + // populated only when cparams.embeddings_pre_norm is enabled and the model graph + // sets llm_graph_result::t_h_pre_norm + buffer_view embd_pre_norm = {nullptr, 0}; + struct sampling_info { // !samplers.empty() to check if any samplers are active std::map samplers; diff --git a/examples/talk-llama/llama-cparams.h b/examples/talk-llama/llama-cparams.h index 9d359474132..20ec59fe335 100644 --- a/examples/talk-llama/llama-cparams.h +++ b/examples/talk-llama/llama-cparams.h @@ -12,6 +12,7 @@ struct llama_cparams { uint32_t n_batch; uint32_t n_ubatch; uint32_t n_seq_max; + uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback int32_t n_threads; // number of threads to use for generation int32_t n_threads_batch; // number of threads to use for batch processing @@ -27,6 +28,8 @@ struct llama_cparams { float yarn_beta_slow; bool embeddings; + bool embeddings_pre_norm; // also extract the hidden state before the final output norm + bool embeddings_pre_norm_masked; // extract for only rows where batch.logits != 0 bool causal_attn; bool offload_kqv; bool flash_attn; @@ -40,6 +43,7 @@ struct llama_cparams { bool kv_unified; bool pipeline_parallel; + enum llama_context_type ctx_type; enum llama_pooling_type pooling_type; ggml_backend_sched_eval_callback cb_eval; diff --git a/examples/talk-llama/llama-ext.h b/examples/talk-llama/llama-ext.h index 8ce29d217cb..edfa71c207c 100644 --- a/examples/talk-llama/llama-ext.h +++ b/examples/talk-llama/llama-ext.h @@ -88,3 +88,19 @@ LLAMA_API int32_t llama_model_n_devices(const struct llama_model * model); LLAMA_API ggml_backend_dev_t llama_model_get_device(const struct llama_model * model, int i); LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx); + +// +// pre-norm embeddings (hidden state before the final output norm) +// + +// Set whether the context outputs pre-norm embeddings or not +// If masked == true, output the embeddings only for the tokens with batch.logits != 0 +// If masked == false, output the embeddings for all tokens in the batch regardless of batch.logits +LLAMA_API void llama_set_embeddings_pre_norm(struct llama_context * ctx, bool value, bool masked); + +// mirrors: +// LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); +LLAMA_API float * llama_get_embeddings_pre_norm (struct llama_context * ctx); + +// LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i); +LLAMA_API float * llama_get_embeddings_pre_norm_ith(struct llama_context * ctx, int32_t i); diff --git a/examples/talk-llama/llama-graph.cpp b/examples/talk-llama/llama-graph.cpp index fe155c92dea..fc027de8b39 100644 --- a/examples/talk-llama/llama-graph.cpp +++ b/examples/talk-llama/llama-graph.cpp @@ -500,15 +500,21 @@ bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) { } void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) { - mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch); - mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch); + // base tensors may not be allocated if there are no non-SWA attention layers + if (self_k_idxs && self_k_idxs->buffer) { + mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch); + mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch); - mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + } - mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch); - mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch); + // swa tensors may not be allocated if there are no SWA attention layers + if (self_k_idxs_swa && self_k_idxs_swa->buffer) { + mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch); + mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch); - mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); + mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); + } if (self_k_rot) { mctx->get_base()->set_input_k_rot(self_k_rot); @@ -534,14 +540,21 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) { bool res = true; - res &= self_k_idxs->ne[0] == params.ubatch.n_tokens; - //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + // base tensors may not be allocated if there are no non-SWA attention layers + if (self_k_idxs && self_k_idxs->buffer) { + res &= self_k_idxs->ne[0] == params.ubatch.n_tokens; + //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + + res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams); + } - res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; - //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + // swa tensors may not be allocated if there are no SWA attention layers + if (self_k_idxs_swa && self_k_idxs_swa->buffer) { + res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; + //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams); - res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams); + res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams); + } return res; } @@ -848,6 +861,9 @@ void llm_graph_result::set_outputs() { if (t_embd_pooled != nullptr) { ggml_set_output(t_embd_pooled); } + if (t_h_pre_norm != nullptr) { + ggml_set_output(t_h_pre_norm); + } for (auto & [seq_id, t] : t_sampled) { if (t != nullptr) { ggml_set_output(t); @@ -2528,7 +2544,8 @@ ggml_tensor * llm_graph_context::build_rs( int32_t rs_zero, const llm_graph_get_rows_fn & get_state_rows) const { - ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, rs_size); + GGML_UNUSED(rs_size); + ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, s->ne[1]); // Clear a single state which will then be copied to the other cleared states. // Note that this is a no-op when the view is zero-sized. diff --git a/examples/talk-llama/llama-graph.h b/examples/talk-llama/llama-graph.h index 5cb1756c6a9..bf6778237e6 100644 --- a/examples/talk-llama/llama-graph.h +++ b/examples/talk-llama/llama-graph.h @@ -32,6 +32,7 @@ enum llm_graph_type { LLM_GRAPH_TYPE_DEFAULT, LLM_GRAPH_TYPE_ENCODER, LLM_GRAPH_TYPE_DECODER, + LLM_GRAPH_TYPE_DECODER_MTP, }; enum llm_ffn_op_type { @@ -580,7 +581,8 @@ struct llm_graph_params { ubatch.n_seqs_unq == other.ubatch.n_seqs_unq && ( (!ubatch.token && !other.ubatch.token) || - (!ubatch.embd && !other.ubatch.embd) + (!ubatch.embd && !other.ubatch.embd) || + (ubatch.token && other.ubatch.token && ubatch.embd && other.ubatch.embd) ); // when we split the batch using "equal_seqs" we have to verify that the participating sequences are the same @@ -644,6 +646,7 @@ class llm_graph_result { ggml_tensor * get_logits() const { return t_logits; } ggml_tensor * get_embd() const { return t_embd; } ggml_tensor * get_embd_pooled() const { return t_embd_pooled; } + ggml_tensor * get_h_pre_norm() const { return t_h_pre_norm; } ggml_cgraph * get_gf() const { return gf; } ggml_context * get_ctx() const { return ctx_compute.get(); } @@ -672,6 +675,7 @@ class llm_graph_result { ggml_tensor * t_logits = nullptr; ggml_tensor * t_embd = nullptr; ggml_tensor * t_embd_pooled = nullptr; + ggml_tensor * t_h_pre_norm = nullptr; // [n_embd, n_outputs] hidden state before final output norm std::map t_sampled_logits; std::map t_candidates; diff --git a/examples/talk-llama/llama-hparams.cpp b/examples/talk-llama/llama-hparams.cpp index 002d15d415f..2239309c8fb 100644 --- a/examples/talk-llama/llama-hparams.cpp +++ b/examples/talk-llama/llama-hparams.cpp @@ -229,6 +229,12 @@ uint32_t llama_hparams::n_embd_head_v_mla() const { } bool llama_hparams::has_kv(uint32_t il) const { + if (kv_only_nextn) { + // MTP head: only the trailing nextn_predict_layers blocks own a KV cache; + // the leading trunk blocks are not executed in this graph. + return nextn_predict_layers > 0 && il >= (n_layer - nextn_predict_layers); + } + if (n_layer_kv_from_start >= 0) { if (il < (uint32_t) n_layer_kv_from_start) { return true; diff --git a/examples/talk-llama/llama-hparams.h b/examples/talk-llama/llama-hparams.h index 0160a89caa2..e2d051edc6c 100644 --- a/examples/talk-llama/llama-hparams.h +++ b/examples/talk-llama/llama-hparams.h @@ -92,6 +92,8 @@ struct llama_hparams { uint32_t moe_latent_size = 0; uint32_t nextn_predict_layers = 0; + bool kv_only_nextn = false; // if true, only the last nextn_predict_layers blocks have a KV cache (MTP head arches) + float f_norm_eps; float f_norm_rms_eps; float f_norm_group_eps; diff --git a/examples/talk-llama/llama-memory-hybrid-iswa.cpp b/examples/talk-llama/llama-memory-hybrid-iswa.cpp index 10e6b459797..72f5c2fea72 100644 --- a/examples/talk-llama/llama-memory-hybrid-iswa.cpp +++ b/examples/talk-llama/llama-memory-hybrid-iswa.cpp @@ -24,6 +24,7 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa( uint32_t rs_size, /* common */ uint32_t n_seq_max, + uint32_t n_rs_seq, bool offload, bool unified, /* layer filters */ @@ -54,6 +55,7 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa( offload, rs_size, n_seq_max, + n_rs_seq, filter_recr == nullptr ? [&](int32_t il) { return hparams.is_recurrent(il); } : filter_recr @@ -73,9 +75,15 @@ llama_memory_context_ptr llama_memory_hybrid_iswa::init_batch(llama_batch_allocr // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - // Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice) - const bool unified = (mem_attn->get_base()->get_n_stream() == 1); - ubatch = balloc.split_equal(n_ubatch, !unified); + if (mem_recr->n_rs_seq > 0) { + // [TAG_RECURRENT_ROLLBACK_SPLITS] + // TODO: recurrent state rollback does not support equal splits + ubatch = balloc.split_seq(n_ubatch); + } else { + // Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice) + const bool unified = (mem_attn->get_base()->get_n_stream() == 1); + ubatch = balloc.split_equal(n_ubatch, !unified); + } } if (ubatch.n_tokens == 0) { diff --git a/examples/talk-llama/llama-memory-hybrid-iswa.h b/examples/talk-llama/llama-memory-hybrid-iswa.h index 807c8aac96c..c9d3f9f57c5 100644 --- a/examples/talk-llama/llama-memory-hybrid-iswa.h +++ b/examples/talk-llama/llama-memory-hybrid-iswa.h @@ -34,6 +34,7 @@ class llama_memory_hybrid_iswa : public llama_memory_i { uint32_t rs_size, /* common */ uint32_t n_seq_max, + uint32_t n_rs_seq, bool offload, bool unified, /* layer filters */ diff --git a/examples/talk-llama/llama-memory-hybrid.cpp b/examples/talk-llama/llama-memory-hybrid.cpp index 4ce1af592c1..33b3b395e0c 100644 --- a/examples/talk-llama/llama-memory-hybrid.cpp +++ b/examples/talk-llama/llama-memory-hybrid.cpp @@ -24,6 +24,7 @@ llama_memory_hybrid::llama_memory_hybrid( uint32_t rs_size, /* common */ uint32_t n_seq_max, + uint32_t n_rs_seq, bool offload, bool unified, /* layer filters */ @@ -54,6 +55,7 @@ llama_memory_hybrid::llama_memory_hybrid( offload, rs_size, n_seq_max, + n_rs_seq, filter_recr == nullptr ? [&](int32_t il) { return hparams.is_recurrent(il); } : filter_recr @@ -73,9 +75,15 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - // Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice) - const bool unified = (mem_attn->get_n_stream() == 1); - ubatch = balloc.split_equal(n_ubatch, !unified); + if (mem_recr->n_rs_seq > 0) { + // [TAG_RECURRENT_ROLLBACK_SPLITS] + // TODO: recurrent state rollback does not support equal splits + ubatch = balloc.split_seq(n_ubatch); + } else { + // Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice) + const bool unified = (mem_attn->get_n_stream() == 1); + ubatch = balloc.split_equal(n_ubatch, !unified); + } } if (ubatch.n_tokens == 0) { diff --git a/examples/talk-llama/llama-memory-hybrid.h b/examples/talk-llama/llama-memory-hybrid.h index 558cafdf984..484eafb7499 100644 --- a/examples/talk-llama/llama-memory-hybrid.h +++ b/examples/talk-llama/llama-memory-hybrid.h @@ -34,6 +34,7 @@ class llama_memory_hybrid : public llama_memory_i { uint32_t rs_size, /* common */ uint32_t n_seq_max, + uint32_t n_rs_seq, bool offload, bool unified, /* layer filters */ diff --git a/examples/talk-llama/llama-memory-recurrent.cpp b/examples/talk-llama/llama-memory-recurrent.cpp index c07f1d969cb..ec5dc5835dd 100644 --- a/examples/talk-llama/llama-memory-recurrent.cpp +++ b/examples/talk-llama/llama-memory-recurrent.cpp @@ -24,6 +24,7 @@ llama_memory_recurrent::llama_memory_recurrent( bool offload, uint32_t mem_size, uint32_t n_seq_max, + uint32_t n_rs_seq, const layer_filter_cb & filter) : hparams(model.hparams), n_seq_max(n_seq_max) { const int32_t n_layer = hparams.n_layer; @@ -31,6 +32,9 @@ llama_memory_recurrent::llama_memory_recurrent( size = mem_size; used = 0; + this->n_rs_seq = n_rs_seq; + rs_idx.assign(n_seq_max, 0); + cells.clear(); cells.resize(mem_size); @@ -92,8 +96,9 @@ llama_memory_recurrent::llama_memory_recurrent( throw std::runtime_error("failed to create ggml context for rs cache"); } - ggml_tensor * r = ggml_new_tensor_2d(ctx, type_r, hparams.n_embd_r(), mem_size); - ggml_tensor * s = ggml_new_tensor_2d(ctx, type_s, hparams.n_embd_s(), mem_size); + const uint32_t n_rows = mem_size * (1 + n_rs_seq); + ggml_tensor * r = ggml_new_tensor_2d(ctx, type_r, hparams.n_embd_r(), n_rows); + ggml_tensor * s = ggml_new_tensor_2d(ctx, type_s, hparams.n_embd_s(), n_rows); ggml_format_name(r, "cache_r_l%d", i); ggml_format_name(s, "cache_s_l%d", i); r_l[i] = r; @@ -115,8 +120,8 @@ llama_memory_recurrent::llama_memory_recurrent( const size_t memory_size_r = size_r_bytes(); const size_t memory_size_s = size_s_bytes(); - LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__, - (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), mem_size, n_layer, n_seq_max, + LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs %2u rs_seq), R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__, + (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), mem_size, n_layer, n_seq_max, n_rs_seq, ggml_type_name(type_r), (float)memory_size_r / (1024.0f * 1024.0f), ggml_type_name(type_s), (float)memory_size_s / (1024.0f * 1024.0f)); } @@ -138,10 +143,11 @@ void llama_memory_recurrent::clear(bool data) { ggml_backend_buffer_clear(buf.get(), 0); } } + + std::fill(rs_idx.begin(), rs_idx.end(), 0); } bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - //printf("[DEBUG] calling llama_memory_recurrent::seq_rm` with `seq_id=%d, p0=%d, p1=%d`\n", seq_id, p0, p1); uint32_t new_head = size; if (p0 < 0) { @@ -152,6 +158,15 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1 = std::numeric_limits::max(); } + const bool rm_all = p0 == 0 && p1 == std::numeric_limits::max(); + if (rm_all) { + if (seq_id >= 0) { + set_rs_idx(seq_id, 0); + } else { + std::fill(rs_idx.begin(), rs_idx.end(), 0); + } + } + // models like Mamba or RWKV can't have a state partially erased at the end // of the sequence because their state isn't preserved for previous tokens if (seq_id >= (int64_t) size) { @@ -161,10 +176,16 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos if (0 <= seq_id) { int32_t & tail_id = cells[seq_id].tail; if (tail_id >= 0) { - const auto & cell = cells[tail_id]; - // partial intersection is invalid if it includes the final pos + auto & cell = cells[tail_id]; + + // partial rollback via per-token snapshot index (bounded by n_rs_seq) if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) { - //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false, p0 = %d, cell.pos = %d, p1 = %d\n", p0, cell.pos, p1); + const llama_pos rollback = cell.pos - (p0 - 1); + if (rollback >= 1 && rollback <= (llama_pos) n_rs_seq) { + set_rs_idx(seq_id, (uint32_t) rollback); + cell.pos = p0 - 1; + return true; + } return false; } // invalidate tails which will be cleared @@ -368,6 +389,13 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const { return result; } +void llama_memory_recurrent::set_rs_idx(llama_seq_id seq_id, uint32_t idx) { + if (seq_id < 0 || (size_t) seq_id >= rs_idx.size()) { + return; + } + rs_idx[seq_id] = (idx > n_rs_seq) ? n_rs_seq : idx; +} + std::map llama_memory_recurrent::memory_breakdown() const { std::map ret; for (const auto & [_, buf] : ctxs_bufs) { @@ -388,9 +416,15 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - // TODO: non-sequential equal split can be done if using unified KV cache - // for simplicity, we always use sequential equal split for now - ubatch = balloc.split_equal(n_ubatch, true); + if (n_rs_seq > 0) { + // [TAG_RECURRENT_ROLLBACK_SPLITS] + // TODO: recurrent state rollback does not support equal splits + ubatch = balloc.split_seq(n_ubatch); + } else { + // TODO: non-sequential equal split can be done if using unified KV cache + // for simplicity, we always use sequential equal split for now + ubatch = balloc.split_equal(n_ubatch, true); + } } if (ubatch.n_tokens == 0) { @@ -703,6 +737,7 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq GGML_UNUSED(flags); std::vector> cell_ranges; // ranges, from inclusive, to exclusive + std::vector> cell_ranges_data; // logical source row ranges uint32_t cell_count = 0; // Count the number of cells with the specified seq_id @@ -712,6 +747,35 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq const auto & cell = cells[i]; if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) { ++cell_count; + uint32_t rs_idx_cur = 0; + + if (n_rs_seq != 0) { + if (seq_id != -1) { + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < rs_idx.size()); + rs_idx_cur = rs_idx[seq_id]; + } else { + bool has_rs_idx = false; + for (const llama_seq_id cell_seq_id : cell.seq_id) { + GGML_ASSERT(cell_seq_id >= 0 && (size_t) cell_seq_id < rs_idx.size()); + + const uint32_t seq_rs_idx = rs_idx[cell_seq_id]; + if (!has_rs_idx) { + rs_idx_cur = seq_rs_idx; + has_rs_idx = true; + } else if (rs_idx_cur != seq_rs_idx) { + GGML_ABORT("cannot write shared recurrent state with different rollback indices"); + } + } + } + } + + const uint32_t cell_id = rs_idx_cur * size + (cell.src >= 0 ? cell.src : (int32_t) i); + if (cell_ranges_data.empty() || cell_ranges_data.back().second != cell_id) { + cell_ranges_data.emplace_back(cell_id, cell_id + 1); + } else { + cell_ranges_data.back().second++; + } + if (cell_range_begin == size) { cell_range_begin = i; } @@ -726,7 +790,7 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq cell_ranges.emplace_back(cell_range_begin, size); } - if (flags % LLAMA_STATE_SEQ_FLAGS_ON_DEVICE && cell_ranges.size() > 1) { + if ((flags & LLAMA_STATE_SEQ_FLAGS_ON_DEVICE) && cell_ranges.size() > 1) { GGML_ABORT("cannot save/load multiple ranges of cells to/from device memory\n"); } @@ -737,10 +801,16 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq } GGML_ASSERT(cell_count == cell_count_check); + cell_count_check = 0; + for (const auto & range : cell_ranges_data) { + cell_count_check += range.second - range.first; + } + GGML_ASSERT(cell_count == cell_count_check); + io.write(&cell_count, sizeof(cell_count)); state_write_meta(io, cell_ranges, seq_id); - state_write_data(io, cell_ranges); + state_write_data(io, cell_ranges_data); } void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { @@ -762,6 +832,14 @@ void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_i } throw std::runtime_error("failed to restore kv cache"); } + + if (n_rs_seq != 0) { + if (seq_id == -1) { + std::fill(rs_idx.begin(), rs_idx.end(), 0); + } else { + set_rs_idx(seq_id, 0); + } + } } void llama_memory_recurrent::state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id) const { @@ -804,7 +882,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: const uint64_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r()); io.write(&r_size_row, sizeof(r_size_row)); - // Write each range of cells of r_size_row length + // Write each logical cell row range. With pending recurrent rollback, + // the logical current state may live in a rollback snapshot plane. for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * r_size_row; @@ -825,7 +904,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: const uint64_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s()); io.write(&s_size_row, sizeof(s_size_row)); - // Write each range of S tensor rows + // Write each logical cell row range. With pending recurrent rollback, + // the logical current state may live in a rollback snapshot plane. for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * s_size_row; @@ -852,9 +932,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: // Write GQA embedding size io.write(&n_embd_s, sizeof(n_embd_s)); - // For each row, we get the element values of each cell + // For each row, we get the element values of each logical cell for (uint32_t j = 0; j < n_embd_s; ++j) { - // Write each range of cells of s_size_el length for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; const size_t src_offset = (range.first + j * mem_size) * s_size_el; @@ -1163,5 +1242,21 @@ ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const { } int32_t llama_memory_recurrent_context::s_copy(int i) const { - return mem->cells[i + mem->head].src0; + const uint32_t cell_idx = i + mem->head; + const int32_t src0 = mem->cells[cell_idx].src0; + + if (mem->n_rs_seq == 0) { + return src0; + } + + uint32_t idx = 0; + if (!mem->cells[cell_idx].seq_id.empty()) { + const llama_seq_id seq = *mem->cells[cell_idx].seq_id.begin(); + if (seq >= 0 && (size_t) seq < mem->rs_idx.size()) { + idx = mem->rs_idx[seq]; + // reset rollback idx + mem->rs_idx[seq] = 0; + } + } + return (int32_t)(idx * mem->size) + src0; } diff --git a/examples/talk-llama/llama-memory-recurrent.h b/examples/talk-llama/llama-memory-recurrent.h index 47f01d73912..b13b7b748f5 100644 --- a/examples/talk-llama/llama-memory-recurrent.h +++ b/examples/talk-llama/llama-memory-recurrent.h @@ -23,6 +23,7 @@ class llama_memory_recurrent : public llama_memory_i { bool offload, uint32_t mem_size, uint32_t n_seq_max, + uint32_t n_rs_seq, const layer_filter_cb & filter); ~llama_memory_recurrent() = default; @@ -69,6 +70,14 @@ class llama_memory_recurrent : public llama_memory_i { uint32_t size = 0; // total number of cells, shared across all sequences uint32_t used = 0; // used cells (i.e. at least one seq_id) + // number of recurrent-state snapshots per seq for rollback; tensors are widened to (1 + n_rs_seq) groups + uint32_t n_rs_seq = 0; + + // per-seq rollback index + std::vector rs_idx; + + void set_rs_idx(llama_seq_id seq_id, uint32_t idx); + // computed before each graph build uint32_t n = 0; diff --git a/examples/talk-llama/llama-memory.h b/examples/talk-llama/llama-memory.h index 4a157b91fdb..4ad1612e45b 100644 --- a/examples/talk-llama/llama-memory.h +++ b/examples/talk-llama/llama-memory.h @@ -1,6 +1,7 @@ #pragma once #include "llama.h" +#include "llama-graph.h" #include #include @@ -20,6 +21,8 @@ struct llama_memory_params { // use full-size SWA cache bool swa_full; + + llama_context_type ctx_type; }; enum llama_memory_status { diff --git a/examples/talk-llama/llama-model-loader.cpp b/examples/talk-llama/llama-model-loader.cpp index 4e65a45a50d..c645d0785ab 100644 --- a/examples/talk-llama/llama-model-loader.cpp +++ b/examples/talk-llama/llama-model-loader.cpp @@ -1312,9 +1312,16 @@ struct ggml_tensor * llama_model_loader::create_tensor_as_view(struct ggml_conte return tensor; } -void llama_model_loader::done_getting_tensors() const { - if (n_created != n_tensors) { - throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created)); +void llama_model_loader::done_getting_tensors(bool partial) const { + if (n_created > n_tensors) { + throw std::runtime_error(format("%s: too many tensors created; expected %d, got %d", __func__, n_tensors, n_created)); + } + if (n_created < n_tensors) { + if (!partial) { + throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created)); + } + LLAMA_LOG_INFO("%s: partial load — used %d of %d tensors in the file (rest belong to a sibling model on the same .gguf)\n", + __func__, n_created, n_tensors); } if (n_tensors_moved > 0) { LLAMA_LOG_DEBUG("%s: tensor '%s' (%s) (and %zu others) cannot be used with preferred buffer type %s, using %s instead\n", diff --git a/examples/talk-llama/llama-model-loader.h b/examples/talk-llama/llama-model-loader.h index 7b3d6703c03..c476026d3e5 100644 --- a/examples/talk-llama/llama-model-loader.h +++ b/examples/talk-llama/llama-model-loader.h @@ -184,7 +184,7 @@ struct llama_model_loader { struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::initializer_list & ne, size_t offset, bool required = true); - void done_getting_tensors() const; + void done_getting_tensors(bool partial = false) const; void init_mappings(bool prefetch = true, llama_mlocks * mlock_mmaps = nullptr); diff --git a/examples/talk-llama/llama-model-saver.cpp b/examples/talk-llama/llama-model-saver.cpp index e83056557bf..528e4c9c069 100644 --- a/examples/talk-llama/llama-model-saver.cpp +++ b/examples/talk-llama/llama-model-saver.cpp @@ -393,6 +393,8 @@ void llama_model_saver::add_tensors_from_model() { add_tensor(model->output); add_tensor(model->output_b); add_tensor(model->output_norm_enc); + add_tensor(model->output_s); + add_tensor(model->output_in_s); add_tensor(model->cls); add_tensor(model->cls_b); add_tensor(model->cls_out); diff --git a/examples/talk-llama/llama-model.cpp b/examples/talk-llama/llama-model.cpp index ff30a2ae7a6..0d21b2a53c5 100644 --- a/examples/talk-llama/llama-model.cpp +++ b/examples/talk-llama/llama-model.cpp @@ -1334,6 +1334,12 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) { if (!layer.ssm_beta_s && layer.ssm_beta) { layer.ssm_beta_s = create_tensor(tn(LLM_TENSOR_SSM_BETA, "scale", i), {1}, TENSOR_NOT_REQUIRED); } + if (!layer.nextn.eh_proj_s && layer.nextn.eh_proj) { + layer.nextn.eh_proj_s = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.nextn.shared_head_head_s && layer.nextn.shared_head_head) { + layer.nextn.shared_head_head_s = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } // input scales if (!layer.wq_in_s && layer.wq) { @@ -1393,11 +1399,30 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) { if (!layer.ssm_beta_in_s && layer.ssm_beta) { layer.ssm_beta_in_s = create_tensor(tn(LLM_TENSOR_SSM_BETA, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); } + if (!layer.nextn.eh_proj_in_s && layer.nextn.eh_proj) { + layer.nextn.eh_proj_in_s = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.nextn.shared_head_head_in_s && layer.nextn.shared_head_head) { + layer.nextn.shared_head_head_in_s = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + } + // output scales + if (output && output->type == GGML_TYPE_NVFP4) { + // weight scale + if (!output_s) { + output_s = create_tensor(tn(LLM_TENSOR_OUTPUT, "scale"), {1}, TENSOR_NOT_REQUIRED); + } + // input scale + if (!output_in_s) { + output_in_s = create_tensor(tn(LLM_TENSOR_OUTPUT, "input_scale"), {1}, TENSOR_NOT_REQUIRED); + } } } - ml.done_getting_tensors(); + GGML_ASSERT(!(output && tok_embd && + strcmp(output->name, tok_embd->name) == 0 && + output->type == GGML_TYPE_NVFP4)); // populate tensors_by_name for (auto & [_, ctx_ptr] : ml.ctx_map) { for (auto * cur = ggml_get_first_tensor(ctx_ptr.get()); cur != NULL; cur = ggml_get_next_tensor(ctx_ptr.get(), cur)) { @@ -1934,6 +1959,12 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, // checks default: { + // The MTP head is dense-attention only on hybrid Qwen3.5/3.6, so use a plain + // attention KV cache for the MTP context instead of the hybrid wrapper. + const bool mtp_on_hybrid_qwen35 = + params.ctx_type == LLAMA_CONTEXT_TYPE_MTP && + (arch == LLM_ARCH_QWEN35 || arch == LLM_ARCH_QWEN35MOE); + if (llm_arch_is_recurrent(arch)) { res = new llama_memory_recurrent( *this, @@ -1942,8 +1973,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, cparams.offload_kqv, std::max((uint32_t) 1, cparams.n_seq_max), cparams.n_seq_max, + cparams.n_rs_seq, nullptr); - } else if (llm_arch_is_hybrid(arch)) { + } else if (llm_arch_is_hybrid(arch) && !mtp_on_hybrid_qwen35) { // The main difference between hybrid architectures is the // layer filters, so pick the right one here llama_memory_hybrid::layer_filter_cb filter_attn = nullptr; @@ -1958,6 +1990,14 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, filter_recr = [&](int32_t il) { return hparams.is_recurrent(il) && hparams.n_ff(il) == 0; }; + } else if (arch == LLM_ARCH_QWEN35 || arch == LLM_ARCH_QWEN35MOE) { + const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; + filter_attn = [&, n_main](int32_t il) { + return (uint32_t)il < n_main && !hparams.is_recurrent(il); + }; + filter_recr = [&, n_main](int32_t il) { + return (uint32_t)il < n_main && hparams.is_recurrent(il); + }; } if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { @@ -1975,6 +2015,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* recurrent_type_s */ GGML_TYPE_F32, /* recurrent_rs_size */ std::max((uint32_t) 1, cparams.n_seq_max), /* n_seq_max */ cparams.n_seq_max, + /* n_rs_seq */ cparams.n_rs_seq, /* offload */ cparams.offload_kqv, /* unified */ cparams.kv_unified, /* filter_attn */ std::move(filter_attn), @@ -1993,6 +2034,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* recurrent_type_v */ GGML_TYPE_F32, /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max), /* n_seq_max */ cparams.n_seq_max, + /* n_rs_seq */ cparams.n_rs_seq, /* offload */ cparams.offload_kqv, /* unified */ cparams.kv_unified, /* filter_attn */ std::move(filter_attn), @@ -2000,6 +2042,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, } } else { llama_memory_i::layer_reuse_cb reuse = nullptr; + llama_kv_cache::layer_filter_cb filter = nullptr; if (arch == LLM_ARCH_GEMMA3N || arch == LLM_ARCH_GEMMA4) { reuse = [&](int32_t il) { @@ -2011,6 +2054,11 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, }; } + if (mtp_on_hybrid_qwen35) { + const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; + filter = [n_main](int32_t il) { return (uint32_t)il >= n_main; }; + } + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { GGML_ASSERT(hparams.is_swa_any()); @@ -2026,7 +2074,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, cparams.n_seq_max, cparams.n_ubatch, 1, - nullptr, + filter, reuse); } else { GGML_ASSERT(!hparams.is_swa_any()); @@ -2043,7 +2091,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, 1, hparams.n_swa, hparams.swa_type, - nullptr, + filter, nullptr); } } @@ -2146,6 +2194,7 @@ int32_t llama_model_n_swa(const llama_model * model) { return model->hparams.n_swa; } + uint32_t llama_model_n_cls_out(const struct llama_model * model) { return model->hparams.n_cls_out; } diff --git a/examples/talk-llama/llama-model.h b/examples/talk-llama/llama-model.h index d63c689185a..398a0aa725c 100644 --- a/examples/talk-llama/llama-model.h +++ b/examples/talk-llama/llama-model.h @@ -202,12 +202,16 @@ struct llama_layer_shortconv { }; struct llama_layer_nextn { - struct ggml_tensor * eh_proj = nullptr; - struct ggml_tensor * embed_tokens = nullptr; - struct ggml_tensor * enorm = nullptr; - struct ggml_tensor * hnorm = nullptr; - struct ggml_tensor * shared_head_head = nullptr; - struct ggml_tensor * shared_head_norm = nullptr; + struct ggml_tensor * eh_proj = nullptr; + struct ggml_tensor * eh_proj_s = nullptr; + struct ggml_tensor * eh_proj_in_s = nullptr; + struct ggml_tensor * embed_tokens = nullptr; + struct ggml_tensor * enorm = nullptr; + struct ggml_tensor * hnorm = nullptr; + struct ggml_tensor * shared_head_head = nullptr; + struct ggml_tensor * shared_head_head_s = nullptr; + struct ggml_tensor * shared_head_head_in_s = nullptr; + struct ggml_tensor * shared_head_norm = nullptr; }; struct llama_layer { @@ -533,6 +537,11 @@ struct llama_model { struct ggml_tensor * output_b = nullptr; struct ggml_tensor * output_norm_enc = nullptr; + + // NVFP4 per-tensor scale2, input_scale for LM head + struct ggml_tensor * output_s = nullptr; + struct ggml_tensor * output_in_s = nullptr; + // classifier struct ggml_tensor * cls = nullptr; struct ggml_tensor * cls_b = nullptr; diff --git a/examples/talk-llama/llama-vocab.cpp b/examples/talk-llama/llama-vocab.cpp index f43cf546ca0..a5cf148b268 100644 --- a/examples/talk-llama/llama-vocab.cpp +++ b/examples/talk-llama/llama-vocab.cpp @@ -530,6 +530,8 @@ struct llm_tokenizer_bpe : llm_tokenizer { struct llm_tokenizer_bpe_session { llm_tokenizer_bpe_session(const llama_vocab & vocab, const llm_tokenizer_bpe & tokenizer) : vocab(vocab), tokenizer(tokenizer) {} + virtual ~llm_tokenizer_bpe_session() = default; + static void append(const llama_token token_id, std::vector & output) { output.push_back(token_id); } @@ -567,7 +569,7 @@ struct llm_tokenizer_bpe_session { } } - void tokenize(const std::string & text, std::vector & output) { + virtual void tokenize(const std::string & text, std::vector & output) { int final_prev_index = -1; const auto word_collection = unicode_regex_split(text, tokenizer.regex_exprs, tokenizer.byte_encode); @@ -1579,6 +1581,88 @@ struct llm_tokenizer_plamo2_session { const llm_tokenizer_plamo2 & tokenizer; }; +// reserved suffix (U+E000) that keeps DNA k-mers distinct from identical +// base-vocab BPE tokens (e.g. CCCCCC) in token_to_id; erased from id_to_token +// text at load +static const std::string dna_kmer_marker = "\xee\x80\x80"; + +struct llm_tokenizer_hybriddna_session : llm_tokenizer_bpe_session { + llm_tokenizer_hybriddna_session(const llama_vocab & vocab, const llm_tokenizer_bpe & tokenizer) : llm_tokenizer_bpe_session{vocab, tokenizer}, vocab{vocab} {} + + void tokenize(const std::string & text, std::vector & output) override { + static const std::string open_tag = ""; + static const std::string close_tag = ""; + + const auto dna_begin_id = vocab.text_to_token(open_tag); + const auto dna_end_id = vocab.text_to_token(close_tag); + const auto dna_oov_id = vocab.text_to_token(""); + + // Fall back to plain BPE if the DNA pieces aren't in the vocab. + if (dna_begin_id == LLAMA_TOKEN_NULL || dna_end_id == LLAMA_TOKEN_NULL || dna_oov_id == LLAMA_TOKEN_NULL) { + llm_tokenizer_bpe_session::tokenize(text, output); + return; + } + + const size_t k = 6; + size_t pos = 0; + + while (pos < text.size()) { + const size_t start = text.find(open_tag, pos); + if (start == std::string::npos) { + if (pos < text.size()) { + llm_tokenizer_bpe_session::tokenize(text.substr(pos), output); + } + break; + } + if (start > pos) { + llm_tokenizer_bpe_session::tokenize(text.substr(pos, start - pos), output); + } + output.push_back(dna_begin_id); + + const size_t content_start = start + open_tag.size(); + const size_t end = text.find(close_tag, content_start); + const size_t content_end = (end == std::string::npos) ? text.size() : end; + + emit_dna_kmers(text.substr(content_start, content_end - content_start), k, dna_oov_id, output); + + if (end == std::string::npos) { + break; + } + output.push_back(dna_end_id); + pos = end + close_tag.size(); + } + } + +private: + void emit_dna_kmers(const std::string & raw, size_t k, llama_token oov_id, std::vector & output) { + std::string seq = raw; + for (char & c : seq) { + if (c >= 'a' && c <= 'z') { + c = char(c - 32); + } + } + + // k-mers carry the reserved marker suffix; a non-ACGT k-mer simply + // isn't in the vocab and falls back to + auto kmer_token = [&](const std::string & kmer) { + const auto tok = vocab.text_to_token(kmer + dna_kmer_marker); + return tok != LLAMA_TOKEN_NULL ? tok : oov_id; + }; + + size_t i = 0; + for (; i + k <= seq.size(); i += k) { + output.push_back(kmer_token(seq.substr(i, k))); + } + if (i < seq.size()) { + std::string kmer = seq.substr(i); + kmer.append(k - kmer.size(), 'A'); + output.push_back(kmer_token(kmer)); + } + } + + const llama_vocab & vocab; +}; + // // impl // @@ -1808,7 +1892,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { special_mask_id = 103; add_sep = true; - } else if (tokenizer_model == "gpt2") { + } else if (tokenizer_model == "gpt2" || tokenizer_model == "hybriddna") { type = LLAMA_VOCAB_TYPE_BPE; // read bpe merges and populate bpe ranks @@ -2266,6 +2350,23 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { } GGML_ASSERT(id_to_token.size() == token_to_id.size()); + // hybriddna: the marker suffix kept k-mer ids distinct in token_to_id; erase + // it from id_to_token so the k-mers detokenize to the bare DNA sequence. The + // k-mers are the block right after , so only scan from there. + if (tokenizer_model == "hybriddna") { + const auto idx = token_to_id.find(""); + if (idx != token_to_id.end()) { + auto it = id_to_token.begin() + idx->second + 1; + for (; it != id_to_token.end(); ++it) { + std::string & text = it->text; + if (text.size() > dna_kmer_marker.size() + && text.compare(text.size() - dna_kmer_marker.size(), dna_kmer_marker.size(), dna_kmer_marker) == 0) { + text.erase(text.size() - dna_kmer_marker.size()); + } + } + } + } + init_tokenizer(type); // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n' @@ -3144,11 +3245,19 @@ std::vector llama_vocab::impl::tokenize( } break; case LLAMA_VOCAB_TYPE_BPE: { - llm_tokenizer_bpe_session session(vocab, *static_cast(tokenizer.get())); // it calls some other methods that are not exist in llm_tokenizer, // here just cast it to bpe tokenizer object + const llm_tokenizer_bpe * tok_bpe = static_cast(tokenizer.get()); + + std::unique_ptr session; + if (vocab.get_tokenizer_model() == "hybriddna") { + session = std::make_unique(vocab, *tok_bpe); + } else { + session = std::make_unique(vocab, *tok_bpe); + } + if (add_special) { - session.append_bos(output); + session->append_bos(output); } for (const auto & fragment : fragment_buffer) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { @@ -3161,15 +3270,15 @@ std::vector llama_vocab::impl::tokenize( #ifdef PRETOKENIZERDEBUG LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str()); #endif - session.tokenize(text, output); + session->tokenize(text, output); } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) - session.append(fragment.token, output); + session->append(fragment.token, output); } } if (add_special) { - session.append_eos(output); - session.check_double_bos_eos(output); + session->append_eos(output); + session->check_double_bos_eos(output); } } break; case LLAMA_VOCAB_TYPE_WPM: diff --git a/examples/talk-llama/llama.h b/examples/talk-llama/llama.h index 308e8ba9dbd..e8374c53b70 100644 --- a/examples/talk-llama/llama.h +++ b/examples/talk-llama/llama.h @@ -198,6 +198,11 @@ extern "C" { LLAMA_SPLIT_MODE_TENSOR = 3, }; + enum llama_context_type { + LLAMA_CONTEXT_TYPE_DEFAULT = 0, + LLAMA_CONTEXT_TYPE_MTP = 1, + }; + // TODO: simplify (https://github.com/ggml-org/llama.cpp/pull/9294#pullrequestreview-2286561979) typedef struct llama_token_data { llama_token id; // token id @@ -333,9 +338,11 @@ extern "C" { uint32_t n_batch; // logical maximum batch size that can be submitted to llama_decode uint32_t n_ubatch; // physical maximum batch size uint32_t n_seq_max; // max number of sequences (i.e. distinct states for recurrent models) + uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback (0 = no rollback) [EXPERIMENTAL] int32_t n_threads; // number of threads to use for generation int32_t n_threads_batch; // number of threads to use for batch processing + enum llama_context_type ctx_type; // set the context type (e.g. MTP) enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type` enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id enum llama_attention_type attention_type; // attention type to use for embeddings @@ -530,6 +537,7 @@ extern "C" { LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx); + LLAMA_API uint32_t llama_n_rs_seq (const struct llama_context * ctx); DEPRECATED(LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model), "use llama_model_n_ctx_train instead"); DEPRECATED(LLAMA_API int32_t llama_n_embd (const struct llama_model * model), "use llama_model_n_embd instead"); @@ -866,7 +874,8 @@ extern "C" { // work only with partial states, such as SWA KV cache or recurrent cache (e.g. Mamba) #define LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY 1 -// keeps the tensor data on device buffers (i.e. not accessible in host memory, but faster save/load) +// Keeps the tensor data on device buffers (i.e. not accessible in host memory, but faster save/load). +// Getting the state for a seq_id with this flag invalidates all prior states gotten for that seq_id with this flag. #define LLAMA_STATE_SEQ_FLAGS_ON_DEVICE 2 typedef uint32_t llama_state_seq_flags; diff --git a/examples/talk-llama/models/afmoe.cpp b/examples/talk-llama/models/afmoe.cpp index 602e3176afd..a7c77ee5d28 100644 --- a/examples/talk-llama/models/afmoe.cpp +++ b/examples/talk-llama/models/afmoe.cpp @@ -277,7 +277,7 @@ llama_model_afmoe::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/apertus.cpp b/examples/talk-llama/models/apertus.cpp index 136ff702957..bec7136521c 100644 --- a/examples/talk-llama/models/apertus.cpp +++ b/examples/talk-llama/models/apertus.cpp @@ -160,7 +160,7 @@ llama_model_apertus::graph::graph(const llama_model & model, const llm_graph_par res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/arcee.cpp b/examples/talk-llama/models/arcee.cpp index 70e86d41130..d086c4717ff 100644 --- a/examples/talk-llama/models/arcee.cpp +++ b/examples/talk-llama/models/arcee.cpp @@ -148,7 +148,7 @@ llama_model_arcee::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/arctic.cpp b/examples/talk-llama/models/arctic.cpp index d8653a44639..27deadffeb7 100644 --- a/examples/talk-llama/models/arctic.cpp +++ b/examples/talk-llama/models/arctic.cpp @@ -171,7 +171,7 @@ llama_model_arctic::graph::graph(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/arwkv7.cpp b/examples/talk-llama/models/arwkv7.cpp index 79aa8c90899..9bd04127b25 100644 --- a/examples/talk-llama/models/arwkv7.cpp +++ b/examples/talk-llama/models/arwkv7.cpp @@ -193,7 +193,7 @@ llama_model_arwkv7::graph::graph(const llama_model & model, const llm_graph_para cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/baichuan.cpp b/examples/talk-llama/models/baichuan.cpp index 4e55290e4e5..4d26081cd5d 100644 --- a/examples/talk-llama/models/baichuan.cpp +++ b/examples/talk-llama/models/baichuan.cpp @@ -146,7 +146,7 @@ llama_model_baichuan::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/bailingmoe.cpp b/examples/talk-llama/models/bailingmoe.cpp index 030dd4f42a4..fe1ae10864b 100644 --- a/examples/talk-llama/models/bailingmoe.cpp +++ b/examples/talk-llama/models/bailingmoe.cpp @@ -171,7 +171,7 @@ llama_model_bailingmoe::graph::graph(const llama_model & model, const llm_graph_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/bailingmoe2.cpp b/examples/talk-llama/models/bailingmoe2.cpp index e7fe3d5b45a..2f0d44a6259 100644 --- a/examples/talk-llama/models/bailingmoe2.cpp +++ b/examples/talk-llama/models/bailingmoe2.cpp @@ -210,7 +210,7 @@ llama_model_bailingmoe2::graph::graph(const llama_model & model, const llm_graph res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/bloom.cpp b/examples/talk-llama/models/bloom.cpp index b600fb0c954..30b0f3d07d0 100644 --- a/examples/talk-llama/models/bloom.cpp +++ b/examples/talk-llama/models/bloom.cpp @@ -142,7 +142,7 @@ llama_model_bloom::graph::graph(const llama_model & model, const llm_graph_param cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/chameleon.cpp b/examples/talk-llama/models/chameleon.cpp index 8510b9e29f8..4bceaefd63b 100644 --- a/examples/talk-llama/models/chameleon.cpp +++ b/examples/talk-llama/models/chameleon.cpp @@ -181,7 +181,7 @@ llama_model_chameleon::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output_with_img_logits", -1); // TODO: this suppresses the output of image tokens, which is required to enable text-only outputs. diff --git a/examples/talk-llama/models/chatglm.cpp b/examples/talk-llama/models/chatglm.cpp index e898eff7939..6766fa71c15 100644 --- a/examples/talk-llama/models/chatglm.cpp +++ b/examples/talk-llama/models/chatglm.cpp @@ -151,7 +151,7 @@ llama_model_chatglm::graph::graph(const llama_model & model, const llm_graph_par cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/codeshell.cpp b/examples/talk-llama/models/codeshell.cpp index e9e85d96713..274dd3342a7 100644 --- a/examples/talk-llama/models/codeshell.cpp +++ b/examples/talk-llama/models/codeshell.cpp @@ -143,7 +143,7 @@ llama_model_codeshell::graph::graph(const llama_model & model, const llm_graph_p cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/cogvlm.cpp b/examples/talk-llama/models/cogvlm.cpp index 79236121bd5..2e231bb3f93 100644 --- a/examples/talk-llama/models/cogvlm.cpp +++ b/examples/talk-llama/models/cogvlm.cpp @@ -150,7 +150,7 @@ llama_model_cogvlm::graph::graph(const llama_model & model, const llm_graph_para cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; ggml_build_forward_expand(gf, cur); diff --git a/examples/talk-llama/models/cohere2.cpp b/examples/talk-llama/models/cohere2.cpp index 12edbae1094..a514cf88fc6 100644 --- a/examples/talk-llama/models/cohere2.cpp +++ b/examples/talk-llama/models/cohere2.cpp @@ -146,7 +146,7 @@ llama_model_cohere2::graph::graph(const llama_model & model, const llm_graph_par res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); if (f_logit_scale) { cur = ggml_scale(ctx0, cur, f_logit_scale); diff --git a/examples/talk-llama/models/command-r.cpp b/examples/talk-llama/models/command-r.cpp index decb89f547b..adf7fcaa20f 100644 --- a/examples/talk-llama/models/command-r.cpp +++ b/examples/talk-llama/models/command-r.cpp @@ -131,7 +131,7 @@ llama_model_command_r::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); if (f_logit_scale) { cur = ggml_scale(ctx0, cur, f_logit_scale); diff --git a/examples/talk-llama/models/dbrx.cpp b/examples/talk-llama/models/dbrx.cpp index bce6b04bcf9..af71c775365 100644 --- a/examples/talk-llama/models/dbrx.cpp +++ b/examples/talk-llama/models/dbrx.cpp @@ -145,7 +145,7 @@ llama_model_dbrx::graph::graph(const llama_model & model, const llm_graph_params res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/deci.cpp b/examples/talk-llama/models/deci.cpp index 9f1a959c32c..567e3535276 100644 --- a/examples/talk-llama/models/deci.cpp +++ b/examples/talk-llama/models/deci.cpp @@ -181,7 +181,7 @@ llama_model_deci::graph::graph(const llama_model & model, const llm_graph_params res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/deepseek.cpp b/examples/talk-llama/models/deepseek.cpp index c7946059662..f52ec9518b6 100644 --- a/examples/talk-llama/models/deepseek.cpp +++ b/examples/talk-llama/models/deepseek.cpp @@ -185,7 +185,7 @@ llama_model_deepseek::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/delta-net-base.cpp b/examples/talk-llama/models/delta-net-base.cpp index 6bc989c9509..4f4c7cac7a8 100644 --- a/examples/talk-llama/models/delta-net-base.cpp +++ b/examples/talk-llama/models/delta-net-base.cpp @@ -1,6 +1,7 @@ #include "models.h" #include "llama-impl.h" +#include "llama-memory-recurrent.h" // utility to get one slice from the third dimension // input dim: [x, y, c, b] @@ -397,7 +398,9 @@ std::pair llm_build_delta_net_base::build_delta_ne GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); - ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s); + // K=1 (final state only): reshape to 3D (S_v*S_v*H_v, 1, n_seqs) for ggml_gated_delta_net. + ggml_tensor * s_3d = ggml_reshape_3d(ctx0, s, S_v * S_v * H_v, 1, n_seqs); + ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s_3d); if (n_tokens == 1) { cb(result, LLAMA_TENSOR_NAME_FGDN_AR, il); } else { @@ -443,3 +446,162 @@ std::pair llm_build_delta_net_base::build_delta_ne return build_delta_net_chunking(q, k, v, g, b, s, il); } + +ggml_tensor * llm_build_delta_net_base::build_conv_state( + llm_graph_input_rs * inp, + ggml_tensor * conv_states_all, + ggml_tensor * qkv_mixed, + int64_t conv_kernel_size, + int64_t conv_channels, + int il) { + const auto * mctx_cur = inp->mctx; + + const auto kv_head = mctx_cur->get_head(); + const auto mem_size = mctx_cur->get_size(); + + const int64_t n_seqs = ubatch.n_seqs; + + ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); + cb(conv_states, "conv_states", il); + + conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); + cb(conv_states, "conv_states_reshaped", il); + + qkv_mixed = ggml_transpose(ctx0, qkv_mixed); + cb(qkv_mixed, "qkv_mixed_transposed", il); + + ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); + cb(conv_input, "conv_input", il); + + const int64_t row_count = (conv_kernel_size - 1) * conv_channels; + + const size_t row_size = ggml_row_size(conv_states_all->type, row_count); + + if (cparams.n_rs_seq == 0) { + const int64_t s_idx = conv_input->ne[0] - conv_states->ne[0]; + const int64_t s_slot = 0; + + ggml_tensor * conv_state_last = + ggml_view_3d(ctx0, conv_input, + conv_kernel_size - 1, conv_channels, n_seqs, + conv_input->nb[1], conv_input->nb[2], + ggml_row_size(conv_input->type, s_idx)); + cb(conv_state_last, "conv_state_last", il); + + ggml_tensor * conv_state_update = + ggml_view_2d(ctx0, conv_states_all, + row_count, n_seqs, conv_states_all->nb[1], + (s_slot * mem_size + kv_head) * row_size); + cb(conv_state_update, "conv_state_update", il); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_state_last, conv_state_update)); + } else { + // [TAG_RECURRENT_ROLLBACK_SPLITS] + // TODO: this logic incorrectly assumes that the last (n_rs_seq + 1) tokens of a sequence in a batch are + // inside the same ubatch. currently with `split_equal()` this is not correct + + const int64_t K = (int64_t) cparams.n_rs_seq + 1; + + for (int64_t t = 1; t <= K; ++t) { + const int64_t s_idx = std::max(0, conv_input->ne[0] - conv_states->ne[0] - K + t); + const int64_t s_slot = K - t; + + ggml_tensor * conv_state_last = + ggml_view_3d(ctx0, conv_input, + conv_kernel_size - 1, conv_channels, n_seqs, + conv_input->nb[1], conv_input->nb[2], + ggml_row_size(conv_input->type, s_idx)); + + ggml_tensor * conv_state_update = + ggml_view_2d(ctx0, + conv_states_all, row_count, n_seqs, + conv_states_all->nb[1], + (s_slot * mem_size + kv_head) * row_size); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_state_last, conv_state_update)); + } + } + + return conv_input; +} + +ggml_tensor * llm_build_delta_net_base::build_recurrent_attn( + llm_graph_input_rs * inp, + ggml_tensor * ssm_states_all, + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il) { + const auto * mctx_cur = inp->mctx; + const auto kv_head = mctx_cur->get_head(); + const uint32_t mem_size = mctx_cur->get_size(); + + const int64_t S_v = s->ne[0]; + const int64_t H_v = s->ne[2]; + const int64_t n_seqs = s->ne[3]; + const int64_t n_seq_tokens = q->ne[2]; + + const bool keep = cparams.n_rs_seq > 0; + + if (!keep) { + auto attn_out = build_delta_net(q, k, v, g, b, s, il); + ggml_tensor * output = attn_out.first; + ggml_tensor * new_state = attn_out.second; + cb(output, "attn_output", il); + cb(new_state, "new_state", il); + + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, new_state, + ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], + kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + + return output; + } + + const int64_t D = S_v * S_v * H_v; + const int64_t K = cparams.n_rs_seq + 1; + + // TODO: remove pad + simplify + ggml_tensor * s_3d = ggml_reshape_3d(ctx0, s, D, 1, n_seqs); + ggml_tensor * s_3d_pad = ggml_pad (ctx0, s_3d, 0, K - 1, 0, 0); + + ggml_tensor * gdn_out = ggml_gated_delta_net(ctx0, q, k, v, g, b, s_3d_pad); + if (n_seq_tokens > 1) { + cb(gdn_out, LLAMA_TENSOR_NAME_FGDN_CH, il); + } else { + cb(gdn_out, LLAMA_TENSOR_NAME_FGDN_AR, il); + } + + const int64_t attn_score_elems = S_v * H_v * n_seq_tokens * n_seqs; + const int64_t state_size_per_snap = S_v * S_v * H_v * n_seqs; + + ggml_tensor * output = ggml_view_4d(ctx0, gdn_out, + S_v, H_v, n_seq_tokens, n_seqs, + ggml_row_size(gdn_out->type, S_v), + ggml_row_size(gdn_out->type, S_v * H_v), + ggml_row_size(gdn_out->type, S_v * H_v * n_seq_tokens), + 0); + cb(output, "attn_output", il); + + const size_t row_size = hparams.n_embd_s() * ggml_element_size(ssm_states_all); + for (int64_t k_i = 0; k_i < K; ++k_i) { + const uint32_t cache_slot = (uint32_t) (K - 1 - k_i); + ggml_tensor * src = ggml_view_4d(ctx0, gdn_out, + S_v, S_v, H_v, n_seqs, + ggml_row_size(gdn_out->type, S_v), + ggml_row_size(gdn_out->type, S_v * S_v), + ggml_row_size(gdn_out->type, S_v * S_v * H_v), + ggml_row_size(gdn_out->type, attn_score_elems + k_i * state_size_per_snap)); + + ggml_tensor * dst = ggml_view_2d(ctx0, ssm_states_all, + hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], + ((size_t) cache_slot * mem_size + kv_head) * row_size); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, src, dst)); + } + + return output; +} diff --git a/examples/talk-llama/models/dots1.cpp b/examples/talk-llama/models/dots1.cpp index 93cbcf9d931..435d27281c6 100644 --- a/examples/talk-llama/models/dots1.cpp +++ b/examples/talk-llama/models/dots1.cpp @@ -183,7 +183,7 @@ llama_model_dots1::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/dream.cpp b/examples/talk-llama/models/dream.cpp index 60a3f0ec285..12ac6f1ce88 100644 --- a/examples/talk-llama/models/dream.cpp +++ b/examples/talk-llama/models/dream.cpp @@ -128,7 +128,7 @@ llama_model_dream::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/ernie4-5-moe.cpp b/examples/talk-llama/models/ernie4-5-moe.cpp index 2bd01a2c512..8d9ff138676 100644 --- a/examples/talk-llama/models/ernie4-5-moe.cpp +++ b/examples/talk-llama/models/ernie4-5-moe.cpp @@ -124,7 +124,7 @@ llama_model_ernie4_5_moe::graph::graph(const llama_model & model, const llm_grap res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/ernie4-5.cpp b/examples/talk-llama/models/ernie4-5.cpp index fa989fe92cd..9b39c605e35 100644 --- a/examples/talk-llama/models/ernie4-5.cpp +++ b/examples/talk-llama/models/ernie4-5.cpp @@ -155,7 +155,7 @@ llama_model_ernie4_5::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/exaone-moe.cpp b/examples/talk-llama/models/exaone-moe.cpp index 54bb3ca86b3..76d91982fc5 100644 --- a/examples/talk-llama/models/exaone-moe.cpp +++ b/examples/talk-llama/models/exaone-moe.cpp @@ -237,7 +237,7 @@ llama_model_exaone_moe::graph::graph(const llama_model & model, const llm_graph_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/exaone.cpp b/examples/talk-llama/models/exaone.cpp index 75d5f60631c..c7e9960d718 100644 --- a/examples/talk-llama/models/exaone.cpp +++ b/examples/talk-llama/models/exaone.cpp @@ -127,7 +127,7 @@ llama_model_exaone::graph::graph(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/exaone4.cpp b/examples/talk-llama/models/exaone4.cpp index 5506e76424d..499e22dde81 100644 --- a/examples/talk-llama/models/exaone4.cpp +++ b/examples/talk-llama/models/exaone4.cpp @@ -163,7 +163,7 @@ llama_model_exaone4::graph::graph(const llama_model & model, const llm_gra res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/falcon-h1.cpp b/examples/talk-llama/models/falcon-h1.cpp index d353befdb8e..94b65a3c7c9 100644 --- a/examples/talk-llama/models/falcon-h1.cpp +++ b/examples/talk-llama/models/falcon-h1.cpp @@ -200,7 +200,7 @@ llama_model_falcon_h1::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/falcon.cpp b/examples/talk-llama/models/falcon.cpp index 75f2cfef560..ad546ef2db5 100644 --- a/examples/talk-llama/models/falcon.cpp +++ b/examples/talk-llama/models/falcon.cpp @@ -152,7 +152,7 @@ llama_model_falcon::graph::graph(const llama_model & model, const llm_graph_para cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/gemma.cpp b/examples/talk-llama/models/gemma.cpp index 06731670007..1519682fdf6 100644 --- a/examples/talk-llama/models/gemma.cpp +++ b/examples/talk-llama/models/gemma.cpp @@ -130,7 +130,7 @@ llama_model_gemma::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/gemma2.cpp b/examples/talk-llama/models/gemma2.cpp index 6255bf740fc..ae3f9ffb530 100644 --- a/examples/talk-llama/models/gemma2.cpp +++ b/examples/talk-llama/models/gemma2.cpp @@ -163,7 +163,7 @@ llama_model_gemma2::graph::graph(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); // final logit soft-capping cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping); diff --git a/examples/talk-llama/models/gemma3.cpp b/examples/talk-llama/models/gemma3.cpp index ee510fe38b0..63a2b380e71 100644 --- a/examples/talk-llama/models/gemma3.cpp +++ b/examples/talk-llama/models/gemma3.cpp @@ -207,7 +207,7 @@ llama_model_gemma3::graph::graph(const llama_model & model, const llm_grap res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); if (hparams.f_final_logit_softcapping) { cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping); diff --git a/examples/talk-llama/models/gemma3n.cpp b/examples/talk-llama/models/gemma3n.cpp index 881499b0ca7..6ec3a006081 100644 --- a/examples/talk-llama/models/gemma3n.cpp +++ b/examples/talk-llama/models/gemma3n.cpp @@ -296,7 +296,7 @@ llama_model_gemma3n::graph::graph(const llama_model & model, const llm_graph_par cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); { // final logit soft-capping diff --git a/examples/talk-llama/models/gemma4.cpp b/examples/talk-llama/models/gemma4.cpp index f45ae4cad59..4f9d8b18bc7 100644 --- a/examples/talk-llama/models/gemma4.cpp +++ b/examples/talk-llama/models/gemma4.cpp @@ -380,7 +380,7 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); if (hparams.f_final_logit_softcapping) { cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping); diff --git a/examples/talk-llama/models/glm4-moe.cpp b/examples/talk-llama/models/glm4-moe.cpp index 45886b51ac1..27654b8cba3 100644 --- a/examples/talk-llama/models/glm4-moe.cpp +++ b/examples/talk-llama/models/glm4-moe.cpp @@ -275,7 +275,7 @@ llama_model_glm4_moe::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/glm4.cpp b/examples/talk-llama/models/glm4.cpp index d6ef76e26d6..7c242fed298 100644 --- a/examples/talk-llama/models/glm4.cpp +++ b/examples/talk-llama/models/glm4.cpp @@ -185,7 +185,7 @@ llama_model_glm4::graph::graph(const llama_model & model, const llm_graph_params res->t_embd = cur; // Output projection - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/gpt2.cpp b/examples/talk-llama/models/gpt2.cpp index ba49c31b56b..e2dcc8b1521 100644 --- a/examples/talk-llama/models/gpt2.cpp +++ b/examples/talk-llama/models/gpt2.cpp @@ -138,7 +138,7 @@ llama_model_gpt2::graph::graph(const llama_model & model, const llm_graph_params cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/gptneox.cpp b/examples/talk-llama/models/gptneox.cpp index 33ebe2d8800..443e35addf2 100644 --- a/examples/talk-llama/models/gptneox.cpp +++ b/examples/talk-llama/models/gptneox.cpp @@ -209,7 +209,7 @@ llama_model_gptneox::graph::graph(const llama_model & model, const llm_graph_par cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/granite-hybrid.cpp b/examples/talk-llama/models/granite-hybrid.cpp index 12e4790ae24..27f6706ea10 100644 --- a/examples/talk-llama/models/granite-hybrid.cpp +++ b/examples/talk-llama/models/granite-hybrid.cpp @@ -186,7 +186,7 @@ llama_model_granite_hybrid::graph::graph(const llama_model & model, const llm_gr res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); // For Granite architectures - scale logits if (hparams.f_logit_scale) { diff --git a/examples/talk-llama/models/granite.cpp b/examples/talk-llama/models/granite.cpp index 5e7c7b68181..cda4aa231fa 100644 --- a/examples/talk-llama/models/granite.cpp +++ b/examples/talk-llama/models/granite.cpp @@ -145,7 +145,7 @@ llama_model_granite::graph::graph( res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); // For Granite architectures - scale logits cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale); diff --git a/examples/talk-llama/models/grok.cpp b/examples/talk-llama/models/grok.cpp index 0bc49d00206..7c46ec1c0f2 100644 --- a/examples/talk-llama/models/grok.cpp +++ b/examples/talk-llama/models/grok.cpp @@ -206,7 +206,7 @@ llama_model_grok::graph::graph(const llama_model & model, const llm_graph_params res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cur = ggml_scale(ctx0, cur, hparams.f_logit_scale); diff --git a/examples/talk-llama/models/grovemoe.cpp b/examples/talk-llama/models/grovemoe.cpp index feef815165b..1cab75adc7f 100644 --- a/examples/talk-llama/models/grovemoe.cpp +++ b/examples/talk-llama/models/grovemoe.cpp @@ -184,7 +184,7 @@ llama_model_grovemoe::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/hunyuan-moe.cpp b/examples/talk-llama/models/hunyuan-moe.cpp index 44af42412f7..deb3c9671f3 100644 --- a/examples/talk-llama/models/hunyuan-moe.cpp +++ b/examples/talk-llama/models/hunyuan-moe.cpp @@ -179,7 +179,7 @@ llama_model_hunyuan_moe::graph::graph(const llama_model & model, const llm_graph res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/hunyuan-vl.cpp b/examples/talk-llama/models/hunyuan-vl.cpp index 5fb9154bec0..da9bb74de7e 100644 --- a/examples/talk-llama/models/hunyuan-vl.cpp +++ b/examples/talk-llama/models/hunyuan-vl.cpp @@ -181,7 +181,7 @@ llama_model_hunyuan_vl::graph::graph(const llama_model & model, const llm_graph_ cb(cur, "result_norm", -1); res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/internlm2.cpp b/examples/talk-llama/models/internlm2.cpp index f0c5580a6f4..f9ee37a24b6 100644 --- a/examples/talk-llama/models/internlm2.cpp +++ b/examples/talk-llama/models/internlm2.cpp @@ -129,7 +129,7 @@ llama_model_internlm2::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/jais.cpp b/examples/talk-llama/models/jais.cpp index a6451dca095..2ba162605f1 100644 --- a/examples/talk-llama/models/jais.cpp +++ b/examples/talk-llama/models/jais.cpp @@ -123,7 +123,7 @@ llama_model_jais::graph::graph(const llama_model & model, const llm_graph_params cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/jais2.cpp b/examples/talk-llama/models/jais2.cpp index ad59b953e8d..8966131441c 100644 --- a/examples/talk-llama/models/jais2.cpp +++ b/examples/talk-llama/models/jais2.cpp @@ -152,7 +152,7 @@ llama_model_jais2::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // Output projection - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/jamba.cpp b/examples/talk-llama/models/jamba.cpp index e1b8d137e38..84ea63c3136 100644 --- a/examples/talk-llama/models/jamba.cpp +++ b/examples/talk-llama/models/jamba.cpp @@ -189,7 +189,7 @@ llama_model_jamba::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/lfm2.cpp b/examples/talk-llama/models/lfm2.cpp index df6a8028736..29081344b24 100644 --- a/examples/talk-llama/models/lfm2.cpp +++ b/examples/talk-llama/models/lfm2.cpp @@ -262,7 +262,7 @@ llama_model_lfm2::graph::graph(const llama_model & model, const llm_graph_ cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/llada-moe.cpp b/examples/talk-llama/models/llada-moe.cpp index b60f67f6c4b..9722dde9f17 100644 --- a/examples/talk-llama/models/llada-moe.cpp +++ b/examples/talk-llama/models/llada-moe.cpp @@ -153,7 +153,7 @@ llama_model_llada_moe::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/llada.cpp b/examples/talk-llama/models/llada.cpp index fa21c5fe32c..58b2c466e17 100644 --- a/examples/talk-llama/models/llada.cpp +++ b/examples/talk-llama/models/llada.cpp @@ -147,7 +147,7 @@ llama_model_llada::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/llama.cpp b/examples/talk-llama/models/llama.cpp index 8ddb5936820..cef66d054b0 100644 --- a/examples/talk-llama/models/llama.cpp +++ b/examples/talk-llama/models/llama.cpp @@ -235,7 +235,7 @@ llama_model_llama::graph::graph(const llama_model & model, const llm_grap if constexpr (!embed) { // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/llama4.cpp b/examples/talk-llama/models/llama4.cpp index 899611d53f6..0ff5376d571 100644 --- a/examples/talk-llama/models/llama4.cpp +++ b/examples/talk-llama/models/llama4.cpp @@ -260,7 +260,7 @@ llama_model_llama4::graph::graph(const llama_model & model, const llm_grap res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/maincoder.cpp b/examples/talk-llama/models/maincoder.cpp index 3dbd82fd362..84cfe399027 100644 --- a/examples/talk-llama/models/maincoder.cpp +++ b/examples/talk-llama/models/maincoder.cpp @@ -141,7 +141,7 @@ llama_model_maincoder::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/mamba.cpp b/examples/talk-llama/models/mamba.cpp index b7708d7fdd1..887a1fa509a 100644 --- a/examples/talk-llama/models/mamba.cpp +++ b/examples/talk-llama/models/mamba.cpp @@ -128,7 +128,7 @@ llama_model_mamba::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/mimo2.cpp b/examples/talk-llama/models/mimo2.cpp index 71996616611..d0295ec116f 100644 --- a/examples/talk-llama/models/mimo2.cpp +++ b/examples/talk-llama/models/mimo2.cpp @@ -231,7 +231,7 @@ llama_model_mimo2::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/minicpm3.cpp b/examples/talk-llama/models/minicpm3.cpp index ff5eb6ffa5f..1ffc54fa7c6 100644 --- a/examples/talk-llama/models/minicpm3.cpp +++ b/examples/talk-llama/models/minicpm3.cpp @@ -251,7 +251,7 @@ llama_model_minicpm3::graph::graph(const llama_model & model, const llm_graph_pa cb(cur, "lmhead_scaling", -1); // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/minimax-m2.cpp b/examples/talk-llama/models/minimax-m2.cpp index 0dee8934692..22e291d73a3 100644 --- a/examples/talk-llama/models/minimax-m2.cpp +++ b/examples/talk-llama/models/minimax-m2.cpp @@ -158,7 +158,7 @@ llama_model_minimax_m2::graph::graph(const llama_model & model, const llm_graph_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/mistral3.cpp b/examples/talk-llama/models/mistral3.cpp index 708da49af1f..4e6ebef82cb 100644 --- a/examples/talk-llama/models/mistral3.cpp +++ b/examples/talk-llama/models/mistral3.cpp @@ -222,7 +222,7 @@ llama_model_mistral3::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/models.h b/examples/talk-llama/models/models.h index 6d5f18a8e20..7e551eb965b 100644 --- a/examples/talk-llama/models/models.h +++ b/examples/talk-llama/models/models.h @@ -46,7 +46,7 @@ struct llm_build_delta_net_base : public llm_graph_context { ggml_tensor * s, int il); - // use the ggml_gated_delta_net fused operator + // use the ggml_gated_delta_net fused operator (K=1; state has shape (D, 1, n_seqs)) std::pair build_delta_net_fused( ggml_tensor * q, ggml_tensor * k, @@ -65,6 +65,29 @@ struct llm_build_delta_net_base : public llm_graph_context { ggml_tensor * b, ggml_tensor * s, int il); + + // read conv state from cache, concat with qkv_mixed, write back (single slot or per-token) + // qkv_mixed: (qkv_dim, n_seq_tokens, n_seqs); returns conv_input: (kernel_size + n_seq_tokens - 1, channels, n_seqs) + ggml_tensor * build_conv_state( + llm_graph_input_rs * inp, + ggml_tensor * conv_states_all, + ggml_tensor * qkv_mixed, + int64_t conv_kernel_size, + int64_t conv_channels, + int il); + + // run delta-net attention and write the new recurrent state(s) back to ssm_states_all + // s: (head_v_dim, head_v_dim, num_v_heads, n_seqs); returns output: (head_v_dim, num_v_heads, n_seq_tokens, n_seqs) + ggml_tensor * build_recurrent_attn( + llm_graph_input_rs * inp, + ggml_tensor * ssm_states_all, + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il); }; struct llm_build_rwkv6_base : public llm_graph_context { @@ -1739,6 +1762,10 @@ struct llama_model_qwen35 : public llama_model_base { const llama_model & model; }; + struct graph_mtp : public llm_graph_context { + graph_mtp(const llama_model & model, const llm_graph_params & params); + }; + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; @@ -1781,6 +1808,10 @@ struct llama_model_qwen35moe : public llama_model_base { const llama_model & model; }; + struct graph_mtp : public llm_graph_context { + graph_mtp(const llama_model & model, const llm_graph_params & params); + }; + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; diff --git a/examples/talk-llama/models/mpt.cpp b/examples/talk-llama/models/mpt.cpp index cfc60e8de29..0229d20ed36 100644 --- a/examples/talk-llama/models/mpt.cpp +++ b/examples/talk-llama/models/mpt.cpp @@ -161,7 +161,7 @@ llama_model_mpt::graph::graph(const llama_model & model, const llm_graph_params cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/nemotron-h.cpp b/examples/talk-llama/models/nemotron-h.cpp index 865461f61db..a82f9c170b4 100644 --- a/examples/talk-llama/models/nemotron-h.cpp +++ b/examples/talk-llama/models/nemotron-h.cpp @@ -174,7 +174,7 @@ llama_model_nemotron_h::graph::graph(const llama_model & model, const llm_graph_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/nemotron.cpp b/examples/talk-llama/models/nemotron.cpp index 0c72ed297aa..5d4a3b5c69e 100644 --- a/examples/talk-llama/models/nemotron.cpp +++ b/examples/talk-llama/models/nemotron.cpp @@ -140,7 +140,7 @@ llama_model_nemotron::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/olmo.cpp b/examples/talk-llama/models/olmo.cpp index 161035e72bc..cfcf17bcb03 100644 --- a/examples/talk-llama/models/olmo.cpp +++ b/examples/talk-llama/models/olmo.cpp @@ -133,7 +133,7 @@ llama_model_olmo::graph::graph(const llama_model & model, const llm_graph_params res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/olmo2.cpp b/examples/talk-llama/models/olmo2.cpp index 9633f269965..7cc262f5504 100644 --- a/examples/talk-llama/models/olmo2.cpp +++ b/examples/talk-llama/models/olmo2.cpp @@ -198,7 +198,7 @@ llama_model_olmo2::graph::graph(const llama_model & model, const llm_graph res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/olmoe.cpp b/examples/talk-llama/models/olmoe.cpp index 4bb9013054c..7976ae44a51 100644 --- a/examples/talk-llama/models/olmoe.cpp +++ b/examples/talk-llama/models/olmoe.cpp @@ -164,7 +164,7 @@ llama_model_olmoe::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/openai-moe.cpp b/examples/talk-llama/models/openai-moe.cpp index 13a590ce646..15b6c8c1205 100644 --- a/examples/talk-llama/models/openai-moe.cpp +++ b/examples/talk-llama/models/openai-moe.cpp @@ -160,7 +160,7 @@ llama_model_openai_moe::graph::graph(const llama_model & model, const llm_graph_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/openelm.cpp b/examples/talk-llama/models/openelm.cpp index b4128e116e7..9f76350fd4d 100644 --- a/examples/talk-llama/models/openelm.cpp +++ b/examples/talk-llama/models/openelm.cpp @@ -162,7 +162,7 @@ llama_model_openelm::graph::graph(const llama_model & model, const llm_graph_par cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/orion.cpp b/examples/talk-llama/models/orion.cpp index 7ace0a5139d..bcb4bbba4b1 100644 --- a/examples/talk-llama/models/orion.cpp +++ b/examples/talk-llama/models/orion.cpp @@ -132,7 +132,7 @@ llama_model_orion::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/paddleocr.cpp b/examples/talk-llama/models/paddleocr.cpp index 1c0eadefa98..d39220bd778 100644 --- a/examples/talk-llama/models/paddleocr.cpp +++ b/examples/talk-llama/models/paddleocr.cpp @@ -98,7 +98,7 @@ llama_model_paddleocr::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/pangu-embed.cpp b/examples/talk-llama/models/pangu-embed.cpp index 41b7e2ac23e..7593f879b24 100644 --- a/examples/talk-llama/models/pangu-embed.cpp +++ b/examples/talk-llama/models/pangu-embed.cpp @@ -148,7 +148,7 @@ llama_model_pangu_embed::graph::graph(const llama_model & model, const llm_graph res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); if (model.output_b != nullptr) { cur = ggml_add(ctx0, cur, model.output_b); diff --git a/examples/talk-llama/models/phi2.cpp b/examples/talk-llama/models/phi2.cpp index a333602c72d..8f3ed5f7b7d 100644 --- a/examples/talk-llama/models/phi2.cpp +++ b/examples/talk-llama/models/phi2.cpp @@ -130,7 +130,7 @@ llama_model_phi2::graph::graph(const llama_model & model, const llm_graph_params cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output_no_bias", -1); cur = ggml_add(ctx0, cur, model.output_b); diff --git a/examples/talk-llama/models/phi3.cpp b/examples/talk-llama/models/phi3.cpp index 0a65e91fefa..f8a4a4d5aa5 100644 --- a/examples/talk-llama/models/phi3.cpp +++ b/examples/talk-llama/models/phi3.cpp @@ -179,7 +179,7 @@ llama_model_phi3::graph::graph(const llama_model & model, const llm_graph_ cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); if (model.output_b != nullptr) { cb(cur, "result_output_no_bias", -1); diff --git a/examples/talk-llama/models/plamo.cpp b/examples/talk-llama/models/plamo.cpp index 4c16c20a0d4..c7ed1211c31 100644 --- a/examples/talk-llama/models/plamo.cpp +++ b/examples/talk-llama/models/plamo.cpp @@ -127,7 +127,7 @@ llama_model_plamo::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/plamo2.cpp b/examples/talk-llama/models/plamo2.cpp index 29c8702606a..b713889fe72 100644 --- a/examples/talk-llama/models/plamo2.cpp +++ b/examples/talk-llama/models/plamo2.cpp @@ -185,7 +185,7 @@ llama_model_plamo2::graph::graph(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); // Explicitly mark as output tensor to ensure proper backend assignment diff --git a/examples/talk-llama/models/plamo3.cpp b/examples/talk-llama/models/plamo3.cpp index 849f1579e63..29f3e803d68 100644 --- a/examples/talk-llama/models/plamo3.cpp +++ b/examples/talk-llama/models/plamo3.cpp @@ -186,7 +186,7 @@ llama_model_plamo3::graph::graph(const llama_model & model, const llm_grap cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); res->t_logits = cur; ggml_build_forward_expand(gf, cur); diff --git a/examples/talk-llama/models/plm.cpp b/examples/talk-llama/models/plm.cpp index 57f5995103b..ce050919e6a 100644 --- a/examples/talk-llama/models/plm.cpp +++ b/examples/talk-llama/models/plm.cpp @@ -204,7 +204,7 @@ llama_model_plm::graph::graph(const llama_model & model, const llm_graph_params cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen.cpp b/examples/talk-llama/models/qwen.cpp index cdc076cdf77..00467dbad7d 100644 --- a/examples/talk-llama/models/qwen.cpp +++ b/examples/talk-llama/models/qwen.cpp @@ -131,7 +131,7 @@ llama_model_qwen::graph::graph(const llama_model & model, const llm_graph_params res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen2.cpp b/examples/talk-llama/models/qwen2.cpp index 6320458a13b..a5147460bae 100644 --- a/examples/talk-llama/models/qwen2.cpp +++ b/examples/talk-llama/models/qwen2.cpp @@ -141,7 +141,7 @@ llama_model_qwen2::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); if (model.output_b != nullptr) { cur = ggml_add(ctx0, cur, model.output_b); diff --git a/examples/talk-llama/models/qwen2moe.cpp b/examples/talk-llama/models/qwen2moe.cpp index 7587c802c68..7cb03859deb 100644 --- a/examples/talk-llama/models/qwen2moe.cpp +++ b/examples/talk-llama/models/qwen2moe.cpp @@ -184,7 +184,7 @@ llama_model_qwen2moe::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen2vl.cpp b/examples/talk-llama/models/qwen2vl.cpp index 1a40fa89be4..d79db682cd4 100644 --- a/examples/talk-llama/models/qwen2vl.cpp +++ b/examples/talk-llama/models/qwen2vl.cpp @@ -134,7 +134,7 @@ llama_model_qwen2vl::graph::graph(const llama_model & model, const llm_graph_par res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen3.cpp b/examples/talk-llama/models/qwen3.cpp index fa656c84ea0..41b97fed956 100644 --- a/examples/talk-llama/models/qwen3.cpp +++ b/examples/talk-llama/models/qwen3.cpp @@ -147,7 +147,7 @@ llama_model_qwen3::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen35.cpp b/examples/talk-llama/models/qwen35.cpp index f276be61ba8..04ecc18fcdc 100644 --- a/examples/talk-llama/models/qwen35.cpp +++ b/examples/talk-llama/models/qwen35.cpp @@ -12,16 +12,22 @@ void llama_model_qwen35::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - // Mark recurrent layers (linear attention layers) + // NextN/MTP (Qwen3.5/3.6): extra decoder block appended beyond the main stack + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + + // Mark recurrent layers (linear attention layers). MTP layers are dense + // attention-only and must be flagged non-recurrent. { + const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; uint32_t full_attn_interval = 4; ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + hparams.recurrent_layer_arr[i] = (i < n_main) && ((i + 1) % full_attn_interval != 0); } } - switch (hparams.n_layer) { + switch (hparams.n_layer - hparams.nextn_predict_layers) { case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_8B : LLM_TYPE_2B; break; case 32: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_9B; break; case 64: type = LLM_TYPE_27B; break; @@ -29,9 +35,14 @@ void llama_model_qwen35::load_arch_hparams(llama_model_loader & ml) { } } -void llama_model_qwen35::load_arch_tensors(llama_model_loader &) { +void llama_model_qwen35::load_arch_tensors(llama_model_loader & ml) { LLAMA_LOAD_LOCALS; + const uint32_t n_main = n_layer - hparams.nextn_predict_layers; + const bool mtp_only = (hparams.nextn_predict_layers > 0) && + (ml.get_weight("blk.0.attn_norm.weight") == nullptr); + const int trunk_flags = mtp_only ? TENSOR_NOT_REQUIRED : 0; + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); // output @@ -43,50 +54,85 @@ void llama_model_qwen35::load_arch_tensors(llama_model_loader &) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); } - // Calculate dimensions from hyperparameters - const int64_t head_k_dim = hparams.ssm_d_state; - const int64_t head_v_dim = hparams.ssm_d_state; - const int64_t n_k_heads = hparams.ssm_n_group; - const int64_t n_v_heads = hparams.ssm_dt_rank; - const int64_t key_dim = head_k_dim * n_k_heads; - const int64_t value_dim = head_v_dim * n_v_heads; - const int64_t conv_dim = key_dim * 2 + value_dim; + auto load_block_trunk = [&](int il, int flags) { + auto & layer = layers[il]; - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; + // Calculate dimensions from hyperparameters + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, flags); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, flags); - if (!hparams.is_recurrent(i)) { + if (!hparams.is_recurrent(il)) { // Attention layers - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, flags); // Q/K normalization for attention layers - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, flags); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, flags); } else { // Linear attention (gated delta net) specific tensors // Create tensors with calculated dimensions - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); - layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); - layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); - layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", il), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", il), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", il), { hparams.ssm_d_conv, conv_dim }, flags); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", il), { hparams.ssm_dt_rank }, flags); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, il), { hparams.ssm_dt_rank }, flags); + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", il), { n_embd, n_v_heads }, flags); + layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", il), { n_embd, n_v_heads }, flags); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", il), { head_v_dim }, flags); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", il), { value_dim, n_embd }, flags); } - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", il), {n_embd, n_ff}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", il), { n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", il), {n_embd, n_ff}, flags); + }; + + auto load_block_mtp = [&](int il) { + auto & layer = layers[il]; + + // MTP block looks like a full-attention Qwen3.5 decoder block. + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, 0); + + create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", il), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", il), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", il), {n_embd, n_ff}, 0); + + // NextN-specific tensors that define the MTP block. + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", il), { 2 * n_embd, n_embd }, 0); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", il), { n_embd }, 0); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", il), { n_embd }, 0); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", il), { n_embd }, TENSOR_NOT_REQUIRED); + }; + + for (int i = 0; i < (int) n_main; ++i) { + load_block_trunk(i, trunk_flags); + } + for (int i = (int) n_main; i < n_layer; ++i) { + load_block_mtp(i); } } std::unique_ptr llama_model_qwen35::build_arch_graph(const llm_graph_params & params) const { + if (params.gtype == LLM_GRAPH_TYPE_DECODER_MTP) { + return std::make_unique(*this, params); + } return std::make_unique(*this, params); } @@ -111,7 +157,9 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_out_ids = build_inp_out_ids(); - for (int il = 0; il < n_layer; ++il) { + // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass. + const int n_transformer_layers = n_layer - (int) hparams.nextn_predict_layers; + for (int il = 0; il < n_transformer_layers; ++il) { ggml_tensor * inpSA = inpL; cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); @@ -128,7 +176,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); } - if (il == n_layer - 1 && inp_out_ids) { + if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -160,6 +208,13 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para } cur = inpL; + cb(cur, "h_pre_norm", -1); + res->t_h_pre_norm = cur; + + if (!cparams.embeddings_pre_norm_masked && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + } + // Final norm cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); @@ -167,7 +222,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para res->t_embd = cur; // LM head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; @@ -297,8 +352,6 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear( const int64_t head_v_dim = d_inner / num_v_heads; const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const auto kv_head = mctx_cur->get_head(); - GGML_ASSERT(n_seqs != 0); GGML_ASSERT(ubatch.equal_seqs()); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); @@ -328,41 +381,14 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear( gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs); - // Get convolution states from cache ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - // Build the convolution states tensor - ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); - cb(conv_states, "conv_states", il); - - // Calculate convolution kernel size ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; const int64_t conv_kernel_size = conv_kernel->ne[0]; const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; - conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); - cb(conv_states, "conv_states_reshaped", il); - - qkv_mixed = ggml_transpose(ctx0, qkv_mixed); - cb(qkv_mixed, "qkv_mixed_transposed", il); - - ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); - cb(conv_input, "conv_input", il); - - // Update convolution state cache - // Extract the last (conv_kernel_size - 1) states from conv_input - ggml_tensor * last_conv_states = - ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], - conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input)); - cb(last_conv_states, "last_conv_states", il); - - ggml_tensor * state_update_target = - ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1], - kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); - cb(state_update_target, "state_update_target", il); - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); + ggml_tensor * conv_input = build_conv_state(inp, conv_states_all, qkv_mixed, conv_kernel_size, conv_channels, il); ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); @@ -413,7 +439,7 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear( //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); // if head keys and value keys are different, repeat to force tensors into matching shapes - // note: need explicit repeat only if we are not using the fused GDN + // note: need explicit repeat only if we are not using the fused GDN. if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) { GGML_ASSERT(num_v_heads % num_k_heads == 0); q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); @@ -424,18 +450,7 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); - auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il); - - ggml_tensor * output = attn_out.first; - ggml_tensor * new_state = attn_out.second; - cb(output, "attn_output", il); - cb(new_state, "new_state", il); - - // Update the recurrent states - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, new_state, - ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], - kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + ggml_tensor * output = build_recurrent_attn(inp, ssm_states_all, q_conv, k_conv, v_conv, gate, beta, state, il); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); @@ -471,3 +486,151 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_ffn(ggml_tensor * cur, cons return cur; } + +// LLM_GRAPH_TYPE_DECODER_MTP draft head for Qwen3.5/3.6 dense series +llama_model_qwen35::graph_mtp::graph_mtp(const llama_model & model, const llm_graph_params & params) + : llm_graph_context(params) { + GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35 MTP requires nextn_predict_layers > 0"); + GGML_ASSERT(hparams.nextn_predict_layers == 1 && "QWEN35 MTP currently only supports a single MTP block"); + + const int64_t n_embd_head = hparams.n_embd_head_v(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + + // hparams.n_layer includes both main model layers and MTP layers. The MTP + // layer is stored immediately after the main layers in model.layers[]. + const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers; + const auto & layer = model.layers[il]; + + GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); + GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm"); + GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm"); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + auto inp = std::make_unique(hparams.n_embd); + + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_input(inp->tokens); + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + ggml_set_input(inp->embd); + ggml_set_name(inp->embd, "mtp_h_input"); + + ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; + + ggml_tensor * h_input = inp->embd; + ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); + cb(tok_embd, "mtp_tok_embd", il); + + res->add_input(std::move(inp)); + + ggml_tensor * inp_pos = build_inp_pos(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); + cb(h_norm, "mtp_hnorm", il); + + ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il); + cb(e_norm, "mtp_enorm", il); + + ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0); + cb(concat, "mtp_concat", il); + + ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat, layer.nextn.eh_proj_s); + cb(cur, "mtp_eh_proj", il); + + ggml_tensor * inpSA = cur; + + cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_norm", il); + + ggml_tensor * Qcur_full = build_lora_mm(layer.wq, cur, layer.wq_s); + cb(Qcur_full, "mtp_Qcur_full", il); + + ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + 0); + Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "mtp_Qcur_normed", il); + + ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + ggml_element_size(Qcur_full) * n_embd_head); + gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); + cb(gate, "mtp_gate", il); + + ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "mtp_Kcur_normed", il); + + ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + cb(Vcur, "mtp_Vcur", il); + + Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + const float kq_scale = hparams.f_attention_scale == 0.0f + ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + cur = build_attn(inp_attn, + nullptr, nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "mtp_attn_pregate", il); + + cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate)); + cur = build_lora_mm(layer.wo, cur, layer.wo_s); + cb(cur, "mtp_attn_out", il); + + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "mtp_attn_residual", il); + + ggml_tensor * ffn_residual = cur; + cur = build_norm(cur, layer.attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_post_norm", il); + + cur = build_ffn(cur, + layer.ffn_up, nullptr, layer.ffn_up_s, + layer.ffn_gate, nullptr, layer.ffn_gate_s, + layer.ffn_down, nullptr, layer.ffn_down_s, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "mtp_ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_residual); + cb(cur, "mtp_post_ffn", il); + + // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step. + // (In the trunk graph this is `t_h_pre_norm`; the MTP head reuses the same slot.) + cb(cur, "h_pre_norm", -1); + res->t_h_pre_norm = cur; + + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + + ggml_tensor * head_norm_w = layer.nextn.shared_head_norm + ? layer.nextn.shared_head_norm + : model.output_norm; + GGML_ASSERT(head_norm_w && "QWEN35 MTP: missing both nextn.shared_head_norm and output_norm"); + cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1); + cb(cur, "mtp_shared_head_norm", -1); + + ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output; + ggml_tensor * head_s = layer.nextn.shared_head_head ? layer.nextn.shared_head_head_s : model.output_s; + GGML_ASSERT(head_w && "QWEN35 MTP: missing LM head (nextn.shared_head_head or model.output)"); + cur = build_lora_mm(head_w, cur, head_s); + cb(cur, "result_output", -1); + + res->t_logits = cur; + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/qwen35moe.cpp b/examples/talk-llama/models/qwen35moe.cpp index cf05dc9d61c..dc24f6ed537 100644 --- a/examples/talk-llama/models/qwen35moe.cpp +++ b/examples/talk-llama/models/qwen35moe.cpp @@ -15,16 +15,22 @@ void llama_model_qwen35moe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - // Mark recurrent layers (linear attention layers) + // NextN/MTP (Qwen3.5/3.6): extra decoder block appended beyond the main stack + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + + // Mark recurrent layers (linear attention layers). MTP layers are dense + // attention-only and must be flagged non-recurrent. { + const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; uint32_t full_attn_interval = 4; ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + hparams.recurrent_layer_arr[i] = (i < n_main) && ((i + 1) % full_attn_interval != 0); } } - switch (hparams.n_layer) { + switch (hparams.n_layer - hparams.nextn_predict_layers) { case 40: type = LLM_TYPE_35B_A3B; break; case 48: type = LLM_TYPE_122B_A10B; break; case 60: type = LLM_TYPE_397B_A17B; break; @@ -32,9 +38,14 @@ void llama_model_qwen35moe::load_arch_hparams(llama_model_loader & ml) { } } -void llama_model_qwen35moe::load_arch_tensors(llama_model_loader &) { +void llama_model_qwen35moe::load_arch_tensors(llama_model_loader & ml) { LLAMA_LOAD_LOCALS; + const uint32_t n_main = n_layer - hparams.nextn_predict_layers; + const bool mtp_only = (hparams.nextn_predict_layers > 0) && + (ml.get_weight("blk.0.attn_norm.weight") == nullptr); + const int trunk_flags = mtp_only ? TENSOR_NOT_REQUIRED : 0; + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); // output @@ -46,60 +57,105 @@ void llama_model_qwen35moe::load_arch_tensors(llama_model_loader &) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); } - const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + auto load_block_trunk = [&](int il, int flags) { + auto & layer = layers[il]; - // Calculate dimensions from hyperparameters - const int64_t head_k_dim = hparams.ssm_d_state; - const int64_t head_v_dim = hparams.ssm_d_state; - const int64_t n_k_heads = hparams.ssm_n_group; - const int64_t n_v_heads = hparams.ssm_dt_rank; - const int64_t key_dim = head_k_dim * n_k_heads; - const int64_t value_dim = head_v_dim * n_v_heads; - const int64_t conv_dim = key_dim * 2 + value_dim; + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; + // Calculate dimensions from hyperparameters + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, flags); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, flags); - if (!hparams.is_recurrent(i)) { + if (!hparams.is_recurrent(il)) { // Attention layers - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, flags); // Q/K normalization for attention layers - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, flags); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, flags); } else { // Linear attention (gated delta net) specific tensors // Create tensors with calculated dimensions - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); - layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); - layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); - layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", il), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", il), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", il), { hparams.ssm_d_conv, conv_dim }, flags); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", il), { hparams.ssm_dt_rank }, flags); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, il), { hparams.ssm_dt_rank }, flags); + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", il), { n_embd, n_v_heads }, flags); + layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", il), { n_embd, n_v_heads }, flags); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", il), { head_v_dim }, flags); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", il), { value_dim, n_embd }, flags); } - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); - create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); + // Routed experts + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", il), { n_embd, n_expert }, flags); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", il), { n_ff_exp, n_embd, n_expert }, flags); + create_tensor_gate_up_exps(layer, il, n_embd, n_ff_exp, n_expert, flags); // Shared experts + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", il), { n_embd }, flags); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", il), { n_embd, n_ff_shexp }, flags); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", il), { n_embd, n_ff_shexp }, flags); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", il), { n_ff_shexp, n_embd }, flags); + }; + + auto load_block_mtp = [&](int il) { + auto & layer = layers[il]; + + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; - layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0); - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0); + // MTP block looks like a full-attention Qwen3.5 decoder block with MoE FFN. + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, 0); + + create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, 0); + + // Routed experts + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", il), { n_embd, n_expert }, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", il), { n_ff_exp, n_embd, n_expert }, 0); + create_tensor_gate_up_exps(layer, il, n_embd, n_ff_exp, n_expert, 0); + + // Shared experts + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", il), { n_embd }, 0); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", il), { n_embd, n_ff_shexp }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", il), { n_embd, n_ff_shexp }, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", il), { n_ff_shexp, n_embd }, 0); + + // NextN-specific tensors that define the MTP block. + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", il), { 2 * n_embd, n_embd }, 0); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", il), { n_embd }, 0); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", il), { n_embd }, 0); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", il), { n_embd }, TENSOR_NOT_REQUIRED); + }; + + for (int i = 0; i < (int) n_main; ++i) { + load_block_trunk(i, trunk_flags); + } + for (int i = (int) n_main; i < n_layer; ++i) { + load_block_mtp(i); } } std::unique_ptr llama_model_qwen35moe::build_arch_graph(const llm_graph_params & params) const { + if (params.gtype == LLM_GRAPH_TYPE_DECODER_MTP) { + return std::make_unique(*this, params); + } return std::make_unique(*this, params); } @@ -124,7 +180,9 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_out_ids = build_inp_out_ids(); - for (int il = 0; il < n_layer; ++il) { + // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass. + const int n_transformer_layers = n_layer - (int) hparams.nextn_predict_layers; + for (int il = 0; il < n_transformer_layers; ++il) { ggml_tensor * inpSA = inpL; cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); @@ -141,7 +199,7 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); } - if (il == n_layer - 1 && inp_out_ids) { + if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -173,6 +231,13 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p } cur = inpL; + cb(cur, "h_pre_norm", -1); + res->t_h_pre_norm = cur; + + if (!cparams.embeddings_pre_norm_masked && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + } + // Final norm cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); @@ -180,7 +245,7 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // LM head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; @@ -310,8 +375,6 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear( const int64_t head_v_dim = d_inner / num_v_heads; const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const auto kv_head = mctx_cur->get_head(); - GGML_ASSERT(n_seqs != 0); GGML_ASSERT(ubatch.equal_seqs()); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); @@ -341,41 +404,14 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear( gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs); - // Get convolution states from cache ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - // Build the convolution states tensor - ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); - cb(conv_states, "conv_states", il); - - // Calculate convolution kernel size ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; const int64_t conv_kernel_size = conv_kernel->ne[0]; const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; - conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); - cb(conv_states, "conv_states_reshaped", il); - - qkv_mixed = ggml_transpose(ctx0, qkv_mixed); - cb(qkv_mixed, "qkv_mixed_transposed", il); - - ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); - cb(conv_input, "conv_input", il); - - // Update convolution state cache - // Extract the last (conv_kernel_size - 1) states from conv_input - ggml_tensor * last_conv_states = - ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], - conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input)); - cb(last_conv_states, "last_conv_states", il); - - ggml_tensor * state_update_target = - ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1], - kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); - cb(state_update_target, "state_update_target", il); - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); + ggml_tensor * conv_input = build_conv_state(inp, conv_states_all, qkv_mixed, conv_kernel_size, conv_channels, il); ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); @@ -426,7 +462,7 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear( //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); // if head keys and value keys are different, repeat to force tensors into matching shapes - // note: need explicit repeat only if we are not using the fused GDN + // note: need explicit repeat only if we are not using the fused GDN. if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) { GGML_ASSERT(num_v_heads % num_k_heads == 0); q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); @@ -437,18 +473,7 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); - auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il); - - ggml_tensor * output = attn_out.first; - ggml_tensor * new_state = attn_out.second; - cb(output, "attn_output", il); - cb(new_state, "new_state", il); - - // Update the recurrent states - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, new_state, - ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], - kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + ggml_tensor * output = build_recurrent_attn(inp, ssm_states_all, q_conv, k_conv, v_conv, gate, beta, state, il); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); @@ -525,3 +550,183 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_ffn(ggml_tensor * cur, c return cur; } + +// LLM_GRAPH_TYPE_DECODER_MTP draft head for Qwen3.5/3.6 MoE +llama_model_qwen35moe::graph_mtp::graph_mtp(const llama_model & model, const llm_graph_params & params) + : llm_graph_context(params) { + GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35MOE MTP requires nextn_predict_layers > 0"); + GGML_ASSERT(hparams.nextn_predict_layers == 1 && "QWEN35MOE MTP currently only supports a single MTP block"); + + const int64_t n_embd_head = hparams.n_embd_head_v(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + + const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers; + const auto & layer = model.layers[il]; + + GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); + GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm"); + GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm"); + GGML_ASSERT(layer.ffn_gate_inp && "MTP block missing ffn_gate_inp"); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + auto inp = std::make_unique(hparams.n_embd); + + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_input(inp->tokens); + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + ggml_set_input(inp->embd); + ggml_set_name(inp->embd, "mtp_h_input"); + + ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; + + ggml_tensor * h_input = inp->embd; + ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); + cb(tok_embd, "mtp_tok_embd", il); + + res->add_input(std::move(inp)); + + ggml_tensor * inp_pos = build_inp_pos(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + auto * inp_attn = build_attn_inp_kv(); + + + ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); + cb(h_norm, "mtp_hnorm", il); + + ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il); + cb(e_norm, "mtp_enorm", il); + + ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0); + cb(concat, "mtp_concat", il); + + ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat, layer.nextn.eh_proj_s); + cb(cur, "mtp_eh_proj", il); + + ggml_tensor * inpSA = cur; + + cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_norm", il); + + ggml_tensor * Qcur_full = build_lora_mm(layer.wq, cur, layer.wq_s); + cb(Qcur_full, "mtp_Qcur_full", il); + + ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + 0); + Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "mtp_Qcur_normed", il); + + ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + ggml_element_size(Qcur_full) * n_embd_head); + gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); + cb(gate, "mtp_gate", il); + + ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "mtp_Kcur_normed", il); + + ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + cb(Vcur, "mtp_Vcur", il); + + Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + const float kq_scale = hparams.f_attention_scale == 0.0f + ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + cur = build_attn(inp_attn, + nullptr, nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "mtp_attn_pregate", il); + + cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate)); + cur = build_lora_mm(layer.wo, cur, layer.wo_s); + cb(cur, "mtp_attn_out", il); + + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "mtp_attn_residual", il); + + ggml_tensor * ffn_residual = cur; + cur = build_norm(cur, layer.attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_post_norm", il); + + // MoE FFN — routed experts plus gated shared expert (mirrors qwen35moe). + ggml_tensor * moe_out = + build_moe_ffn(cur, + layer.ffn_gate_inp, + layer.ffn_up_exps, + layer.ffn_gate_exps, + layer.ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + hparams.expert_weights_scale, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il, + nullptr, layer.ffn_gate_up_exps, + layer.ffn_up_exps_s, + layer.ffn_gate_exps_s, + layer.ffn_down_exps_s); + cb(moe_out, "mtp_ffn_moe_out", il); + + if (layer.ffn_up_shexp != nullptr) { + ggml_tensor * ffn_shexp = + build_ffn(cur, + layer.ffn_up_shexp, nullptr, layer.ffn_up_shexp_s, + layer.ffn_gate_shexp, nullptr, layer.ffn_gate_shexp_s, + layer.ffn_down_shexp, nullptr, layer.ffn_down_shexp_s, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "mtp_ffn_shexp", il); + + ggml_tensor * shared_gate = build_lora_mm(layer.ffn_gate_inp_shexp, cur); + shared_gate = ggml_sigmoid(ctx0, shared_gate); + cb(shared_gate, "mtp_shared_expert_gate_sigmoid", il); + + ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate); + cb(ffn_shexp, "mtp_ffn_shexp_gated", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + } else { + cur = moe_out; + } + cb(cur, "mtp_ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_residual); + cb(cur, "mtp_post_ffn", il); + + // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step. + cb(cur, "h_pre_norm", -1); + res->t_h_pre_norm = cur; + + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + + ggml_tensor * head_norm_w = layer.nextn.shared_head_norm + ? layer.nextn.shared_head_norm + : model.output_norm; + GGML_ASSERT(head_norm_w && "QWEN35MOE MTP: missing both nextn.shared_head_norm and output_norm"); + cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1); + cb(cur, "mtp_shared_head_norm", -1); + + ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output; + ggml_tensor * head_s = layer.nextn.shared_head_head ? layer.nextn.shared_head_head_s : model.output_s; + GGML_ASSERT(head_w && "QWEN35MOE MTP: missing LM head (nextn.shared_head_head or model.output)"); + cur = build_lora_mm(head_w, cur, head_s); + cb(cur, "result_output", -1); + + res->t_logits = cur; + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/qwen3moe.cpp b/examples/talk-llama/models/qwen3moe.cpp index 4440b83aa45..a4f8e1379c9 100644 --- a/examples/talk-llama/models/qwen3moe.cpp +++ b/examples/talk-llama/models/qwen3moe.cpp @@ -168,7 +168,7 @@ llama_model_qwen3moe::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen3next.cpp b/examples/talk-llama/models/qwen3next.cpp index cb1b4814caf..1d873427db5 100644 --- a/examples/talk-llama/models/qwen3next.cpp +++ b/examples/talk-llama/models/qwen3next.cpp @@ -176,7 +176,7 @@ llama_model_qwen3next::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // LM head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; @@ -378,8 +378,6 @@ ggml_tensor * llama_model_qwen3next::graph::build_layer_attn_linear( const int64_t head_v_dim = d_inner / num_v_heads; const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const auto kv_head = mctx_cur->get_head(); - GGML_ASSERT(n_seqs != 0); GGML_ASSERT(ubatch.equal_seqs()); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); @@ -429,41 +427,14 @@ ggml_tensor * llama_model_qwen3next::graph::build_layer_attn_linear( beta = ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs); gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs); - // Get convolution states from cache ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - // Build the convolution states tensor - ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); - cb(conv_states, "conv_states", il); - - // Calculate convolution kernel size ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; const int64_t conv_kernel_size = conv_kernel->ne[0]; const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; - conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); - cb(conv_states, "conv_states_reshaped", il); - - qkv_mixed = ggml_transpose(ctx0, qkv_mixed); - cb(qkv_mixed, "qkv_mixed_transposed", il); - - ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); - cb(conv_input, "conv_input", il); - - // Update convolution state cache - // Extract the last (conv_kernel_size - 1) states from conv_input - ggml_tensor * last_conv_states = - ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], - conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input)); - cb(last_conv_states, "last_conv_states", il); - - ggml_tensor * state_update_target = - ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1], - kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); - cb(state_update_target, "state_update_target", il); - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); + ggml_tensor * conv_input = build_conv_state(inp, conv_states_all, qkv_mixed, conv_kernel_size, conv_channels, il); ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); @@ -540,18 +511,7 @@ ggml_tensor * llama_model_qwen3next::graph::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); - auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il); - - ggml_tensor * output = attn_out.first; - ggml_tensor * new_state = attn_out.second; - cb(output, "attn_output", il); - cb(new_state, "new_state", il); - - // Update the recurrent states - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, new_state, - ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], - kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + ggml_tensor * output = build_recurrent_attn(inp, ssm_states_all, q_conv, k_conv, v_conv, gate, beta, state, il); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); diff --git a/examples/talk-llama/models/qwen3vl.cpp b/examples/talk-llama/models/qwen3vl.cpp index 7871f8f7952..5defd893944 100644 --- a/examples/talk-llama/models/qwen3vl.cpp +++ b/examples/talk-llama/models/qwen3vl.cpp @@ -163,7 +163,7 @@ llama_model_qwen3vl::graph::graph(const llama_model & model, const llm_graph_par res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen3vlmoe.cpp b/examples/talk-llama/models/qwen3vlmoe.cpp index b99143c8908..5b77df57122 100644 --- a/examples/talk-llama/models/qwen3vlmoe.cpp +++ b/examples/talk-llama/models/qwen3vlmoe.cpp @@ -180,7 +180,7 @@ llama_model_qwen3vlmoe::graph::graph(const llama_model & model, const llm_graph_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/refact.cpp b/examples/talk-llama/models/refact.cpp index f14f10917ff..bf3949a9092 100644 --- a/examples/talk-llama/models/refact.cpp +++ b/examples/talk-llama/models/refact.cpp @@ -150,7 +150,7 @@ llama_model_refact::graph::graph(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/rnd1.cpp b/examples/talk-llama/models/rnd1.cpp index 325ee73ba5c..ca8e009615e 100644 --- a/examples/talk-llama/models/rnd1.cpp +++ b/examples/talk-llama/models/rnd1.cpp @@ -167,7 +167,7 @@ llama_model_rnd1::graph::graph(const llama_model & model, const llm_graph_params res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/rwkv6.cpp b/examples/talk-llama/models/rwkv6.cpp index 2944711acec..ba2a9dfa0db 100644 --- a/examples/talk-llama/models/rwkv6.cpp +++ b/examples/talk-llama/models/rwkv6.cpp @@ -176,7 +176,7 @@ llama_model_rwkv6::graph::graph(const llama_model & model, const llm_graph_param cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/rwkv6qwen2.cpp b/examples/talk-llama/models/rwkv6qwen2.cpp index 6f7d1f5722f..566b8cdcb54 100644 --- a/examples/talk-llama/models/rwkv6qwen2.cpp +++ b/examples/talk-llama/models/rwkv6qwen2.cpp @@ -158,7 +158,7 @@ llama_model_rwkv6qwen2::graph::graph(const llama_model & model, const llm_graph_ cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/rwkv7.cpp b/examples/talk-llama/models/rwkv7.cpp index b205e3935e1..7574b252621 100644 --- a/examples/talk-llama/models/rwkv7.cpp +++ b/examples/talk-llama/models/rwkv7.cpp @@ -202,7 +202,7 @@ llama_model_rwkv7::graph::graph(const llama_model & model, const llm_graph_param cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/seed-oss.cpp b/examples/talk-llama/models/seed-oss.cpp index 83e114740b6..806cba574be 100644 --- a/examples/talk-llama/models/seed-oss.cpp +++ b/examples/talk-llama/models/seed-oss.cpp @@ -141,7 +141,7 @@ llama_model_seed_oss::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/smallthinker.cpp b/examples/talk-llama/models/smallthinker.cpp index 3214e7cbad3..4231cccc666 100644 --- a/examples/talk-llama/models/smallthinker.cpp +++ b/examples/talk-llama/models/smallthinker.cpp @@ -178,7 +178,7 @@ llama_model_smallthinker::graph::graph(const llama_model & model, const ll res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/smollm3.cpp b/examples/talk-llama/models/smollm3.cpp index 7adaf34c534..90e7d473eaf 100644 --- a/examples/talk-llama/models/smollm3.cpp +++ b/examples/talk-llama/models/smollm3.cpp @@ -143,7 +143,7 @@ llama_model_smollm3::graph::graph(const llama_model & model, const llm_graph_par res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/stablelm.cpp b/examples/talk-llama/models/stablelm.cpp index 8f613e55947..4da7f7aefcf 100644 --- a/examples/talk-llama/models/stablelm.cpp +++ b/examples/talk-llama/models/stablelm.cpp @@ -163,7 +163,7 @@ llama_model_stablelm::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/starcoder.cpp b/examples/talk-llama/models/starcoder.cpp index 58cf0ac0edc..e131af058bc 100644 --- a/examples/talk-llama/models/starcoder.cpp +++ b/examples/talk-llama/models/starcoder.cpp @@ -135,7 +135,7 @@ llama_model_starcoder::graph::graph(const llama_model & model, const llm_graph_p cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/starcoder2.cpp b/examples/talk-llama/models/starcoder2.cpp index 45dae0602d4..9c207c02885 100644 --- a/examples/talk-llama/models/starcoder2.cpp +++ b/examples/talk-llama/models/starcoder2.cpp @@ -148,7 +148,7 @@ llama_model_starcoder2::graph::graph(const llama_model & model, const llm_graph_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/step35.cpp b/examples/talk-llama/models/step35.cpp index c4789752d21..3b68e68707a 100644 --- a/examples/talk-llama/models/step35.cpp +++ b/examples/talk-llama/models/step35.cpp @@ -261,7 +261,7 @@ llama_model_step35::graph::graph(const llama_model & model, const llm_graph_para cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/t5.cpp b/examples/talk-llama/models/t5.cpp index 27a0711ba41..73e32741406 100644 --- a/examples/talk-llama/models/t5.cpp +++ b/examples/talk-llama/models/t5.cpp @@ -265,7 +265,7 @@ llama_model_t5::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/wavtokenizer-dec.cpp b/examples/talk-llama/models/wavtokenizer-dec.cpp index a873e5d2e8f..214fed99bad 100644 --- a/examples/talk-llama/models/wavtokenizer-dec.cpp +++ b/examples/talk-llama/models/wavtokenizer-dec.cpp @@ -253,7 +253,7 @@ llama_model_wavtokenizer_dec::graph::graph(const llama_model & model, const llm_ LLM_NORM, -1); // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cur = ggml_add(ctx0, cur, model.output_b); diff --git a/examples/talk-llama/models/xverse.cpp b/examples/talk-llama/models/xverse.cpp index e4d111e622a..d6d1c7a2e5d 100644 --- a/examples/talk-llama/models/xverse.cpp +++ b/examples/talk-llama/models/xverse.cpp @@ -126,7 +126,7 @@ llama_model_xverse::graph::graph(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/unicode.cpp b/examples/talk-llama/unicode.cpp index dc13e53f09f..b02ecdc930f 100644 --- a/examples/talk-llama/unicode.cpp +++ b/examples/talk-llama/unicode.cpp @@ -605,6 +605,136 @@ static std::vector unicode_regex_split_custom_qwen2(const std::string & return bpe_offsets; } +// Qwen3.5 system regex: "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?[\\p{L}\\p{M}]+|\\p{N}| ?[^\\s\\p{L}\\p{M}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" +// Compared to Qwen2, letter-runs also consume Unicode combining marks (\p{M}): [\p{L}\p{M}]+ instead of \p{L}+ +static std::vector unicode_regex_split_custom_qwen35(const std::string & text, const std::vector & offsets) { + std::vector bpe_offsets; // store the offset of each word + bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size + + const auto cpts = unicode_cpts_from_utf8(text); + + size_t start = 0; + for (auto offset : offsets) { + const size_t offset_ini = start; + const size_t offset_end = start + offset; + assert(offset_end <= cpts.size()); + start = offset_end; + + static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF; + auto _get_cpt = [&] (const size_t pos) -> uint32_t { + return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE; + }; + + auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags { + return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{}; + }; + + size_t _prev_end = offset_ini; + auto _add_token = [&] (const size_t end) -> size_t { + assert(_prev_end <= end && end <= offset_end); + size_t len = end - _prev_end; + if (len > 0) { + bpe_offsets.push_back(len); + } + _prev_end = end; + return len; + }; + + for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) { + const uint32_t cpt = _get_cpt(pos); + const auto flags = _get_flags(pos); + + // regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive + if (cpt == '\'' && pos+1 < offset_end) { + uint32_t cpt_next = unicode_tolower(_get_cpt(pos+1)); + if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') { + pos += _add_token(pos+2); + continue; + } + if (pos+2 < offset_end) { + uint32_t cpt_next_next = unicode_tolower(_get_cpt(pos+2)); + if ((cpt_next == 'r' && cpt_next_next == 'e') || + (cpt_next == 'v' && cpt_next_next == 'e') || + (cpt_next == 'l' && cpt_next_next == 'l')) { + pos += _add_token(pos+3); + continue; + } + } + } + + // regex: [^\r\n\p{L}\p{N}]?[\p{L}\p{M}]+ + if (!(cpt == '\r' || cpt == '\n' || flags.is_number)) { + if (flags.is_letter || flags.is_accent_mark || _get_flags(pos + 1).is_accent_mark || _get_flags(pos+1).is_letter) { + pos++; + while (_get_flags(pos).is_letter || _get_flags(pos).is_accent_mark) { + pos++; + } + _add_token(pos); + continue; + } + } + + // regex: \p{N} + if (flags.is_number) { + pos++; + _add_token(pos); + continue; + } + + // regex: ?[^\s\p{L}\p{M}\p{N}]+[\r\n]* + auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags); + if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_accent_mark | flags2.is_number) && flags.as_uint()) { + pos += (cpt == ' '); + while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_accent_mark | flags2.is_number) && flags2.as_uint()) { + flags2 = _get_flags(++pos); + } + uint32_t cpt2 = _get_cpt(pos); + while (cpt2 == '\r' || cpt2 == '\n') { + cpt2 = _get_cpt(++pos); + } + _add_token(pos); + continue; + } + + size_t num_whitespaces = 0; + size_t last_end_r_or_n = 0; + while (_get_flags(pos+num_whitespaces).is_whitespace) { + uint32_t cpt2 = _get_cpt(pos+num_whitespaces); + if (cpt2 == '\r' || cpt2 == '\n') { + last_end_r_or_n = pos + num_whitespaces + 1; + } + num_whitespaces++; + } + + // regex: \s*[\r\n]+ + if (last_end_r_or_n > 0) { + pos = last_end_r_or_n; + _add_token(pos); + continue; + } + + // regex: \s+(?!\S) + if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) { + pos += num_whitespaces - 1; + _add_token(pos); + continue; + } + + // regex: \s+ + if (num_whitespaces > 0) { + pos += num_whitespaces; + _add_token(pos); + continue; + } + + // no matches + _add_token(++pos); + } + } + + return bpe_offsets; +} + template static std::vector unicode_regex_split_stl(const std::basic_string & text, const std::basic_string & regex, const std::vector & offsets) { using BidirIt = typename std::basic_string::const_iterator; @@ -929,6 +1059,9 @@ static std::vector unicode_regex_split_custom(const std::string & text, } else if ( regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") { bpe_offsets = unicode_regex_split_custom_qwen2(text, offsets); + } else if ( + regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?[\\p{L}\\p{M}]+|\\p{N}| ?[^\\s\\p{L}\\p{M}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") { + bpe_offsets = unicode_regex_split_custom_qwen35(text, offsets); } else if (regex_expr == "\\p{Han}+") { // K2's first pattern - handle all K2 patterns together bpe_offsets = unicode_regex_split_custom_kimi_k2(text, offsets);