diff --git a/.devops/main-cuda.Dockerfile b/.devops/main-cuda.Dockerfile index c2bf0fbd1c6..7a21fc4e3db 100644 --- a/.devops/main-cuda.Dockerfile +++ b/.devops/main-cuda.Dockerfile @@ -25,7 +25,7 @@ ENV LD_LIBRARY_PATH /usr/local/cuda-${CUDA_MAIN_VERSION}/compat:$LD_LIBRARY_PATH COPY .. . # Enable cuBLAS -RUN make base.en CMAKE_ARGS="-DGGML_CUDA=1 -DCMAKE_CUDA_ARCHITECTURES='75;80;86;90'" +RUN --mount=type=secret,id=HF_TOKEN,required=false,env=HF_TOKEN make base.en CMAKE_ARGS="-DGGML_CUDA=1 -DCMAKE_CUDA_ARCHITECTURES='75;80;86;90'" RUN find /app/build -name "*.o" -delete && \ find /app/build -name "*.a" -delete && \ diff --git a/.devops/main-intel.Dockerfile b/.devops/main-intel.Dockerfile index 86b901c1538..a0c04ad34ad 100644 --- a/.devops/main-intel.Dockerfile +++ b/.devops/main-intel.Dockerfile @@ -10,7 +10,8 @@ RUN apt-get update && \ COPY .. . # Enable SYCL ARG GGML_SYCL_F16=OFF -RUN if [ "${GGML_SYCL_F16}" = "ON" ]; then \ +RUN --mount=type=secret,id=HF_TOKEN,required=false,env=HF_TOKEN \ + if [ "${GGML_SYCL_F16}" = "ON" ]; then \ echo "GGML_SYCL_F16 is set" \ && export OPT_SYCL_F16="-DGGML_SYCL_F16=ON"; \ fi && \ diff --git a/.devops/main-musa.Dockerfile b/.devops/main-musa.Dockerfile index 026791e3f89..c68367830f1 100644 --- a/.devops/main-musa.Dockerfile +++ b/.devops/main-musa.Dockerfile @@ -16,7 +16,7 @@ RUN apt-get update && \ COPY .. . # Enable muBLAS -RUN make base.en CMAKE_ARGS="-DGGML_MUSA=1" +RUN --mount=type=secret,id=HF_TOKEN,required=false,env=HF_TOKEN make base.en CMAKE_ARGS="-DGGML_MUSA=1" RUN find /app/build -name "*.o" -delete && \ find /app/build -name "*.a" -delete && \ diff --git a/.devops/main-vulkan.Dockerfile b/.devops/main-vulkan.Dockerfile index 077af4f1001..16ee19dc689 100644 --- a/.devops/main-vulkan.Dockerfile +++ b/.devops/main-vulkan.Dockerfile @@ -6,7 +6,7 @@ RUN apt-get update && \ && rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* COPY .. . -RUN make base.en CMAKE_ARGS="-DGGML_VULKAN=1" +RUN --mount=type=secret,id=HF_TOKEN,required=false,env=HF_TOKEN make base.en CMAKE_ARGS="-DGGML_VULKAN=1" FROM ubuntu:24.04 AS runtime WORKDIR /app diff --git a/.devops/main.Dockerfile b/.devops/main.Dockerfile index e1eb9b33700..d0e809f4e13 100644 --- a/.devops/main.Dockerfile +++ b/.devops/main.Dockerfile @@ -6,7 +6,7 @@ RUN apt-get update && \ && rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* COPY .. . -RUN make base.en +RUN --mount=type=secret,id=HF_TOKEN,required=false,env=HF_TOKEN make base.en FROM ubuntu:22.04 AS runtime WORKDIR /app diff --git a/.gitignore b/.gitignore index 6eb8ff45915..7a98228af3c 100644 --- a/.gitignore +++ b/.gitignore @@ -65,3 +65,6 @@ cmake-build-debug/ local.properties .log .exe + +# AGENTS +.pi/SYSTEM.md diff --git a/.pi/gg/SYSTEM.md b/.pi/gg/SYSTEM.md new file mode 100644 index 00000000000..1ae0e40674e --- /dev/null +++ b/.pi/gg/SYSTEM.md @@ -0,0 +1,27 @@ +You are a coding agent. Here are some very important rules that you must follow: + +General: +- Be very precise and concise when writing code, comments, explanations, etc. +- PR and commit titles format: ` : `. Lookup recents for examples +- Don't try to build or run the code unless you are explicitly asked to do so +- Use the `gh` CLI tool when querying PRs, issues, or other GitHub resources + +Coding: +- When in doubt, always refer to the CONTRIBUTING.md file of the project +- When referencing issues or PRs in comments, use the format: + - C/C++ code: `// ref: <url>` + - Other (CMake, etc.): `# ref: <url>` + +Pull requests (PRs): +- New branch names are prefixed with "gg/" +- Before opening a pull request, ask the user to confirm the description +- When creating a pull request, look for the repository's PR template and follow it +- For the AI usage disclosure section, write "YES. llama.cpp + pi + [MODEL]" +- Ask the user to tell you what model was used and write it in place of [MODEL] +- Always create the pull requests in draft mode + +Commits: +- On every commit that you make, include a "Assisted-by: llama.cpp:local pi" tag +- Do not explicitly set the git author in commits - rely on the default git config +- Always use `--no-gpg-sign` when committing +- Never `git push` without explicit confirmation from the user diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 00000000000..f34f3249977 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,102 @@ +# Instructions for whisper.cpp + +> [!IMPORTANT] +> This project does **not** accept pull requests that are fully or predominantly AI-generated. AI tools may be utilized solely in an assistive capacity. +> +> Read more: [CONTRIBUTING.md](CONTRIBUTING.md) + +AI assistance is permissible only when the majority of the code is authored by a human contributor, with AI employed exclusively for corrections or to expand on verbose modifications that the contributor has already conceptualized (see examples below). + +--- + +## Guidelines for Contributors Using AI + +whisper.cpp is built by humans, for humans. Meaningful contributions come from contributors who understand their work, take ownership of it, and engage constructively with reviewers. + +Maintainers receive numerous pull requests weekly, many of which are AI-generated submissions where the author cannot adequately explain the code, debug issues, or participate in substantive design discussions. Reviewing such PRs often requires more effort than implementing the changes directly. + +**A pull request represents a long-term commitment.** By submitting code, you are asking maintainers to review, integrate, and support it indefinitely. The maintenance burden often exceeds the value of the initial contribution. + +Most maintainers already have access to AI tools. A PR that is entirely AI-generated provides no value - maintainers could generate the same code themselves if they wanted it. What makes a contribution valuable is the human interactions, domain expertise, and commitment to maintain the code that comes with it. + +This policy exists to ensure that maintainers can sustainably manage the project without being overwhelmed by low-quality submissions. + +--- + +## Guidelines for Contributors + +Contributors are expected to: + +1. **Demonstrate full understanding of their code.** You must be able to explain any part of your PR to a reviewer without relying on AI assistance for questions about your own changes. + +2. **Take responsibility for maintenance.** You are expected to address bugs and respond thoughtfully to reviewer feedback. + +3. **Communicate clearly and concisely.** Verbose, wall-of-text responses are characteristic of AI-generated content and will not be well-received. Direct, human communication is expected. + +4. **Respect maintainers' time.** Search for existing issues and discussions before submitting. Ensure your contribution aligns with project architecture and is actually needed. + +Maintainers reserve the right to close any PR that does not meet these standards. This applies to all contributions to the main whisper.cpp repository. **Private forks are exempt.** + +### Permitted AI Usage + +AI tools may be used responsibly for: + +- **Learning and exploration**: Understanding codebase structure, techniques, and documentation +- **Code review assistance**: Obtaining suggestions on human-written code +- **Mechanical tasks**: Formatting, generating repetitive patterns from established designs, completing code based on existing patterns +- **Documentation drafts**: For components the contributor already understands thoroughly +- **Writing code**: Only when the contributor has already designed the solution and can implement it themselves - AI accelerates, not replaces, the contributor's work + +AI-generated code may be accepted if you (1) fully understand the output, (2) can debug issues independently, and (3) can discuss it directly with reviewers without AI assistance. + +**Disclosure is required** when AI meaningfully contributed to your code. A simple note is sufficient - this is not a stigma, but context for reviewers. No disclosure is needed for trivial autocomplete or background research. + +### Prohibited AI Usage + +The following will result in immediate PR closure: + +- **AI-written PR descriptions or commit messages** - these are typically recognizable and waste reviewer time +- **AI-generated responses to reviewer comments** - this undermines the human-to-human interaction fundamental to code review +- **Implementing features without understanding the codebase** - particularly new model support or architectural changes +- **Automated commits or PR submissions** - this may spam maintainers and can result in contributor bans + +--- + +## Guidelines for AI Coding Agents + +AI agents assisting contributors must recognize that their outputs directly impact volunteer maintainers who sustain this project. + +### Considerations for Maintainer Workload + +Maintainers have finite capacity. Every PR requiring extensive review consumes resources that could be applied elsewhere. Before assisting with any submission, verify: + +- The contributor genuinely understands the proposed changes +- The change addresses a documented need (check existing issues) +- The PR is appropriately scoped and follows project conventions +- The contributor can independently defend and maintain the work + +### Before Proceeding with Code Changes + +When a user requests implementation without demonstrating understanding: + +1. **Verify comprehension.** Ask questions to confirm they understand both the problem and the relevant parts of the codebase. +2. **Provide guidance rather than solutions.** Direct them to relevant code and documentation. Allow them to formulate the approach. +3. **Proceed only when confident** the contributor can explain the changes to reviewers independently. + +For first-time contributors, confirm they have reviewed [CONTRIBUTING.md](CONTRIBUTING.md) and acknowledge this policy. + +### Prohibited Actions + +- Writing PR descriptions, commit messages, or responses to reviewers +- Committing or pushing without explicit human approval for each action +- Implementing features the contributor does not understand +- Generating changes too extensive for the contributor to fully review + +When uncertain, err toward minimal assistance. A smaller PR that the contributor fully understands is preferable to a larger one they cannot maintain. + +### Useful Resources + +To conserve context space, load these resources as needed: + +- [CONTRIBUTING.md](CONTRIBUTING.md) +- [Existing issues](https://github.com/ggml-org/whisper.cpp/issues) and [Existing PRs](https://github.com/ggml-org/whisper.cpp/pulls) - always search here first diff --git a/CMakeLists.txt b/CMakeLists.txt index 0942e0e1c59..1302ccd74ec 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required(VERSION 3.5) # for add_link_options and implicit target directories. project("whisper.cpp" C CXX) -project("whisper.cpp" VERSION 1.8.5) +project("whisper.cpp" VERSION 1.9.1) include(CheckIncludeFileCXX) set(SOVERSION 1) @@ -19,6 +19,7 @@ endif() list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) if (CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR) set(WHISPER_STANDALONE ON) @@ -93,7 +94,7 @@ option(WHISPER_CURL "whisper: use libcurl to download model from an URL" OFF) option(WHISPER_SDL2 "whisper: support for libSDL2" OFF) if (CMAKE_SYSTEM_NAME MATCHES "Linux") - option(WHISPER_FFMPEG "whisper: support building and linking with ffmpeg libs (avcodec, swresample, ...)" OFF) + option(WHISPER_COMMON_FFMPEG "whisper: examples link with ffmpeg libs in order to decode more audio formats" OFF) endif() option(WHISPER_COREML "whisper: enable Core ML framework" OFF) @@ -129,6 +130,7 @@ whisper_option_depr(WARNING WHISPER_RPC GGML_RPC) whisper_option_depr(WARNING WHISPER_SYCL GGML_SYCL) whisper_option_depr(WARNING WHISPER_SYCL_F16 GGML_SYCL_F16) whisper_option_depr(WARNING WHISPER_CCACHE GGML_CCACHE) +whisper_option_depr(WARNING WHISPER_FFMPEG WHISPER_COMMON_FFMPEG) if (GGML_CUDA AND NOT MSVC) #GGML_CUDA enabled, add the necessary compile options -Wno-deprecated-gpu-targets @@ -208,6 +210,13 @@ target_compile_definitions(whisper PRIVATE WHISPER_VERSION="${PROJECT_VERSION}" ) +set_target_properties(parakeet PROPERTIES PUBLIC_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/include/parakeet.h) +install(TARGETS parakeet LIBRARY PUBLIC_HEADER) + +target_compile_definitions(parakeet PRIVATE + PARAKEET_VERSION="${PROJECT_VERSION}" +) + configure_package_config_file( ${CMAKE_CURRENT_SOURCE_DIR}/cmake/whisper-config.cmake.in ${CMAKE_CURRENT_BINARY_DIR}/whisper-config.cmake @@ -227,7 +236,36 @@ configure_file(cmake/whisper.pc.in @ONLY) install(FILES "${CMAKE_CURRENT_BINARY_DIR}/whisper.pc" - DESTINATION lib/pkgconfig) + DESTINATION ${CMAKE_INSTALL_LIBDIR}/pkgconfig) + +set(PARAKEET_INCLUDE_INSTALL_DIR ${CMAKE_INSTALL_INCLUDEDIR} CACHE PATH "Location of header files") +set(PARAKEET_LIB_INSTALL_DIR ${CMAKE_INSTALL_LIBDIR} CACHE PATH "Location of library files") +set(PARAKEET_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} CACHE PATH "Location of binary files") + +configure_package_config_file( + ${CMAKE_CURRENT_SOURCE_DIR}/cmake/parakeet-config.cmake.in + ${CMAKE_CURRENT_BINARY_DIR}/parakeet-config.cmake + INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/parakeet + PATH_VARS + PARAKEET_INCLUDE_INSTALL_DIR + PARAKEET_LIB_INSTALL_DIR + PARAKEET_BIN_INSTALL_DIR) + +write_basic_package_version_file( + ${CMAKE_CURRENT_BINARY_DIR}/parakeet-version.cmake + VERSION ${WHISPER_INSTALL_VERSION} + COMPATIBILITY SameMajorVersion) + +install(FILES ${CMAKE_CURRENT_BINARY_DIR}/parakeet-config.cmake + ${CMAKE_CURRENT_BINARY_DIR}/parakeet-version.cmake + DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/parakeet) + +configure_file(cmake/parakeet.pc.in + "${CMAKE_CURRENT_BINARY_DIR}/parakeet.pc" + @ONLY) + +install(FILES "${CMAKE_CURRENT_BINARY_DIR}/parakeet.pc" + DESTINATION ${CMAKE_INSTALL_LIBDIR}/pkgconfig) # # programs, examples and tests diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000000..c301604f1de --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,176 @@ +# Contributors + +The project differentiates between 3 levels of contributors: + +- Contributors: people who have contributed before (no special privileges) +- Collaborators (Triage): people with significant contributions, who may be responsible for some parts of the code, and are expected to maintain and review contributions for the code they own +- Maintainers: responsible for reviewing and merging PRs, after approval from the code owners + +# AI Usage Policy + +> [!IMPORTANT] +> This project does **not** accept pull requests that are fully or predominantly AI-generated. AI tools may be utilized solely in an assistive capacity. +> +> Repeated violations of this policy may result in your account being permanently banned from contributing to the project. +> +> Detailed information regarding permissible and restricted uses of AI can be found in the [AGENTS.md](AGENTS.md) file. + +Code that is initially generated by AI and subsequently edited will still be considered AI-generated. AI assistance is permissible only when the majority of the code is authored by a human contributor, with AI employed exclusively for corrections or to expand on verbose modifications that the contributor has already conceptualized (e.g., generating repeated lines with minor variations). + +If AI is used to generate any portion of the code, contributors must adhere to the following requirements: + +1. Explicitly disclose the manner in which AI was employed. +2. Perform a comprehensive manual review prior to submitting the pull request. +3. Be prepared to explain every line of code they submitted when asked about it by a maintainer. +4. It is strictly prohibited to use AI to write your posts for you (bug reports, feature requests, pull request descriptions, Github discussions, responding to humans, ...). + +For more info, please refer to the [AGENTS.md](AGENTS.md) file. + +# Pull requests (for contributors & collaborators) + +Before submitting your PR: +- Search for existing PRs to prevent duplicating efforts +- whisper.cpp uses the ggml tensor library for model evaluation. If you are unfamiliar with ggml, consider taking a look at the [examples in the ggml repository](https://github.com/ggml-org/ggml/tree/master/examples/). [simple](https://github.com/ggml-org/ggml/tree/master/examples/simple) shows the bare minimum for using ggml. [gpt-2](https://github.com/ggml-org/ggml/tree/master/examples/gpt-2) has minimal implementations for language model inference using GPT-2. [mnist](https://github.com/ggml-org/ggml/tree/master/examples/mnist) demonstrates how to train and evaluate a simple image classifier +- Test your changes: + - Execute [the full CI locally on your machine](ci/README.md) before publishing +- Create separate PRs for each feature or fix: + - Avoid combining unrelated changes in a single PR + - For intricate features, consider opening a feature request first to discuss and align expectations +- If you are a new contributor + - Limit your open PRs to 1 + - Do not submit trivial fixes (e.g. typos, formatting changes) + +After submitting your PR: +- Expect requests for modifications to ensure the code meets whisper.cpp's standards for quality and long-term maintainability +- Maintainers will rely on your insights and approval when making a final decision to approve and merge a PR +- If your PR becomes stale, rebase it on top of latest `master` to get maintainers attention + +# Pull requests (for maintainers) + +- Squash-merge PRs +- Use the following format for the squashed commit title: `<module> : <commit title> (#<issue_number>)`. For example: `utils : fix typo in utils.py (#1234)` +- Optionally pick a `<module>` from here: https://github.com/ggml-org/llama.cpp/wiki/Modules +- Let other maintainers merge their own PRs +- When merging a PR, make sure you have a good understanding of the changes +- Be mindful of maintenance: most of the work going into a feature happens after the PR is merged. If the PR author is not committed to contribute long-term, someone else needs to take responsibility (you) + +Maintainers reserve the right to decline review or close pull requests for any reason, without any questions, particularly under any of the following conditions: +- The proposed change is already mentioned in the roadmap or an existing issue, and it has been assigned to someone. +- The pull request duplicates an existing one. +- The contributor fails to adhere to this contributing guide or the AI policy. + +# Coding guidelines + +- Avoid adding third-party dependencies, extra files, extra headers, etc. +- Always consider cross-compatibility with other operating systems and architectures +- Avoid fancy-looking modern STL constructs, use basic `for` loops, avoid templates, keep it simple +- Vertical alignment makes things more readable and easier to batch edit +- Clean-up any trailing whitespaces, use 4 spaces for indentation, brackets on the same line, `void * ptr`, `int & a` +- Use sized integer types such as `int32_t` in the public API, e.g. `size_t` may also be appropriate for allocation sizes or byte offsets +- Declare structs with `struct foo {}` instead of `typedef struct foo {} foo` + - In C++ code omit optional `struct` and `enum` keyword whenever they are not necessary + ```cpp + // OK + llama_context * ctx; + const llama_rope_type rope_type; + + // not OK + struct llama_context * ctx; + const enum llama_rope_type rope_type; + ``` + + _(NOTE: this guideline is yet to be applied to the `whisper.cpp` codebase. New code should follow this guideline.)_ + +- Try to follow the existing patterns in the code (indentation, spaces, etc.). In case of doubt use `clang-format` (from clang-tools v15+) to format the added code +- For anything not covered in the current guidelines, refer to the [C++ Core Guidelines](https://isocpp.github.io/CppCoreGuidelines/CppCoreGuidelines) +- Tensors store data in row-major order. We refer to dimension 0 as columns, 1 as rows, 2 as matrices +- Matrix multiplication is unconventional: [`C = ggml_mul_mat(ctx, A, B)`](https://github.com/ggml-org/llama.cpp/blob/880e352277fc017df4d5794f0c21c44e1eae2b84/ggml.h#L1058-L1064) means $C^T = A B^T \Leftrightarrow C = B A^T.$ + +![matmul](media/matmul.png) + +# Naming guidelines + +- Use `snake_case` for function, variable and type names +- Naming usually optimizes for longest common prefix (see https://github.com/ggml-org/ggml/pull/302#discussion_r1243240963) + + ```cpp + // not OK + int small_number; + int big_number; + + // OK + int number_small; + int number_big; + ``` + +- Enum values are always in upper case and prefixed with the enum name + + ```cpp + enum llama_vocab_type { + LLAMA_VOCAB_TYPE_NONE = 0, + LLAMA_VOCAB_TYPE_SPM = 1, + LLAMA_VOCAB_TYPE_BPE = 2, + LLAMA_VOCAB_TYPE_WPM = 3, + LLAMA_VOCAB_TYPE_UGM = 4, + LLAMA_VOCAB_TYPE_RWKV = 5, + }; + ``` + +- The general naming pattern is `<class>_<method>`, with `<method>` being `<action>_<noun>` + + ```cpp + llama_model_init(); // class: "llama_model", method: "init" + llama_sampler_chain_remove(); // class: "llama_sampler_chain", method: "remove" + llama_sampler_get_seed(); // class: "llama_sampler", method: "get_seed" + llama_set_embeddings(); // class: "llama_context", method: "set_embeddings" + llama_n_threads(); // class: "llama_context", method: "n_threads" + llama_adapter_lora_free(); // class: "llama_adapter_lora", method: "free" + ``` + + - The `get` `<action>` can be omitted + - The `<noun>` can be omitted if not necessary + - The `_context` suffix of the `<class>` is optional. Use it to disambiguate symbols when needed + - Use `init`/`free` for constructor/destructor `<action>` + +- Use the `_t` suffix when a type is supposed to be opaque to the user - it's not relevant to them if it is a struct or anything else + + ```cpp + typedef struct llama_context * llama_context_t; + + enum llama_pooling_type llama_pooling_type(const llama_context_t ctx); + ``` + + _(NOTE: this guideline is yet to be applied to the `whisper.cpp` codebase. New code should follow this guideline)_ + +- C/C++ filenames are all lowercase with dashes. Headers use the `.h` extension. Source files use the `.c` or `.cpp` extension +- Python filenames are all lowercase with underscores + +- _(TODO: abbreviations usage)_ + +# Preprocessor directives + +- _(TODO: add guidelines with examples and apply them to the codebase)_ + + ```cpp + #ifdef FOO + #endif // FOO + ``` + +# Code maintenance + +- New code should follow the guidelines (coding, naming, etc.) outlined in this document. Exceptions are allowed in isolated, backend-specific parts of the code that do not interface directly with the `ggml` interfaces. + _(NOTE: for legacy reasons, existing code is not required to follow this guideline)_ + +- For changes in server, please make sure to refer to the [server development documentation](./tools/server/README-dev.md) + +# Documentation + +- Documentation is a community effort +- When you need to look into the source code to figure out how to use an API consider adding a short summary to the header file for future reference +- When you notice incorrect or outdated documentation, please update it + +# Resources + +The Github issues, PRs and discussions contain a lot of information that can be useful to get familiar with the codebase. For convenience, some of the more important information is referenced from Github projects: + +https://github.com/ggml-org/whisper.cpp/projects diff --git a/README.md b/README.md index 59e1c946539..78378aa3679 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ [![Conan Center](https://shields.io/conan/v/whisper-cpp)](https://conan.io/center/whisper-cpp) [![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/) -Stable: [v1.8.1](https://github.com/ggml-org/whisper.cpp/releases/tag/v1.8.1) / [Roadmap](https://github.com/orgs/ggml-org/projects/4/) +Stable: [v1.9.1](https://github.com/ggml-org/whisper.cpp/releases/tag/v1.9.1) / [Roadmap](https://github.com/orgs/ggml-org/projects/4/) High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model: @@ -443,9 +443,10 @@ cmake -B build -DGGML_MUSA=1 -DMUSA_ARCHITECTURES="21" cmake --build build -j --config Release ``` -## FFmpeg support (Linux only) +## FFmpeg support (examples only) -If you want to support more audio formats (such as Opus and AAC), you can turn on the `WHISPER_FFMPEG` build flag to enable FFmpeg integration. +By default, the examples in this repo use the [miniaudio](https://github.com/mackron/miniaudio) library to decode audio files. +Some of the examples also can use FFmpeg for decoding and broader format support. To enable that, build with `WHISPER_COMMON_FFMPEG`. First, you need to install required libraries: @@ -460,7 +461,7 @@ sudo dnf install libavcodec-free-devel libavformat-free-devel libavutil-free-dev Then you can build the project as follows: ```bash -cmake -B build -D WHISPER_FFMPEG=yes +cmake -B build -D WHISPER_COMMON_FFMPEG=yes cmake --build build ``` diff --git a/bindings/javascript/package.json b/bindings/javascript/package.json index caf12b6dd2d..09829326605 100644 --- a/bindings/javascript/package.json +++ b/bindings/javascript/package.json @@ -1,6 +1,6 @@ { "name": "whisper.cpp", - "version": "1.8.5", + "version": "1.9.1", "description": "Whisper speech recognition", "main": "whisper.js", "scripts": { diff --git a/bindings/ruby/README.md b/bindings/ruby/README.md index 07b81830c58..7f6b7d92c09 100644 --- a/bindings/ruby/README.md +++ b/bindings/ruby/README.md @@ -396,6 +396,37 @@ whisper .full(Whisper::Params.new, samples) ``` +### Parakeet ### + +whispercpp gem now supports NVIDIA's ASR model Parakeet. + +If you want to use Parakeet instead of Whisper, the API should feel familiar. +In most cases, replace `Whisper::Context` and `Whisper::Params` with `Whisper::Parakeet::Context` and `Whisper::Parakeet::Params`, then use `#transcribe`, `#full`, `#each_segment`, and `#each_token` in the same way. + +```ruby +require "whisper" + +# It's useful to assign Whisper::Parakeet to top-level Parakeet constant unless you use Parakeet gem. +Parakeet = Whisper::Parakeet + +parakeet = Parakeet::Context.new("path/to/model") + +params = Parakeet::Params.new( + no_context: true +) + +parakeet + .transcribe("path/to/audio.wav", params) + .each_segment do |segment| + puts "[#{segment.start_time} --> #{segment.end_time}] #{segment.text}" + end +``` + +The main differences are: + +* Namespace is `Whisper::Parakeet`. +* Parakeet also supports `on_new_token` / `new_token_callback` in addition to segment and progress callbacks. + Custom context params --------------------- diff --git a/bindings/ruby/Rakefile b/bindings/ruby/Rakefile index 7b521b3bdfa..2327651a06a 100644 --- a/bindings/ruby/Rakefile +++ b/bindings/ruby/Rakefile @@ -84,6 +84,21 @@ else end end +TEST_PARAKEET_MODEL = "test/fixtures/for-tests-ggml-parakeet-tdt.bin" +TEST_PARAKEET_MODEL_SRC = File.expand_path(File.join(__dir__, "..", "..", "models", "for-tests-ggml-parakeet-tdt.bin")) +TEST_PARAKEET_MODEL_DIR = TEST_PARAKEET_MODEL.pathmap("%d") +directory TEST_PARAKEET_MODEL_DIR +if File.exist? TEST_PARAKEET_MODEL_SRC + file TEST_PARAKEET_MODEL => [TEST_PARAKEET_MODEL_SRC, TEST_PARAKEET_MODEL_DIR] do |t| + symlink t.source, t.name + end +else + require "open-uri" + file TEST_PARAKEET_MODEL => TEST_PARAKEET_MODEL_DIR do |t| + File.write t.name, URI("https://github.com/ggml-org/whisper.cpp/raw/refs/heads/master/models/for-tests-ggml-parakeet-tdt.bin").read + end +end + TEST_MEMORY_VIEW = "test/jfk_reader/jfk_reader.#{RbConfig::CONFIG['DLEXT']}" file TEST_MEMORY_VIEW => "test/jfk_reader/jfk_reader.c" do |t| chdir "test/jfk_reader" do @@ -93,4 +108,4 @@ file TEST_MEMORY_VIEW => "test/jfk_reader/jfk_reader.c" do |t| end CLEAN.include TEST_MEMORY_VIEW -task test: [LIB_FILE, TEST_MEMORY_VIEW, TEST_FIXTURE_AUDIO] +task test: [LIB_FILE, TEST_MEMORY_VIEW, TEST_FIXTURE_AUDIO, TEST_PARAKEET_MODEL] diff --git a/bindings/ruby/ext/extconf.rb b/bindings/ruby/ext/extconf.rb index 4b09b6ebe13..99894f1234d 100644 --- a/bindings/ruby/ext/extconf.rb +++ b/bindings/ruby/ext/extconf.rb @@ -30,6 +30,6 @@ #{libs}: cmake-targets cmake-targets: #{"\t"}"#{cmake}" -S sources -B build #{options} - #{"\t"}"#{cmake}" --build build --config Release --target common whisper + #{"\t"}"#{cmake}" --build build --config Release --target common whisper parakeet EOF end diff --git a/bindings/ruby/ext/ruby_whisper.c b/bindings/ruby/ext/ruby_whisper.c index 56fceb1c894..7941b1a99dd 100644 --- a/bindings/ruby/ext/ruby_whisper.c +++ b/bindings/ruby/ext/ruby_whisper.c @@ -1,19 +1,29 @@ #include "ruby_whisper.h" VALUE mWhisper; +VALUE mLogSettable; VALUE mVAD; +VALUE mParakeet; VALUE cContext; VALUE cParams; VALUE cVADContext; VALUE cVADParams; VALUE cVADSegments; VALUE cVADSegment; +VALUE cParakeetContext; +VALUE cParakeetContextParams; +VALUE cParakeetParams; +VALUE cParakeetSegment; +VALUE cParakeetModel; VALUE eError; VALUE cSegment; VALUE cToken; VALUE cModel; +VALUE mOutputContext; +VALUE mOutputSegment; + ID id_to_s; ID id_call; ID id___method__; @@ -27,9 +37,11 @@ ID id_pre_converted_models; ID id_coreml_compiled_models; ID id_cache; ID id_n_processors; - -static bool is_log_callback_finalized = false; -static bool is_ruby_log_callback_present = false; +ID id_extended; +ID id_start_log_callback_thread; +ID id_log_callback_thread; +ID id_alive_p; +ID id_join; // High level API extern VALUE ruby_whisper_segment_allocate(VALUE klass); @@ -45,8 +57,13 @@ extern void init_ruby_whisper_vad_params(VALUE *mVAD); extern void init_ruby_whisper_vad_context(VALUE *mVAD); extern void init_ruby_whisper_vad_segment(VALUE *mVAD); extern void init_ruby_whisper_vad_segments(VALUE *mVAD); +extern void init_ruby_whisper_parakeet(VALUE *mWhisper); extern void register_callbacks(ruby_whisper_params *rwp, VALUE *context); +static ruby_whisper_log_queue whisper_log_queue; + +LOG_SETTABLE_SETUP(whisper_log_queue, mWhisper, whisper_log_set) + /* * call-seq: * lang_max_id -> Integer @@ -102,79 +119,6 @@ static VALUE ruby_whisper_s_system_info_str(VALUE self) { return rb_str_new2(whisper_print_system_info()); } -static VALUE ruby_whisper_s_finalize_log_callback(VALUE self, VALUE id) { - is_log_callback_finalized = true; - return Qnil; -} - -typedef struct { - int level; - const char * buffer; -} call_log_callbacks_args; - -static void* -call_log_callbacks(void *v_args) { - VALUE log_callback = rb_iv_get(mWhisper, "log_callback"); - if (NIL_P(log_callback)) { - return NULL; - } - - call_log_callbacks_args *args = (call_log_callbacks_args *)v_args; - VALUE user_data = rb_iv_get(mWhisper, "user_data"); - rb_funcall(log_callback, id_call, 3, INT2NUM(args->level), rb_str_new2(args->buffer), user_data); - - return NULL; -} - -static void -ruby_whisper_log_callback(enum ggml_log_level level, const char * buffer, void * user_data) { - if (is_log_callback_finalized) { - return; - } - if (!is_ruby_log_callback_present) { - return; - } - - call_log_callbacks_args args = { - level, - buffer, - }; - if (ruby_thread_has_gvl_p()) { - call_log_callbacks((void *)&args); - } else { - rb_thread_call_with_gvl(call_log_callbacks, (void *)&args); - } -} - -/* - * call-seq: - * log_set ->(level, buffer, user_data) { ... }, user_data -> nil - */ -static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_data) { - VALUE old_callback = rb_iv_get(self, "log_callback"); - if (!NIL_P(old_callback)) { - rb_undefine_finalizer(old_callback); - } - - rb_iv_set(self, "log_callback", log_callback); - rb_iv_set(self, "user_data", user_data); - - if (!NIL_P(log_callback)) { - VALUE finalize_log_callback = rb_funcall(mWhisper, rb_intern("method"), 1, rb_str_new2("finalize_log_callback")); - rb_define_finalizer(log_callback, finalize_log_callback); - } - - if (NIL_P(log_callback)) { - whisper_log_set(NULL, NULL); - is_ruby_log_callback_present = false; - } else { - whisper_log_set(ruby_whisper_log_callback, NULL); - is_ruby_log_callback_present = true; - } - - return Qnil; -} - void Init_whisper() { id_to_s = rb_intern("to_s"); id_call = rb_intern("call"); @@ -189,9 +133,19 @@ void Init_whisper() { id_coreml_compiled_models = rb_intern("coreml_compiled_models"); id_cache = rb_intern("cache"); id_n_processors = rb_intern("n_processors"); + id_extended = rb_intern("extended"); + id_start_log_callback_thread = rb_intern("start_log_callback_thread"); + id_log_callback_thread = rb_intern("@log_callback_thread"); + id_alive_p = rb_intern("alive?"); + id_join = rb_intern("join"); mWhisper = rb_define_module("Whisper"); + rb_require("whisper/log_settable"); + mLogSettable = rb_path2class("Whisper::LogSettable"); mVAD = rb_define_module_under(mWhisper, "VAD"); + rb_require("whisper/output"); + mOutputContext = rb_path2class("Whisper::Output::Context"); + mOutputSegment = rb_path2class("Whisper::Output::Segment"); rb_define_const(mWhisper, "VERSION", rb_str_new2(whisper_version())); rb_define_const(mWhisper, "LOG_LEVEL_NONE", INT2NUM(GGML_LOG_LEVEL_NONE)); @@ -222,8 +176,8 @@ void Init_whisper() { rb_define_singleton_method(mWhisper, "lang_str", ruby_whisper_s_lang_str, 1); rb_define_singleton_method(mWhisper, "lang_str_full", ruby_whisper_s_lang_str_full, 1); rb_define_singleton_method(mWhisper, "system_info_str", ruby_whisper_s_system_info_str, 0); - rb_define_singleton_method(mWhisper, "log_set", ruby_whisper_s_log_set, 2); - rb_define_private_method(rb_singleton_class(mWhisper), "finalize_log_callback", ruby_whisper_s_finalize_log_callback, 1); + + LOG_SETTABLE_INIT(whisper_log_queue, mWhisper) cContext = init_ruby_whisper_context(&mWhisper); init_ruby_whisper_context_params(&cContext); @@ -236,8 +190,10 @@ void Init_whisper() { init_ruby_whisper_vad_segment(&mVAD); init_ruby_whisper_vad_segments(&mVAD); init_ruby_whisper_vad_context(&mVAD); + init_ruby_whisper_parakeet(&mWhisper); - rb_require("whisper/context"); - rb_require("whisper/segment"); rb_require("whisper/model/uri"); + + rb_include_module(cContext, mOutputContext); + rb_include_module(cSegment, mOutputSegment); } diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h index ba4d8b6fbcc..10e90674953 100644 --- a/bindings/ruby/ext/ruby_whisper.h +++ b/bindings/ruby/ext/ruby_whisper.h @@ -5,8 +5,12 @@ #include <ruby/version.h> #include <ruby/util.h> #include <ruby/thread.h> +#include <ruby/thread_native.h> +#include <ruby/atomic.h> #include <ruby/memory_view.h> #include "whisper.h" +#include "parakeet.h" +#include "ruby_whisper_log_settable.h" #if RUBY_API_VERSION_MAJOR < 4 // Exists but not declared as public API @@ -20,13 +24,28 @@ typedef struct { VALUE callbacks; } ruby_whisper_callback_container; -typedef struct { - VALUE *context; - VALUE user_data; - VALUE callback; - VALUE callbacks; - bool is_interrupted; -} ruby_whisper_abort_callback_container; +typedef struct ruby_whisper_abort_callback_user_data { + volatile rb_atomic_t is_interrupted; + ruby_whisper_callback_container *callback_container; +} ruby_whisper_abort_callback_user_data; + +typedef struct ruby_whisper_log { + enum ggml_log_level level; + char *text; + size_t length; + size_t capacity; +} ruby_whisper_log; + +typedef struct ruby_whisper_log_queue { + rb_nativethread_lock_t lock; + rb_nativethread_cond_t cond; + bool is_open; + + size_t head; + size_t tail; + size_t size; + ruby_whisper_log *logs; +} ruby_whisper_log_queue; typedef struct { struct whisper_context *context; @@ -42,7 +61,7 @@ typedef struct { ruby_whisper_callback_container *new_segment_callback_container; ruby_whisper_callback_container *progress_callback_container; ruby_whisper_callback_container *encoder_begin_callback_container; - ruby_whisper_abort_callback_container *abort_callback_container; + ruby_whisper_callback_container *abort_callback_container; VALUE vad_params; } ruby_whisper_params; @@ -84,6 +103,63 @@ typedef struct parsed_samples_t { bool memview_exported; } parsed_samples_t; +typedef struct { + VALUE *context; + VALUE *params; + float *samples; + int n_samples; +} ruby_whisper_full_args; + +typedef struct ruby_whisper_full_parallel_args { + VALUE *context; + VALUE *params; + float *samples; + int n_samples; + int n_processors; +} ruby_whisper_full_parallel_args; + +typedef struct { + struct parakeet_full_params params; + ruby_whisper_callback_container *new_segment_callback_container; + ruby_whisper_callback_container *new_token_callback_container; + ruby_whisper_callback_container *progress_callback_container; + ruby_whisper_callback_container *encoder_begin_callback_container; + ruby_whisper_callback_container *abort_callback_container; +} ruby_whisper_parakeet_params; + +typedef struct { + struct parakeet_context_params params; +} ruby_whisper_parakeet_context_params; + +typedef struct { + struct parakeet_context *context; +} ruby_whisper_parakeet_context; + +typedef struct { + VALUE context; + int index; +} ruby_whisper_parakeet_segment; + +typedef struct { + parakeet_token_data *token_data; + VALUE text; +} ruby_whisper_parakeet_token; + +typedef struct { + VALUE context; +} ruby_whisper_parakeet_model; + +extern ID id_extended; +extern ID id_log_callback_thread; +extern ID id_start_log_callback_thread; +extern ID id_alive_p; +extern ID id_join; +extern void ruby_whisper_log_queue_initialize(ruby_whisper_log_queue *log_queue); +extern void ruby_whisper_log_queue_open(ruby_whisper_log_queue *log_queue); +extern void ruby_whisper_log_queue_close(ruby_whisper_log_queue *log_queue); +extern void ruby_whisper_log_queue_enqueue(ruby_whisper_log_queue *log_queue, enum ggml_log_level level, const char *text); +extern VALUE ruby_whisper_log_queue_drain(ruby_whisper_log_queue *log_queue); + #define GetContext(obj, rw) do { \ TypedData_Get_Struct((obj), ruby_whisper, &ruby_whisper_type, (rw)); \ if ((rw)->context == NULL) { \ @@ -120,4 +196,47 @@ typedef struct parsed_samples_t { } \ } while (0) +#define GetParakeetContextParams(obj, rwpcp) do { \ + TypedData_Get_Struct((obj), ruby_whisper_parakeet_context_params, &ruby_whisper_parakeet_context_params_type, (rwpcp)); \ +} while (0) + +#define GetParakeetContext(obj, rwpc) do { \ + TypedData_Get_Struct((obj), ruby_whisper_parakeet_context, &ruby_whisper_parakeet_context_type, (rwpc)); \ + if ((rwpc)->context == NULL) { \ + rb_raise(rb_eRuntimeError, "Not initialized"); \ + } \ +} while (0) + +#define GetParakeetParams(obj, rwpp) do { \ + TypedData_Get_Struct((obj), ruby_whisper_parakeet_params, &ruby_whisper_parakeet_params_type, (rwpp)); \ + if (!(rwpp)->new_segment_callback_container || \ + !(rwpp)->new_token_callback_container || \ + !(rwpp)->progress_callback_container || \ + !(rwpp)->encoder_begin_callback_container || \ + !(rwpp)->abort_callback_container) { \ + rb_raise(rb_eRuntimeError, "Not initialized"); \ + } \ +} while (0) + +#define GetParakeetSegment(obj, rwps) do { \ + TypedData_Get_Struct((obj), ruby_whisper_parakeet_segment, &ruby_whisper_parakeet_segment_type, (rwps)); \ + if (!(rwps)->context) { \ + rb_raise(rb_eRuntimeError, "Not initialized"); \ + } \ +} while (0) + +#define GetParakeetToken(obj, rwpt) do { \ + TypedData_Get_Struct((obj), ruby_whisper_parakeet_token, &ruby_whisper_parakeet_token_type, (rwpt)); \ + if (!(rwpt)->token_data) { \ + rb_raise(rb_eRuntimeError, "Not initialized"); \ + } \ +} while (0) + +#define GetParakeetModel(obj, rwpm) do { \ + TypedData_Get_Struct((obj), ruby_whisper_parakeet_model, &ruby_whisper_parakeet_model_type, (rwpm)); \ + if (NIL_P((rwpm)->context)) { \ + rb_raise(rb_eRuntimeError, "Not initialized"); \ + } \ +} while (0) + #endif diff --git a/bindings/ruby/ext/ruby_whisper_context.c b/bindings/ruby/ext/ruby_whisper_context.c index 26058fc07e6..9e5fc33e726 100644 --- a/bindings/ruby/ext/ruby_whisper_context.c +++ b/bindings/ruby/ext/ruby_whisper_context.c @@ -28,7 +28,7 @@ extern const rb_data_type_t ruby_whisper_context_params_type; extern VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self); extern VALUE rb_whisper_model_s_new(VALUE context); extern VALUE rb_whisper_segment_s_new(VALUE context, int index); -extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors); +extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors, ruby_whisper_abort_callback_user_data *abort_callback_user_data); ID transcribe_option_names[1]; @@ -38,21 +38,6 @@ typedef struct fill_samples_args { int n_samples; } fill_samples_args; -typedef struct full_args { - VALUE *context; - VALUE *params; - float *samples; - int n_samples; -} full_args; - -typedef struct full_parallel_args { - VALUE *context; - VALUE *params; - float *samples; - int n_samples; - int n_processors; -} full_parallel_args; - typedef struct full_without_gvl_args { struct whisper_context *context; struct whisper_full_params *params; @@ -71,7 +56,7 @@ typedef struct full_parallel_without_gvl_args { } full_parallel_without_gvl_args; typedef struct full_ubf_args { - ruby_whisper_abort_callback_container *abort_callback_container; + ruby_whisper_abort_callback_user_data *abort_callback_user_data; } full_ubf_args; static void @@ -379,7 +364,7 @@ fill_samples(VALUE rb_args) return Qnil; } -struct parsed_samples_t +parsed_samples_t parse_samples(VALUE *samples, VALUE *n_samples) { bool memview_available = rb_memory_view_available_p(*samples); @@ -480,20 +465,24 @@ full_ubf(void *rb_args) { full_ubf_args *args = (full_ubf_args *)rb_args; - args->abort_callback_container->is_interrupted = true; + RUBY_ATOMIC_SET(args->abort_callback_user_data->is_interrupted, 1); } -static VALUE +VALUE full_body(VALUE rb_args) { - full_args *args = (full_args *)rb_args; + ruby_whisper_full_args *args = (ruby_whisper_full_args *)rb_args; ruby_whisper *rw; ruby_whisper_params *rwp; GetContext(*args->context, rw); TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp); - prepare_transcription(rwp, args->context, 1); + ruby_whisper_abort_callback_user_data abort_callback_user_data = { + 0, + NULL, + }; + prepare_transcription(rwp, args->context, 1, &abort_callback_user_data); struct full_without_gvl_args full_without_gvl_args = { rw->context, @@ -503,7 +492,7 @@ full_body(VALUE rb_args) 0, }; full_ubf_args full_ubf_args = { - rwp->abort_callback_container, + &abort_callback_user_data, }; rb_thread_call_without_gvl(full_without_gvl, (void *)&full_without_gvl_args, full_ubf, (void *)&full_ubf_args); return INT2NUM(full_without_gvl_args.result); @@ -529,7 +518,7 @@ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self) VALUE n_samples = argc == 2 ? Qnil : argv[2]; struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples); - full_args args = { + ruby_whisper_full_args args = { &self, &argv[0], parsed.samples, @@ -552,17 +541,21 @@ full_parallel_without_gvl(void *rb_args) return NULL; } -static VALUE +VALUE full_parallel_body(VALUE rb_args) { - full_parallel_args *args = (full_parallel_args *)rb_args; + ruby_whisper_full_parallel_args *args = (ruby_whisper_full_parallel_args *)rb_args; ruby_whisper *rw; ruby_whisper_params *rwp; GetContext(*args->context, rw); TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp); - prepare_transcription(rwp, args->context, args->n_processors); + ruby_whisper_abort_callback_user_data abort_callback_user_data = { + 0, + NULL, + }; + prepare_transcription(rwp, args->context, args->n_processors, &abort_callback_user_data); struct full_parallel_without_gvl_args full_parallel_without_gvl_args = { rw->context, @@ -573,7 +566,7 @@ full_parallel_body(VALUE rb_args) 0, }; full_ubf_args full_ubf_args = { - rwp->abort_callback_container, + &abort_callback_user_data, }; rb_thread_call_without_gvl(full_parallel_without_gvl, (void *)&full_parallel_without_gvl_args, full_ubf, (void *)&full_ubf_args); return INT2NUM(full_parallel_without_gvl_args.result); @@ -613,7 +606,7 @@ ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self) break; } struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples); - const full_parallel_args args = { + const ruby_whisper_full_parallel_args args = { &self, &argv[0], parsed.samples, diff --git a/bindings/ruby/ext/ruby_whisper_log_queue.c b/bindings/ruby/ext/ruby_whisper_log_queue.c new file mode 100644 index 00000000000..6558a339c6f --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_log_queue.c @@ -0,0 +1,180 @@ +#include "ruby_whisper.h" + +#define LOG_QUEUE_CAPACITY 256 +#define LOG_DEFAULT_CAPACITY 1024 + +void +ruby_whisper_log_queue_initialize(ruby_whisper_log_queue *log_queue) +{ + rb_nativethread_lock_initialize(&log_queue->lock); + rb_native_cond_initialize(&log_queue->cond); + log_queue->head = 0; + log_queue->tail = 0; + log_queue->size = 0; + log_queue->is_open = true; + log_queue->logs = ALLOC_N(ruby_whisper_log, LOG_QUEUE_CAPACITY); + for (size_t i = 0; i < LOG_QUEUE_CAPACITY; i++) { + // we cannot call Ruby API like ALLOC_N because this slot may be realloced without GVL + // this doesn't be freed because log queue lives until the end of process + char *slot = malloc(sizeof(char) * LOG_QUEUE_CAPACITY); + if (!slot) { + rb_raise(rb_eRuntimeError, "Could not allocate memory for log text"); + } + ruby_whisper_log log = { + 0, + slot, + 0, + LOG_QUEUE_CAPACITY, + }; + log_queue->logs[i] = log; + } +} + +void +ruby_whisper_log_queue_open(ruby_whisper_log_queue *log_queue) +{ + rb_nativethread_lock_lock(&log_queue->lock); + + log_queue->is_open = true; + + rb_native_cond_signal(&log_queue->cond); + + rb_nativethread_lock_unlock(&log_queue->lock); +} + +void +ruby_whisper_log_queue_close(ruby_whisper_log_queue *log_queue) +{ + rb_nativethread_lock_lock(&log_queue->lock); + + log_queue->is_open = false; + rb_native_cond_broadcast(&log_queue->cond); + + rb_nativethread_lock_unlock(&log_queue->lock); +} + +static size_t +calc_enough_cap(size_t len) +{ + size_t quot = len / LOG_DEFAULT_CAPACITY; + size_t rem = len % LOG_DEFAULT_CAPACITY; + + return sizeof(char) * (rem == 0 ? quot : quot + 1) * LOG_DEFAULT_CAPACITY; +} + +void +ruby_whisper_log_queue_enqueue(ruby_whisper_log_queue *log_queue, enum ggml_log_level level, const char *text) +{ + rb_nativethread_lock_lock(&log_queue->lock); + + if (!log_queue->is_open) { + rb_nativethread_lock_unlock(&log_queue->lock); + return; + } + + size_t len = strlen(text); + ruby_whisper_log *log = &log_queue->logs[log_queue->head]; + if (len > log->capacity) { + size_t new_cap = calc_enough_cap(len); + // we cannot call Ruby API like REALLOC_N because this function is called without GVL + char *slot = realloc(log->text, new_cap); + if (!slot) { + rb_nativethread_lock_unlock(&log_queue->lock); + return; + } + log->text = slot; + log->capacity = new_cap; + } + // we cannot call Ruby API like MEMCPY because this function is called without GVL + memcpy(log->text, text, sizeof(char) * len); + log->length = len; + log->level = level; + log_queue->head = (log_queue->head + 1) % LOG_QUEUE_CAPACITY; + bool is_full = log_queue->size >= LOG_QUEUE_CAPACITY; + log_queue->size = is_full ? LOG_QUEUE_CAPACITY : log_queue->size + 1; + if (is_full) { + log_queue->tail = log_queue->head; + } + + rb_native_cond_signal(&log_queue->cond); + rb_nativethread_lock_unlock(&log_queue->lock); +} + +static void* +ruby_whisper_log_queue_wait(void *args) +{ + ruby_whisper_log_queue *log_queue = (ruby_whisper_log_queue *)args; + + rb_native_cond_wait(&log_queue->cond, &log_queue->lock); + rb_nativethread_lock_unlock(&log_queue->lock); + + return NULL; +} + +static void +ruby_whisper_log_queue_wait_ubf(void *args) +{ + ruby_whisper_log_queue *log_queue = (ruby_whisper_log_queue *)args; + + rb_native_cond_broadcast(&log_queue->cond); +} + +typedef struct { + enum ggml_log_level level; + size_t length; + char *text; +} log_snapshot; + +VALUE +ruby_whisper_log_queue_drain(ruby_whisper_log_queue *log_queue) +{ + log_snapshot logs[LOG_QUEUE_CAPACITY]; + + rb_nativethread_lock_lock(&log_queue->lock); + + while (log_queue->size == 0 && log_queue->is_open) { + rb_thread_call_without_gvl(ruby_whisper_log_queue_wait, (void *)log_queue, ruby_whisper_log_queue_wait_ubf, (void *)log_queue); + rb_nativethread_lock_lock(&log_queue->lock); + } + + if (log_queue->size == 0 && !log_queue->is_open) { + rb_native_cond_broadcast(&log_queue->cond); + rb_nativethread_lock_unlock(&log_queue->lock); + return Qnil; + } + + size_t size = log_queue->size; + ruby_whisper_log *log; + size_t i; + for (i = 0; i < size; i++) { + log = &log_queue->logs[(log_queue->tail + i) % LOG_QUEUE_CAPACITY]; + logs[i].level = log->level; + logs[i].length = log->length; + char *text = malloc(log->length); + if (!text) { + logs[i].text = NULL; + continue; + } + logs[i].text = text; + memcpy(logs[i].text, log->text, log->length); + } + log_queue->size = 0; + log_queue->tail = log_queue->head; + + rb_native_cond_signal(&log_queue->cond); + + rb_nativethread_lock_unlock(&log_queue->lock); + + VALUE rb_logs = rb_ary_new2(size); + VALUE rb_text; + for (i = 0; i < size; i++) { + if (!logs[i].text) { + continue; + } + rb_text = rb_str_new(logs[i].text, logs[i].length); + free(logs[i].text); + rb_ary_push(rb_logs, rb_ary_new3(2, INT2NUM(logs[i].level), rb_text)); + } + + return rb_logs; +} diff --git a/bindings/ruby/ext/ruby_whisper_log_settable.h b/bindings/ruby/ext/ruby_whisper_log_settable.h new file mode 100644 index 00000000000..b98fbac826b --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_log_settable.h @@ -0,0 +1,47 @@ +#ifndef RUBY_WHISPER_LOG_SETTABLE_H +#define RUBY_WHISPER_LOG_SETTABLE_H + +#define LOG_SETTABLE_SETUP(log_queue, mod, log_set) \ + static VALUE \ + ruby_whisper_##log_queue##_s_drain_logs(VALUE self) \ + { \ + return ruby_whisper_log_queue_drain(&log_queue); \ + } \ + static void \ + ruby_whisper_##log_queue##_log_callback(enum ggml_log_level level, const char *text, void *user_data) \ + { \ + ruby_whisper_log_queue_enqueue(&log_queue, level, text); \ + } \ + static VALUE \ + ruby_whisper_##log_queue##_s_log_set(VALUE self, VALUE log_callback, VALUE user_data) \ + { \ + rb_iv_set(self, "@log_callback", log_callback); \ + rb_iv_set(self, "@log_callback_user_data", user_data); \ + if (NIL_P(log_callback)) { \ + log_set(NULL, NULL); \ + } else { \ + ruby_whisper_log_queue_open(&log_queue); \ + rb_funcall((mod), id_start_log_callback_thread, 0); \ + log_set(ruby_whisper_##log_queue##_log_callback, NULL); \ + } \ + return Qnil; \ + } \ + static void \ + ruby_whisper_##log_queue##_end_proc(VALUE args) \ + { \ + ruby_whisper_log_queue_close(&log_queue); \ + VALUE log_callback_thread = rb_ivar_get(mod, id_log_callback_thread); \ + if (!NIL_P(log_callback_thread) && RTEST(rb_funcall(log_callback_thread, id_alive_p, 0))) { \ + rb_funcall(log_callback_thread, id_join, 0); \ + } \ + } + +#define LOG_SETTABLE_INIT(log_queue, mod) \ + ruby_whisper_log_queue_initialize(&log_queue); \ + rb_define_singleton_method(mod, "drain_logs", ruby_whisper_##log_queue##_s_drain_logs, 0); \ + rb_define_singleton_method(mod, "log_set", ruby_whisper_##log_queue##_s_log_set, 2); \ + rb_set_end_proc(ruby_whisper_##log_queue##_end_proc, Qnil); \ + rb_extend_object(mod, mLogSettable); \ + rb_funcall(mLogSettable, id_extended, 1, mod); + +#endif diff --git a/bindings/ruby/ext/ruby_whisper_parakeet.c b/bindings/ruby/ext/ruby_whisper_parakeet.c new file mode 100644 index 00000000000..d69369401d0 --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet.c @@ -0,0 +1,49 @@ +#include "ruby_whisper.h" +#include <stdio.h> +#include <unistd.h> + +extern VALUE mParakeet; +extern VALUE mLogSettable; +extern VALUE cParakeetContext; +extern VALUE cParakeetSegment; +extern VALUE mOutputContext; +extern VALUE mOutputSegment; + +extern void init_ruby_whisper_parakeet_params(VALUE *mParakeet); +extern void init_ruby_whisper_parakeet_token(VALUE *mParakeet); +extern void init_ruby_whisper_parakeet_segment(VALUE *mParakeet); +extern VALUE init_ruby_whisper_parakeet_context(VALUE *mParakeet); +extern void init_ruby_whisper_parakeet_context_params(VALUE *cParakeetContext); +extern void init_ruby_whisper_parakeet_model(VALUE *mParakeet); + +static ruby_whisper_log_queue parakeet_log_queue; + +LOG_SETTABLE_SETUP(parakeet_log_queue, mParakeet, parakeet_log_set) + +static VALUE +ruby_whisper_parakeet_s_system_info_str(VALUE self) +{ + return rb_str_new2(parakeet_print_system_info()); +} + +void +init_ruby_whisper_parakeet(VALUE *mWhisper) +{ + mParakeet = rb_define_module_under(*mWhisper, "Parakeet"); + + rb_define_const(mParakeet, "VERSION", rb_str_new2(parakeet_version())); + + LOG_SETTABLE_INIT(parakeet_log_queue, mParakeet) + + rb_define_singleton_method(mParakeet, "system_info_str", ruby_whisper_parakeet_s_system_info_str, 0); + + init_ruby_whisper_parakeet_params(&mParakeet); + init_ruby_whisper_parakeet_token(&mParakeet); + init_ruby_whisper_parakeet_segment(&mParakeet); + cParakeetContext = init_ruby_whisper_parakeet_context(&mParakeet); + init_ruby_whisper_parakeet_context_params(&cParakeetContext); + init_ruby_whisper_parakeet_model(&mParakeet); + + rb_include_module(cParakeetContext, mOutputContext); + rb_include_module(cParakeetSegment, mOutputSegment); +} diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_context.c b/bindings/ruby/ext/ruby_whisper_parakeet_context.c new file mode 100644 index 00000000000..b4a2fc5c4b7 --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet_context.c @@ -0,0 +1,304 @@ +#include "ruby_whisper.h" + +#define ITERATE_SEGMENT_ATTRS(ITERATOR) \ + ITERATOR(get_segment_t0, LONG) \ + ITERATOR(get_segment_t1, LONG) \ + ITERATOR(get_segment_text, STRING) \ + ITERATOR(n_tokens, INT) + +#define ITERATE_TOKEN_ATTRS(ITERATOR) \ + ITERATOR(get_token_text, STRING) \ + ITERATOR(get_token_id, INT) \ + ITERATOR(get_token_p, FLOAT) + +#define VAL_FROM_LONG(v) LONG2NUM(v) +#define VAL_FROM_STRING(v) rb_utf8_str_new_cstr(v) +#define VAL_FROM_INT(v) INT2NUM(v) +#define VAL_FROM_FLOAT(v) DBL2NUM(v) +#define READER(type) VAL_FROM_##type + +extern ID id_to_s; +extern ID id___method__; +extern ID id_to_enum; +extern ID id_new; + +extern VALUE cParakeetContext; +extern VALUE eError; + +extern VALUE ruby_whisper_normalize_model_path(VALUE model_path); +extern VALUE ruby_whisper_parakeet_transcribe(VALUE self, VALUE audio_path, VALUE params); +extern VALUE ruby_whisper_parakeet_segment_init(VALUE context, int index); +extern parsed_samples_t parse_samples(VALUE *samples, VALUE *n_samples); +extern VALUE release_samples(VALUE rb_parsed_args); +extern void ruby_whisper_parakeet_prepare_transcription(ruby_whisper_parakeet_params *rwpp, VALUE *context, ruby_whisper_abort_callback_user_data *abort_callback_user_data); +extern rb_data_type_t ruby_whisper_parakeet_params_type; +extern rb_data_type_t ruby_whisper_parakeet_context_params_type; +extern VALUE ruby_whisper_parakeet_token_s_from_token_data(struct parakeet_context *context, const parakeet_token_data *token_data); +extern VALUE ruby_whisper_parakeet_model_s_new(VALUE context); + +static void +ruby_whisper_parakeet_context_free(void *p) +{ + ruby_whisper_parakeet_context *rwpc = (ruby_whisper_parakeet_context *)p; + if (rwpc->context) { + parakeet_free(rwpc->context); + rwpc->context = NULL; + } + xfree(rwpc); +} + +static size_t +ruby_whisper_parakeet_context_memsize(const void *p) +{ + ruby_whisper_parakeet_context *rwpc = (ruby_whisper_parakeet_context *)p; + if (!rwpc) { + return 0; + } + size_t size = sizeof(*rwpc); + return size; +} + +const rb_data_type_t ruby_whisper_parakeet_context_type = { + "ruby_whisper_parakeet_context", + {0, ruby_whisper_parakeet_context_free, ruby_whisper_parakeet_context_memsize,}, + 0, 0, + 0 +}; + +static VALUE +ruby_whisper_parakeet_context_allocate(VALUE klass) +{ + ruby_whisper_parakeet_context *rwpc; + + VALUE obj = TypedData_Make_Struct(klass, ruby_whisper_parakeet_context, &ruby_whisper_parakeet_context_type, rwpc); + rwpc->context = NULL; + + return obj; +} + +typedef struct { + struct parakeet_context **context; + char *model_path; + struct parakeet_context_params params; +} ruby_whisper_parakeet_context_init_args; + +static void* +ruby_whisper_parakeet_context_init_without_gvl(void *args) +{ + ruby_whisper_parakeet_context_init_args *init_args = (ruby_whisper_parakeet_context_init_args *)args; + *init_args->context = parakeet_init_from_file_with_params(init_args->model_path, init_args->params); + return NULL; +} + +static VALUE +ruby_whisper_parakeet_context_initialize(int argc, VALUE *argv, VALUE self) +{ + ruby_whisper_parakeet_context *rwpc; + VALUE model_path; + VALUE context_params; + struct parakeet_context_params params; + + rb_scan_args(argc, argv, "11", &model_path, &context_params); + TypedData_Get_Struct(self, ruby_whisper_parakeet_context, &ruby_whisper_parakeet_context_type, rwpc); + + model_path = ruby_whisper_normalize_model_path(model_path); + if (!rb_respond_to(model_path, id_to_s)) { + rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Parakeet::Context"); + } + if (NIL_P(context_params)) { + params = parakeet_context_default_params(); + } else { + ruby_whisper_parakeet_context_params *rwpcp; + GetParakeetContextParams(context_params, rwpcp); + params = rwpcp->params; + } + ruby_whisper_parakeet_context_init_args init_args = { + &rwpc->context, + StringValueCStr(model_path), + params, + }; + rb_thread_call_without_gvl(ruby_whisper_parakeet_context_init_without_gvl, (void *)&init_args, NULL, NULL); + if (rwpc->context == NULL) { + rb_raise(rb_eRuntimeError, "Failed to load model"); + } + + return Qnil; +} + +static VALUE +ruby_whisper_parakeet_context_full_n_segments(VALUE self) +{ + ruby_whisper_parakeet_context *rwpc; + GetParakeetContext(self, rwpc); + + return INT2NUM(parakeet_full_n_segments(rwpc->context)); +} + +#define DEF_SEGMENT_ATTR(name, type) \ + static VALUE \ + ruby_whisper_parakeet_context_full_##name(VALUE self, VALUE i_segment) \ + { \ + ruby_whisper_parakeet_context *rwpc; \ + GetParakeetContext(self, rwpc); \ + return READER(type)(parakeet_full_##name(rwpc->context, NUM2INT(i_segment))); \ + } + +ITERATE_SEGMENT_ATTRS(DEF_SEGMENT_ATTR) + +#define DEF_TOKEN_ATTR(name, type) \ + static VALUE \ + ruby_whisper_parakeet_context_full_##name(VALUE self, VALUE i_segment, VALUE i_token) \ + { \ + ruby_whisper_parakeet_context *rwpc; \ + GetParakeetContext(self, rwpc); \ + return READER(type)(parakeet_full_##name(rwpc->context, NUM2INT(i_segment), NUM2INT(i_token))); \ + } + +ITERATE_TOKEN_ATTRS(DEF_TOKEN_ATTR) + +static VALUE +ruby_whisper_parakeet_context_full_get_token_data(VALUE self, VALUE i_segment, VALUE i_token) +{ + ruby_whisper_parakeet_context *rwpc; + GetParakeetContext(self, rwpc); + parakeet_token_data token_data = parakeet_full_get_token_data(rwpc->context, NUM2INT(i_segment), NUM2INT(i_token)); + + return ruby_whisper_parakeet_token_s_from_token_data(rwpc->context, &token_data); +} + +static VALUE +ruby_whisper_parakeet_context_each_segment(VALUE self) +{ + if (!rb_block_given_p()) { + const VALUE method_name = rb_funcall(self, id___method__, 0); + return rb_funcall(self, id_to_enum, 1, method_name); + } + + ruby_whisper_parakeet_context *rwpc; + GetParakeetContext(self, rwpc); + + const int n_segments = parakeet_full_n_segments(rwpc->context); + for (int i = 0; i < n_segments; ++i) { + rb_yield(ruby_whisper_parakeet_segment_init(self, i)); + } + + return self; +} + +typedef struct { + struct parakeet_context *context; + struct parakeet_full_params *params; + float *samples; + int n_samples; + int result; +} parakeet_full_without_gvl_args; + +static void* +parakeet_full_without_gvl(void *rb_args) +{ + parakeet_full_without_gvl_args *args = (parakeet_full_without_gvl_args *)rb_args; + args->result = parakeet_full(args->context, *args->params, args->samples, args->n_samples); + + return NULL; +} + +typedef struct { + ruby_whisper_abort_callback_user_data *abort_callback_user_data; +} parakeet_full_ubf_args; + +static void +parakeet_full_ubf(void *rb_args) +{ + parakeet_full_ubf_args *args = (parakeet_full_ubf_args *)rb_args; + + RUBY_ATOMIC_SET(args->abort_callback_user_data->is_interrupted, 1); +} + +VALUE +ruby_whisper_parakeet_context_full_body(VALUE rb_args) +{ + ruby_whisper_full_args *args = (ruby_whisper_full_args *)rb_args; + ruby_whisper_parakeet_context *rwpc; + GetParakeetContext(*args->context, rwpc); + ruby_whisper_parakeet_params *rwpp; + GetParakeetParams(*args->params, rwpp); + + ruby_whisper_abort_callback_user_data abort_callback_user_data = { + 0, + NULL, + }; + ruby_whisper_parakeet_prepare_transcription(rwpp, args->context, &abort_callback_user_data); + + parakeet_full_without_gvl_args full_without_gvl_args = { + rwpc->context, + &rwpp->params, + args->samples, + args->n_samples, + 0 + }; + parakeet_full_ubf_args full_ubf_args = { + &abort_callback_user_data, + }; + rb_thread_call_without_gvl(parakeet_full_without_gvl, (void *)&full_without_gvl_args, parakeet_full_ubf, (void *)&full_ubf_args); + + return INT2NUM(full_without_gvl_args.result); +} + +static VALUE +ruby_whisper_parakeet_context_full(int argc, VALUE *argv, VALUE self) +{ + if (argc < 2 || argc > 3) { + rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc); + } + + VALUE n_samples = argc == 2 ? Qnil : argv[2]; + + struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples); + ruby_whisper_full_args args = { + &self, + &argv[0], + parsed.samples, + parsed.n_samples, + }; + VALUE rb_result = rb_ensure(ruby_whisper_parakeet_context_full_body, (VALUE)&args, release_samples, (VALUE)&parsed); + const int result = NUM2INT(rb_result); + if (result == 0) { + return self; + } else { + rb_exc_raise(rb_funcall(eError, id_new, 1, rb_result)); + } +} + +static VALUE +ruby_whisper_parakeet_context_get_model(VALUE self) +{ + return ruby_whisper_parakeet_model_s_new(self); +} + +VALUE +init_ruby_whisper_parakeet_context(VALUE *mParakeet) +{ + cParakeetContext = rb_define_class_under(*mParakeet, "Context", rb_cObject); + + rb_define_alloc_func(cParakeetContext, ruby_whisper_parakeet_context_allocate); + + rb_define_method(cParakeetContext, "initialize", ruby_whisper_parakeet_context_initialize, -1); + rb_define_method(cParakeetContext, "transcribe", ruby_whisper_parakeet_transcribe, 2); + rb_define_method(cParakeetContext, "full_n_segments", ruby_whisper_parakeet_context_full_n_segments, 0); + rb_define_method(cParakeetContext, "full_get_token_data", ruby_whisper_parakeet_context_full_get_token_data, 2); + rb_define_method(cParakeetContext, "model", ruby_whisper_parakeet_context_get_model, 0); + rb_define_method(cParakeetContext, "each_segment", ruby_whisper_parakeet_context_each_segment, 0); + rb_define_method(cParakeetContext, "full", ruby_whisper_parakeet_context_full, -1); + +#define REGISTER_SEGMENT_ATTR(name, type) \ + rb_define_method(cParakeetContext, "full_" #name, ruby_whisper_parakeet_context_full_##name, 1); + + ITERATE_SEGMENT_ATTRS(REGISTER_SEGMENT_ATTR) + +#define REGISTER_TOKEN_ATTR(name, type) \ + rb_define_method(cParakeetContext, "full_" #name, ruby_whisper_parakeet_context_full_##name, 2); + + ITERATE_TOKEN_ATTRS(REGISTER_TOKEN_ATTR) + + return cParakeetContext; +} diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_context_params.c b/bindings/ruby/ext/ruby_whisper_parakeet_context_params.c new file mode 100644 index 00000000000..38bd6d57ce1 --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet_context_params.c @@ -0,0 +1,117 @@ +#include "ruby_whisper.h" + +#define ITERATE_ATTRS(ITERATOR) \ + ITERATOR(use_gpu, BOOL) \ + ITERATOR(gpu_device, INT) + +#define VAL_FROM_BOOL(v) ((v) ? Qtrue : Qfalse) +#define VAL_TO_BOOL(v) (RTEST(v)) +#define VAL_FROM_INT(v) (INT2NUM(v)) +#define VAL_TO_INT(v) (NUM2INT(v)) +#define READER(type) VAL_FROM_##type +#define WRITER(type) VAL_TO_##type + +#define DEF_ATTR(name, type) \ + static VALUE \ + ruby_whisper_parakeet_context_params_get_##name(VALUE self) \ + { \ + ruby_whisper_parakeet_context_params *rwpcp; \ + GetParakeetContextParams(self, rwpcp); \ + return READER(type)(rwpcp->params.name); \ + } \ + static VALUE \ + ruby_whisper_parakeet_context_params_set_##name(VALUE self, VALUE val) \ + { \ + ruby_whisper_parakeet_context_params *rwpcp; \ + GetParakeetContextParams(self, rwpcp); \ + rwpcp->params.name = WRITER(type)(val); \ + return val; \ + } + +enum { +#define DEF_IDX(name, type) RUBY_WHISPER_PARAKEET_CONTEXT_PARAMS_##name, + + ITERATE_ATTRS(DEF_IDX) + RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS +}; + +extern VALUE cParakeetContextParams; + +typedef VALUE (*param_writer_t)(VALUE, VALUE); + +static ID param_names[RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS]; +static param_writer_t param_writers[RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS]; + +static size_t +ruby_whisper_parakeet_context_params_memsize(const void *p) +{ + if (!p) { + return 0; + } + return sizeof(ruby_whisper_parakeet_context_params); +} + +const rb_data_type_t ruby_whisper_parakeet_context_params_type = { + "ruby_whisper_parakeet_context_params", + {0, RUBY_DEFAULT_FREE, ruby_whisper_parakeet_context_params_memsize,}, + 0, 0, + 0, +}; + +static VALUE +ruby_whisper_parakeet_context_params_s_allocate(VALUE klass) +{ + ruby_whisper_parakeet_context_params *rwpcp; + return TypedData_Make_Struct(klass, ruby_whisper_parakeet_context_params, &ruby_whisper_parakeet_context_params_type, rwpcp); +} + +static VALUE +ruby_whisper_parakeet_context_params_initialize(int argc, VALUE *argv, VALUE self) +{ + VALUE kw_hash; + VALUE values[RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS] = {Qundef}; + VALUE value; + ruby_whisper_parakeet_context_params *rwpcp; + int i; + + TypedData_Get_Struct(self, ruby_whisper_parakeet_context_params, &ruby_whisper_parakeet_context_params_type, rwpcp); + rwpcp->params = parakeet_context_default_params(); + + rb_scan_args_kw(RB_SCAN_ARGS_KEYWORDS, argc, argv, ":", &kw_hash); + if (NIL_P(kw_hash)) { + return Qnil; + } + + rb_get_kwargs(kw_hash, param_names, 0, RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS, values); + for (i = 0; i < RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS; i++) { + value = values[i]; + if (value == Qundef) { + continue; + } + param_writers[i](self, value); + } + + return Qnil; +} + +ITERATE_ATTRS(DEF_ATTR) + +void +init_ruby_whisper_parakeet_context_params(VALUE *cParakeetContext) +{ + cParakeetContextParams = rb_define_class_under(*cParakeetContext, "Params", rb_cObject); + + rb_define_alloc_func(cParakeetContextParams, ruby_whisper_parakeet_context_params_s_allocate); + + rb_define_method(cParakeetContextParams, "initialize", ruby_whisper_parakeet_context_params_initialize, -1); + + int i = 0; +#define REGISTER_ATTR(name, type) \ + param_names[i] = rb_intern(#name); \ + param_writers[i] = ruby_whisper_parakeet_context_params_set_##name; \ + rb_define_method(cParakeetContextParams, #name, ruby_whisper_parakeet_context_params_get_##name, 0); \ + rb_define_method(cParakeetContextParams, #name "=", ruby_whisper_parakeet_context_params_set_##name, 1); \ + i++; + + ITERATE_ATTRS(REGISTER_ATTR) +} diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_model.c b/bindings/ruby/ext/ruby_whisper_parakeet_model.c new file mode 100644 index 00000000000..dce43c688e7 --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet_model.c @@ -0,0 +1,84 @@ +#include "ruby_whisper.h" + +#define ITERATE_ATTRS(ITERATOR) \ + ITERATOR(n_vocab) \ + ITERATOR(n_audio_ctx) \ + ITERATOR(n_audio_state) \ + ITERATOR(n_audio_head) \ + ITERATOR(n_audio_layer) \ + ITERATOR(n_mels) \ + ITERATOR(ftype) + +extern rb_data_type_t ruby_whisper_parakeet_context_type; +extern VALUE cParakeetModel; + +static void +ruby_whisper_parakeet_model_mark(void *p) +{ + ruby_whisper_parakeet_model *rwpm = (ruby_whisper_parakeet_model *)p; + if (!NIL_P(rwpm->context)) { + rb_gc_mark(rwpm->context); + } +} + +static size_t +ruby_whisper_parakeet_model_memsize(const void *p) +{ + if (!p) { + return 0; + } + return sizeof(ruby_whisper_parakeet_model); +} + +static const rb_data_type_t ruby_whisper_parakeet_model_type = { + "ruby_whisper_parakeet_model", + {ruby_whisper_parakeet_model_mark, RUBY_DEFAULT_FREE, ruby_whisper_parakeet_model_memsize}, + 0, 0, + 0 +}; + +static VALUE +ruby_whisper_parakeet_model_s_allocate(VALUE klass) +{ + ruby_whisper_parakeet_model *rwpm; + VALUE model = TypedData_Make_Struct(klass, ruby_whisper_parakeet_model, &ruby_whisper_parakeet_model_type, rwpm); + rwpm->context = Qnil; + + return model; +} + +VALUE +ruby_whisper_parakeet_model_s_new(VALUE context) +{ + const VALUE model = ruby_whisper_parakeet_model_s_allocate(cParakeetModel); + ruby_whisper_parakeet_model *rwpm; + TypedData_Get_Struct(model, ruby_whisper_parakeet_model, &ruby_whisper_parakeet_model_type, rwpm); + rwpm->context = context; + return model; +} + +#define DEF_ATTR(name) \ + static VALUE \ + ruby_whisper_parakeet_model_get_##name(VALUE self) \ + { \ + ruby_whisper_parakeet_model *rwpm; \ + ruby_whisper_parakeet_context *rwpc; \ + GetParakeetModel(self, rwpm); \ + GetParakeetContext(rwpm->context, rwpc); \ + return INT2NUM(parakeet_model_##name(rwpc->context)); \ + } + +ITERATE_ATTRS(DEF_ATTR) + +void +init_ruby_whisper_parakeet_model(VALUE *mParakeet) +{ + cParakeetModel = rb_define_class_under(*mParakeet, "Model", rb_cObject); + + rb_define_alloc_func(cParakeetModel, ruby_whisper_parakeet_model_s_allocate); + +#define REGISTER_ATTR(name) \ + rb_define_method(cParakeetModel, #name, ruby_whisper_parakeet_model_get_##name, 0); + + ITERATE_ATTRS(REGISTER_ATTR) +} diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_params.c b/bindings/ruby/ext/ruby_whisper_parakeet_params.c new file mode 100644 index 00000000000..076e2a0cdfb --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet_params.c @@ -0,0 +1,548 @@ +#include "ruby_whisper.h" + +#define ITERATE_PARAMS(ITERATOR) \ + ITERATOR(n_threads, INT) \ + ITERATOR(offset_ms, INT) \ + ITERATOR(duration_ms, INT) \ + ITERATOR(no_context, BOOL) \ + ITERATOR(audio_ctx, INT) + +#define ITERATE_NORMAL_CALLBACK_NAMES(ITERATOR, DATA) \ + ITERATOR(new_segment, DATA) \ + ITERATOR(new_token, DATA) \ + ITERATOR(progress, DATA) \ + ITERATOR(encoder_begin, DATA) + +#define ITERATE_NORMAL_CALLBACK_PARAM(name, ITERATOR) ITERATOR(name##_callback) +#define ITERATE_NORMAL_CALLBACK_PARAMS(ITERATOR) \ + ITERATE_NORMAL_CALLBACK_NAMES(ITERATE_NORMAL_CALLBACK_PARAM, ITERATOR) + +#define ITERATE_CALLBACK_PARAMS(ITERATOR) \ + ITERATE_NORMAL_CALLBACK_PARAMS(ITERATOR) \ + ITERATOR(abort_callback) + +enum { +#define DEF_IDX(name, type) RUBY_WHISPER_PARAKEET_PARAM_##name, +#define DEF_IDX_CALLBACK(name) RUBY_WHISPER_PARAKEET_PARAM_##name, +#define DEF_IDX_USER_DATA(name) RUBY_WHISPER_PARAKEET_PARAM_##name##_user_data, + ITERATE_PARAMS(DEF_IDX) + ITERATE_CALLBACK_PARAMS(DEF_IDX_CALLBACK) + ITERATE_CALLBACK_PARAMS(DEF_IDX_USER_DATA) + + RUBY_WHISPER_PARAKEET_NUM_PARAMS +}; + +#define VAL_TO_INT(v) (NUM2INT(v)) +#define VAL_FROM_INT(v) (INT2NUM(v)) +#define VAL_TO_BOOL(v) (RTEST(v)) +#define VAL_FROM_BOOL(v) (v ? Qtrue : Qfalse) + +extern VALUE cParakeetParams; +extern ID id_call; + +extern void ruby_whisper_callback_container_mark(ruby_whisper_callback_container *rwc); +extern ruby_whisper_callback_container* ruby_whisper_callback_container_allocate(void); +extern bool ruby_whisper_callback_container_is_present(const ruby_whisper_callback_container *container); +extern VALUE ruby_whisper_parakeet_segment_init(VALUE context, int index); +extern VALUE ruby_whisper_parakeet_token_s_from_token_data(struct parakeet_context *context, const parakeet_token_data *token_data); + +static ID param_names[RUBY_WHISPER_PARAKEET_NUM_PARAMS]; +typedef VALUE (*param_writer_t)(VALUE, VALUE); +static param_writer_t param_writers[RUBY_WHISPER_PARAKEET_NUM_PARAMS]; + +typedef struct { + const ruby_whisper_callback_container *container; + struct parakeet_state *state; + int n_new; +} call_parakeet_new_segment_callbacks_args; + +static void* +call_parakeet_new_segment_callbacks(void *v_args) +{ + call_parakeet_new_segment_callbacks_args *args = (call_parakeet_new_segment_callbacks_args *)v_args; + const ruby_whisper_callback_container *container = args->container; + + if (!NIL_P(container->callback)) { + rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(args->n_new), container->user_data); + } + if (NIL_P(container->callbacks)) { + return NULL; + } + const long n_callbacks = RARRAY_LEN(container->callbacks); + if (n_callbacks == 0) { + return NULL; + } + const int n_segments = parakeet_full_n_segments_from_state(args->state); + for (int i = args->n_new; i > 0; i--) { + int i_segment = n_segments - i; + VALUE segment = ruby_whisper_parakeet_segment_init(*container->context, i_segment); + for (int j = 0; j < n_callbacks; j++) { + VALUE cb = rb_ary_entry(container->callbacks, j); + rb_funcall(cb, id_call, 1, segment); + } + } + + return NULL; +} + +static void +ruby_whisper_parakeet_new_segment_callback(struct parakeet_context *context, struct parakeet_state *state, int n_new, void *user_data) +{ + const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; + if (!ruby_whisper_callback_container_is_present(container)) { + return; + } + + call_parakeet_new_segment_callbacks_args args = { + container, + state, + n_new, + }; + rb_thread_call_with_gvl(call_parakeet_new_segment_callbacks, (void *)&args); +} + +typedef struct { + const ruby_whisper_callback_container *container; + struct parakeet_context *context; + struct parakeet_state *state; + const parakeet_token_data *token_data; +} call_parakeet_new_token_callbacks_args; + +static void* +call_parakeet_new_token_callbacks(void *v_args) +{ + call_parakeet_new_token_callbacks_args *args = (call_parakeet_new_token_callbacks_args *)v_args; + VALUE token = Qnil; + const ruby_whisper_callback_container *container = args->container; + + if (!NIL_P(container->callback)) { + token = ruby_whisper_parakeet_token_s_from_token_data(args->context, args->token_data); + rb_funcall(container->callback, id_call, 4, *container->context, Qnil, token, container->user_data); + } + if (NIL_P(container->callbacks)) { + return NULL; + } + const long n_callbacks = RARRAY_LEN(container->callbacks); + if (n_callbacks == 0) { + return NULL; + } + if (NIL_P(token)) { + token = ruby_whisper_parakeet_token_s_from_token_data(args->context, args->token_data); + } + for (int i = 0; i < n_callbacks; i++) { + VALUE cb = rb_ary_entry(container->callbacks, i); + rb_funcall(cb, id_call, 1, token); + } + + return NULL; +} + +static void +ruby_whisper_parakeet_new_token_callback(struct parakeet_context *context, struct parakeet_state *state, const parakeet_token_data *token_data, void *user_data) +{ + const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; + if (!ruby_whisper_callback_container_is_present(container)) { + return; + } + + call_parakeet_new_token_callbacks_args args = { + container, + context, + state, + token_data, + }; + rb_thread_call_with_gvl(call_parakeet_new_token_callbacks, (void *)&args); +} + +typedef struct { + const ruby_whisper_callback_container *container; + struct parakeet_state *state; + int progress; +} call_parakeet_progress_callbacks_args; + +static void* +call_parakeet_progress_callback(void *v_args) +{ + call_parakeet_progress_callbacks_args *args = (call_parakeet_progress_callbacks_args *)v_args; + const ruby_whisper_callback_container *container = args->container; + + if (!NIL_P(container->callback)) { + rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(args->progress), container->user_data); + } + if (NIL_P(container->callbacks)) { + return NULL; + } + const long n_callbacks = RARRAY_LEN(container->callbacks); + if (n_callbacks == 0) { + return NULL; + } + for (long i = 0; i < n_callbacks; i++) { + VALUE cb = rb_ary_entry(container->callbacks, i); + rb_funcall(cb, id_call, 1, INT2NUM(args->progress)); + } + + return NULL; +} + +static void +ruby_whisper_parakeet_progress_callback(struct parakeet_context *context, struct parakeet_state *state, int progress, void *user_data) +{ + const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; + if (!ruby_whisper_callback_container_is_present(container)) { + return; + } + + call_parakeet_progress_callbacks_args args = { + container, + state, + progress, + }; + rb_thread_call_with_gvl(call_parakeet_progress_callback, (void *)&args); +} + +typedef struct { + const ruby_whisper_callback_container *container; + struct parakeet_state *state; + bool is_continued; +} call_parakeet_encoder_begin_callbacks_args; + +static void* +call_parakeet_encoder_begin_callbacks(void *v_args) +{ + call_parakeet_encoder_begin_callbacks_args *args = (call_parakeet_encoder_begin_callbacks_args *)v_args; + const ruby_whisper_callback_container *container = args->container; + VALUE result = Qnil; + + if (!NIL_P(container->callback)) { + result = rb_funcall(container->callback, id_call, 3, *container->context, Qnil, container->user_data); + if (result == Qfalse) { + args->is_continued = false; + return NULL; + } + } + if (NIL_P(container->callbacks)) { + return NULL; + } + const long n_callbacks = RARRAY_LEN(container->callbacks); + if (n_callbacks == 0) { + return NULL; + } + for (long i = 0; i < n_callbacks; i++) { + VALUE cb = rb_ary_entry(container->callbacks, i); + result = rb_funcall(cb, id_call, 0); + if (result == Qfalse) { + args->is_continued = false; + return NULL; + } + } + + return NULL; +} + +static bool +ruby_whisper_parakeet_encoder_begin_callback(struct parakeet_context *context, struct parakeet_state *state, void *user_data) +{ + const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; + if (!ruby_whisper_callback_container_is_present(container)) { + return true; + } + + call_parakeet_encoder_begin_callbacks_args args = { + container, + state, + true, + }; + rb_thread_call_with_gvl(call_parakeet_encoder_begin_callbacks, (void *)&args); + + return args.is_continued; +} + +typedef struct { + const ruby_whisper_callback_container *container; + bool is_interrupted; +} call_parakeet_abort_callbacks_args; + +static void* +call_parakeet_abort_callbacks(void *v_args) +{ + call_parakeet_abort_callbacks_args *args = (call_parakeet_abort_callbacks_args *)v_args; + const ruby_whisper_callback_container *container = args->container; + VALUE result = Qnil; + + if (!NIL_P(container->callback)) { + result = rb_funcall(container->callback, id_call, 1, container->user_data); + if (RTEST(result)) { + args->is_interrupted = true; + return NULL; + } + } + if (NIL_P(container->callbacks)) { + return NULL; + } + const long n_callbacks = RARRAY_LEN(container->callbacks); + if (n_callbacks == 0) { + return NULL; + } + VALUE cb; + for (long i = 0; i < n_callbacks; i++) { + cb = rb_ary_entry(container->callbacks, i); + result = rb_funcall(cb, id_call, 0); + if (RTEST(result)) { + args->is_interrupted = true; + return NULL; + } + } + + return NULL; +} + +static bool +ruby_whisper_parakeet_abort_callback(void *user_data) +{ + ruby_whisper_abort_callback_user_data *data = (ruby_whisper_abort_callback_user_data *)user_data; + + int is_interrupted = RUBY_ATOMIC_LOAD(data->is_interrupted); + if (is_interrupted) { + return true; + } + + if (!(data->callback_container) || !ruby_whisper_callback_container_is_present(data->callback_container)) { + return false; + } + + call_parakeet_abort_callbacks_args args = { + data->callback_container, + false, + }; + rb_thread_call_with_gvl(call_parakeet_abort_callbacks, (void *)&args); + + return args.is_interrupted; +} + +#define CALLBACK_CONTAINER_NAME(name) name ## _container + +void +ruby_whisper_parakeet_prepare_transcription(ruby_whisper_parakeet_params *rwpp, VALUE *context, ruby_whisper_abort_callback_user_data *abort_callback_user_data) +{ +#define PARAM_NAME(name) name +#define USER_DATA_NAME(name) name##_user_data +#define REGISTER_CALLBACK(name) \ + if (ruby_whisper_callback_container_is_present(rwpp->CALLBACK_CONTAINER_NAME(name))) { \ + rwpp->CALLBACK_CONTAINER_NAME(name)->context = context; \ + rwpp->params.PARAM_NAME(name) = ruby_whisper_parakeet_##name; \ + rwpp->params.USER_DATA_NAME(name) = rwpp->CALLBACK_CONTAINER_NAME(name); \ + } + + ITERATE_NORMAL_CALLBACK_PARAMS(REGISTER_CALLBACK) + + if (ruby_whisper_callback_container_is_present(rwpp->abort_callback_container)) { + abort_callback_user_data->callback_container = rwpp->abort_callback_container; + } + rwpp->params.abort_callback = ruby_whisper_parakeet_abort_callback; + rwpp->params.abort_callback_user_data = (void *)abort_callback_user_data; +} + +static void +ruby_whisper_parakeet_params_mark(void *p) +{ + ruby_whisper_parakeet_params *rwpp = (ruby_whisper_parakeet_params *)p; + +#define MARK_CONTAINER(name) \ + if (rwpp->name##_container) { \ + ruby_whisper_callback_container_mark(rwpp->name##_container); \ + } + + ITERATE_CALLBACK_PARAMS(MARK_CONTAINER) +} + +static void +ruby_whisper_parakeet_params_free(void *p) +{ + ruby_whisper_parakeet_params *rwpp = (ruby_whisper_parakeet_params *)p; + +#define FREE_CONTAINER(name) \ + if (rwpp->name##_container) { \ + xfree(rwpp->name##_container); \ + } + + ITERATE_CALLBACK_PARAMS(FREE_CONTAINER) + + xfree(rwpp); +} + +static size_t +ruby_whisper_parakeet_params_memsize(const void *p) +{ + const struct ruby_whisper_parakeet_params *params = p; + if (!params) { + return 0; + } + return sizeof(ruby_whisper_parakeet_params); +} + +const rb_data_type_t ruby_whisper_parakeet_params_type = { + "ruby_whisper_parakeet_params", + {ruby_whisper_parakeet_params_mark, ruby_whisper_parakeet_params_free, ruby_whisper_parakeet_params_memsize,}, + 0, 0, + 0 +}; + +#define READER(type) VAL_FROM_##type +#define WRITER(type) VAL_TO_##type +#define DEF_PARAM_ATTR(name, type) \ + static VALUE \ + ruby_whisper_parakeet_params_get_##name(VALUE self) \ + { \ + ruby_whisper_parakeet_params *rwpp; \ + GetParakeetParams(self, rwpp); \ + return READER(type)(rwpp->params.name); \ + } \ + static VALUE \ + ruby_whisper_parakeet_params_set_##name(VALUE self, VALUE val) \ + { \ + ruby_whisper_parakeet_params *rwpp; \ + GetParakeetParams(self, rwpp); \ + rwpp->params.name = WRITER(type)(val); \ + return val; \ + } + +#define DEF_CALLBACK_PARAM_ATTR(name) \ + static VALUE \ + ruby_whisper_parakeet_params_get_##name(VALUE self) \ + { \ + ruby_whisper_parakeet_params *rwpp; \ + GetParakeetParams(self, rwpp); \ + return rwpp->CALLBACK_CONTAINER_NAME(name)->callback; \ + } \ + static VALUE \ + ruby_whisper_parakeet_params_set_##name(VALUE self, VALUE val) \ + { \ + ruby_whisper_parakeet_params *rwpp; \ + GetParakeetParams(self, rwpp); \ + rwpp->CALLBACK_CONTAINER_NAME(name)->callback = (val); \ + return val; \ + } + +#define DEF_USER_DATA_PARAM_ATTR(name) \ + static VALUE \ + ruby_whisper_parakeet_params_get_##name##_user_data(VALUE self) \ + { \ + ruby_whisper_parakeet_params *rwpp; \ + GetParakeetParams(self, rwpp); \ + return rwpp->CALLBACK_CONTAINER_NAME(name)->user_data; \ + } \ + static VALUE \ + ruby_whisper_parakeet_params_set_##name##_user_data(VALUE self, VALUE val) \ + { \ + ruby_whisper_parakeet_params *rwpp; \ + GetParakeetParams(self, rwpp); \ + rwpp->CALLBACK_CONTAINER_NAME(name)->user_data = val; \ + return val; \ + } + +#define DEF_HOOK(name, data) \ + static VALUE \ + ruby_whisper_parakeet_params_on_##name(VALUE self) \ + { \ + ruby_whisper_parakeet_params *rwpp; \ + GetParakeetParams(self, rwpp); \ + const VALUE blk = rb_block_proc(); \ + if (NIL_P(rwpp->name##_callback_container->callbacks)) { \ + rwpp->name##_callback_container->callbacks = rb_ary_new(); \ + } \ + rb_ary_push(rwpp->name##_callback_container->callbacks, blk); \ + return Qnil; \ + } + +ITERATE_PARAMS(DEF_PARAM_ATTR) +ITERATE_CALLBACK_PARAMS(DEF_CALLBACK_PARAM_ATTR) +ITERATE_CALLBACK_PARAMS(DEF_USER_DATA_PARAM_ATTR) +ITERATE_NORMAL_CALLBACK_NAMES(DEF_HOOK, _) + +static VALUE +ruby_whisper_parakeet_params_abort_on(VALUE self) +{ + ruby_whisper_parakeet_params *rwpp; + GetParakeetParams(self, rwpp); + const VALUE blk = rb_block_proc(); + if (NIL_P(rwpp->abort_callback_container->callbacks)) { + rwpp->abort_callback_container->callbacks = rb_ary_new(); + } + rb_ary_push(rwpp->abort_callback_container->callbacks, blk); + + return Qnil; +} + +static VALUE +ruby_whisper_parakeet_params_s_allocate(VALUE klass) +{ + ruby_whisper_parakeet_params *rwpp; + VALUE obj = TypedData_Make_Struct(klass, ruby_whisper_parakeet_params, &ruby_whisper_parakeet_params_type, rwpp); + rwpp->params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY); + return obj; +} + +static VALUE +ruby_whisper_parakeet_params_initialize(int argc, VALUE *argv, VALUE self) +{ + VALUE kw_hash; + VALUE values[RUBY_WHISPER_PARAKEET_NUM_PARAMS] = {Qundef}; + VALUE value; + ruby_whisper_parakeet_params *rwpp; + int i; + + TypedData_Get_Struct(self, ruby_whisper_parakeet_params, &ruby_whisper_parakeet_params_type, rwpp); + +#define INIT_CONTAINER(name) rwpp->name##_container = ruby_whisper_callback_container_allocate(); + + ITERATE_CALLBACK_PARAMS(INIT_CONTAINER) + + rb_scan_args_kw(RB_SCAN_ARGS_KEYWORDS, argc, argv, ":", &kw_hash); + if (NIL_P(kw_hash)) { + return Qnil; + } + + rb_get_kwargs(kw_hash, param_names, 0, RUBY_WHISPER_PARAKEET_NUM_PARAMS, values); + + for (i = 0; i < RUBY_WHISPER_PARAKEET_NUM_PARAMS; i++) { + value = values[i]; + if (value == Qundef) { + continue; + } + param_writers[i](self, value); + } + + return Qnil; +} + +void +init_ruby_whisper_parakeet_params(VALUE *mParakeet) +{ + cParakeetParams = rb_define_class_under(*mParakeet, "Params", rb_cObject); + rb_define_alloc_func(cParakeetParams, ruby_whisper_parakeet_params_s_allocate); + + rb_define_method(cParakeetParams, "initialize", ruby_whisper_parakeet_params_initialize, -1); + + int i = 0; +#define REGISTER_PARAM(name) \ + param_names[i] = rb_intern(#name); \ + param_writers[i] = ruby_whisper_parakeet_params_set_##name; \ + rb_define_method(cParakeetParams, #name, ruby_whisper_parakeet_params_get_##name, 0); \ + rb_define_method(cParakeetParams, #name "=", ruby_whisper_parakeet_params_set_##name, 1); \ + i++; + +#define REGISTER_PARAM_ATTR(name, type) REGISTER_PARAM(name) +#define REGISTER_CALLBACK_PARAM_ATTR(name) REGISTER_PARAM(name) +#define REGISTER_USER_DATA_PARAM_ATTR(name) REGISTER_PARAM(name##_user_data) + + ITERATE_PARAMS(REGISTER_PARAM_ATTR) + ITERATE_CALLBACK_PARAMS(REGISTER_CALLBACK_PARAM_ATTR) + ITERATE_CALLBACK_PARAMS(REGISTER_USER_DATA_PARAM_ATTR) + +#define REGISTER_HOOK(name, data) \ + rb_define_method(cParakeetParams, "on_" #name, ruby_whisper_parakeet_params_on_##name, 0); + + ITERATE_NORMAL_CALLBACK_NAMES(REGISTER_HOOK, _) + + rb_define_method(cParakeetParams, "abort_on", ruby_whisper_parakeet_params_abort_on, 0); +} diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_segment.c b/bindings/ruby/ext/ruby_whisper_parakeet_segment.c new file mode 100644 index 00000000000..b1e81ba930c --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet_segment.c @@ -0,0 +1,157 @@ +#include "ruby_whisper.h" + +#define ITERATE_ATTRS(ITERATOR) \ + ITERATOR(start_time, t0, TIME) \ + ITERATOR(end_time, t1, TIME) \ + ITERATOR(text, text, STRING) + +enum { +#define DEF_IDX(name, c_name, type) RUBY_WHISPER_PARAKEET_SEGMENT_##name, + + ITERATE_ATTRS(DEF_IDX) + RUBY_WHISPER_PARAKEET_SEGMENT_NUM_ATTRS, +}; + +#define VAL_FROM_TIME(v) (LONG2NUM((v) * 10)) +#define VAL_FROM_STRING(v) (rb_str_new2(v)) +#define READER(type) VAL_FROM_##type +#define DEF_ATTR(rb_name, c_name, type) \ + static VALUE \ + ruby_whisper_parakeet_get_##rb_name(VALUE self) \ + { \ + ruby_whisper_parakeet_segment *rwps; \ + GetParakeetSegment(self, rwps); \ + ruby_whisper_parakeet_context *rwpc; \ + GetParakeetContext(rwps->context, rwpc); \ + return READER(type)(parakeet_full_get_segment_##c_name(rwpc->context, rwps->index)); \ + } + +extern ID id___method__; +extern ID id_to_enum; +extern VALUE cParakeetSegment; +extern VALUE sym_start_time; +extern VALUE sym_end_time; +extern VALUE sym_text; +extern const rb_data_type_t ruby_whisper_parakeet_context_type; +extern VALUE ruby_whisper_parakeet_token_s_from_index(struct parakeet_context *context, int i_segment, int i_token); + +static void +rb_whisper_parakeet_segment_mark(void *p) +{ + ruby_whisper_parakeet_segment *rwps = (ruby_whisper_parakeet_segment *)p; + rb_gc_mark(rwps->context); +} + +static size_t +ruby_whisper_parakeet_segment_memsize(const void *p) +{ + const ruby_whisper_parakeet_segment *rwps = (const ruby_whisper_parakeet_segment *)p; + if (!rwps) { + return 0; + } + return sizeof(*rwps); +} + +static const rb_data_type_t ruby_whisper_parakeet_segment_type = { + "ruby_whisper_parakeet_segment", + {rb_whisper_parakeet_segment_mark, RUBY_DEFAULT_FREE, ruby_whisper_parakeet_segment_memsize,}, + 0, 0, + 0 +}; + +static VALUE +ruby_whisper_parakeet_segment_s_allocate(VALUE klass) +{ + ruby_whisper_parakeet_segment *rwps; + return TypedData_Make_Struct(klass, ruby_whisper_parakeet_segment, &ruby_whisper_parakeet_segment_type, rwps); +} + +VALUE +ruby_whisper_parakeet_segment_init(VALUE context, int index) +{ + ruby_whisper_parakeet_segment *rwps; + + const VALUE segment = ruby_whisper_parakeet_segment_s_allocate(cParakeetSegment); + TypedData_Get_Struct(segment, ruby_whisper_parakeet_segment, &ruby_whisper_parakeet_segment_type, rwps); + rwps->context = context; + rwps->index = index; + + return segment; +} + +ITERATE_ATTRS(DEF_ATTR) + +static VALUE +ruby_whisper_parakeet_segment_each_token(VALUE self) +{ + if (!rb_block_given_p()) { + const VALUE method_name = rb_funcall(self, id___method__, 0); + return rb_funcall(self, id_to_enum, 1, method_name); + } + + ruby_whisper_parakeet_segment *rwps; + GetParakeetSegment(self, rwps); + ruby_whisper_parakeet_context *rwpc; + GetParakeetContext(rwps->context, rwpc); + + const int n_tokens = parakeet_full_n_tokens(rwpc->context, rwps->index); + for (int i = 0; i < n_tokens; i++) { + rb_yield(ruby_whisper_parakeet_token_s_from_index(rwpc->context, rwps->index, i)); + } + + return self; +} + +static VALUE +ruby_whisper_parakeet_segment_deconstruct_keys(VALUE self, VALUE keys) +{ + ruby_whisper_parakeet_segment *rwps; + GetParakeetSegment(self, rwps); + ruby_whisper_parakeet_context *rwpc; + GetParakeetContext(rwps->context, rwpc); + + VALUE hash = rb_hash_new(); + long n_keys; + if (NIL_P(keys)) { + keys = rb_ary_new3( + RUBY_WHISPER_PARAKEET_SEGMENT_NUM_ATTRS, + sym_start_time, + sym_end_time, + sym_text + ); + n_keys = RUBY_WHISPER_PARAKEET_SEGMENT_NUM_ATTRS; + } else { + n_keys = RARRAY_LEN(keys); + if (n_keys > RUBY_WHISPER_PARAKEET_SEGMENT_NUM_ATTRS) { + return hash; + } + } + for (int i = 0; i < n_keys; i++) { + VALUE key = rb_ary_entry(keys, i); + +#define CHECK_AND_SET_KEY(rb_name, c_name, type) \ + if (key == sym_##rb_name) { \ + rb_hash_aset(hash, key, ruby_whisper_parakeet_get_##rb_name(self)); \ + } + + ITERATE_ATTRS(CHECK_AND_SET_KEY) + } + + return hash; +} + +void +init_ruby_whisper_parakeet_segment(VALUE *mParakeet) +{ + cParakeetSegment = rb_define_class_under(*mParakeet, "Segment", rb_cObject); + + rb_define_alloc_func(cParakeetSegment, ruby_whisper_parakeet_segment_s_allocate); + +#define REGISTER_ATTR(rb_name, c_name, type) \ + rb_define_method(cParakeetSegment, #rb_name, ruby_whisper_parakeet_get_##rb_name, 0); + + ITERATE_ATTRS(REGISTER_ATTR) + + rb_define_method(cParakeetSegment, "each_token", ruby_whisper_parakeet_segment_each_token, 0); + rb_define_method(cParakeetSegment, "deconstruct_keys", ruby_whisper_parakeet_segment_deconstruct_keys, 1); +} diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_token.c b/bindings/ruby/ext/ruby_whisper_parakeet_token.c new file mode 100644 index 00000000000..a00b7ae1cbb --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet_token.c @@ -0,0 +1,188 @@ +#include "ruby_whisper.h" + +#define ITERATE_MEMBERS(ITERATOR) \ + ITERATOR(id, id, id, id, INT) \ + ITERATOR(duration_idx, duration_idx, duration_idx, duration_idx, INT) \ + ITERATOR(duration_value, duration_value, duration_value, duration_value, INT) \ + ITERATOR(frame_index, frame_index, frame_index, frame_index, INT) \ + ITERATOR(probability, probability, p, p, FLOAT) \ + ITERATOR(log_probability, log_probability, plog, plog, FLOAT) \ + ITERATOR(start_time, start_time, start_time, t0, TIME) \ + ITERATOR(end_time, end_time, end_time, t1, TIME) \ + ITERATOR(word_start?, word_start, word_start_p, is_word_start, BOOL) + +#define ITERATE_ATTRS(ITERATOR) \ + ITERATOR(text, text, text, text, STRING) + +enum { +#define DEF_IDX(rb_name, s_key, c_name, p_name, type) RUBY_WHISPER_PARAKEET_TOKEN_##c_name, + + ITERATE_MEMBERS(DEF_IDX) + ITERATE_ATTRS(DEF_IDX) + RUBY_WHISPER_PARAKEET_TOKEN_NUM_ATTRS, +}; + +#define VAL_FROM_INT(v) (INT2NUM(v)) +#define VAL_FROM_FLOAT(v) (DBL2NUM(v)) +#define VAL_FROM_TIME(v) (LONG2NUM(v * 10)) +#define VAL_FROM_BOOL(v) ((v) ? Qtrue : Qfalse) +#define VAL_FROM_STRING(v) (rb_str_new2(v)) + +#define READER(type) VAL_FROM_##type +#define MEMBER_NAME(name) name +#define DEF_MEMBER_ATTR(rb_name, s_key, c_name, p_name, type) \ + static VALUE \ + ruby_whisper_parakeet_token_get_##c_name(VALUE self) \ + { \ + ruby_whisper_parakeet_token *rwpt; \ + GetParakeetToken(self, rwpt); \ + return READER(type)(rwpt->token_data->MEMBER_NAME(p_name)); \ + } + +#define DEF_ATTR(rb_name, s_key, c_name, p_name, type) \ + static VALUE \ + ruby_whisper_parakeet_token_get_##c_name(VALUE self) \ + { \ + ruby_whisper_parakeet_token *rwpt; \ + GetParakeetToken(self, rwpt); \ + return rwpt->p_name; \ + } + +VALUE cParakeetToken; + +#define DEC_ATTR_SYMS(rb_name, s_key, c_name, p_name, type) static VALUE sym_##s_key; + +ITERATE_MEMBERS(DEC_ATTR_SYMS) +ITERATE_ATTRS(DEC_ATTR_SYMS) + +static void +ruby_whisper_parakeet_token_mark(void *p) +{ + ruby_whisper_parakeet_token *rwpt = (ruby_whisper_parakeet_token *)p; + rb_gc_mark(rwpt->text); +} + +static void +ruby_whisper_parakeet_token_free(void *p) +{ + ruby_whisper_parakeet_token *rwpt = (ruby_whisper_parakeet_token *)p; + if (rwpt->token_data) { + xfree(rwpt->token_data); + rwpt->token_data = NULL; + } + xfree(rwpt); +} + +static size_t +ruby_whisper_parakeet_token_memsize(const void *p) +{ + ruby_whisper_parakeet_token *rwpt = (ruby_whisper_parakeet_token *)p; + if (!rwpt) { + return 0; + } + size_t size = sizeof(*rwpt); + if (rwpt->token_data) { + size += sizeof(*rwpt->token_data); + } + + return size; +} + +static const rb_data_type_t ruby_whisper_parakeet_token_type = { + "ruby_whisper_parakeet_token", + {ruby_whisper_parakeet_token_mark, ruby_whisper_parakeet_token_free, ruby_whisper_parakeet_token_memsize}, + 0, 0, + 0, +}; + +static VALUE +ruby_whisper_parakeet_token_s_allocate(VALUE klass) +{ + ruby_whisper_parakeet_token *rwpt; + VALUE token = TypedData_Make_Struct(klass, ruby_whisper_parakeet_token, &ruby_whisper_parakeet_token_type, rwpt); + + rwpt->token_data = NULL; + rwpt->text = Qnil; + + return token; +} + +VALUE +ruby_whisper_parakeet_token_s_from_token_data(struct parakeet_context *context, const parakeet_token_data *token_data) +{ + const VALUE token = ruby_whisper_parakeet_token_s_allocate(cParakeetToken); + ruby_whisper_parakeet_token *rwpt; + TypedData_Get_Struct(token, ruby_whisper_parakeet_token, &ruby_whisper_parakeet_token_type, rwpt); + + rwpt->token_data = ALLOC(parakeet_token_data); + *rwpt->token_data = *token_data; + rwpt->text = rb_utf8_str_new_cstr(parakeet_token_to_str(context, token_data->id)); + + return token; +} + +VALUE +ruby_whisper_parakeet_token_s_from_index(struct parakeet_context *context, int i_segment, int i_token) +{ + parakeet_token_data token_data = parakeet_full_get_token_data(context, i_segment, i_token); + return ruby_whisper_parakeet_token_s_from_token_data(context, &token_data); +} + +ITERATE_MEMBERS(DEF_MEMBER_ATTR) +// Define #text using parakeet_token_to_str or parakeet_token_to_text +ITERATE_ATTRS(DEF_ATTR) + +static VALUE +ruby_whisper_parakeet_token_deconstruct_keys(VALUE self, VALUE keys) +{ + ruby_whisper_parakeet_token *rwpt; + GetParakeetToken(self, rwpt); + + VALUE hash = rb_hash_new(); + long n_keys = 0; + + if (NIL_P(keys)) { + VALUE attrs[] = { +#define LIST_SYMS(rb_name, s_key, c_name, p_name, type) sym_##s_key, + + ITERATE_MEMBERS(LIST_SYMS) + ITERATE_ATTRS(LIST_SYMS) + }; + keys = rb_ary_new_from_values(RUBY_WHISPER_PARAKEET_TOKEN_NUM_ATTRS, attrs); + n_keys = RUBY_WHISPER_PARAKEET_TOKEN_NUM_ATTRS; + } else { + n_keys = RARRAY_LEN(keys); + if (n_keys > RUBY_WHISPER_PARAKEET_TOKEN_NUM_ATTRS) { + return hash; + } + } + for (long i = 0; i < n_keys; i++) { + VALUE key = rb_ary_entry(keys, i); + +#define CHECK_AND_SET_KEY(rb_name, s_key, c_name, p_name, type) \ + if (key == sym_##s_key) { \ + rb_hash_aset(hash, key, ruby_whisper_parakeet_token_get_##c_name(self)); \ + } + + ITERATE_MEMBERS(CHECK_AND_SET_KEY) + ITERATE_ATTRS(CHECK_AND_SET_KEY) + } + + return hash; +} + +void +init_ruby_whisper_parakeet_token(VALUE *mParakeet) +{ + cParakeetToken = rb_define_class_under(*mParakeet, "Token", rb_cObject); + rb_define_alloc_func(cParakeetToken, ruby_whisper_parakeet_token_s_allocate); + +#define REGISTER_ATTR(rb_name, s_key, c_name, p_name, type) \ + sym_##s_key = ID2SYM(rb_intern(#s_key)); \ + rb_define_method(cParakeetToken, #rb_name, ruby_whisper_parakeet_token_get_##c_name, 0); + + ITERATE_MEMBERS(REGISTER_ATTR) + ITERATE_ATTRS(REGISTER_ATTR) + + rb_define_method(cParakeetToken, "deconstruct_keys", ruby_whisper_parakeet_token_deconstruct_keys, 1); +} diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_transcribe.cpp b/bindings/ruby/ext/ruby_whisper_parakeet_transcribe.cpp new file mode 100644 index 00000000000..c4deccce84a --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet_transcribe.cpp @@ -0,0 +1,58 @@ +#include "ruby_whisper.h" +#include "common-whisper.h" +#include <string> +#include <vector> + +#ifdef __cplusplus +extern "C" { +#endif + +extern const rb_data_type_t ruby_whisper_parakeet_context_type; +extern const rb_data_type_t ruby_whisper_parakeet_params_type; + +extern VALUE ruby_whisper_parakeet_context_full_body(VALUE rb_args); + +extern ID id_to_path; +extern ID id_new; + +extern VALUE eError; + +VALUE +ruby_whisper_parakeet_transcribe(VALUE self, VALUE audio_path, VALUE params) +{ + if (rb_respond_to(audio_path, id_to_path)) { + audio_path = rb_funcall(audio_path, id_to_path, 0); + } + + std::string fname = StringValueCStr(audio_path); + std::vector<float> pcmf32; + std::vector<std::vector<float>> pcmf32s; + + if (!read_audio_data(fname, pcmf32, pcmf32s, false)) { + rb_raise(rb_eRuntimeError, "Failed to open %s", fname.c_str()); + return Qnil; + } + + ruby_whisper_parakeet_context *rwpc; + ruby_whisper_parakeet_params *rwpp; + GetParakeetContext(self, rwpc); + GetParakeetParams(params, rwpp); + + ruby_whisper_full_args args = { + &self, + ¶ms, + pcmf32.data(), + (int)pcmf32.size(), + }; + VALUE rb_result = ruby_whisper_parakeet_context_full_body((VALUE)&args); + const int result = NUM2INT(rb_result); + if (result == 0) { + return self; + } else { + rb_exc_raise(rb_funcall(eError, id_new, 1, rb_result)); + } +} + +#ifdef __cplusplus +} +#endif diff --git a/bindings/ruby/ext/ruby_whisper_params.c b/bindings/ruby/ext/ruby_whisper_params.c index 2aae7c12d19..f38e9bde3ea 100644 --- a/bindings/ruby/ext/ruby_whisper_params.c +++ b/bindings/ruby/ext/ruby_whisper_params.c @@ -76,8 +76,8 @@ static ID id_vad; static ID id_vad_model_path; static ID id_vad_params; -static void -rb_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc) +void +ruby_whisper_callback_container_mark(ruby_whisper_callback_container *rwc) { if (rwc == NULL) return; @@ -86,8 +86,8 @@ rb_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc) rb_gc_mark(rwc->callbacks); } -static ruby_whisper_callback_container* -rb_whisper_callback_container_allocate() { +ruby_whisper_callback_container* +ruby_whisper_callback_container_allocate() { ruby_whisper_callback_container *container; container = ALLOC(ruby_whisper_callback_container); container->context = NULL; @@ -97,38 +97,11 @@ rb_whisper_callback_container_allocate() { return container; } -static void -rb_whisper_abort_callback_container_mark(ruby_whisper_abort_callback_container *rwc) -{ - if (rwc == NULL) return; - - rb_gc_mark(rwc->user_data); - rb_gc_mark(rwc->callback); - rb_gc_mark(rwc->callbacks); -} - -static ruby_whisper_abort_callback_container* -rb_whisper_abort_callback_container_allocate() { - ruby_whisper_abort_callback_container *container; - container = ALLOC(ruby_whisper_abort_callback_container); - container->context = NULL; - container->user_data = Qnil; - container->callback = Qnil; - container->callbacks = Qnil; - container->is_interrupted = false; - return container; -} - -static bool +bool ruby_whisper_callback_container_is_present(const ruby_whisper_callback_container *container) { return !NIL_P(container->callback) || !NIL_P(container->callbacks); } -static bool -ruby_whisper_abort_callback_container_is_present(const ruby_whisper_abort_callback_container *container) { - return !NIL_P(container->callback) || !NIL_P(container->callbacks); -} - typedef struct { const ruby_whisper_callback_container *container; struct whisper_state *state; @@ -283,24 +256,19 @@ static bool encoder_begin_callback(struct whisper_context *ctx, struct whisper_s } typedef struct { - const ruby_whisper_abort_callback_container *container; - struct whisper_state *state; + const ruby_whisper_callback_container *container; bool is_interrupted; } call_abort_callbacks_args; static void* call_abort_callbacks(void *v_args) { call_abort_callbacks_args *args = (call_abort_callbacks_args *)v_args; - const ruby_whisper_abort_callback_container *container = args->container; - - if (container->is_interrupted) { - args->is_interrupted = true; - return NULL; - } + const ruby_whisper_callback_container *container = args->container; + VALUE result = Qnil; if (!NIL_P(container->callback)) { - VALUE result = rb_funcall(container->callback, id_call, 1, container->user_data); - if (!NIL_P(result) && Qfalse != result) { + result = rb_funcall(container->callback, id_call, 1, container->user_data); + if (RTEST(result)) { args->is_interrupted = true; return NULL; } @@ -308,14 +276,14 @@ call_abort_callbacks(void *v_args) { if (NIL_P(container->callbacks)) { return NULL; } - const long callbacks_len = RARRAY_LEN(container->callbacks); - if (0 == callbacks_len) { + const long n_callbacks = RARRAY_LEN(container->callbacks); + if (0 == n_callbacks) { return NULL; } - for (int j = 0; j < callbacks_len; j++) { + for (int j = 0; j < n_callbacks; j++) { VALUE cb = rb_ary_entry(container->callbacks, j); - VALUE result = rb_funcall(cb, id_call, 1, container->user_data); - if (!NIL_P(result) && Qfalse != result) { + VALUE result = rb_funcall(cb, id_call, 0); + if (RTEST(result)) { args->is_interrupted = true; return NULL; } @@ -325,19 +293,19 @@ call_abort_callbacks(void *v_args) { } static bool abort_callback(void * user_data) { - const ruby_whisper_abort_callback_container *container = (ruby_whisper_abort_callback_container *)user_data; + ruby_whisper_abort_callback_user_data *data = (ruby_whisper_abort_callback_user_data *)user_data; - if (container->is_interrupted) { + int is_interrupted = RUBY_ATOMIC_LOAD(data->is_interrupted); + if (is_interrupted) { return true; } - if (!ruby_whisper_abort_callback_container_is_present(container)) { + if (!(data->callback_container) || !ruby_whisper_callback_container_is_present(data->callback_container)) { return false; } call_abort_callbacks_args args = { - container, - NULL, + data->callback_container, false }; rb_thread_call_with_gvl(call_abort_callbacks, (void *)&args); @@ -352,29 +320,19 @@ check_thread_safety(ruby_whisper_params *rwp, int n_processors) return; } - if (ruby_whisper_callback_container_is_present(rwp->new_segment_callback_container)) { - rb_raise(rb_eRuntimeError, "new segment callback not supported on parallel transcription"); - } - - if (ruby_whisper_callback_container_is_present(rwp->progress_callback_container)) { - rb_raise(rb_eRuntimeError, "progress callback not supported on parallel transcription"); - } + // new_segment_callback is called only after multiple threads are joined + // progress_callback is not called when parallel if (ruby_whisper_callback_container_is_present(rwp->encoder_begin_callback_container)) { rb_raise(rb_eRuntimeError, "encoder begin callback not supported on parallel transcription"); } - if (ruby_whisper_abort_callback_container_is_present(rwp->abort_callback_container)) { + if (ruby_whisper_callback_container_is_present(rwp->abort_callback_container)) { rb_raise(rb_eRuntimeError, "abort callback not supported on parallel transcription"); } - - VALUE log_callback = rb_iv_get(mWhisper, "log_callback"); - if (!NIL_P(log_callback)) { - rb_raise(rb_eRuntimeError, "log callback not supported for parallel transcription"); - } } -static void register_callbacks(ruby_whisper_params * rwp, VALUE * context) { +static void register_callbacks(ruby_whisper_params * rwp, VALUE * context, ruby_whisper_abort_callback_user_data *abort_callback_user_data) { if (ruby_whisper_callback_container_is_present(rwp->new_segment_callback_container)) { rwp->new_segment_callback_container->context = context; rwp->params.new_segment_callback = new_segment_callback; @@ -393,10 +351,10 @@ static void register_callbacks(ruby_whisper_params * rwp, VALUE * context) { rwp->params.encoder_begin_callback_user_data = rwp->encoder_begin_callback_container; } + abort_callback_user_data->callback_container = rwp->abort_callback_container; rwp->abort_callback_container->context = context; rwp->params.abort_callback = abort_callback; - rwp->abort_callback_container->is_interrupted = false; - rwp->params.abort_callback_user_data = rwp->abort_callback_container; + rwp->params.abort_callback_user_data = (void *)abort_callback_user_data; } static void set_vad_params(ruby_whisper_params *rwp) @@ -406,14 +364,11 @@ static void set_vad_params(ruby_whisper_params *rwp) rwp->params.vad_params = rwvp->params; } -/* - TODO: Set abort callback to trap SIGINT and SIGTERM -*/ void -prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors) +prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors, ruby_whisper_abort_callback_user_data *abort_callback_user_data) { check_thread_safety(rwp, n_processors); - register_callbacks(rwp, context); + register_callbacks(rwp, context, abort_callback_user_data); set_vad_params(rwp); } @@ -421,10 +376,10 @@ void rb_whisper_params_mark(void *p) { ruby_whisper_params *rwp = (ruby_whisper_params *)p; - rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container); - rb_whisper_callbcack_container_mark(rwp->progress_callback_container); - rb_whisper_callbcack_container_mark(rwp->encoder_begin_callback_container); - rb_whisper_abort_callback_container_mark(rwp->abort_callback_container); + ruby_whisper_callback_container_mark(rwp->new_segment_callback_container); + ruby_whisper_callback_container_mark(rwp->progress_callback_container); + ruby_whisper_callback_container_mark(rwp->encoder_begin_callback_container); + ruby_whisper_callback_container_mark(rwp->abort_callback_container); rb_gc_mark(rwp->vad_params); } @@ -492,10 +447,10 @@ ruby_whisper_params_allocate(VALUE klass) } rwp->diarize = false; rwp->vad_params = TypedData_Wrap_Struct(cVADParams, &ruby_whisper_vad_params_type, (void *)&rwp->params.vad_params); - rwp->new_segment_callback_container = rb_whisper_callback_container_allocate(); - rwp->progress_callback_container = rb_whisper_callback_container_allocate(); - rwp->encoder_begin_callback_container = rb_whisper_callback_container_allocate(); - rwp->abort_callback_container = rb_whisper_abort_callback_container_allocate(); + rwp->new_segment_callback_container = ruby_whisper_callback_container_allocate(); + rwp->progress_callback_container = ruby_whisper_callback_container_allocate(); + rwp->encoder_begin_callback_container = ruby_whisper_callback_container_allocate(); + rwp->abort_callback_container = ruby_whisper_callback_container_allocate(); return obj; } diff --git a/bindings/ruby/ext/ruby_whisper_segment.c b/bindings/ruby/ext/ruby_whisper_segment.c index ee0d66c4cc8..cf0372797d3 100644 --- a/bindings/ruby/ext/ruby_whisper_segment.c +++ b/bindings/ruby/ext/ruby_whisper_segment.c @@ -4,12 +4,12 @@ extern ID id___method__; extern ID id_to_enum; -static VALUE sym_start_time; -static VALUE sym_end_time; -static VALUE sym_text; -static VALUE sym_no_speech_prob; -static VALUE sym_speaker_turn_next; -static VALUE sym_n_tokens; +VALUE sym_start_time; +VALUE sym_end_time; +VALUE sym_text; +VALUE sym_no_speech_prob; +VALUE sym_speaker_turn_next; +VALUE sym_n_tokens; extern const rb_data_type_t ruby_whisper_type; diff --git a/bindings/ruby/ext/ruby_whisper_transcribe.cpp b/bindings/ruby/ext/ruby_whisper_transcribe.cpp index 37656af1c44..73f606ca476 100644 --- a/bindings/ruby/ext/ruby_whisper_transcribe.cpp +++ b/bindings/ruby/ext/ruby_whisper_transcribe.cpp @@ -16,6 +16,8 @@ extern ID id_to_path; extern ID transcribe_option_names[1]; extern void prepare_transcription(ruby_whisper_params * rwp, VALUE * self, int n_processors); +extern VALUE full_body(VALUE rb_args); +extern VALUE full_parallel_body(VALUE rb_args); typedef struct{ struct whisper_context *context; @@ -35,18 +37,6 @@ transcribe_without_gvl(void *rb_args) return NULL; } -typedef struct { - ruby_whisper_abort_callback_container *abort_callback_container; -} transcribe_ubf_args; - -static void -transcribe_ubf(void *rb_args) -{ - transcribe_ubf_args *args = (transcribe_ubf_args *)rb_args; - - args->abort_callback_container->is_interrupted = true; -} - /* * transcribe a single file * can emit to a block results @@ -91,32 +81,28 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str()); return self; } - // Commented out because it is work in progress - // { - // static bool is_aborted = false; // NOTE: this should be atomic to avoid data race - - // rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) { - // bool is_aborted = *(bool*)user_data; - // return !is_aborted; - // }; - // rwp->params.encoder_begin_callback_user_data = &is_aborted; - // } - - prepare_transcription(rwp, &self, n_processors); - - transcribe_without_gvl_args args = { - rw->context, - &rwp->params, - pcmf32.data(), - pcmf32.size(), - n_processors, - 0, - }; - transcribe_ubf_args ubf_args = { - rwp->abort_callback_container, - }; - rb_thread_call_without_gvl(transcribe_without_gvl, (void *)&args, transcribe_ubf, (void *)&ubf_args); - if (args.result != 0) { + + VALUE rb_result; + if (n_processors == 1) { + ruby_whisper_full_args args = { + &self, + ¶ms, + pcmf32.data(), + (int)pcmf32.size(), + }; + rb_result = full_body((VALUE)&args); + } else { + ruby_whisper_full_parallel_args parallel_args = { + &self, + ¶ms, + pcmf32.data(), + (int)pcmf32.size(), + n_processors, + }; + rb_result = full_parallel_body((VALUE)¶llel_args); + } + const int result = NUM2INT(rb_result); + if (result != 0) { fprintf(stderr, "failed to process audio\n"); return self; } diff --git a/bindings/ruby/lib/whisper/context.rb b/bindings/ruby/lib/whisper/context.rb deleted file mode 100644 index c3a134b773d..00000000000 --- a/bindings/ruby/lib/whisper/context.rb +++ /dev/null @@ -1,15 +0,0 @@ -module Whisper - class Context - def to_srt - each_segment.with_index.reduce("") {|srt, (segment, index)| - srt << "#{index + 1}\n#{segment.to_srt_cue}\n" - } - end - - def to_webvtt - each_segment.with_index.reduce("WEBVTT\n\n") {|webvtt, (segment, index)| - webvtt << "#{index + 1}\n#{segment.to_webvtt_cue}\n" - } - end - end -end diff --git a/bindings/ruby/lib/whisper/log_settable.rb b/bindings/ruby/lib/whisper/log_settable.rb new file mode 100644 index 00000000000..2f8218d26ee --- /dev/null +++ b/bindings/ruby/lib/whisper/log_settable.rb @@ -0,0 +1,36 @@ +require "mutex_m" + +module Whisper + module LogSettable + class << self + def extended(base) + base.extend Mutex_m + end + end + + private + + def start_log_callback_thread + return if @log_callback_thread&.alive? + + @log_callback_thread = Thread.new { + begin + while logs = drain_logs + begin + callback, user_data = synchronize {[@log_callback, @log_callback_user_data]} + next if callback.nil? + + logs.each do |(level, text)| + callback.call level, text, user_data + end + rescue => err + $stderr.puts err + end + end + rescue => err + $stderr.puts err + end + } + end + end +end diff --git a/bindings/ruby/lib/whisper/model/uri.rb b/bindings/ruby/lib/whisper/model/uri.rb index 8eb57e5e8cf..ef92eb901c4 100644 --- a/bindings/ruby/lib/whisper/model/uri.rb +++ b/bindings/ruby/lib/whisper/model/uri.rb @@ -41,6 +41,8 @@ def base_cache_dir def cache path = cache_path + return path if cache_path.exist? + headers = {} headers["if-modified-since"] = path.mtime.httpdate if path.exist? request @uri, headers @@ -216,8 +218,18 @@ def escaping(path) @pre_converted_models[name] = URI.new("https://huggingface.co/ggml-org/whisper-vad/resolve/main/ggml-#{name}.bin") end + %w[ + parakeet-tdt-0.6b-v3-f16 + parakeet-tdt-0.6b-v3-f32 + parakeet-tdt-0.6b-v3-q4_0 + parakeet-tdt-0.6b-v3-q4_k + parakeet-tdt-0.6b-v3-q8_0 + ].each do |name| + @pre_converted_models[name] = URI.new("https://huggingface.co/ggml-org/parakeet-GGUF/resolve/main/ggml-#{name}.bin") + end + @coreml_compiled_models = @pre_converted_models.each_with_object({}) {|(name, uri), models| - next if name.end_with?("-tdrz") || name.start_with?("silero-") + next if name.end_with?("-tdrz") || name.start_with?("silero-") || name.start_with?("parakeet-") if matched = name.match(/\A(?<name>.*)-q\d_\d\z/) name = matched[:name] diff --git a/bindings/ruby/lib/whisper/output.rb b/bindings/ruby/lib/whisper/output.rb new file mode 100644 index 00000000000..1781af17a33 --- /dev/null +++ b/bindings/ruby/lib/whisper/output.rb @@ -0,0 +1,74 @@ +module Whisper + module Output + module Context + def to_srt + each_segment.with_index.reduce("") {|srt, (segment, index)| + srt << "#{index + 1}\n#{segment.to_srt_cue}\n" + } + end + + def to_webvtt + each_segment.with_index.reduce("WEBVTT\n\n") {|webvtt, (segment, index)| + webvtt << "#{index + 1}\n#{segment.to_webvtt_cue}\n" + } + end + end + + module Segment + SRT_ESCAPES = { + "&" => "&", + "<" => "<", + ">" => ">", + } + SRT_ESCAPES_RE = Regexp.union(SRT_ESCAPES.keys) + private_constant :SRT_ESCAPES, :SRT_ESCAPES_RE + + def to_srt_cue + "#{srt_start_time} --> #{srt_end_time}\n#{srt_text}\n" + end + + def to_webvtt_cue + "#{webvtt_start_time} --> #{webvtt_end_time}\n#{webvtt_text}\n" + end + + private + + def time_to_a(time) + sec, decimal_part = time.divmod(1000) + min, sec = sec.divmod(60) + hour, min = min.divmod(60) + [hour, min, sec, decimal_part] + end + + def srt_time(time) + "%02d:%02d:%02d,%03d" % time_to_a(time) + end + + def srt_start_time + srt_time(start_time) + end + + def srt_end_time + srt_time(end_time) + end + + def srt_text + text.gsub(SRT_ESCAPES_RE, SRT_ESCAPES) + end + + def webvtt_time(time) + "%02d:%02d:%02d.%03d" % time_to_a(time) + end + + def webvtt_start_time + webvtt_time(start_time) + end + + def webvtt_end_time + webvtt_time(end_time) + end + + alias webvtt_text srt_text + end + end +end diff --git a/bindings/ruby/lib/whisper/segment.rb b/bindings/ruby/lib/whisper/segment.rb deleted file mode 100644 index dc187dcac36..00000000000 --- a/bindings/ruby/lib/whisper/segment.rb +++ /dev/null @@ -1,58 +0,0 @@ -module Whisper - class Segment - SRT_ESCAPES = { - "&" => "&", - "<" => "<", - ">" => ">", - } - SRT_ESCAPES_RE = Regexp.union(SRT_ESCAPES.keys) - private_constant :SRT_ESCAPES, :SRT_ESCAPES_RE - - def to_srt_cue - "#{srt_start_time} --> #{srt_end_time}\n#{srt_text}\n" - end - - def to_webvtt_cue - "#{webvtt_start_time} --> #{webvtt_end_time}\n#{webvtt_text}\n" - end - - private - - def time_to_a(time) - sec, decimal_part = time.divmod(1000) - min, sec = sec.divmod(60) - hour, min = min.divmod(60) - [hour, min, sec, decimal_part] - end - - def srt_time(time) - "%02d:%02d:%02d,%03d" % time_to_a(time) - end - - def srt_start_time - srt_time(start_time) - end - - def srt_end_time - srt_time(end_time) - end - - def srt_text - text.gsub(SRT_ESCAPES_RE, SRT_ESCAPES) - end - - def webvtt_time(time) - "%02d:%02d:%02d.%03d" % time_to_a(time) - end - - def webvtt_start_time - webvtt_time(start_time) - end - - def webvtt_end_time - webvtt_time(end_time) - end - - alias webvtt_text srt_text - end -end diff --git a/bindings/ruby/sig/whisper.rbs b/bindings/ruby/sig/whisper.rbs index cbec4803820..c12e1fe55e5 100644 --- a/bindings/ruby/sig/whisper.rbs +++ b/bindings/ruby/sig/whisper.rbs @@ -40,7 +40,21 @@ module Whisper def self.log_set: (log_callback?, Object? user_data) -> log_callback def self.system_info_str: () -> String + module Output + module Context + def to_srt: () -> String + def to_webvtt: () -> String + end + + module Segment + def to_srt_cue: () -> String + def to_webvtt_cue: () -> String + end + end + class Context + include Output::Context + def self.new: (String | path | ::URI::HTTP) -> instance # transcribe a single file @@ -139,17 +153,14 @@ module Whisper | (Whisper::Params, _Samples, ?Integer n_samples) -> self | (Whisper::Params, _Samples, ?Integer? n_samples, Integer n_processors) -> self - def to_srt: () -> String - def to_webvtt: () -> String - class Params def self.new: ( - use_gpu: boolish, - flash_attn: boolish, - gpu_device: Integer, - dtw_token_timestamps: boolish, - dtw_aheads_preset: Integer, - dtw_n_top: Integer | nil, + ?use_gpu: boolish, + ?flash_attn: boolish, + ?gpu_device: Integer, + ?dtw_token_timestamps: boolish, + ?dtw_aheads_preset: Integer, + ?dtw_n_top: Integer | nil, ) -> instance def use_gpu=: (boolish) -> boolish @@ -444,6 +455,9 @@ module Whisper def abort_on: { (Object user_data) -> boolish } -> void end + module LogSettable + end + class Model def self.pre_converted_models: () -> Hash[String, Model::URI] def self.coreml_compiled_models: () -> Hash[Model::URI, Model::ZipURI] @@ -474,6 +488,8 @@ module Whisper end class Segment + include Output::Segment + type deconstructed_keys = { start_time: (Integer | nil), end_time: (Integer | nil), @@ -514,9 +530,6 @@ module Whisper # def each_token: { (Token) -> void } -> void | () -> Enumerator[Token] - def to_srt_cue: () -> String - def to_webvtt_cue: () -> String - # Possible keys: `:start_time`, `:end_time`, `:text`, `:no_speech_prob`, `:speaker_turn_next` # @@ -528,7 +541,7 @@ module Whisper def deconstruct_keys: (Array[:start_time | :end_time | :text | :no_speech_prob | :speaker_turn_next | :n_tokens] | nil) -> deconstructed_keys end - module Token + class Token type deconstructed_keys = { id: (Integer | nil), tid: (Integer | nil), @@ -598,6 +611,336 @@ module Whisper def deconstruct_keys: (Array[:id | :tid | :probability | :log_probability | :pt | :ptsum | :t_dtw | :voice_length | :start_time | :end_time | :text] | nil) -> deconstructed_keys end + module Parakeet + extend LogSettable + + VERSION: String + + # Control logging output. The default behavior is to print to stderr. + # + def self.log_set: (nil, Object? user_data) -> nil + | (^(Integer level, String message, Object user_data) -> void, Object? user_data) -> nil + def self.system_info_str: () -> String + + class Context + include Output::Context + + # Load a Parakeet model from the given file path. + # + def self.new: (String | path | ::URI::HTTP, ?Params) -> instance + + # Transcribe a single audio file. + # + def transcribe: (path audio_file_path, Whisper::Parakeet::Params) -> self + + # Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text. + # Not thread safe for the same context. + # + # The second argument `samples` must be an array of samples, respond to `:length`, + # or be a MemoryView of an array of float. It must be 32 bit float PCM audio data. + # + def full: (Whisper::Parakeet::Params, Array[Float] samples, ?Integer n_samples) -> self + | (Whisper::Parakeet::Params, _Samples, ?Integer n_samples) -> self + + # Number of generated text segments. + # + def full_n_segments: () -> Integer + + # Start time of a segment indexed by `segment_index` in centiseconds (10 times milliseconds). + # + # full_get_segment_t0(3) # => 1668 (16680 ms) + # + def full_get_segment_t0: (Integer segment_index) -> Integer + + # End time of a segment indexed by `segment_index` in centiseconds (10 times milliseconds). + # + # full_get_segment_t1(3) # => 1668 (16680 ms) + # + def full_get_segment_t1: (Integer segment_index) -> Integer + + # Text of a segment indexed by `segment_index`. + # + # full_get_segment_text(3) # => "ask not what your country can do for you, ..." + # + def full_get_segment_text: (Integer segment_index) -> String + + # Number of tokens in the segment indexed by `segment_index`. + # + def full_n_tokens: (Integer segment_index) -> Integer + + # Text of the token indexed by `token_index` in the segment indexed by `segment_index`. + # + def full_get_token_text: (Integer segment_index, Integer token_index) -> String + + # Token id of the token indexed by `token_index` in the segment indexed by `segment_index`. + # + def full_get_token_id: (Integer segment_index, Integer token_index) -> Integer + + # Probability of the token indexed by `token_index` in the segment indexed by `segment_index`. + # + def full_get_token_p: (Integer segment_index, Integer token_index) -> Float + + # Token data of the token indexed by `token_index` in the segment indexed by `segment_index`. + # + def full_get_token_data: (Integer segment_index, Integer token_index) -> Token + + def model: () -> Model + + # Yields each Whisper::Parakeet::Segment: + # + # parakeet.transcribe("path/to/audio.wav", params) + # parakeet.each_segment do |segment| + # puts segment.text + # end + # + # Returns an `Enumerator` if no block given: + # + # parakeet.transcribe("path/to/audio.wav", params) + # enum = parakeet.each_segment + # enum.to_a # => [#<Whisper::Parakeet::Segment>, ...] + # + def each_segment: { (Segment) -> void } -> void + | () -> Enumerator[Segment] + + class Params + def self.new: (?use_gpu: boolish, ?gpu_device: Integer) -> instance + def use_gpu: () -> boolish + def use_gpu=: (boolish) -> boolish + def gpu_device: () -> Integer + def gpu_device=: (Integer) -> Integer + end + end + + class Params + def self.new: ( + ?n_threads: Integer, + ?offset_ms: Integer, + ?duration_ms: Integer, + ?no_context: boolish, + ?audio_ctx: Integer, + ?new_segment_callback: ^(Whisper::Parakeet::Context, untyped, Integer n_new, Object user_data) -> void, + ?new_segment_callback_user_data: Object, + ?new_token_callback: ^(Whisper::Parakeet::Context, untyped, Whisper::Parakeet::Token, Object user_data) -> void, + ?new_token_callback_user_data: Object, + ?progress_callback: ^(Whisper::Parakeet::Context, untyped, Integer progress, Object user_data) -> void, + ?progress_callback_user_data: Object, + ?encoder_begin_callback: ^(Whisper::Parakeet::Context, untyped, Object user_data) -> boolish, + ?encoder_begin_callback_user_data: Object, + ?abort_callback: ^(Object user_data) -> boolish, + ?abort_callback_user_data: Object + ) -> instance + + # Number of threads to use. + # + def n_threads=: (Integer) -> Integer + def n_threads: () -> Integer + + # Start offset in ms. + # + def offset_ms=: (Integer) -> Integer + def offset_ms: () -> Integer + + # Audio duration to process in ms. + # + def duration_ms=: (Integer) -> Integer + def duration_ms: () -> Integer + + # If `true`, does not use past transcription (if any) as context. + # + def no_context=: (boolish) -> boolish + def no_context: () -> (true | false) + + # Overwrite the audio context size. `0` uses the default value. + # + def audio_ctx=: (Integer) -> Integer + def audio_ctx: () -> Integer + + # Sets new segment callback, called for every newly generated text segment. + # + # params.new_segment_callback = ->(context, _, n_new, user_data) { + # # ... + # } + # + def new_segment_callback=: (^(Whisper::Parakeet::Context, untyped, Integer n_new, Object user_data) -> void) -> (^(Whisper::Parakeet::Context, untyped, Integer n_new, Object user_data) -> void) + def new_segment_callback: () -> ((^(Whisper::Parakeet::Context, untyped, Integer n_new, Object user_data) -> void) | nil) + + # Sets user data passed to the last argument of new segment callback. + # + def new_segment_callback_user_data=: (Object?) -> Object? + def new_segment_callback_user_data: () -> Object? + + # Sets token callback, called for every newly predicted token. + # + def new_token_callback=: (^(Whisper::Parakeet::Context, untyped, Whisper::Parakeet::Token, Object user_data) -> void) -> (^(Whisper::Parakeet::Context, untyped, Whisper::Parakeet::Token, Object user_data) -> void) + def new_token_callback: () -> ((^(Whisper::Parakeet::Context, untyped, Whisper::Parakeet::Token, Object user_data) -> void) | nil) + + # Sets user data passed to the last argument of token callback. + # + def new_token_callback_user_data=: (Object?) -> Object? + def new_token_callback_user_data: () -> Object? + + # Sets progress callback, called on each progress update. + # + # +progress+ is an Integer between 0 and 100. + # + def progress_callback=: (^(Whisper::Parakeet::Context, untyped, Integer progress, Object user_data) -> void) -> (^(Whisper::Parakeet::Context, untyped, Integer progress, Object user_data) -> void) + def progress_callback: () -> ((^(Whisper::Parakeet::Context, untyped, Integer progress, Object user_data) -> void) | nil) + + # Sets user data passed to the last argument of progress callback. + # + def progress_callback_user_data=: (Object?) -> Object? + def progress_callback_user_data: () -> Object? + + # Sets encoder begin callback, called each time before the encoder starts. + # + # If it returns `false`, the computation is aborted. + # + def encoder_begin_callback=: (^(Whisper::Parakeet::Context, untyped, Object user_data) -> boolish) -> (^(Whisper::Parakeet::Context, untyped, Object user_data) -> boolish) + def encoder_begin_callback: () -> ((^(Whisper::Parakeet::Context, untyped, Object user_data) -> boolish) | nil) + + # Sets user data passed to the last argument of encoder begin callback. + # + def encoder_begin_callback_user_data=: (Object?) -> Object? + def encoder_begin_callback_user_data: () -> Object? + + # Sets abort callback, called each time before ggml computation starts. + # + def abort_callback=: (^(Object user_data) -> boolish) -> (^(Object user_data) -> boolish) + def abort_callback: () -> ((^(Object user_data) -> boolish) | nil) + + # Sets user data passed to the last argument of abort callback. + # + def abort_callback_user_data=: (Object?) -> Object? + def abort_callback_user_data: () -> Object? + + # Hook called on new segment. Yields each Whisper::Parakeet::Segment. + # + def on_new_segment: { (Segment) -> void } -> void + + # Hook called on new token. Yields each Whisper::Parakeet::Token. + # + def on_new_token: { (Token) -> void } -> void + + # Hook called on progress update. Yields each progress `Integer` between 0 and 100. + # + def on_progress: { (Integer progress) -> void } -> void + + # Hook called each time before the encoder starts. + # + def on_encoder_begin: { () -> boolish } -> void + + # Call block to determine whether abort or not. Return `true` when you want to abort. + # + def abort_on: { () -> boolish } -> void + end + + class Segment + include Output::Segment + + type deconstructed_keys = { + start_time: (Integer | nil), + end_time: (Integer | nil), + text: (String | nil) + } + + # Start time in milliseconds. + # + def start_time: () -> Integer + + # End time in milliseconds. + # + def end_time: () -> Integer + + # Text of the segment. + # + def text: () -> String + + # Yields each Whisper::Parakeet::Token: + # + # parakeet.each_segment.first.each_token do |token| + # p token + # end + # + # Returns an `Enumerator` if no block is given: + # + # parakeet.each_segment.first.each_token.to_a # => [#<Whisper::Parakeet::Token>, ...] + # + def each_token: { (Token) -> void } -> void + | () -> Enumerator[Token] + + # Possible keys: `:start_time`, `:end_time`, `:text` + # + def deconstruct_keys: (Array[:start_time | :end_time | :text] | nil) -> deconstructed_keys + end + + class Token + type deconstructed_keys = { + id: (Integer | nil), + duration_idx: (Integer | nil), + duration_value: (Integer | nil), + frame_index: (Integer | nil), + probability: (Float | nil), + log_probability: (Float | nil), + start_time: (Integer | nil), + end_time: (Integer | nil), + word_start: ((true | false) | nil), + text: (String | nil), + } + + # Token ID. + # + def id: () -> Integer + + # Index into the model's durations array. + # + def duration_idx: () -> Integer + + # Actual duration value. + # + def duration_value: () -> Integer + + # Frame index of the token. + # + def frame_index: () -> Integer + + # Probability of the token. + # + def probability: () -> Float + + # Log probability of the token. + # + def log_probability: () -> Float + + # Start time of the token in milliseconds. + # + def start_time: () -> Integer + + # End time of the token in milliseconds. + # + def end_time: () -> Integer + + # Whether this token is the start of a word. + # + def word_start?: () -> (true | false) + + # Get the token text of the token. + # + def text: () -> String + + def deconstruct_keys: (Array[:id | :duration_idx | :duration_value | :frame_index | :probability | :log_probability | :start_time | :end_time | :word_start | :text] | nil) -> deconstructed_keys + end + + class Model + def n_vocab: () -> Integer + def n_audio_ctx: () -> Integer + def n_audio_state: () -> Integer + def n_audio_head: () -> Integer + def n_audio_layer: () -> Integer + def n_mels: () -> Integer + def ftype: () -> Integer + end + end + module VAD class Params def self.new: ( diff --git a/bindings/ruby/test/helper.rb b/bindings/ruby/test/helper.rb index 56cd3849fdd..5e37ad98596 100644 --- a/bindings/ruby/test/helper.rb +++ b/bindings/ruby/test/helper.rb @@ -5,6 +5,8 @@ class TestBase < Test::Unit::TestCase AUDIO = File.join(__dir__, "fixtures", "jfk.wav") + Parakeet = Whisper::Parakeet + class << self def whisper return @whisper if @whisper diff --git a/bindings/ruby/test/test_callback.rb b/bindings/ruby/test/test_callback.rb index a7f49245ade..6490c8abb48 100644 --- a/bindings/ruby/test/test_callback.rb +++ b/bindings/ruby/test/test_callback.rb @@ -129,6 +129,7 @@ def test_encoder_begin_callback_abort return false } @whisper.transcribe(@audio, @params) + sleep 0.5 # wait for logs dequeued assert_match(/encoder_begin_callback returned false - aborting/, logs.join) Whisper.log_set ->(level, buffer, user_data) {}, nil end diff --git a/bindings/ruby/test/test_parakeet.rb b/bindings/ruby/test/test_parakeet.rb new file mode 100644 index 00000000000..bfd57076f56 --- /dev/null +++ b/bindings/ruby/test/test_parakeet.rb @@ -0,0 +1,28 @@ +require_relative "helper" +require "stringio" + +class TestParakeet < TestBase + def test_log_set + log_callback = Parakeet.instance_variable_get("@log_callback") + user_data = Parakeet.instance_variable_get("@log_callback_user_data") + + $stdout = StringIO.new + Parakeet.log_set proc {|level, message, _| puts [level, message].join(": ")}, nil + Parakeet::Context.new("test/fixtures/for-tests-ggml-parakeet-tdt.bin") + sleep 0.1 + $stdout.rewind + logs = $stdout.string + assert_match /loading model from/, logs + ensure + $stdout = STDOUT + Parakeet.log_set log_callback, user_data + end + + def test_system_info_str + assert_match /\APARAKEET : /, Parakeet.system_info_str + end + + def test_version + assert_instance_of String, Parakeet::VERSION + end +end diff --git a/bindings/ruby/test/test_parakeet_callback.rb b/bindings/ruby/test/test_parakeet_callback.rb new file mode 100644 index 00000000000..1209e960f09 --- /dev/null +++ b/bindings/ruby/test/test_parakeet_callback.rb @@ -0,0 +1,107 @@ +require_relative "helper" + +class TestParakeetCallback < TestBase + def setup + omit "Skip not to download large model" if ENV["CI"] + + Whisper.instance_variable_set "@whisper", nil + GC.start + @params = Parakeet::Params.new + @parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0") + end + + def test_new_segment_callback + @params.new_segment_callback = ->(context, state, n_new, user_data) { + assert_kind_of Integer, n_new + assert n_new > 0 + assert_same @parakeet, context + + n_segments = context.full_n_segments + n_new.times do |i| + i_segment = n_segments - 1 + i + start_time = context.full_get_segment_t0(i_segment) * 10 + end_time = context.full_get_segment_t1(i_segment) * 10 + text = context.full_get_segment_text(i_segment) + + assert_kind_of Integer, start_time + assert start_time >= 0 + assert_kind_of Integer, end_time + assert end_time > 0 + assert_match(/ask not what your country can do for you, ask what you can do for your/, text) if i_segment == 0 + end + } + + @parakeet.transcribe AUDIO, @params + end + + def test_on_new_segment + seg = nil + index = 0 + @params.on_new_segment do |segment| + assert_instance_of Parakeet::Segment, segment + if index == 0 + seg = segment + assert_equal 0, segment.start_time + assert_match(/ask not what your country can do for you, ask what you can do for your/, segment.text) + end + index += 1 + end + @parakeet.transcribe AUDIO, @params + assert_equal 0, seg.start_time + assert_match /ask not what your country can do for you, ask what you can do for your/, seg.text + end + + def test_on_new_token + index = 0 + @params.on_new_token do |token| + assert_instance_of Parakeet::Token, token + if index == 0 + assert_instance_of Integer, token.start_time + assert_match "▁And", token.text + end + index += 1 + end + + @parakeet.transcribe AUDIO, @params + end + + def test_on_progress + first = nil + @params.on_progress do |progress| + assert_kind_of Integer, progress + assert 0 <= progress && progress <= 100 + first = progress if first.nil? + end + + @parakeet.transcribe AUDIO, @params + + assert_equal 0, first + end + + def test_on_encoder_begin + i = 0 + @params.on_encoder_begin do + i += 1 + end + + @parakeet.transcribe AUDIO, @params + + assert i > 0 + end + + def test_abort_on + do_abort = false + @params.on_new_segment do |segment| + do_abort = true if segment.text.match?(/ask/) + end + i = 0 + @params.abort_on do + i += 1 + do_abort + end + + @parakeet.transcribe(AUDIO, @params) rescue nil + + assert i > 0 + end +end diff --git a/bindings/ruby/test/test_parakeet_context.rb b/bindings/ruby/test/test_parakeet_context.rb new file mode 100644 index 00000000000..2d039ce75f5 --- /dev/null +++ b/bindings/ruby/test/test_parakeet_context.rb @@ -0,0 +1,116 @@ +require_relative "helper" +require "stringio" + +class TestParakeetContext < TestBase + def setup + omit "Skip not to download large model" if ENV["CI"] + + Whisper.instance_variable_set "@whisper", nil + GC.start + + @parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0") + @params = Parakeet::Params.new + end + + def test_new + assert_instance_of Parakeet::Context, @parakeet + end + + def test_new_with_params + log_callback = Parakeet.instance_variable_get(:@log_callback) + user_data = Parakeet.instance_variable_get(:@log_callback_user_data) + begin + logs = "" + Parakeet.log_set proc {|level, message| logs << message}, nil + params = Parakeet::Context::Params.new(use_gpu: false) + parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0", params) + assert_instance_of Parakeet::Context, parakeet + assert_match /use gpu\s+=\s+0/, logs + ensure + Parakeet.log_set log_callback, user_data + end + end + + sub_test_case "full" do + def setup + super + @samples = File.read(AUDIO, nil, 78).unpack("s<*").collect {|i| i.to_f / 2**15} + end + + def test_full + @parakeet.full @params, @samples, @samples.length + + segments = @parakeet.each_segment.to_a + assert_equal 1, segments.length + assert_match /ask not what your country can do for you, ask what you can do for your/, segments.first.text + end + + def test_full_without_length + @parakeet.full(@params, @samples) + + segments = @parakeet.each_segment.to_a + assert_equal 1, segments.length + assert_match /ask not what your country can do for you, ask what you can do for your/, @parakeet.each_segment.first.text + end + + def test_full_enumerator + samples = @samples.each + @parakeet.full @params, samples, @samples.length + + segments = @parakeet.each_segment.to_a + assert_equal 1, segments.length + assert_match /ask not what your country can do for you, ask what you can do for your/, @parakeet.each_segment.first.text + end + + def test_full_enumerator_without_length + samples = @samples.each + assert_raise ArgumentError do + @parakeet.full @params, samples + end + end + + def test_full_enumerator_with_too_large_length + samples = @samples.each.take(10).to_enum + assert_raise StopIteration do + @parakeet.full @params, samples, 11 + end + end + + def test_full_with_memory_view + samples = JFKReader.new(AUDIO) + @parakeet.full @params, samples + + segments = @parakeet.each_segment.to_a + assert_equal 1, segments.length + assert_match /ask not what your country can do for you, ask what you can do for your/, @parakeet.each_segment.first.text + end + + def test_full_with_memroy_view_gc + samples = JFKReader.new(AUDIO) + @parakeet.full(@params, samples) + GC.start + require "fiddle" + Fiddle::MemoryView.export samples do |view| + assert_equal 176000, view.to_s.unpack("#{view.format}*").length + end + end + end + + def test_transcribe + assert_nothing_raised do + @parakeet.transcribe AUDIO, @params + end + end + + def test_transcribe_with_pathname + assert_nothing_raised do + @parakeet.transcribe Pathname(AUDIO), @params + end + end + + def test_transcribe_with_nothing + assert_raise_message(/open/) do + @parakeet.transcribe "nothing", @params + end + end +end diff --git a/bindings/ruby/test/test_parakeet_context_params.rb b/bindings/ruby/test/test_parakeet_context_params.rb new file mode 100644 index 00000000000..fcd0f2410f7 --- /dev/null +++ b/bindings/ruby/test/test_parakeet_context_params.rb @@ -0,0 +1,24 @@ +require_relative "helper" + +class TestParakeetContextParams < TestBase + def setup + @params = Parakeet::Context::Params.new + end + + def test_new + assert_instance_of Parakeet::Context::Params, @params + end + + def test_attributes + assert_true @params.use_gpu + assert_instance_of Integer, @params.gpu_device + end + + def test_attribute_writer + @params.use_gpu = false + assert_false @params.use_gpu + + @params.gpu_device = 2 + assert_equal 2, @params.gpu_device + end +end diff --git a/bindings/ruby/test/test_parakeet_model.rb b/bindings/ruby/test/test_parakeet_model.rb new file mode 100644 index 00000000000..5343b35ed8e --- /dev/null +++ b/bindings/ruby/test/test_parakeet_model.rb @@ -0,0 +1,21 @@ +require_relative "helper" + +class TestParakeetModel < TestBase + def test_model + parakeet = Parakeet::Context.new("test/fixtures/for-tests-ggml-parakeet-tdt.bin") + assert_instance_of Parakeet::Model, parakeet.model + end + + def test_attributes + parakeet = Parakeet::Context.new("test/fixtures/for-tests-ggml-parakeet-tdt.bin") + model = parakeet.model + + assert_equal 10, model.n_vocab + assert_equal 3200, model.n_audio_ctx + assert_equal 8, model.n_audio_state + assert_equal 2, model.n_audio_head + assert_equal 1, model.n_audio_layer + assert_equal 16, model.n_mels + assert_equal 0, model.ftype + end +end diff --git a/bindings/ruby/test/test_parakeet_params.rb b/bindings/ruby/test/test_parakeet_params.rb new file mode 100644 index 00000000000..dc651f7ab12 --- /dev/null +++ b/bindings/ruby/test/test_parakeet_params.rb @@ -0,0 +1,78 @@ +require_relative "helper" +require "etc" + +class TestParakeetParams < TestBase + PARAM_NAMES = [ + :n_threads, + :offset_ms, + :duration_ms, + :no_context, + :audio_ctx + ] + + def setup + @params = Parakeet::Params.new + end + + def test_new + assert_instance_of Parakeet::Params, @params + end + + def test_n_threads + assert_equal [4, Etc.nprocessors].min, @params.n_threads + + @params.n_threads = 1 + assert_equal 1, @params.n_threads + end + + def test_offset_ms + assert_equal 0, @params.offset_ms + + @params.offset_ms = 10_000 + assert_equal 10_000, @params.offset_ms + end + + def test_duration_ms + assert_equal 0, @params.duration_ms + + @params.duration_ms = 60_000 + assert_equal 60_000, @params.duration_ms + end + + def test_no_context + assert_equal true, @params.no_context + + @params.no_context = false + assert_equal false, @params.no_context + end + + def test_audio_ctx + assert_equal 0, @params.audio_ctx + + @params.audio_ctx = 1 + assert_equal 1, @params.audio_ctx + end + + def test_new_with_kw_args + params = Parakeet::Params.new(n_threads: 1) + assert_equal 1, params.n_threads + assert_equal 0, params.offset_ms + end + + data(PARAM_NAMES.collect {|param| [param, param]}.to_h) + def test_new_with_kw_args_default_values(param) + default_value = @params.send(param) + value = case [param, default_value] + in [*, true | false] + !default_value + in [*, Integer] + default_value + 1 + end + params = Parakeet::Params.new(param => value) + assert_equal value, params.send(param) + + PARAM_NAMES.reject {|name| name == param}.each do |name| + assert_equal @params.send(name), params.send(name) + end + end +end diff --git a/bindings/ruby/test/test_parakeet_segment.rb b/bindings/ruby/test/test_parakeet_segment.rb new file mode 100644 index 00000000000..d5b99bd5ee6 --- /dev/null +++ b/bindings/ruby/test/test_parakeet_segment.rb @@ -0,0 +1,42 @@ +require_relative "helper" + +class TestParakeetSegment < TestBase + def setup + omit "Skip not to download large model" if ENV["CI"] + + @parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0") + @parakeet.transcribe AUDIO, Parakeet::Params.new + end + + def test_segment + whole_text = "" + @parakeet.each_segment do |segment| + assert_instance_of Parakeet::Segment, segment + assert_kind_of Integer, segment.start_time + assert segment.end_time >= segment.start_time + assert_kind_of String, segment.text + whole_text << segment.text + end + assert_match(/ask not what your country can do for you, ask what you can do for your country/, whole_text) + end + + def test_deconstruct_keys + segment = @parakeet.each_segment.first + expected = { + start_time: segment.start_time, + end_time: segment.end_time, + text: segment.text + } + assert_equal expected, segment.deconstruct_keys([:start_time, :end_time, :text]) + end + + def test_deconstruct_keys_with_nil + segment = @parakeet.each_segment.first + expected = { + start_time: segment.start_time, + end_time: segment.end_time, + text: segment.text + } + assert_equal expected, segment.deconstruct_keys(nil) + end +end diff --git a/bindings/ruby/test/test_parakeet_token.rb b/bindings/ruby/test/test_parakeet_token.rb new file mode 100644 index 00000000000..6f0b8b5a37c --- /dev/null +++ b/bindings/ruby/test/test_parakeet_token.rb @@ -0,0 +1,73 @@ +require_relative "helper" + +class TestParakeetToken < TestBase + ATTRS = %i[ + id + duration_idx + duration_value + frame_index + probability + log_probability + start_time + end_time + word_start? + text + ] + + def setup + omit "Skip not to download large model" if ENV["CI"] + + Whisper.instance_variable_set "@whisper", nil + GC.start + + parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0") + params = Parakeet::Params.new + parakeet.transcribe AUDIO, params + @segment = parakeet.each_segment.first + end + + def test_each_token + i = 0 + @segment.each_token do |token| + i += 1 + assert_instance_of Parakeet::Token, token + end + assert_equal 38, i + end + + def test_each_token_without_block + assert_instance_of Enumerator, @segment.each_token + end + + def test_token + token = @segment.each_token.first + + assert_instance_of Parakeet::Token, token + assert_instance_of Integer, token.id + assert_instance_of Integer, token.duration_idx + assert_instance_of Integer, token.duration_value + assert_instance_of Integer, token.frame_index + assert_instance_of Float, token.probability + assert_instance_of Float, token.log_probability + assert_instance_of Integer, token.start_time + assert_instance_of Integer, token.end_time + assert_instance_of String, token.text + end + + def test_text + assert_equal ["▁And", "▁so", ",", "▁my", "▁f", "ell", "ow", "▁Amer", "ic", "ans", ",", "▁a", "sk", "▁not", "▁what", "▁your", "▁co", "un", "tr", "y", "▁can", "▁do", "▁for", "▁you", ",", "▁a", "sk", "▁what", "▁you", "▁can", "▁do", "▁for", "▁your", "▁co", "un", "tr", "y", "."], + @segment.each_token.collect(&:text) + end + + def test_deconstruct_keys_with_nil + token = @segment.each_token.first + expected = ATTRS.collect {|attr| [attr.to_s.sub(/\?\z/, "").intern, token.send(attr)]}.to_h + assert_equal expected, token.deconstruct_keys(nil) + end + + def test_deconstruct_keys_with_keys + token = @segment.each_token.first + expected = ATTRS.collect {|attr| [attr.to_s.sub(/\?\z/, "").intern, token.send(attr)]}.to_h + assert_equal expected, token.deconstruct_keys(expected.keys) + end +end diff --git a/bindings/ruby/test/test_vad_segment.rb b/bindings/ruby/test/test_vad_segment.rb index 7348562cb15..6d66c27fd32 100644 --- a/bindings/ruby/test/test_vad_segment.rb +++ b/bindings/ruby/test/test_vad_segment.rb @@ -9,7 +9,7 @@ def test_initialize end assert_raise do - segments.end_time + segment.end_time end assert_raise do diff --git a/bindings/ruby/test/test_whisper.rb b/bindings/ruby/test/test_whisper.rb index f7e25239d5d..082547e7c08 100644 --- a/bindings/ruby/test/test_whisper.rb +++ b/bindings/ruby/test/test_whisper.rb @@ -149,6 +149,7 @@ def test_log_set } Whisper.log_set log_callback, user_data Whisper::Context.new("base.en") + sleep 0.1 # wait for logs dequeued assert logs.length > 30 logs.each do |log| diff --git a/bindings/ruby/whispercpp.gemspec b/bindings/ruby/whispercpp.gemspec index 2d952222f29..301ecfcc13d 100644 --- a/bindings/ruby/whispercpp.gemspec +++ b/bindings/ruby/whispercpp.gemspec @@ -23,7 +23,7 @@ Gem::Specification.new do |s| s.test_files = s.files.select {|file| file.start_with? "test/"} s.extensions << 'ext/extconf.rb' - s.required_ruby_version = '>= 3.1.0' + s.required_ruby_version = '>= 3.3.0' #### Documentation and testing. s.homepage = 'https://github.com/ggml-org/whisper.cpp' diff --git a/ci/run.sh b/ci/run.sh index b03fdf1c6b1..dca4476a0fa 100644 --- a/ci/run.sh +++ b/ci/run.sh @@ -151,8 +151,15 @@ function gg_download_model { local cwd=`pwd` mkdir -p "$MNT/models" cd "$MNT/models" + set -x bash "$cwd/models/download-ggml-model.sh" ${model_name} . + local download_status=$? + set +x cd "$cwd" + if [ $download_status -ne 0 ]; then + echo "Error: failed to download model ${model_name}" + ret=1 + fi fi } diff --git a/cmake/parakeet-config.cmake.in b/cmake/parakeet-config.cmake.in new file mode 100644 index 00000000000..aadb55c2d19 --- /dev/null +++ b/cmake/parakeet-config.cmake.in @@ -0,0 +1,30 @@ +set(PARAKEET_VERSION @WHISPER_INSTALL_VERSION@) +set(PARAKEET_BUILD_COMMIT @WHISPER_BUILD_COMMIT@) +set(PARAKEET_BUILD_NUMBER @WHISPER_BUILD_NUMBER@) +set(PARAKEET_SHARED_LIB @BUILD_SHARED_LIBS@) + +@PACKAGE_INIT@ + +set_and_check(PARAKEET_INCLUDE_DIR "@PACKAGE_PARAKEET_INCLUDE_INSTALL_DIR@") +set_and_check(PARAKEET_LIB_DIR "@PACKAGE_PARAKEET_LIB_INSTALL_DIR@") +set_and_check(PARAKEET_BIN_DIR "@PACKAGE_PARAKEET_BIN_INSTALL_DIR@") + +find_package(ggml REQUIRED HINTS ${PARAKEET_LIB_DIR}/cmake) + +find_library(parakeet_LIBRARY parakeet + REQUIRED + HINTS ${PARAKEET_LIB_DIR} + NO_CMAKE_FIND_ROOT_PATH +) + +add_library(parakeet UNKNOWN IMPORTED) +set_target_properties(parakeet + PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${PARAKEET_INCLUDE_DIR}" + INTERFACE_LINK_LIBRARIES "ggml::ggml;ggml::ggml-base;" + IMPORTED_LINK_INTERFACE_LANGUAGES "CXX" + IMPORTED_LOCATION "${parakeet_LIBRARY}" + INTERFACE_COMPILE_FEATURES cxx_std_11 + POSITION_INDEPENDENT_CODE ON) + +check_required_components(parakeet) diff --git a/cmake/parakeet.pc.in b/cmake/parakeet.pc.in new file mode 100644 index 00000000000..5a25fbb2e42 --- /dev/null +++ b/cmake/parakeet.pc.in @@ -0,0 +1,10 @@ +prefix=@CMAKE_INSTALL_PREFIX@ +exec_prefix=${prefix} +libdir=${prefix}/@CMAKE_INSTALL_LIBDIR@ +includedir=${prefix}/include + +Name: parakeet +Description: Port of NVIDIA's Parakeet model in C/C++ +Version: @PROJECT_VERSION@ +Libs: -L${libdir} -lggml -lggml-base -lparakeet +Cflags: -I${includedir} diff --git a/cmake/whisper.pc.in b/cmake/whisper.pc.in index 73d9f4a0c1e..200179d5d11 100644 --- a/cmake/whisper.pc.in +++ b/cmake/whisper.pc.in @@ -1,7 +1,7 @@ prefix=@CMAKE_INSTALL_PREFIX@ exec_prefix=${prefix} -libdir=${exec_prefix}/lib -includedir=${prefix}/include/whisper +libdir=${prefix}/@CMAKE_INSTALL_LIBDIR@ +includedir=${prefix}/include Name: whisper Description: Port of OpenAI's Whisper model in C/C++ diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index b202ca00b77..7aedb9df683 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -20,7 +20,7 @@ set(TARGET common) unset(COMMON_EXTRA_LIBS) -if (WHISPER_FFMPEG) +if (WHISPER_COMMON_FFMPEG) # As of cmake 3.27, there is no official cmake support for FindFFmpeg. # Consequnelty we added a FindFFmpeg.cmake script the cmake subfolder: # whisper.cpp does not need the full ffmpeg libs, just AVFORMAT AVCODEC AVUTIL SWRESAMPLE @@ -39,7 +39,7 @@ if (WHISPER_FFMPEG) message(STATUS "Found avformat ${AVFORMAT_VERSION}") include_directories(${FFMPEG_INCLUDE_DIRS}) - add_compile_definitions(WHISPER_FFMPEG) + add_compile_definitions(WHISPER_COMMON_FFMPEG) list(APPEND COMMON_EXTRA_LIBS ${FFMPEG_LIBRARIES}) @@ -107,6 +107,8 @@ else() add_subdirectory(server) add_subdirectory(quantize) add_subdirectory(vad-speech-segments) + add_subdirectory(parakeet-cli) + add_subdirectory(parakeet-quantize) if (WHISPER_SDL2) add_subdirectory(stream) add_subdirectory(command) diff --git a/examples/cli/cli.cpp b/examples/cli/cli.cpp index 4e84c1b2750..e505bf0e18d 100644 --- a/examples/cli/cli.cpp +++ b/examples/cli/cli.cpp @@ -151,6 +151,10 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params whisper_print_usage(argc, argv, params); exit(0); } + if (arg == "--version") { + fprintf(stdout, "whisper.cpp version: %s\n", whisper_version()); + exit(0); + } #define ARGV_NEXT (((i + 1) < argc) ? argv[++i] : requires_value_error(arg)) else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(ARGV_NEXT); } else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(ARGV_NEXT); } @@ -234,6 +238,7 @@ static void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params fprintf(stderr, "\n"); fprintf(stderr, "options:\n"); fprintf(stderr, " -h, --help [default] show this help message and exit\n"); + fprintf(stderr, " --version show version information and exit\n"); fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors); fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms); @@ -738,18 +743,47 @@ static void output_json( if (full) { start_arr("tokens"); const int n = whisper_full_n_tokens(ctx, i); - for (int j = 0; j < n; ++j) { - auto token = whisper_full_get_token_data(ctx, i, j); + + // Merge adjacent tokens whose bytes together form a + // single UTF-8 codepoint. Multi-byte characters (CJK + // in particular) can end up split across whisper + // tokens, which used to produce invalid UTF-8 in the + // JSON string. Refs issue #1798. + struct merged_token { + std::string text; + whisper_token_data data; + int64_t t1; + }; + std::vector<merged_token> merged; + merged.reserve(n); + for (int j = 0; j < n; ) { + auto tok = whisper_full_get_token_data(ctx, i, j); + merged_token m{ whisper_token_to_str(ctx, tok.id), tok, tok.t1 }; + ++j; + while (j < n && utf8_trailing_bytes_needed(m.text) > 0) { + auto tok_next = whisper_full_get_token_data(ctx, i, j); + m.text += whisper_token_to_str(ctx, tok_next.id); + if (tok_next.t1 > -1) { + m.t1 = tok_next.t1; + } + ++j; + } + merged.push_back(std::move(m)); + } + + const int nm = (int) merged.size(); + for (int j = 0; j < nm; ++j) { + const auto & mt = merged[j]; start_obj(nullptr); - value_s("text", whisper_token_to_str(ctx, token.id), false); - if(token.t0 > -1 && token.t1 > -1) { + value_s("text", mt.text.c_str(), false); + if (mt.data.t0 > -1 && mt.t1 > -1) { // If we have per-token timestamps, write them out - times_o(token.t0, token.t1, false); + times_o(mt.data.t0, mt.t1, false); } - value_i("id", token.id, false); - value_f("p", token.p, false); - value_f("t_dtw", token.t_dtw, true); - end_obj(j == (n - 1)); + value_i("id", mt.data.id, false); + value_f("p", mt.data.p, false); + value_f("t_dtw", mt.data.t_dtw, true); + end_obj(j == (nm - 1)); } end_arr(!params.diarize && !params.tinydiarize); } diff --git a/examples/common-whisper.cpp b/examples/common-whisper.cpp index d29166b50d8..b12481c013f 100644 --- a/examples/common-whisper.cpp +++ b/examples/common-whisper.cpp @@ -34,9 +34,9 @@ #include <cstring> #include <fstream> -#ifdef WHISPER_FFMPEG -// as implemented in ffmpeg_trancode.cpp only embedded in common lib if whisper built with ffmpeg support -extern bool ffmpeg_decode_audio(const std::string & ifname, std::vector<uint8_t> & wav_data); +#ifdef WHISPER_COMMON_FFMPEG +// as implemented in ffmpeg-trancode.cpp only embedded in common lib if whisper built with ffmpeg support +extern bool ffmpeg_decode_audio(const std::string & ifname, std::vector<uint8_t> & wav_data, int out_sample_rate = WHISPER_SAMPLE_RATE); #endif // extract f32 PCM frames from an initialized decoder, downmix to mono and keep the stereo split @@ -75,7 +75,7 @@ static bool read_audio_from_decoder(ma_decoder & decoder, std::vector<float> & p return true; } -bool read_audio_data(const std::string & fname, std::vector<float>& pcmf32, std::vector<std::vector<float>>& pcmf32s, bool stereo) { +bool read_audio_data(const std::string & fname, std::vector<float> & pcmf32, std::vector<std::vector<float>> & pcmf32s, bool stereo) { std::vector<uint8_t> audio_data; // used for pipe input from stdin or ffmpeg decoding output ma_result result; @@ -96,53 +96,67 @@ bool read_audio_data(const std::string & fname, std::vector<float>& pcmf32, std: decoder_config = ma_decoder_config_init(ma_format_f32, stereo ? 2 : 1, WHISPER_SAMPLE_RATE); if (fname == "-") { - #ifdef _WIN32 - _setmode(_fileno(stdin), _O_BINARY); - #endif - - uint8_t buf[1024]; - while (true) - { - const size_t n = fread(buf, 1, sizeof(buf), stdin); - if (n == 0) { - break; - } - audio_data.insert(audio_data.end(), buf, buf + n); - } - - result = ma_decoder_init_memory(audio_data.data(), audio_data.size(), &decoder_config, &decoder); +#ifdef _WIN32 + _setmode(_fileno(stdin), _O_BINARY); +#endif + + uint8_t buf[1024]; + while (true) + { + const size_t n = fread(buf, 1, sizeof(buf), stdin); + if (n == 0) { + break; + } + audio_data.insert(audio_data.end(), buf, buf + n); + } + + result = ma_decoder_init_memory(audio_data.data(), audio_data.size(), &decoder_config, &decoder); if (result != MA_SUCCESS) { - fprintf(stderr, "Error: failed to open audio data from stdin (%s)\n", ma_result_description(result)); - return false; - } + fprintf(stderr, "%s: failed to open audio data from stdin (%s)\n", __func__, ma_result_description(result)); + return false; + } decoder.initialized = true; - fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, audio_data.size()); - } - else { - result = ma_decoder_init_file(fname.c_str(), &decoder_config, &decoder); - if (result == MA_SUCCESS) { - decoder.initialized = true; + fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, audio_data.size()); + } else { + fprintf(stderr, "%s: reading audio data from '%s' ...\n", __func__, fname.c_str()); + + // first try miniaudio. if it fails (or skipped) - try ffmpeg + { + const char * skip = getenv("WHISPER_COMMON_MINIAUDIO_SKIP"); + if (!skip || strlen(skip) == 0 || strcmp(skip, "0") == 0) { + fprintf(stderr, "%s: trying to decode with miniaudio\n", __func__); + + result = ma_decoder_init_file(fname.c_str(), &decoder_config, &decoder); + if (result == MA_SUCCESS) { + decoder.initialized = true; + } + } else { + fprintf(stderr, "%s: skipping miniaudio\n", __func__); + } } -#if defined(WHISPER_FFMPEG) + +#if defined(WHISPER_COMMON_FFMPEG) if (!decoder.initialized) { + fprintf(stderr, "%s: trying to decode with ffmpeg\n", __func__); + if (ffmpeg_decode_audio(fname, audio_data) != 0) { - fprintf(stderr, "error: failed to ffmpeg decode '%s'\n", fname.c_str()); + fprintf(stderr, "%s: failed to ffmpeg decode\n", __func__); return false; } result = ma_decoder_init_memory(audio_data.data(), audio_data.size(), &decoder_config, &decoder); if (result != MA_SUCCESS) { - fprintf(stderr, "error: failed to read audio data as wav (%s)\n", ma_result_description(result)); + fprintf(stderr, "%s: failed to read audio data as wav (%s)\n", __func__, ma_result_description(result)); return false; } decoder.initialized = true; } -#else - if (!decoder.initialized) { - fprintf(stderr, "error: failed to read audio data from (%s)\n", fname.c_str()); - return false; - } #endif + + if (!decoder.initialized) { + fprintf(stderr, "%s: failed to read audio data\n", __func__); + return false; + } } return read_audio_from_decoder(decoder.decoder, pcmf32, pcmf32s, stereo); @@ -184,6 +198,34 @@ int timestamp_to_sample(int64_t t, int n_samples, int whisper_sample_rate) { return std::max(0, std::min((int) n_samples - 1, (int) ((t*whisper_sample_rate)/100))); } +int utf8_trailing_bytes_needed(const std::string & s) { + const int n = (int) s.size(); + int i = n - 1; + while (i >= 0 && ((unsigned char) s[i] & 0xC0) == 0x80) { + --i; + } + if (i < 0) { + return 0; + } + + const unsigned char c = (unsigned char) s[i]; + int expected; + if ((c & 0x80) == 0x00) { + expected = 1; + } else if ((c & 0xE0) == 0xC0) { + expected = 2; + } else if ((c & 0xF0) == 0xE0) { + expected = 3; + } else if ((c & 0xF8) == 0xF0) { + expected = 4; + } else { + return 0; + } + + const int have = n - i; + return have >= expected ? 0 : (expected - have); +} + bool speak_with_file(const std::string & command, const std::string & text, const std::string & path, int voice_id) { std::ofstream speak_file(path.c_str()); if (speak_file.fail()) { diff --git a/examples/common-whisper.h b/examples/common-whisper.h index 8714c381046..aec430d3635 100644 --- a/examples/common-whisper.h +++ b/examples/common-whisper.h @@ -28,5 +28,8 @@ std::string to_timestamp(int64_t t, bool comma = false); // given a timestamp get the sample int timestamp_to_sample(int64_t t, int n_samples, int whisper_sample_rate); +// Returns the number of trailing bytes still needed for s to end on a complete UTF-8 codepoint. +int utf8_trailing_bytes_needed(const std::string & s); + // write text to file, and call system("command voice_id file") bool speak_with_file(const std::string & command, const std::string & text, const std::string & path, int voice_id); diff --git a/examples/ffmpeg-transcode.cpp b/examples/ffmpeg-transcode.cpp index 1fae58a4ffa..7657af69823 100644 --- a/examples/ffmpeg-transcode.cpp +++ b/examples/ffmpeg-transcode.cpp @@ -1,368 +1,238 @@ -/* SPDX-License-Identifier: GPL-2.0 */ +#ifdef WHISPER_COMMON_FFMPEG -/* - * transcode.c - convert audio file to WAVE - * - * Copyright (C) 2019 Andrew Clayton <andrew@digital-domain.net> - * Copyright (C) 2024 William Tambellini <william.tambellini@gmail.com> - */ - -// Just for conveninent C++ API -#include <vector> #include <string> - -// C -#include <stdio.h> -#include <stdlib.h> -#include <string.h> -#include <stdbool.h> -#include <stdint.h> -#include <sys/types.h> -#include <sys/stat.h> -#include <fcntl.h> -#include <unistd.h> -#include <sys/mman.h> +#include <vector> +#include <cstdio> +#include <cstring> extern "C" { -#include <libavutil/opt.h> -#include <libavcodec/avcodec.h> #include <libavformat/avformat.h> +#include <libavcodec/avcodec.h> #include <libswresample/swresample.h> } -typedef uint64_t u64; -typedef int64_t s64; -typedef uint32_t u32; -typedef int32_t s32; -typedef uint16_t u16; -typedef int16_t s16; -typedef uint8_t u8; -typedef int8_t s8; - -#define WAVE_SAMPLE_RATE 16000 -#define AVIO_CTX_BUF_SZ 4096 - -static const char* ffmpegLog = getenv("FFMPEG_LOG"); -// Todo: add __FILE__ __LINE__ -#define LOG(...) \ - do { if (ffmpegLog) fprintf(stderr, __VA_ARGS__); } while(0) // C99 - -/* - * WAVE file header based on definition from - * https://gist.github.com/Jon-Schneider/8b7c53d27a7a13346a643dac9c19d34f - * - * We must ensure this structure doesn't have any holes or - * padding so we can just map it straight to the WAVE data. - */ -struct wave_hdr { - /* RIFF Header: "RIFF" */ - char riff_header[4]; - /* size of audio data + sizeof(struct wave_hdr) - 8 */ - int wav_size; - /* "WAVE" */ - char wav_header[4]; - - /* Format Header */ - /* "fmt " (includes trailing space) */ - char fmt_header[4]; - /* Should be 16 for PCM */ - int fmt_chunk_size; - /* Should be 1 for PCM. 3 for IEEE Float */ - s16 audio_format; - s16 num_channels; - int sample_rate; - /* - * Number of bytes per second - * sample_rate * num_channels * bit_depth/8 - */ - int byte_rate; - /* num_channels * bytes per sample */ - s16 sample_alignment; - /* bits per sample */ - s16 bit_depth; - - /* Data Header */ - /* "data" */ - char data_header[4]; - /* - * size of audio - * number of samples * num_channels * bit_depth/8 - */ - int data_bytes; -} __attribute__((__packed__)); - -struct audio_buffer { - u8 *ptr; - int size; /* size left in the buffer */ -}; - -static void set_wave_hdr(wave_hdr& wh, size_t size) { - memcpy(&wh.riff_header, "RIFF", 4); - wh.wav_size = size + sizeof(struct wave_hdr) - 8; - memcpy(&wh.wav_header, "WAVE", 4); - memcpy(&wh.fmt_header, "fmt ", 4); - wh.fmt_chunk_size = 16; - wh.audio_format = 1; - wh.num_channels = 1; - wh.sample_rate = WAVE_SAMPLE_RATE; - wh.sample_alignment = 2; - wh.bit_depth = 16; - wh.byte_rate = wh.sample_rate * wh.sample_alignment; - memcpy(&wh.data_header, "data", 4); - wh.data_bytes = size; +// Write a minimal WAV header into the output buffer. +// Returns the number of bytes written (44 for a standard PCM WAV header). +static size_t wav_header_write(uint8_t * buf, int num_channels, int sample_rate, int bits_per_sample, uint32_t data_size) { + // RIFF header + memcpy(buf, "RIFF", 4); + uint32_t chunk_size = 36 + data_size; + memcpy(buf + 4, &chunk_size, 4); + memcpy(buf + 8, "WAVE", 4); + + // fmt subchunk + memcpy(buf + 12, "fmt ", 4); + uint32_t subchunk1_size = 16; + memcpy(buf + 16, &subchunk1_size, 4); + uint16_t audio_format = 1; // PCM + memcpy(buf + 20, &audio_format, 2); + memcpy(buf + 22, &num_channels, 2); + memcpy(buf + 24, &sample_rate, 4); + + int bytes_per_sample = (bits_per_sample / 8) * num_channels; + int byte_rate = sample_rate * bytes_per_sample; + memcpy(buf + 28, &byte_rate, 4); + memcpy(buf + 32, &bytes_per_sample, 2); + memcpy(buf + 34, &bits_per_sample, 2); + + // data subchunk + memcpy(buf + 36, "data", 4); + memcpy(buf + 40, &data_size, 4); + + return 44; } -static void write_wave_hdr(int fd, size_t size) { - struct wave_hdr wh; - set_wave_hdr(wh, size); - write(fd, &wh, sizeof(struct wave_hdr)); -} +bool ffmpeg_decode_audio(const std::string & ifname, std::vector<uint8_t> & wav_data, int out_sample_rate) { + { + const char * verbose = getenv("WHISPER_COMMON_FFMPEG_VERBOSE"); + if (verbose && strcmp(verbose, "2") == 0) { + av_log_set_level(AV_LOG_DEBUG); + } else if (verbose && strcmp(verbose, "1") == 0) { + av_log_set_level(AV_LOG_VERBOSE); + } else { + av_log_set_level(AV_LOG_WARNING); + } + } -static int map_file(int fd, u8 **ptr, size_t *size) -{ - struct stat sb; + AVFormatContext * fmt_ctx = nullptr; + if (avformat_open_input(&fmt_ctx, ifname.c_str(), nullptr, nullptr) != 0) { + fprintf(stderr, "error: failed to open input file '%s'\n", ifname.c_str()); + return true; + } - fstat(fd, &sb); - *size = sb.st_size; + if (avformat_find_stream_info(fmt_ctx, nullptr) < 0) { + fprintf(stderr, "error: failed to find stream information\n"); + avformat_close_input(&fmt_ctx); + return true; + } - *ptr = (u8*)mmap(NULL, *size, PROT_READ|PROT_WRITE, MAP_PRIVATE, fd, 0); - if (*ptr == MAP_FAILED) { - perror("mmap"); - return -1; - } + // Find the first audio stream + int audio_stream_idx = -1; + for (unsigned int i = 0; i < fmt_ctx->nb_streams; i++) { + if (fmt_ctx->streams[i]->codecpar->codec_type == AVMEDIA_TYPE_AUDIO) { + audio_stream_idx = i; + break; + } + } - return 0; -} + if (audio_stream_idx == -1) { + fprintf(stderr, "error: failed to find an audio stream in '%s'\n", ifname.c_str()); + avformat_close_input(&fmt_ctx); + return true; + } -static int read_packet(void *opaque, u8 *buf, int buf_size) -{ - struct audio_buffer *audio_buf = (audio_buffer*)opaque; + AVStream * audio_stream = fmt_ctx->streams[audio_stream_idx]; - buf_size = FFMIN(buf_size, audio_buf->size); + // Open the decoder + const AVCodec * codec = avcodec_find_decoder(audio_stream->codecpar->codec_id); + if (!codec) { + fprintf(stderr, "error: failed to find decoder for codec id %d\n", audio_stream->codecpar->codec_id); + avformat_close_input(&fmt_ctx); + return true; + } - /* copy internal buffer data to buf */ - memcpy(buf, audio_buf->ptr, buf_size); - audio_buf->ptr += buf_size; - audio_buf->size -= buf_size; + AVCodecContext * codec_ctx = avcodec_alloc_context3(codec); + if (!codec_ctx) { + fprintf(stderr, "error: failed to allocate codec context\n"); + avformat_close_input(&fmt_ctx); + return true; + } - return buf_size; -} + if (avcodec_parameters_to_context(codec_ctx, audio_stream->codecpar) < 0) { + fprintf(stderr, "error: failed to copy codec parameters to context\n"); + avcodec_free_context(&codec_ctx); + avformat_close_input(&fmt_ctx); + return true; + } -static void convert_frame(struct SwrContext *swr, AVCodecContext *codec, - AVFrame *frame, s16 **data, int *size, bool flush) -{ - int nr_samples; - s64 delay; - u8 *buffer; - - delay = swr_get_delay(swr, codec->sample_rate); - nr_samples = av_rescale_rnd(delay + frame->nb_samples, - WAVE_SAMPLE_RATE, codec->sample_rate, - AV_ROUND_UP); - av_samples_alloc(&buffer, NULL, 1, nr_samples, AV_SAMPLE_FMT_S16, 0); - - /* - * !flush is used to check if we are flushing any remaining - * conversion buffers... - */ - nr_samples = swr_convert(swr, &buffer, nr_samples, - !flush ? (const u8 **)frame->data : NULL, - !flush ? frame->nb_samples : 0); - - *data = (s16*)realloc(*data, (*size + nr_samples) * sizeof(s16)); - memcpy(*data + *size, buffer, nr_samples * sizeof(s16)); - *size += nr_samples; - av_freep(&buffer); -} + if (avcodec_open2(codec_ctx, codec, nullptr) < 0) { + fprintf(stderr, "error: failed to open codec\n"); + avcodec_free_context(&codec_ctx); + avformat_close_input(&fmt_ctx); + return true; + } -static bool is_audio_stream(const AVStream *stream) -{ - if (stream->codecpar->codec_type == AVMEDIA_TYPE_AUDIO) - return true; + // Setup resampler: convert to 16-bit signed PCM, mono, 16000 Hz + const enum AVSampleFormat out_sample_fmt = AV_SAMPLE_FMT_S16; - return false; -} + AVChannelLayout out_ch_layout = AV_CHANNEL_LAYOUT_MONO; -// Return non zero on error, 0 on success -// audio_buffer: input memory -// data: decoded output audio data (wav file) -// size: size of output data -static int decode_audio(struct audio_buffer *audio_buf, s16 **data, int *size) -{ - LOG("decode_audio: input size: %d\n", audio_buf->size); - AVFormatContext *fmt_ctx; - AVIOContext *avio_ctx; - AVStream *stream; - AVCodecContext *codec; - AVPacket *packet; - AVFrame *frame; - struct SwrContext *swr; - u8 *avio_ctx_buffer; - unsigned int i; - int stream_index = -1; - int err; - const size_t errbuffsize = 1024; - char errbuff[errbuffsize]; - - fmt_ctx = avformat_alloc_context(); - avio_ctx_buffer = (u8*)av_malloc(AVIO_CTX_BUF_SZ); - LOG("Creating an avio context: AVIO_CTX_BUF_SZ=%d\n", AVIO_CTX_BUF_SZ); - avio_ctx = avio_alloc_context(avio_ctx_buffer, AVIO_CTX_BUF_SZ, 0, audio_buf, &read_packet, NULL, NULL); - fmt_ctx->pb = avio_ctx; - - // open the input stream and read header - err = avformat_open_input(&fmt_ctx, NULL, NULL, NULL); - if (err) { - LOG("Could not read audio buffer: %d: %s\n", err, av_make_error_string(errbuff, errbuffsize, err)); - return err; - } - - err = avformat_find_stream_info(fmt_ctx, NULL); - if (err < 0) { - LOG("Could not retrieve stream info from audio buffer: %d\n", err); - return err; - } - - for (i = 0; i < fmt_ctx->nb_streams; i++) { - if (is_audio_stream(fmt_ctx->streams[i])) { - stream_index = i; - break; - } - } - - if (stream_index == -1) { - LOG("Could not retrieve audio stream from buffer\n"); - return -1; - } - - stream = fmt_ctx->streams[stream_index]; - codec = avcodec_alloc_context3( - avcodec_find_decoder(stream->codecpar->codec_id)); - avcodec_parameters_to_context(codec, stream->codecpar); - err = avcodec_open2(codec, avcodec_find_decoder(codec->codec_id), - NULL); - if (err) { - LOG("Failed to open decoder for stream #%d in audio buffer\n", stream_index); - return err; - } - - /* prepare resampler */ - swr = swr_alloc(); - -#if LIBAVCODEC_VERSION_MAJOR > 60 - AVChannelLayout in_ch_layout = codec->ch_layout; - AVChannelLayout out_ch_layout = AV_CHANNEL_LAYOUT_MONO; - - /* Set the source audio layout as-is */ - av_opt_set_chlayout(swr, "in_chlayout", &in_ch_layout, 0); - av_opt_set_int(swr, "in_sample_rate", codec->sample_rate, 0); - av_opt_set_sample_fmt(swr, "in_sample_fmt", codec->sample_fmt, 0); - - /* Convert it into 16khz Mono */ - av_opt_set_chlayout(swr, "out_chlayout", &out_ch_layout, 0); - av_opt_set_int(swr, "out_sample_rate", WAVE_SAMPLE_RATE, 0); - av_opt_set_sample_fmt(swr, "out_sample_fmt", AV_SAMPLE_FMT_S16, 0); -#else - av_opt_set_int(swr, "in_channel_count", codec->channels, 0); - av_opt_set_int(swr, "out_channel_count", 1, 0); - av_opt_set_int(swr, "in_channel_layout", codec->channel_layout, 0); - av_opt_set_int(swr, "out_channel_layout", AV_CH_LAYOUT_MONO, 0); - av_opt_set_int(swr, "in_sample_rate", codec->sample_rate, 0); - av_opt_set_int(swr, "out_sample_rate", WAVE_SAMPLE_RATE, 0); - av_opt_set_sample_fmt(swr, "in_sample_fmt", codec->sample_fmt, 0); - av_opt_set_sample_fmt(swr, "out_sample_fmt", AV_SAMPLE_FMT_S16, 0); -#endif - - swr_init(swr); - if (!swr_is_initialized(swr)) { - LOG("Resampler has not been properly initialized\n"); - return -1; - } - - packet=av_packet_alloc(); - if (!packet) { - LOG("Error allocating the packet\n"); - return -1; - } - frame = av_frame_alloc(); - if (!frame) { - LOG("Error allocating the frame\n"); - return -1; - } - - /* iterate through frames */ - *data = NULL; - *size = 0; - while (av_read_frame(fmt_ctx, packet) >= 0) { - avcodec_send_packet(codec, packet); - - err = avcodec_receive_frame(codec, frame); - if (err == AVERROR(EAGAIN)) - continue; - - convert_frame(swr, codec, frame, data, size, false); - } - /* Flush any remaining conversion buffers... */ - convert_frame(swr, codec, frame, data, size, true); - - av_packet_free(&packet); - av_frame_free(&frame); - swr_free(&swr); - //avio_context_free(); // todo? - avcodec_free_context(&codec); - avformat_close_input(&fmt_ctx); - avformat_free_context(fmt_ctx); - - if (avio_ctx) { - av_freep(&avio_ctx->buffer); - av_freep(&avio_ctx); - } - - return 0; -} + SwrContext * swr_ctx = nullptr; + if (swr_alloc_set_opts2(&swr_ctx, &out_ch_layout, out_sample_fmt, out_sample_rate, + &codec_ctx->ch_layout, codec_ctx->sample_fmt, codec_ctx->sample_rate, + 0, nullptr) < 0) { + fprintf(stderr, "error: failed to allocate swr context\n"); + avcodec_free_context(&codec_ctx); + avformat_close_input(&fmt_ctx); + return true; + } -// in mem decoding/conversion/resampling: -// ifname: input file path -// owav_data: in mem wav file. Can be forwarded as it to whisper/drwav -// return 0 on success -int ffmpeg_decode_audio(const std::string &ifname, std::vector<uint8_t>& owav_data) { - LOG("ffmpeg_decode_audio: %s\n", ifname.c_str()); - int ifd = open(ifname.c_str(), O_RDONLY); - if (ifd == -1) { - fprintf(stderr, "Couldn't open input file %s\n", ifname.c_str()); - return -1; + if (swr_init(swr_ctx) < 0) { + fprintf(stderr, "error: failed to initialize swr context\n"); + swr_free(&swr_ctx); + avcodec_free_context(&codec_ctx); + avformat_close_input(&fmt_ctx); + return true; } - u8 *ibuf = NULL; - size_t ibuf_size; - int err = map_file(ifd, &ibuf, &ibuf_size); - if (err) { - LOG("Couldn't map input file %s\n", ifname.c_str()); - return err; + + // Decode and resample + AVPacket * packet = av_packet_alloc(); + AVFrame * frame = av_frame_alloc(); + + // Buffer to collect resampled output + std::vector<int16_t> pcm_data; + + // Max output samples per swr_convert call + const int max_out_samples = 16 * 1024; + std::vector<int16_t> out_buffer(max_out_samples); + + while (av_read_frame(fmt_ctx, packet) >= 0) { + if (packet->stream_index != audio_stream_idx) { + av_packet_unref(packet); + continue; + } + + int ret = avcodec_send_packet(codec_ctx, packet); + av_packet_unref(packet); + + if (ret < 0) { + continue; + } + + while (ret >= 0) { + ret = avcodec_receive_frame(codec_ctx, frame); + if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) { + break; + } + if (ret < 0) { + break; + } + + // Resample + int out_samples = av_rescale_rnd(swr_get_delay(swr_ctx, out_sample_rate) + frame->nb_samples, + out_sample_rate, out_sample_rate, AV_ROUND_UP); + if (out_samples > (int)out_buffer.size()) { + out_buffer.resize(out_samples); + } + + const uint8_t * in_data[16] = {0}; + for (int p = 0; p < (int)codec_ctx->ch_layout.nb_channels && p < 16; p++) { + in_data[p] = frame->data[p]; + } + uint8_t * out_data[16] = {0}; + out_data[0] = (uint8_t *)out_buffer.data(); + + int got_samples = swr_convert(swr_ctx, out_data, out_samples, in_data, frame->nb_samples); + if (got_samples > 0) { + pcm_data.insert(pcm_data.end(), out_buffer.begin(), out_buffer.begin() + got_samples); + } + } } - LOG("Mapped input file: %s size: %d\n", ibuf, (int) ibuf_size); - struct audio_buffer inaudio_buf; - inaudio_buf.ptr = ibuf; - inaudio_buf.size = ibuf_size; - - s16 *odata=NULL; - int osize=0; - - err = decode_audio(&inaudio_buf, &odata, &osize); - LOG("decode_audio returned %d \n", err); - if (err != 0) { - LOG("decode_audio failed\n"); - return err; + + // Flush the decoder + avcodec_send_packet(codec_ctx, nullptr); + while (avcodec_receive_frame(codec_ctx, frame) >= 0) { + int out_samples = av_rescale_rnd(swr_get_delay(swr_ctx, out_sample_rate) + frame->nb_samples, + out_sample_rate, out_sample_rate, AV_ROUND_UP); + if (out_samples > (int)out_buffer.size()) { + out_buffer.resize(out_samples); + } + const uint8_t * in_data[16] = {0}; + for (int p = 0; p < (int)codec_ctx->ch_layout.nb_channels && p < 16; p++) { + in_data[p] = frame->data[p]; + } + uint8_t * out_data[16] = {0}; + out_data[0] = (uint8_t *)out_buffer.data(); + + int got_samples = swr_convert(swr_ctx, out_data, out_samples, in_data, frame->nb_samples); + if (got_samples > 0) { + pcm_data.insert(pcm_data.end(), out_buffer.begin(), out_buffer.begin() + got_samples); + } } - LOG("decode_audio output size: %d\n", osize); - - wave_hdr wh; - const size_t outdatasize = osize * sizeof(s16); - set_wave_hdr(wh, outdatasize); - owav_data.resize(sizeof(wave_hdr) + outdatasize); - // header: - memcpy(owav_data.data(), &wh, sizeof(wave_hdr)); - // the data: - memcpy(owav_data.data() + sizeof(wave_hdr), odata, osize* sizeof(s16)); - - return 0; + + // Flush the resampler + uint8_t * out_data[16] = {0}; + out_data[0] = (uint8_t *)out_buffer.data(); + int flush_samples = swr_convert(swr_ctx, out_data, max_out_samples, nullptr, 0); + if (flush_samples > 0) { + pcm_data.insert(pcm_data.end(), out_buffer.begin(), out_buffer.begin() + flush_samples); + } + + // Build WAV output + uint32_t data_size = pcm_data.size() * sizeof(int16_t); + wav_data.resize(44 + data_size); + + wav_header_write(wav_data.data(), 1, out_sample_rate, 16, data_size); + memcpy(wav_data.data() + 44, pcm_data.data(), data_size); + + // Cleanup + av_frame_free(&frame); + av_packet_free(&packet); + swr_free(&swr_ctx); + avcodec_free_context(&codec_ctx); + avformat_close_input(&fmt_ctx); + + return false; // success } + +#endif // WHISPER_COMMON_FFMPEG diff --git a/examples/parakeet-cli/CMakeLists.txt b/examples/parakeet-cli/CMakeLists.txt new file mode 100644 index 00000000000..adb9aba38ef --- /dev/null +++ b/examples/parakeet-cli/CMakeLists.txt @@ -0,0 +1,8 @@ +set(TARGET parakeet-cli) +add_executable(${TARGET} parakeet-cli.cpp) + +include(DefaultTargetOptions) + +target_link_libraries(${TARGET} PRIVATE common parakeet ${FFMPEG_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT}) + +install(TARGETS ${TARGET} RUNTIME) diff --git a/examples/parakeet-cli/README.md b/examples/parakeet-cli/README.md new file mode 100644 index 00000000000..ccb8404f542 --- /dev/null +++ b/examples/parakeet-cli/README.md @@ -0,0 +1,106 @@ +# whisper.cpp/examples/parakeet-cli + +This is an example of using the [Parakeet] model in whisper.cpp. + +### Download converted model +```console +$ hf download ggml-org/parakeet-GGUF parakeet-tdt-0.6b-v3-f16.bin --local-dir models +``` + +### Building +```console +$ cmake -B build -S . +$ cmake --build build --target parakeet-cli -j 12 +``` + +### Usage +```console +$ ./build/bin/parakeet-cli --help + +usage: ./build/bin/parakeet-cli [options] file0 file1 ... +supported audio formats: flac, mp3, ogg, wav + +options: + -h, --help [default] show this help message and exit + -t N, --threads N [4 ] number of threads to use during computation + -m, --model FILE [models/ggml-parakeet-tdt-0.6b-v3.bin] model path + -f, --file FILE [ ] input audio file + -ng, --no-gpu [false ] disable GPU + -dev N, --device N [0 ] GPU device to use + -ps, --print-segments [false ] print segment information +``` + +### Example +```console +$ ./build/bin/parakeet-cli -m models/parakeet-tdt-0.6b-v3-f16.bin -f samples/jfk.wav +Processing audio (176000 samples, 11.00 seconds) +Processing audio: total_frames=1101, chunk_size=1101 +parakeet_decode: starting decode with n_frames=138 +And so, my fellow Americans, ask not what your country can do for you, ask what you can do for your country. +``` + +To print segment information: +```console +$ ./build/bin/parakeet-cli -m models/parakeet-tdt-0.6b-v3-f16.bin -f samples/jfk.wav --print-segments +Processing audio (176000 samples, 11.00 seconds) +Processing audio: total_frames=1101, chunk_size=1101 +parakeet_decode: starting decode with n_frames=138 +And so, my fellow Americans, ask not what your country can do for you, ask what you can do for your country. + +Segments (1): +Segment 0: [0 -> 1101] "And so, my fellow Americans, ask not what your country can do for you, ask what you can do for your country." +Tokens [38]: + [ 0] id= 1976 frame= 3 dur_idx= 4 dur_val= 4 p=0.9996 plog=-15.6206 t0= 24 t1= 56 word_start=true "▁And" + [ 1] id= 547 frame= 7 dur_idx= 4 dur_val= 4 p=0.9999 plog=-18.7922 t0= 56 t1= 88 word_start=true "▁so" + [ 2] id= 7877 frame= 11 dur_idx= 2 dur_val= 2 p=0.8451 plog=-14.5929 t0= 88 t1= 88 word_start=false "," + [ 3] id= 1103 frame= 13 dur_idx= 3 dur_val= 3 p=0.9996 plog=-15.6127 t0= 104 t1= 128 word_start=true "▁my" + [ 4] id= 309 frame= 16 dur_idx= 1 dur_val= 1 p=0.9912 plog=-11.9635 t0= 128 t1= 136 word_start=true "▁f" + [ 5] id= 530 frame= 17 dur_idx= 2 dur_val= 2 p=1.0000 plog=-13.5239 t0= 136 t1= 152 word_start=false "ell" + [ 6] id= 596 frame= 19 dur_idx= 3 dur_val= 3 p=1.0000 plog=-16.3120 t0= 152 t1= 176 word_start=false "ow" + [ 7] id= 3213 frame= 22 dur_idx= 4 dur_val= 4 p=0.9999 plog=-10.1462 t0= 176 t1= 208 word_start=true "▁Amer" + [ 8] id= 404 frame= 26 dur_idx= 4 dur_val= 4 p=1.0000 plog=-25.0910 t0= 208 t1= 240 word_start=false "ic" + [ 9] id= 667 frame= 30 dur_idx= 4 dur_val= 4 p=1.0000 plog=-27.1707 t0= 240 t1= 272 word_start=false "ans" + [10] id= 7877 frame= 37 dur_idx= 4 dur_val= 4 p=0.9094 plog=-16.3405 t0= 272 t1= 272 word_start=false "," + [11] id= 279 frame= 41 dur_idx= 4 dur_val= 4 p=0.9980 plog=-19.7244 t0= 328 t1= 360 word_start=true "▁a" + [12] id= 583 frame= 45 dur_idx= 4 dur_val= 4 p=1.0000 plog=-24.5312 t0= 360 t1= 392 word_start=false "sk" + [13] id= 1491 frame= 53 dur_idx= 4 dur_val= 4 p=1.0000 plog=-23.2991 t0= 424 t1= 456 word_start=true "▁not" + [14] id= 3470 frame= 65 dur_idx= 4 dur_val= 4 p=0.9995 plog=-16.7306 t0= 520 t1= 552 word_start=true "▁what" + [15] id= 3629 frame= 69 dur_idx= 2 dur_val= 2 p=0.8139 plog=-11.6486 t0= 552 t1= 568 word_start=true "▁your" + [16] id= 867 frame= 75 dur_idx= 1 dur_val= 1 p=0.9980 plog=-12.5265 t0= 600 t1= 608 word_start=true "▁co" + [17] id= 331 frame= 76 dur_idx= 2 dur_val= 2 p=1.0000 plog=-11.6697 t0= 608 t1= 624 word_start=false "un" + [18] id= 958 frame= 78 dur_idx= 2 dur_val= 2 p=1.0000 plog=-11.3621 t0= 624 t1= 640 word_start=false "tr" + [19] id= 7893 frame= 80 dur_idx= 2 dur_val= 2 p=1.0000 plog=-14.3245 t0= 640 t1= 656 word_start=false "y" + [20] id= 2059 frame= 82 dur_idx= 3 dur_val= 3 p=1.0000 plog=-17.7694 t0= 656 t1= 680 word_start=true "▁can" + [21] id= 458 frame= 85 dur_idx= 4 dur_val= 4 p=1.0000 plog=-23.2510 t0= 680 t1= 712 word_start=true "▁do" + [22] id= 509 frame= 89 dur_idx= 4 dur_val= 4 p=1.0000 plog=-23.0688 t0= 712 t1= 744 word_start=true "▁for" + [23] id= 1180 frame= 93 dur_idx= 4 dur_val= 4 p=0.9999 plog=-25.0567 t0= 744 t1= 776 word_start=true "▁you" + [24] id= 7877 frame= 98 dur_idx= 4 dur_val= 4 p=0.8820 plog=-14.2549 t0= 776 t1= 776 word_start=false "," + [25] id= 279 frame=102 dur_idx= 3 dur_val= 3 p=0.9992 plog=-16.8176 t0= 816 t1= 840 word_start=true "▁a" + [26] id= 583 frame=105 dur_idx= 4 dur_val= 4 p=1.0000 plog=-21.0352 t0= 840 t1= 872 word_start=false "sk" + [27] id= 3470 frame=109 dur_idx= 3 dur_val= 3 p=0.9999 plog=-15.4659 t0= 872 t1= 896 word_start=true "▁what" + [28] id= 1180 frame=112 dur_idx= 4 dur_val= 4 p=0.9997 plog=-17.6392 t0= 896 t1= 928 word_start=true "▁you" + [29] id= 2059 frame=116 dur_idx= 3 dur_val= 3 p=0.9999 plog=-15.5484 t0= 928 t1= 952 word_start=true "▁can" + [30] id= 458 frame=119 dur_idx= 2 dur_val= 2 p=1.0000 plog=-15.9953 t0= 952 t1= 968 word_start=true "▁do" + [31] id= 509 frame=121 dur_idx= 3 dur_val= 3 p=1.0000 plog=-15.9605 t0= 968 t1= 992 word_start=true "▁for" + [32] id= 3629 frame=124 dur_idx= 2 dur_val= 2 p=0.9994 plog=-12.2083 t0= 992 t1=1008 word_start=true "▁your" + [33] id= 867 frame=126 dur_idx= 2 dur_val= 2 p=0.9969 plog=-9.1252 t0=1008 t1=1024 word_start=true "▁co" + [34] id= 331 frame=128 dur_idx= 1 dur_val= 1 p=0.9999 plog=-12.6911 t0=1024 t1=1032 word_start=false "un" + [35] id= 958 frame=129 dur_idx= 1 dur_val= 1 p=1.0000 plog=-8.8885 t0=1032 t1=1040 word_start=false "tr" + [36] id= 7893 frame=130 dur_idx= 2 dur_val= 2 p=1.0000 plog=-14.1441 t0=1040 t1=1056 word_start=false "y" + [37] id= 7883 frame=132 dur_idx= 4 dur_val= 4 p=0.9567 plog=-11.5227 t0=1056 t1=1056 word_start=false "." +``` + +### Model conversion +Clone the original model from Hugging Face: +```console +$ git clone https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3 +``` +Convert the model: +```console +(venv) $ python models/convert-parakeet-to-ggml.py \ + --model <path to cloned model> \ + --out-dir models \ + --out-name ggml-parakeet-tdt-0.6b-v3-f16.bin +``` + +[Parakeet]: https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3 diff --git a/examples/parakeet-cli/parakeet-cli.cpp b/examples/parakeet-cli/parakeet-cli.cpp new file mode 100644 index 00000000000..03ddc7f8b8c --- /dev/null +++ b/examples/parakeet-cli/parakeet-cli.cpp @@ -0,0 +1,243 @@ +#include "parakeet.h" +#include "common-whisper.h" + +#include <cstdio> +#include <string> +#include <thread> +#include <vector> +#include <cstring> +#include <fstream> + +// command-line parameters +struct parakeet_params { + int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); + + bool use_gpu = true; + int32_t gpu_device = 0; + + bool print_segments = false; + bool output_txt = false; + bool no_prints = false; + + std::string model = "models/ggml-parakeet-tdt-0.6b-v3.bin"; + std::string output_file = ""; + std::vector<std::string> fname_inp = {}; +}; + +static void parakeet_print_usage(int argc, char ** argv, const parakeet_params & params); + +static char * requires_value_error(const std::string & arg) { + fprintf(stderr, "error: argument %s requires value\n", arg.c_str()); + exit(1); +} + +static bool parakeet_params_parse(int argc, char ** argv, parakeet_params & params) { + if (const char * env_device = std::getenv("PARAKEET_ARG_DEVICE")) { + params.gpu_device = std::stoi(env_device); + } + + for (int i = 1; i < argc; i++) { + std::string arg = argv[i]; + + if (arg == "-"){ + params.fname_inp.push_back(arg); + continue; + } + + if (arg[0] != '-') { + params.fname_inp.push_back(arg); + continue; + } + + if (arg == "-h" || arg == "--help") { + parakeet_print_usage(argc, argv, params); + exit(0); + } + #define ARGV_NEXT (((i + 1) < argc) ? argv[++i] : requires_value_error(arg)) + else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(ARGV_NEXT); } + else if (arg == "-m" || arg == "--model") { params.model = ARGV_NEXT; } + else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(ARGV_NEXT); } + else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-dev" || arg == "--device") { params.gpu_device = std::stoi(ARGV_NEXT); } + else if (arg == "-ps" || arg == "--print-segments") { params.print_segments = true; } + else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; } + else if (arg == "-of" || arg == "--output-file") { params.output_file = ARGV_NEXT; } + else if (arg == "-np" || arg == "--no-prints") { params.no_prints = true; } + else { + fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); + parakeet_print_usage(argc, argv, params); + exit(1); + } + } + + return true; +} + +static void parakeet_print_usage(int /*argc*/, char ** argv, const parakeet_params & params) { + fprintf(stderr, "\n"); + fprintf(stderr, "usage: %s [options] file0 file1 ...\n", argv[0]); + fprintf(stderr, "supported audio formats: flac, mp3, ogg, wav\n"); + fprintf(stderr, "\n"); + fprintf(stderr, "options:\n"); + fprintf(stderr, " -h, --help [default] show this help message and exit\n"); + fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); + fprintf(stderr, " -m, --model FILE [%-7s] model path\n", params.model.c_str()); + fprintf(stderr, " -f, --file FILE [%-7s] input audio file\n", ""); + fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -dev N, --device N [%-7d] GPU device to use\n", params.gpu_device); + fprintf(stderr, " -ps, --print-segments [%-7s] print segment information\n", params.print_segments ? "true" : "false"); + fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false"); + fprintf(stderr, " -of, --output-file FILE [%-7s] output file path (without file extension)\n", ""); + fprintf(stderr, " -np, --no-prints [%-7s] do not print anything other than the results\n", params.no_prints ? "true" : "false"); + fprintf(stderr, "\n"); +} + +void token_callback(parakeet_context * ctx, parakeet_state * state, const parakeet_token_data * token_data, void * user_data) { + bool * is_first = (bool *) user_data; + + const char * token_str = parakeet_token_to_str(ctx, token_data->id); + char text_buf[256]; + parakeet_token_to_text(token_str, *is_first, text_buf, sizeof(text_buf)); + printf("%s", text_buf); + fflush(stdout); + + *is_first = false; +} + +static void cb_log_disable(enum ggml_log_level , const char * , void * ) { } + +int main(int argc, char ** argv) { + ggml_backend_load_all(); + + parakeet_params params; + + if (parakeet_params_parse(argc, argv, params) == false) { + return 1; + } + + if (params.no_prints) { + parakeet_log_set(cb_log_disable, NULL); + } + + if (params.fname_inp.empty()) { + fprintf(stderr, "error: no input files specified\n"); + parakeet_print_usage(argc, argv, params); + return 1; + } + + struct parakeet_context_params ctx_params = parakeet_context_default_params(); + ctx_params.use_gpu = params.use_gpu; + ctx_params.gpu_device = params.gpu_device; + + if (!params.no_prints) { + fprintf(stderr, "Loading Parakeet model from: %s\n", params.model.c_str()); + } + + + struct parakeet_context * pctx = parakeet_init_from_file_with_params(params.model.c_str(), ctx_params); + if (pctx == nullptr) { + fprintf(stderr, "error: failed to load Parakeet model from '%s'\n", params.model.c_str()); + return 1; + } + + if (!params.no_prints) { + fprintf(stderr, "Successfully loaded Parakeet model\n"); + fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", + params.n_threads, (int32_t) std::thread::hardware_concurrency(), parakeet_print_system_info()); + } + + // Process each input file + for (const auto & fname : params.fname_inp) { + if (!params.no_prints) { + fprintf(stderr, "\nProcessing file: %s\n", fname.c_str()); + } + + std::vector<float> pcmf32; + std::vector<std::vector<float>> pcmf32s; + if (!read_audio_data(fname.c_str(), pcmf32, pcmf32s, false)) { + fprintf(stderr, "error: failed to read audio file '%s'\n", fname.c_str()); + continue; + } + + if (pcmf32.empty()) { + fprintf(stderr, "error: no audio data in file '%s'\n", fname.c_str()); + continue; + } + + bool is_first = true; + struct parakeet_full_params full_params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY); + full_params.n_threads = params.n_threads; + full_params.new_token_callback = token_callback; + full_params.new_token_callback_user_data = &is_first; + + const int mel_frames = (int)(pcmf32.size() / PARAKEET_HOP_LENGTH); + int ret = parakeet_full(pctx, full_params, pcmf32.data(), pcmf32.size()); + + if (ret != 0) { + fprintf(stderr, "error: failed to process audio file '%s'\n", fname.c_str()); + continue; + } + + printf("\n"); + + if (params.output_txt) { + const std::string fname_out = (!params.output_file.empty() ? params.output_file : fname) + ".txt"; + + std::ofstream fout(fname_out); + if (fout.is_open()) { + const int n_segments = parakeet_full_n_segments(pctx); + for (int i = 0; i < n_segments; ++i) { + const char * text = parakeet_full_get_segment_text(pctx, i); + fout << text << "\n"; + } + fout.close(); + if (!params.no_prints) { + fprintf(stderr, "Output written to: %s\n", fname_out.c_str()); + } + } else { + fprintf(stderr, "error: failed to open '%s' for writing\n", fname_out.c_str()); + } + } + + if (!params.no_prints) { + parakeet_print_timings(pctx); + } + + if (params.print_segments) { + const int n_segments = parakeet_full_n_segments(pctx); + fprintf(stderr, "\nSegments (%d):\n", n_segments); + + for (int i = 0; i < n_segments; i++) { + const char * text = parakeet_full_get_segment_text(pctx, i); + const int64_t t0 = parakeet_full_get_segment_t0(pctx, i); + const int64_t t1 = parakeet_full_get_segment_t1(pctx, i); + const int n_tokens = parakeet_full_n_tokens(pctx, i); + + fprintf(stderr, "Segment %d: [%lld -> %lld] \"%s\"\n", i, (long long)t0, (long long)t1, text); + fprintf(stderr, "Tokens [%d]:\n", n_tokens); + + for (int j = 0; j < n_tokens; j++) { + parakeet_token_data token_data = parakeet_full_get_token_data(pctx, i, j); + const char * token_str = parakeet_token_to_str(pctx, token_data.id); + + fprintf(stderr, " [%2d] id=%5d frame=%3d dur_idx=%2d dur_val=%2d p=%.4f plog=%.4f t0=%4lld t1=%4lld word_start=%s \"%s\"\n", + j, + token_data.id, + token_data.frame_index, + token_data.duration_idx, + token_data.duration_value, + token_data.p, + token_data.plog, + (long long)token_data.t0, + (long long)token_data.t1, + token_data.is_word_start ? "true": "false", + token_str); + } + } + } + } + + parakeet_free(pctx); + + return 0; +} diff --git a/examples/parakeet-quantize/CMakeLists.txt b/examples/parakeet-quantize/CMakeLists.txt new file mode 100644 index 00000000000..6b46da18d27 --- /dev/null +++ b/examples/parakeet-quantize/CMakeLists.txt @@ -0,0 +1,7 @@ +set(TARGET parakeet-quantize) +add_executable(${TARGET} parakeet-quantize.cpp) + +include(DefaultTargetOptions) + +target_link_libraries(${TARGET} PRIVATE common parakeet ${CMAKE_THREAD_LIBS_INIT}) +install(TARGETS ${TARGET} RUNTIME) diff --git a/examples/parakeet-quantize/parakeet-quantize.cpp b/examples/parakeet-quantize/parakeet-quantize.cpp new file mode 100644 index 00000000000..a5d9616420f --- /dev/null +++ b/examples/parakeet-quantize/parakeet-quantize.cpp @@ -0,0 +1,230 @@ +#include "ggml.h" +#include "ggml-backend.h" + +#include "common-ggml.h" + +#include <cassert> +#include <cstdio> +#include <cstring> +#include <fstream> +#include <string> +#include <vector> + +struct parakeet_hparams { + int32_t n_vocab = 0; + int32_t n_audio_ctx = 0; + int32_t n_audio_state = 0; + int32_t n_audio_head = 0; + int32_t n_audio_layer = 0; + int32_t n_mels = 0; + int32_t ftype = 0; + int32_t n_fft = 0; + int32_t subsampling_factor = 0; + int32_t n_subsampling_channels = 0; + int32_t n_conv_kernel = 0; + int32_t n_pred_dim = 0; + int32_t n_pred_layers = 0; + int32_t n_tdt_durations = 0; + int32_t n_max_tokens = 0; +}; + +static bool parakeet_model_quantize(const std::string & fname_inp, const std::string & fname_out, ggml_ftype ftype) { + printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str()); + + auto finp = std::ifstream(fname_inp, std::ios::binary); + if (!finp) { + fprintf(stderr, "%s: failed to open '%s' for reading\n", __func__, fname_inp.c_str()); + return false; + } + + auto fout = std::ofstream(fname_out, std::ios::binary); + if (!fout) { + fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out.c_str()); + return false; + } + + // magic + { + uint32_t magic; + finp.read((char *) &magic, sizeof(magic)); + if (magic != GGML_FILE_MAGIC) { + fprintf(stderr, "%s: invalid model file (bad magic)\n", __func__); + return false; + } + fout.write((char *) &magic, sizeof(magic)); + } + + // hparams + parakeet_hparams hparams; + { + finp.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab)); + finp.read((char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx)); + finp.read((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state)); + finp.read((char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head)); + finp.read((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer)); + finp.read((char *) &hparams.n_mels, sizeof(hparams.n_mels)); + finp.read((char *) &hparams.ftype, sizeof(hparams.ftype)); + finp.read((char *) &hparams.n_fft, sizeof(hparams.n_fft)); + finp.read((char *) &hparams.subsampling_factor, sizeof(hparams.subsampling_factor)); + finp.read((char *) &hparams.n_subsampling_channels, sizeof(hparams.n_subsampling_channels)); + finp.read((char *) &hparams.n_conv_kernel, sizeof(hparams.n_conv_kernel)); + finp.read((char *) &hparams.n_pred_dim, sizeof(hparams.n_pred_dim)); + finp.read((char *) &hparams.n_pred_layers, sizeof(hparams.n_pred_layers)); + finp.read((char *) &hparams.n_tdt_durations, sizeof(hparams.n_tdt_durations)); + finp.read((char *) &hparams.n_max_tokens, sizeof(hparams.n_max_tokens)); + + const int32_t qntvr_src = hparams.ftype / GGML_QNT_VERSION_FACTOR; + const int32_t ftype_dst = GGML_QNT_VERSION * GGML_QNT_VERSION_FACTOR + ftype; + + fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab); + fprintf(stderr, "%s: n_audio_state = %d\n", __func__, hparams.n_audio_state); + fprintf(stderr, "%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer); + fprintf(stderr, "%s: n_mels = %d\n", __func__, hparams.n_mels); + fprintf(stderr, "%s: ftype (src) = %d\n", __func__, hparams.ftype); + fprintf(stderr, "%s: qntvr (src) = %d\n", __func__, qntvr_src); + fprintf(stderr, "%s: ftype (dst) = %d\n", __func__, ftype_dst); + fprintf(stderr, "%s: qntvr (dst) = %d\n", __func__, GGML_QNT_VERSION); + + fout.write((char *) &hparams.n_vocab, sizeof(hparams.n_vocab)); + fout.write((char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx)); + fout.write((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state)); + fout.write((char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head)); + fout.write((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer)); + fout.write((char *) &hparams.n_mels, sizeof(hparams.n_mels)); + fout.write((char *) &ftype_dst, sizeof(ftype_dst)); + fout.write((char *) &hparams.n_fft, sizeof(hparams.n_fft)); + fout.write((char *) &hparams.subsampling_factor, sizeof(hparams.subsampling_factor)); + fout.write((char *) &hparams.n_subsampling_channels, sizeof(hparams.n_subsampling_channels)); + fout.write((char *) &hparams.n_conv_kernel, sizeof(hparams.n_conv_kernel)); + fout.write((char *) &hparams.n_pred_dim, sizeof(hparams.n_pred_dim)); + fout.write((char *) &hparams.n_pred_layers, sizeof(hparams.n_pred_layers)); + fout.write((char *) &hparams.n_tdt_durations, sizeof(hparams.n_tdt_durations)); + fout.write((char *) &hparams.n_max_tokens, sizeof(hparams.n_max_tokens)); + } + + // mel filterbank + { + int32_t n_mel, n_fb; + finp.read((char *) &n_mel, sizeof(n_mel)); + fout.write((char *) &n_mel, sizeof(n_mel)); + finp.read((char *) &n_fb, sizeof(n_fb)); + fout.write((char *) &n_fb, sizeof(n_fb)); + + const size_t n = (size_t) n_mel * n_fb; + std::vector<float> buf(n); + finp.read((char *) buf.data(), n * sizeof(float)); + fout.write((char *) buf.data(), n * sizeof(float)); + } + + // window function + { + int32_t n_window; + finp.read((char *) &n_window, sizeof(n_window)); + fout.write((char *) &n_window, sizeof(n_window)); + + std::vector<float> buf(n_window); + finp.read((char *) buf.data(), n_window * sizeof(float)); + fout.write((char *) buf.data(), n_window * sizeof(float)); + } + + // TDT durations + { + std::vector<uint32_t> buf(hparams.n_tdt_durations); + finp.read((char *) buf.data(), hparams.n_tdt_durations * sizeof(uint32_t)); + fout.write((char *) buf.data(), hparams.n_tdt_durations * sizeof(uint32_t)); + } + + // vocab + { + int32_t n_tokens; + finp.read((char *) &n_tokens, sizeof(n_tokens)); + fout.write((char *) &n_tokens, sizeof(n_tokens)); + + for (int i = 0; i < n_tokens; ++i) { + int32_t len; + finp.read((char *) &len, sizeof(len)); + fout.write((char *) &len, sizeof(len)); + + std::string token(len, '\0'); + finp.read(&token[0], len); + fout.write(&token[0], len); + } + } + + // tensors — quantize 2D weights skipping tensors that must stay F32: + // ggml_ssm_conv / ggml_conv2d_dw CUDA kernels require F32 weights. + // pos_bias_u / pos_bias_v are declared F32 in the loader. + const std::vector<std::string> to_quant = { ".*" }; + std::vector<std::string> to_skip = { + // CUDA kernel constraints (ggml_ssm_conv / ggml_conv2d_dw require F32 weights) + "encoder\\.layers\\..+\\.conv\\.depthwise_conv\\.weight", + // Declared F32 in loader (pos_bias tensors) + "encoder\\.layers\\..+\\.self_attn\\.pos_bias_u", + "encoder\\.layers\\..+\\.self_attn\\.pos_bias_v", + }; + + // Prediction/joint tensors use n_pred_dim as their inner dimension. K-quant + // types (block size 256) cannot quantize 640 evenly, so keep them F32. For + // other types (Q8_0, Q4_0, block size 32) 640 is divisible and they can be + // quantized normally. The loader mirrors this logic at load time. + { + const ggml_type qtype = ggml_ftype_to_ggml_type(ftype); + const int32_t blck = ggml_blck_size(qtype); + if (blck > 1 && hparams.n_pred_dim % blck != 0) { + to_skip.push_back("decoder\\.prediction\\.embed\\.weight"); + to_skip.push_back("decoder\\.prediction\\.dec_rnn\\.lstm\\.weight_ih_l.*"); + to_skip.push_back("decoder\\.prediction\\.dec_rnn\\.lstm\\.weight_hh_l.*"); + to_skip.push_back("joint\\.pred\\.weight"); + to_skip.push_back("joint\\.joint_net\\.2\\.weight"); + } + } + + if (!ggml_common_quantize_0(finp, fout, ftype, to_quant, to_skip)) { + fprintf(stderr, "%s: failed to quantize tensors\n", __func__); + return false; + } + + finp.close(); + fout.close(); + + return true; +} + +int main(int argc, char ** argv) { + ggml_backend_load_all(); + + if (argc != 4) { + fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", argv[0]); + ggml_print_ftypes(stderr); + return 1; + } + + // initialise F16 lookup tables + { + struct ggml_init_params params = { 0, NULL, false }; + struct ggml_context * ctx = ggml_init(params); + ggml_free(ctx); + } + + const std::string fname_inp = argv[1]; + const std::string fname_out = argv[2]; + const ggml_ftype ftype = ggml_parse_ftype(argv[3]); + + if (ftype == GGML_FTYPE_UNKNOWN) { + fprintf(stderr, "%s: invalid quantization type\n", argv[0]); + ggml_print_ftypes(stderr); + return 1; + } + + const int64_t t_start_us = ggml_time_us(); + + if (!parakeet_model_quantize(fname_inp, fname_out, ftype)) { + fprintf(stderr, "%s: failed to quantize model from '%s'\n", argv[0], fname_inp.c_str()); + return 1; + } + + printf("\n%s: quantize time = %8.2f ms\n", argv[0], (ggml_time_us() - t_start_us) / 1000.0f); + printf("%s: output model = %s\n", argv[0], fname_out.c_str()); + + return 0; +} diff --git a/examples/server/server.cpp b/examples/server/server.cpp index aae74c3d840..b87ef27375f 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1107,10 +1107,29 @@ int main(int argc, char ** argv) { } segment["tokens"].push_back(token.id); - json word = json{{"word", whisper_full_get_token_text(ctx, i, j)}}; + std::string word_text = whisper_full_get_token_text(ctx, i, j); + int64_t word_t1 = token.t1; + + while (j + 1 < n_tokens && utf8_trailing_bytes_needed(word_text) > 0) { + const whisper_token_data next_token = whisper_full_get_token_data(ctx, i, j + 1); + // Keep verbose_json tokens free of EOT ids, matching the pre-merge server behavior. + if (next_token.id >= whisper_token_eot(ctx)) { + break; + } + + ++j; + segment["tokens"].push_back(next_token.id); + word_text += whisper_full_get_token_text(ctx, i, j); + if (next_token.t1 > -1) { + word_t1 = next_token.t1; + } + total_logprob += next_token.plog; + } + + json word = json{{"word", word_text}}; if (!params.no_timestamps && params.token_timestamps) { word["start"] = token.t0 * 0.01; - word["end"] = token.t1 * 0.01; + word["end"] = word_t1 * 0.01; word["t_dtw"] = token.t_dtw; } word["probability"] = token.p; diff --git a/examples/talk-llama/CMakeLists.txt b/examples/talk-llama/CMakeLists.txt index 1adeef8f511..13b284ed0e9 100644 --- a/examples/talk-llama/CMakeLists.txt +++ b/examples/talk-llama/CMakeLists.txt @@ -20,6 +20,7 @@ if (WHISPER_SDL2) llama-io.cpp llama-kv-cache.cpp llama-kv-cache-iswa.cpp + llama-kv-cache-dsa.cpp llama-memory-recurrent.cpp llama-memory-hybrid.cpp llama-memory-hybrid-iswa.cpp diff --git a/examples/talk-llama/llama-adapter.cpp b/examples/talk-llama/llama-adapter.cpp index 4a1aaa955a8..3e0fe66afff 100644 --- a/examples/talk-llama/llama-adapter.cpp +++ b/examples/talk-llama/llama-adapter.cpp @@ -41,7 +41,7 @@ bool llama_adapter_cvec::init(const llama_model & model) { auto it = ctx_map.find(buft); if (it == ctx_map.end()) { ggml_init_params params = { - /*.mem_size =*/ hparams.n_layer*ggml_tensor_overhead(), + /*.mem_size =*/ hparams.n_layer()*ggml_tensor_overhead(), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; @@ -61,9 +61,9 @@ bool llama_adapter_cvec::init(const llama_model & model) { }; // make tensors - tensors.reserve(hparams.n_layer); + tensors.reserve(hparams.n_layer()); tensors.push_back(nullptr); // there's never a tensor for layer 0 - for (size_t il = 1; il < hparams.n_layer; il++) { + for (size_t il = 1; il < hparams.n_layer(); il++) { ggml_backend_buffer_type_t buft = model.select_buft(il); ggml_context * ctx = ctx_for_buft(buft); if (!ctx) { @@ -121,7 +121,7 @@ bool llama_adapter_cvec::apply( layer_start = il_start; layer_end = il_end; - for (size_t il = 1; il < hparams.n_layer; il++) { + for (size_t il = 1; il < hparams.n_layer(); il++) { assert(tensors[il] != nullptr); const size_t off = n_embd * (il - 1); // buffer doesn't have data for layer 0, since it's never present diff --git a/examples/talk-llama/llama-arch.cpp b/examples/talk-llama/llama-arch.cpp index c9eead18aa3..9f93d5bc7ce 100644 --- a/examples/talk-llama/llama-arch.cpp +++ b/examples/talk-llama/llama-arch.cpp @@ -3,7 +3,6 @@ #include "llama-impl.h" #include <map> -#include <set> #include <vector> static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = { @@ -57,6 +56,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = { { LLM_ARCH_GEMMA3, "gemma3" }, { LLM_ARCH_GEMMA3N, "gemma3n" }, { LLM_ARCH_GEMMA4, "gemma4" }, + { LLM_ARCH_GEMMA4_ASSISTANT, "gemma4-assistant" }, { LLM_ARCH_GEMMA_EMBEDDING, "gemma-embedding" }, { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, @@ -75,6 +75,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = { { LLM_ARCH_DEEPSEEK, "deepseek" }, { LLM_ARCH_DEEPSEEK2, "deepseek2" }, { LLM_ARCH_DEEPSEEK2OCR, "deepseek2-ocr" }, + { LLM_ARCH_DEEPSEEK32, "deepseek32" }, { LLM_ARCH_CHATGLM, "chatglm" }, { LLM_ARCH_GLM4, "glm4" }, { LLM_ARCH_GLM4_MOE, "glm4moe" }, @@ -126,6 +127,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = { { LLM_ARCH_RND1, "rnd1" }, { LLM_ARCH_PANGU_EMBED, "pangu-embedded" }, { LLM_ARCH_MISTRAL3, "mistral3" }, + { LLM_ARCH_EAGLE3, "eagle3" }, { LLM_ARCH_MISTRAL4, "mistral4" }, { LLM_ARCH_PADDLEOCR, "paddleocr" }, { LLM_ARCH_MIMO2, "mimo2" }, @@ -133,6 +135,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = { { LLM_ARCH_LLAMA_EMBED, "llama-embed" }, { LLM_ARCH_MAINCODER, "maincoder" }, { LLM_ARCH_KIMI_LINEAR, "kimi-linear" }, + { LLM_ARCH_TALKIE, "talkie" }, + { LLM_ARCH_MELLUM, "mellum" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -193,6 +197,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = { { LLM_KV_MOE_LATENT_SIZE, "%s.moe_latent_size" }, { LLM_KV_NEXTN_PREDICT_LAYERS, "%s.nextn_predict_layers" }, { LLM_KV_NUM_DEEPSTACK_LAYERS, "%s.n_deepstack_layers" }, + { LLM_KV_DEEPSTACK_MAPPING, "%s.deepstack_mapping" }, + { LLM_KV_HIDDEN_ACT, "%s.hidden_activation" }, { LLM_KV_POOLING_TYPE, "%s.pooling_type" }, { LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" }, @@ -243,6 +249,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = { { LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, "%s.attention.indexer.key_length" }, { LLM_KV_ATTENTION_INDEXER_TOP_K, "%s.attention.indexer.top_k" }, { LLM_KV_ATTENTION_SHARED_KV_LAYERS, "%s.attention.shared_kv_layers" }, + { LLM_KV_ATTENTION_RECURRENT_LAYERS, "%s.attention.recurrent_layers" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_DIMENSION_COUNT_SWA, "%s.rope.dimension_count_swa" }, @@ -285,44 +292,51 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = { { LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" }, + { LLM_KV_TARGET_LAYERS, "%s.target_layers" }, + { LLM_KV_TARGET_HIDDEN_SIZE, "%s.target_hidden_size" }, + { LLM_KV_NORM_BEFORE_RESIDUAL, "%s.norm_before_residual" }, + { LLM_KV_SHORTCONV_L_CACHE, "%s.shortconv.l_cache" }, // sentence-transformers dense modules feature dims { LLM_KV_DENSE_2_FEAT_IN, "%s.dense_2_feat_in" }, - { LLM_KV_DENSE_2_FEAT_OUT, "%s.dense_2_feat_out" }, - { LLM_KV_DENSE_3_FEAT_IN, "%s.dense_3_feat_in" }, - { LLM_KV_DENSE_3_FEAT_OUT, "%s.dense_3_feat_out" }, - - { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, - { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, - { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" }, - { LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" }, - { LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, "tokenizer.ggml.token_type_count" }, - { LLM_KV_TOKENIZER_SCORES, "tokenizer.ggml.scores" }, - { LLM_KV_TOKENIZER_MERGES, "tokenizer.ggml.merges" }, - { LLM_KV_TOKENIZER_BOS_ID, "tokenizer.ggml.bos_token_id" }, - { LLM_KV_TOKENIZER_EOS_ID, "tokenizer.ggml.eos_token_id" }, - { LLM_KV_TOKENIZER_EOT_ID, "tokenizer.ggml.eot_token_id" }, - { LLM_KV_TOKENIZER_EOM_ID, "tokenizer.ggml.eom_token_id" }, - { LLM_KV_TOKENIZER_UNK_ID, "tokenizer.ggml.unknown_token_id" }, - { LLM_KV_TOKENIZER_SEP_ID, "tokenizer.ggml.seperator_token_id" }, - { LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" }, - { LLM_KV_TOKENIZER_CLS_ID, "tokenizer.ggml.cls_token_id" }, - { LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" }, - { LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" }, - { LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" }, - { LLM_KV_TOKENIZER_ADD_SEP, "tokenizer.ggml.add_sep_token" }, - { LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" }, - { LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, "tokenizer.ggml.remove_extra_whitespaces" }, - { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" }, - { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, - { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, - { LLM_KV_TOKENIZER_CHAT_TEMPLATE, "tokenizer.chat_template" }, - { LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" }, - { LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" }, - { LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" }, - { LLM_KV_TOKENIZER_FIM_PAD_ID, "tokenizer.ggml.fim_pad_token_id" }, - { LLM_KV_TOKENIZER_FIM_REP_ID, "tokenizer.ggml.fim_rep_token_id" }, - { LLM_KV_TOKENIZER_FIM_SEP_ID, "tokenizer.ggml.fim_sep_token_id" }, + { LLM_KV_DENSE_2_FEAT_OUT, "%s.dense_2_feat_out" }, + { LLM_KV_DENSE_3_FEAT_IN, "%s.dense_3_feat_in" }, + { LLM_KV_DENSE_3_FEAT_OUT, "%s.dense_3_feat_out" }, + + { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, + { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, + { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" }, + { LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" }, + { LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, "tokenizer.ggml.token_type_count" }, + { LLM_KV_TOKENIZER_SCORES, "tokenizer.ggml.scores" }, + { LLM_KV_TOKENIZER_MERGES, "tokenizer.ggml.merges" }, + { LLM_KV_TOKENIZER_BOS_ID, "tokenizer.ggml.bos_token_id" }, + { LLM_KV_TOKENIZER_EOS_ID, "tokenizer.ggml.eos_token_id" }, + { LLM_KV_TOKENIZER_EOT_ID, "tokenizer.ggml.eot_token_id" }, + { LLM_KV_TOKENIZER_EOM_ID, "tokenizer.ggml.eom_token_id" }, + { LLM_KV_TOKENIZER_UNK_ID, "tokenizer.ggml.unknown_token_id" }, + { LLM_KV_TOKENIZER_SEP_ID, "tokenizer.ggml.seperator_token_id" }, + { LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" }, + { LLM_KV_TOKENIZER_CLS_ID, "tokenizer.ggml.cls_token_id" }, + { LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" }, + { LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" }, + { LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" }, + { LLM_KV_TOKENIZER_ADD_SEP, "tokenizer.ggml.add_sep_token" }, + { LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" }, + { LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, "tokenizer.ggml.remove_extra_whitespaces" }, + { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" }, + { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, + { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, + { LLM_KV_TOKENIZER_CHAT_TEMPLATE, "tokenizer.chat_template" }, + { LLM_KV_TOKENIZER_NORMALIZER_LOWERCASE, "tokenizer.ggml.normalizer.lowercase" }, + { LLM_KV_TOKENIZER_NORMALIZER_STRIP_ACCENTS, "tokenizer.ggml.normalizer.strip_accents" }, + { LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" }, + { LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" }, + { LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" }, + { LLM_KV_TOKENIZER_FIM_PAD_ID, "tokenizer.ggml.fim_pad_token_id" }, + { LLM_KV_TOKENIZER_FIM_REP_ID, "tokenizer.ggml.fim_rep_token_id" }, + { LLM_KV_TOKENIZER_FIM_SEP_ID, "tokenizer.ggml.fim_sep_token_id" }, + { LLM_KV_TOKENIZER_SUPPRESS_TOKENS, "tokenizer.ggml.suppress_tokens" }, { LLM_KV_ADAPTER_TYPE, "adapter.type" }, { LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" }, @@ -445,6 +459,8 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = { { LLM_TENSOR_FFN_NORM_EXPS, "blk.%d.ffn_norm_exps" }, { LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" }, { LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" }, + { LLM_TENSOR_NEXTN_PROJ_PRE, "nextn.pre_projection" }, + { LLM_TENSOR_NEXTN_PROJ_POST, "nextn.post_projection" }, { LLM_TENSOR_NEXTN_EH_PROJ, "blk.%d.nextn.eh_proj" }, { LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.%d.nextn.embed_tokens" }, { LLM_TENSOR_NEXTN_ENORM, "blk.%d.nextn.enorm" }, @@ -548,6 +564,10 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = { { LLM_TENSOR_INDEXER_PROJ, "blk.%d.indexer.proj" }, { LLM_TENSOR_INDEXER_ATTN_K, "blk.%d.indexer.attn_k" }, { LLM_TENSOR_INDEXER_ATTN_Q_B, "blk.%d.indexer.attn_q_b" }, + { LLM_TENSOR_MASKED_EMBD_CENTROIDS, "masked_embd_centroids" }, + { LLM_TENSOR_MASKED_EMBD_ORDERING, "masked_embd_ordering" }, + { LLM_TENSOR_FC, "fc" }, + { LLM_TENSOR_D2T, "d2t" }, }; // declare information about the model weight tensors: @@ -757,6 +777,8 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = { {LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_PROJ_PRE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_PROJ_POST, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // NextN/MTP tensors are stored per-block (blk.%d.nextn.*) even though only the // last nextn_predict_layers blocks carry them. Classify as LAYER_REPEATING so // the model loader doesn't fault on the block index. @@ -767,8 +789,14 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = { {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // Nemotron 3 Super - {LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, - {LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + // latent projections feed ggml_mul_mat, the buft probe must use MUL_MAT to keep them on GPU + {LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_MASKED_EMBD_CENTROIDS, {LLM_TENSOR_LAYER_INPUT, GGML_OP_NONE}}, + {LLM_TENSOR_MASKED_EMBD_ORDERING, {LLM_TENSOR_LAYER_INPUT, GGML_OP_NONE}}, + // eagle3 + {LLM_TENSOR_FC, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_D2T, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, }; LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} @@ -902,6 +930,7 @@ bool llm_arch_supports_sm_tensor(const llm_arch & arch) { case LLM_ARCH_OLMO2: case LLM_ARCH_OLMOE: case LLM_ARCH_DEEPSEEK2: + case LLM_ARCH_DEEPSEEK32: case LLM_ARCH_GLM_DSA: case LLM_ARCH_BITNET: case LLM_ARCH_T5: diff --git a/examples/talk-llama/llama-arch.h b/examples/talk-llama/llama-arch.h index 89cf16cc37c..c5245fb5891 100644 --- a/examples/talk-llama/llama-arch.h +++ b/examples/talk-llama/llama-arch.h @@ -61,6 +61,7 @@ enum llm_arch { LLM_ARCH_GEMMA3, LLM_ARCH_GEMMA3N, LLM_ARCH_GEMMA4, + LLM_ARCH_GEMMA4_ASSISTANT, LLM_ARCH_GEMMA_EMBEDDING, LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, @@ -79,6 +80,7 @@ enum llm_arch { LLM_ARCH_DEEPSEEK, LLM_ARCH_DEEPSEEK2, LLM_ARCH_DEEPSEEK2OCR, + LLM_ARCH_DEEPSEEK32, LLM_ARCH_CHATGLM, LLM_ARCH_GLM4, LLM_ARCH_GLM4_MOE, @@ -137,6 +139,9 @@ enum llm_arch { LLM_ARCH_LLAMA_EMBED, LLM_ARCH_MAINCODER, LLM_ARCH_KIMI_LINEAR, + LLM_ARCH_TALKIE, + LLM_ARCH_MELLUM, + LLM_ARCH_EAGLE3, LLM_ARCH_UNKNOWN, }; @@ -197,6 +202,8 @@ enum llm_kv { LLM_KV_MOE_LATENT_SIZE, LLM_KV_NEXTN_PREDICT_LAYERS, LLM_KV_NUM_DEEPSTACK_LAYERS, + LLM_KV_DEEPSTACK_MAPPING, + LLM_KV_HIDDEN_ACT, LLM_KV_POOLING_TYPE, LLM_KV_LOGIT_SCALE, LLM_KV_DECODER_START_TOKEN_ID, @@ -247,6 +254,7 @@ enum llm_kv { LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, LLM_KV_ATTENTION_INDEXER_TOP_K, LLM_KV_ATTENTION_SHARED_KV_LAYERS, + LLM_KV_ATTENTION_RECURRENT_LAYERS, LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_DIMENSION_COUNT_SWA, @@ -306,12 +314,15 @@ enum llm_kv { LLM_KV_TOKENIZER_HF_JSON, LLM_KV_TOKENIZER_RWKV, LLM_KV_TOKENIZER_CHAT_TEMPLATE, + LLM_KV_TOKENIZER_NORMALIZER_LOWERCASE, + LLM_KV_TOKENIZER_NORMALIZER_STRIP_ACCENTS, LLM_KV_TOKENIZER_FIM_PRE_ID, LLM_KV_TOKENIZER_FIM_SUF_ID, LLM_KV_TOKENIZER_FIM_MID_ID, LLM_KV_TOKENIZER_FIM_PAD_ID, LLM_KV_TOKENIZER_FIM_REP_ID, LLM_KV_TOKENIZER_FIM_SEP_ID, + LLM_KV_TOKENIZER_SUPPRESS_TOKENS, LLM_KV_ADAPTER_TYPE, LLM_KV_ADAPTER_LORA_ALPHA, @@ -327,6 +338,10 @@ enum llm_kv { LLM_KV_CLASSIFIER_OUTPUT_LABELS, + LLM_KV_TARGET_LAYERS, + LLM_KV_TARGET_HIDDEN_SIZE, + LLM_KV_NORM_BEFORE_RESIDUAL, + LLM_KV_SHORTCONV_L_CACHE, LLM_KV_XIELU_ALPHA_N, @@ -549,14 +564,21 @@ enum llm_tensor { LLM_TENSOR_INDEXER_PROJ, LLM_TENSOR_INDEXER_ATTN_K, LLM_TENSOR_INDEXER_ATTN_Q_B, + LLM_TENSOR_NEXTN_PROJ_PRE, + LLM_TENSOR_NEXTN_PROJ_POST, LLM_TENSOR_NEXTN_EH_PROJ, LLM_TENSOR_NEXTN_EMBED_TOKENS, LLM_TENSOR_NEXTN_ENORM, LLM_TENSOR_NEXTN_HNORM, LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, + LLM_TENSOR_MASKED_EMBD_CENTROIDS, + LLM_TENSOR_MASKED_EMBD_ORDERING, + LLM_TENSOR_FC, + LLM_TENSOR_D2T, }; + enum llm_tensor_layer { LLM_TENSOR_LAYER_INPUT, LLM_TENSOR_LAYER_REPEATING, diff --git a/examples/talk-llama/llama-chat.cpp b/examples/talk-llama/llama-chat.cpp index f10397747b0..6d822ec62d6 100644 --- a/examples/talk-llama/llama-chat.cpp +++ b/examples/talk-llama/llama-chat.cpp @@ -62,6 +62,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = { { "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD }, { "granite", LLM_CHAT_TEMPLATE_GRANITE_3_X }, { "granite-4.0", LLM_CHAT_TEMPLATE_GRANITE_4_0 }, + { "granite-4.1", LLM_CHAT_TEMPLATE_GRANITE_4_1 }, { "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT }, { "megrez", LLM_CHAT_TEMPLATE_MEGREZ }, { "yandex", LLM_CHAT_TEMPLATE_YANDEX }, @@ -194,7 +195,10 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_RWKV_WORLD; } else if (tmpl_contains("<|start_of_role|>")) { if (tmpl_contains("<tool_call>") || tmpl_contains("<tools>")) { - return LLM_CHAT_TEMPLATE_GRANITE_4_0; + if (tmpl_contains("g4_default_system_message")) { + return LLM_CHAT_TEMPLATE_GRANITE_4_0; + } + return LLM_CHAT_TEMPLATE_GRANITE_4_1; } return LLM_CHAT_TEMPLATE_GRANITE_3_X; } else if (tmpl_contains("message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1]")) { @@ -651,6 +655,20 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "<|start_of_role|>assistant<|end_of_role|>"; } + } else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE_4_1) { + // IBM Granite 4.1 template + for (const auto & message : chat) { + std::string role(message->role); + if (role == "assistant_tool_call") { + ss << "<|start_of_role|>assistant<|end_of_role|><|tool_call|>"; + } else { + ss << "<|start_of_role|>" << role << "<|end_of_role|>"; + } + ss << message->content << "<|end_of_text|>\n"; + } + if (add_ass) { + ss << "<|start_of_role|>assistant<|end_of_role|>"; + } } else if (tmpl == LLM_CHAT_TEMPLATE_GIGACHAT) { // GigaChat template bool has_system = !chat.empty() && std::string(chat[0]->role) == "system"; diff --git a/examples/talk-llama/llama-chat.h b/examples/talk-llama/llama-chat.h index ea6540c0be7..dc37f919a96 100644 --- a/examples/talk-llama/llama-chat.h +++ b/examples/talk-llama/llama-chat.h @@ -41,6 +41,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_RWKV_WORLD, LLM_CHAT_TEMPLATE_GRANITE_3_X, LLM_CHAT_TEMPLATE_GRANITE_4_0, + LLM_CHAT_TEMPLATE_GRANITE_4_1, LLM_CHAT_TEMPLATE_GIGACHAT, LLM_CHAT_TEMPLATE_MEGREZ, LLM_CHAT_TEMPLATE_YANDEX, diff --git a/examples/talk-llama/llama-context.cpp b/examples/talk-llama/llama-context.cpp index ad36c06667d..168dbabd766 100644 --- a/examples/talk-llama/llama-context.cpp +++ b/examples/talk-llama/llama-context.cpp @@ -58,19 +58,24 @@ llama_context::llama_context( cparams.n_rs_seq = 0; } - cparams.n_threads = params.n_threads; - cparams.n_threads_batch = params.n_threads_batch; - cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor; - cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor; - cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast; - cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow; - cparams.embeddings = params.embeddings; - cparams.embeddings_pre_norm = false; - cparams.embeddings_pre_norm_masked = false; - cparams.offload_kqv = params.offload_kqv; - cparams.no_perf = params.no_perf; - cparams.pooling_type = params.pooling_type; - cparams.warmup = false; + cparams.n_threads = params.n_threads; + cparams.n_threads_batch = params.n_threads_batch; + cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor; + cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor; + cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast; + cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow; + cparams.embeddings = params.embeddings; + cparams.embeddings_nextn = false; + cparams.embeddings_nextn_masked = false; + cparams.offload_kqv = params.offload_kqv; + cparams.no_perf = params.no_perf; + cparams.warmup = false; + + cparams.embeddings_layer_inp.resize(hparams.n_layer(), false); + embd_layer_inp.resize(hparams.n_layer()); + + cparams.ctx_type = params.ctx_type; + cparams.pooling_type = params.pooling_type; cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base; @@ -83,7 +88,26 @@ llama_context::llama_context( cparams.cb_eval = params.cb_eval; cparams.cb_eval_user_data = params.cb_eval_user_data; - cparams.ctx_type = params.ctx_type; + cparams.ctx_other = nullptr; + + // TODO: more generic + if (model.arch == LLM_ARCH_GEMMA4_ASSISTANT) { + if (params.ctx_other == nullptr) { + // TODO: change from runtime_error to llama_exception to avoid printing error message + throw std::runtime_error("Gemma4Assistant requires ctx_other to be set (this warning is normal during memory fitting)"); + } + + cparams.ctx_other = params.ctx_other; + } + + if (model.arch == LLM_ARCH_EAGLE3) { + if (model.tok_embd == nullptr || model.output == nullptr) { + if (params.ctx_other == nullptr) { + throw std::runtime_error("EAGLE3 requires ctx_other to be set (this warning is normal during memory fitting)"); + } + cparams.ctx_other = params.ctx_other; + } + } // Initialize backend samplers here so they are part of the sampling graph // before the reserve passes run later in this function. This avoids a later @@ -182,6 +206,8 @@ llama_context::llama_context( cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); + cparams.n_outputs_max = params.n_outputs_max == 0 || llama_model_has_encoder(&model) ? cparams.n_batch : params.n_outputs_max; + cparams.op_offload = params.op_offload; cparams.kv_unified = params.kv_unified; @@ -227,6 +253,7 @@ llama_context::llama_context( LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); LLAMA_LOG_INFO("%s: n_rs_seq = %u\n", __func__, cparams.n_rs_seq); + LLAMA_LOG_INFO("%s: n_outputs_max = %u\n", __func__, cparams.n_outputs_max); if (cparams.n_ctx_seq < hparams.n_ctx_train) { LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n", @@ -296,10 +323,11 @@ llama_context::llama_context( // init the memory module if (!hparams.vocab_only) { llama_memory_params params_mem = { - /*.type_k =*/ params.type_k, - /*.type_v =*/ params.type_v, - /*.swa_full =*/ params.swa_full, - /*.ctx_type= */ cparams.ctx_type, + /*.type_k =*/ params.type_k, + /*.type_v =*/ params.type_v, + /*.swa_full =*/ params.swa_full, + /*.ctx_type =*/ cparams.ctx_type, + /*.mem_other =*/ llama_get_memory(cparams.ctx_other), }; memory.reset(model.create_memory(params_mem, cparams)); @@ -337,7 +365,7 @@ llama_context::llama_context( // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary bool pipeline_parallel = model.n_devices() > 1 && - model.n_gpu_layers() > model.hparams.n_layer && + model.n_gpu_layers() > model.hparams.n_layer_all && model.split_mode() == LLAMA_SPLIT_MODE_LAYER && cparams.offload_kqv && !model.has_tensor_overrides(); @@ -531,7 +559,7 @@ void llama_context::sched_reserve() { // note: n_outputs must match n_tokens for embedding models with mean/rank pooling, // because build_pooling creates inp_mean with shape [n_tokens, n_seqs] and multiplies // it with t_embd which is reduced to [n_outputs, ...] via out_ids. if n_outputs != n_tokens, - // the ggml_mul_mat assertion fails. this matches the pp reservation below (line ~553). + // the ggml_mul_mat assertion fails. const uint32_t n_tokens_ch = 16*n_seqs; auto * gf = graph_reserve(n_tokens_ch, n_seqs, n_tokens_ch, mctx.get(), true); if (!gf) { @@ -577,16 +605,18 @@ void llama_context::sched_reserve() { int n_splits_tg = -1; int n_nodes_tg = -1; + const uint32_t n_outputs_pp = std::min(n_tokens, cparams.n_outputs_max); + // reserve pp (prompt processing) graph first so that buffers are only allocated once { - auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), + auto * gf = graph_reserve(n_tokens, n_seqs, n_outputs_pp, mctx.get(), model.hparams.no_alloc, model.hparams.no_alloc ? backend_buf_exp_size.data() : nullptr); if (!gf) { if (cparams.pipeline_parallel) { LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__); cparams.pipeline_parallel = false; sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload)); - gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); + gf = graph_reserve(n_tokens, n_seqs, n_outputs_pp, mctx.get()); } if (!gf) { throw std::runtime_error("failed to allocate compute pp buffers"); @@ -614,7 +644,7 @@ void llama_context::sched_reserve() { // // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get()); // - auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc); + auto * gf = graph_reserve(n_tokens, n_seqs, n_outputs_pp, mctx.get(), model.hparams.no_alloc); if (!gf) { throw std::runtime_error("failed to allocate compute pp buffers"); } @@ -774,7 +804,9 @@ bool llama_context::memory_update(bool optimize) { const uint32_t n_seqs = cparams.n_seq_max; const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); - auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); + const uint32_t n_outputs_max = std::min(n_tokens, cparams.n_outputs_max); + + auto * gf = graph_reserve(n_tokens, n_seqs, n_outputs_max, mctx.get()); if (!gf) { LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__); } @@ -882,34 +914,34 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) { return it->second.data(); } -float * llama_context::get_embeddings_pre_norm() { +float * llama_context::get_embeddings_nextn() { output_reorder(); - return embd_pre_norm.data; + return embd_nextn.data; } -float * llama_context::get_embeddings_pre_norm_ith(int32_t i) { +float * llama_context::get_embeddings_nextn_ith(int32_t i) { output_reorder(); try { - if (embd_pre_norm.data == nullptr) { - throw std::runtime_error("no pre-norm embeddings"); + if (embd_nextn.data == nullptr) { + throw std::runtime_error("no nextn embeddings"); } - const uint32_t n_embd = model.hparams.n_embd; + const uint32_t n_embd = model.hparams.n_embd_out(); - if (!cparams.embeddings_pre_norm_masked) { - // unmasked: pre-norm rows are stored densely, indexed by raw token position. - if (i < 0 || (size_t)(i + 1) * n_embd > embd_pre_norm.size) { - throw std::runtime_error(format("out of range [0, %zu)", embd_pre_norm.size / n_embd)); + if (!cparams.embeddings_nextn_masked) { + // unmasked: nextn rows are stored densely, indexed by raw token position. + if (i < 0 || (size_t)(i + 1) * n_embd > embd_nextn.size) { + throw std::runtime_error(format("out of range [0, %zu)", embd_nextn.size / n_embd)); } - return embd_pre_norm.data + (size_t) i * n_embd; + return embd_nextn.data + (size_t) i * n_embd; } const int64_t j = output_resolve_row(i); - return embd_pre_norm.data + j*n_embd; + return embd_nextn.data + j*n_embd; } catch (const std::exception & err) { - LLAMA_LOG_ERROR("%s: invalid pre-norm embeddings id %d, reason: %s\n", __func__, i, err.what()); + LLAMA_LOG_ERROR("%s: invalid nextn embeddings id %d, reason: %s\n", __func__, i, err.what()); #ifndef NDEBUG GGML_ABORT("fatal error"); #else @@ -918,6 +950,14 @@ float * llama_context::get_embeddings_pre_norm_ith(int32_t i) { } } +float * llama_context::get_embeddings_layer_inp(uint32_t lid) { + output_reorder(); + + GGML_ASSERT(lid < embd_layer_inp.size() && embd_layer_inp[lid].has_data()); + + return embd_layer_inp[lid].data; +} + llama_token llama_context::get_sampled_token_ith(int32_t idx) { output_reorder(); @@ -1098,11 +1138,22 @@ void llama_context::set_embeddings(bool value) { //sched_need_reserve = true; } -void llama_context::set_embeddings_pre_norm(bool value, bool masked) { +void llama_context::set_embeddings_nextn(bool value, bool masked) { LLAMA_LOG_DEBUG("%s: value = %d, masked = %d\n", __func__, value, masked); - cparams.embeddings_pre_norm = value; - cparams.embeddings_pre_norm_masked = masked; + cparams.embeddings_nextn = value; + cparams.embeddings_nextn_masked = masked; +} + +void llama_context::set_embeddings_layer_inp(uint32_t lid, bool enable) { + LLAMA_LOG_DEBUG("%s: lid = %d, enable = %d\n", __func__, lid, enable); + + GGML_ASSERT(lid < model.hparams.n_layer()); + + cparams.embeddings_layer_inp[lid] = enable; + + // note: without this reserve, the draft acceptance drops to zero. not sure why - this is unexpected + sched_need_reserve = true; } void llama_context::set_causal_attn(bool value) { @@ -1319,7 +1370,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll } int llama_context::encode(const llama_batch & batch_inp) { - // MTP hook batches carry both token (next-token id) and embd (h_pre_norm row), + // MTP hook batches carry both token (next-token id) and embd (h_nextn row), // so accept either present rather than requiring exactly one. GGML_ASSERT(batch_inp.token || batch_inp.embd); @@ -1330,7 +1381,8 @@ int llama_context::encode(const llama_batch & batch_inp) { const auto & hparams = model.hparams; - const int64_t n_embd = hparams.n_embd_inp(); + // eagle3/DFlash: features as encoder input, and non-draft paths fall back to model's input dim + const int64_t n_embd = hparams.n_embd_inp(); const int64_t n_vocab = model.vocab.n_tokens(); // note: during encode, we always pass the full sequence starting from pos = 0 @@ -1392,9 +1444,9 @@ int llama_context::encode(const llama_batch & batch_inp) { } } - auto * t_logits = res->get_logits(); - auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); - auto * t_h_pre_norm = cparams.embeddings_pre_norm ? res->get_h_pre_norm() : nullptr; + auto * t_logits = res->get_logits(); + auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); + auto * t_h_nextn = cparams.embeddings_nextn ? res->get_h_nextn() : nullptr; // extract logits if (logits.data && t_logits) { @@ -1460,14 +1512,14 @@ int llama_context::encode(const llama_batch & batch_inp) { } } - // extract pre-norm embeddings (hidden state before the final output norm) - if (embd_pre_norm.data && t_h_pre_norm && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { - ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm); + // extract nextn embeddings (hidden state before the final output norm) + if (embd_nextn.data && t_h_nextn && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { + ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_nextn); GGML_ASSERT(backend_h != nullptr); - const uint32_t n_embd = hparams.n_embd; - GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_pre_norm.size); - ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm.data, 0, n_tokens*n_embd*sizeof(float)); + const uint32_t n_embd = hparams.n_embd_out(); + GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_nextn.size); + ggml_backend_tensor_get_async(backend_h, t_h_nextn, embd_nextn.data, 0, n_tokens*n_embd*sizeof(float)); } // TODO: hacky solution @@ -1622,7 +1674,7 @@ static bool needs_raw_logits(const llama_ubatch & ubatch, const std::map<llama_s } int llama_context::decode(const llama_batch & batch_inp) { - // MTP hook batches carry both token (next-token id) and embd (h_pre_norm row), + // MTP hook batches carry both token (next-token id) and embd (h_nextn row), // so accept either present rather than requiring exactly one. GGML_ASSERT(batch_inp.token || batch_inp.embd); @@ -1822,9 +1874,9 @@ int llama_context::decode(const llama_batch & batch_inp) { // ggml_graph_dump_dot(gf, NULL, "llama.dot"); //} - auto * t_logits = res->get_logits(); - auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr; - auto * t_h_pre_norm = cparams.embeddings_pre_norm ? res->get_h_pre_norm() : nullptr; + auto * t_logits = res->get_logits(); + auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr; + auto * t_h_nextn = cparams.embeddings_nextn ? res->get_h_nextn() : nullptr; if (t_embd && res->get_embd_pooled()) { t_embd = res->get_embd_pooled(); @@ -1905,22 +1957,24 @@ int llama_context::decode(const llama_batch & batch_inp) { } } - // extract pre-norm embeddings (hidden state before the final output norm) + extract_layer_inputs(res, n_tokens_prev, ubatch.n_tokens); + + // extract nextn embeddings before // only meaningful in LLAMA_POOLING_TYPE_NONE (per-token); other pooling modes are ignored. { - const bool masked = cparams.embeddings_pre_norm_masked; + const bool masked = cparams.embeddings_nextn_masked; const int64_t n_rows = masked ? n_outputs : (int64_t) ubatch.n_tokens; const int64_t offset = masked ? n_outputs_prev : n_tokens_prev; - if (embd_pre_norm.data && t_h_pre_norm && n_rows > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { - ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm); + if (embd_nextn.data && t_h_nextn && n_rows > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { + ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_nextn); GGML_ASSERT(backend_h != nullptr); - const uint32_t n_embd = hparams.n_embd; - float * embd_pre_norm_out = embd_pre_norm.data + offset*n_embd; + const uint32_t n_embd = hparams.n_embd_out(); + float * embd_nextn_out = embd_nextn.data + offset*n_embd; - GGML_ASSERT((offset + n_rows)*n_embd <= (int64_t) embd_pre_norm.size); - ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm_out, 0, n_rows*n_embd*sizeof(float)); + GGML_ASSERT((offset + n_rows)*n_embd <= (int64_t) embd_nextn.size); + ggml_backend_tensor_get_async(backend_h, t_h_nextn, embd_nextn_out, 0, n_rows*n_embd*sizeof(float)); } } @@ -2012,9 +2066,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { const auto n_embd = hparams.n_embd; const auto n_embd_out = hparams.n_embd_out(); - bool has_logits = true; - bool has_embd = cparams.embeddings; - bool has_embd_pre_norm = cparams.embeddings_pre_norm; + bool has_logits = true; + bool has_embd = cparams.embeddings; + bool has_embd_nextn = cparams.embeddings_nextn; // TODO: hacky enc-dec support if (model.arch == LLM_ARCH_T5) { @@ -2022,18 +2076,24 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { has_embd = true; } - size_t backend_float_count = 0; size_t backend_token_count = 0; + size_t embd_layer_inp_float_count = 0; - logits.size = has_logits ? n_vocab*n_outputs_max : 0; - embd.size = has_embd ? n_embd_out*n_outputs_max : 0; - embd_pre_norm.size = has_embd_pre_norm ? n_embd*n_outputs_max : 0; + logits.size = has_logits ? n_vocab*n_outputs_max : 0; + embd.size = has_embd ? n_embd_out*n_outputs_max : 0; + embd_nextn.size = has_embd_nextn ? n_embd_out*n_outputs_max : 0; - if (has_embd_pre_norm && !cparams.embeddings_pre_norm_masked) { - // unmasked: pre-norm row exists for every token in the batch, not just + if (has_embd_nextn && !cparams.embeddings_nextn_masked) { + // unmasked: nextn row exists for every token in the batch, not just // those flagged via batch.logits[i] -> size by token count instead. - embd_pre_norm.size = (size_t) n_embd * n_batch; + embd_nextn.size = (size_t) n_embd_out * n_batch; + } + + for (bool enabled : cparams.embeddings_layer_inp) { + if (enabled) { + embd_layer_inp_float_count += (size_t) n_embd * n_batch; + } } // Allocate backend sampling output buffers if there are backend samplers configured. @@ -2050,8 +2110,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0; const size_t new_size = - (logits.size + embd.size + embd_pre_norm.size + backend_float_count) * sizeof(float) + - ( backend_token_count) * sizeof(llama_token); + (logits.size + embd.size + embd_nextn.size + embd_layer_inp_float_count + backend_float_count) * sizeof(float) + + ( backend_token_count) * sizeof(llama_token); // alloc only when more than the current capacity is required // TODO: also consider shrinking the buffer @@ -2067,7 +2127,10 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { buf_output = nullptr; logits.data = nullptr; embd.data = nullptr; - embd_pre_norm.data = nullptr; + embd_nextn.data = nullptr; + for (auto & layer_inp : embd_layer_inp) { + layer_inp = {nullptr, 0}; + } } auto * buft = ggml_backend_cpu_buffer_type(); @@ -2096,8 +2159,17 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { embd = has_embd ? buffer_view<float>{(float *) (base + offset), embd.size} : buffer_view<float>{nullptr, 0}; offset += embd.size * sizeof(float); - embd_pre_norm = has_embd_pre_norm ? buffer_view<float>{(float *) (base + offset), embd_pre_norm.size} : buffer_view<float>{nullptr, 0}; - offset += embd_pre_norm.size * sizeof(float); + embd_nextn = has_embd_nextn ? buffer_view<float>{(float *) (base + offset), embd_nextn.size} : buffer_view<float>{nullptr, 0}; + offset += embd_nextn.size * sizeof(float); + + for (uint32_t il = 0; il < embd_layer_inp.size(); ++il) { + if (cparams.embeddings_layer_inp[il]) { + embd_layer_inp[il] = buffer_view<float>{(float *) (base + offset), (size_t) n_embd * n_batch}; + offset += embd_layer_inp[il].size * sizeof(float); + } else { + embd_layer_inp[il] = buffer_view<float>{nullptr, 0}; + } + } if (has_sampling) { sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)}; @@ -2140,9 +2212,39 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { this->n_outputs = 0; + GGML_ASSERT(n_outputs_max <= cparams.n_outputs_max); + return n_outputs_max; } +void llama_context::extract_layer_inputs(const llm_graph_result * res, size_t token_offset, size_t n_tokens) { + for (uint32_t il = 0; il < cparams.embeddings_layer_inp.size(); ++il) { + if (!cparams.embeddings_layer_inp[il]) { + continue; + } + if (!embd_layer_inp[il].has_data()) { + GGML_ABORT("output layer input buffer not allocated"); + } + ggml_tensor * t = res->get_layer_inp((int) il); + if (!t) { + GGML_ABORT("layer input tensor not found"); + } + + const size_t nbytes = ggml_nbytes(t); + const size_t nfloats = nbytes / sizeof(float); + GGML_ASSERT(n_tokens > 0); + GGML_ASSERT(nfloats % n_tokens == 0); + + const size_t row_floats = nfloats / n_tokens; + const size_t dst_offset = token_offset * row_floats; + GGML_ASSERT(dst_offset + nfloats <= embd_layer_inp[il].size); + + ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched.get(), t); + GGML_ASSERT(backend != nullptr); + ggml_backend_tensor_get_async(backend, t, embd_layer_inp[il].data + dst_offset, 0, nbytes); + } +} + void llama_context::output_reorder() { const uint64_t n_vocab = model.vocab.n_tokens(); const uint64_t n_embd = model.hparams.n_embd; @@ -2163,9 +2265,19 @@ void llama_context::output_reorder() { } } - if (embd_pre_norm.size > 0) { + if (embd_nextn.size > 0) { for (uint64_t k = 0; k < n_embd; k++) { - std::swap(embd_pre_norm.data[i0*n_embd + k], embd_pre_norm.data[i1*n_embd + k]); + std::swap(embd_nextn.data[i0*n_embd + k], embd_nextn.data[i1*n_embd + k]); + } + } + + if (embd_layer_inp.size() > 0) { + for (int lid = 0; lid < (int) embd_layer_inp.size(); ++lid) { + if (embd_layer_inp[lid].size > 0) { + for (uint64_t k = 0; k < n_embd; ++k) { + std::swap(embd_layer_inp[lid].data[i0*n_embd + k], embd_layer_inp[lid].data[i1*n_embd + k]); + } + } } } @@ -2226,8 +2338,6 @@ ggml_cgraph * llama_context::graph_reserve( if (n_tokens % n_seqs != 0) { n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs - n_outputs = std::max(n_outputs, n_tokens); - LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs); } @@ -2343,7 +2453,7 @@ llm_graph_cb llama_context::graph_get_cb() const { // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends // FIXME: fix in ggml_backend_sched - const bool full_offload = model.n_gpu_layers() > model.hparams.n_layer; + const bool full_offload = model.n_gpu_layers() > model.hparams.n_layer_all; if (ubatch.n_tokens < 32 || full_offload) { if (il != -1 && strcmp(name, "norm") == 0) { const auto & dev_layer = model.dev_layer(il); @@ -3337,6 +3447,7 @@ llama_context_params llama_context_default_params() { /*.n_ubatch =*/ 512, /*.n_seq_max =*/ 1, /*.n_rs_seq =*/ 0, + /*.n_outputs_max =*/ 0, /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS, /*.ctx_type =*/ LLAMA_CONTEXT_TYPE_DEFAULT, @@ -3366,6 +3477,7 @@ llama_context_params llama_context_default_params() { /*.kv_unified =*/ false, /*.sampler =*/ nullptr, /*.n_sampler =*/ 0, + /*.ctx_other =*/ nullptr, }; return result; @@ -3403,15 +3515,11 @@ llama_context * llama_init_from_model( LLAMA_LOG_ERROR("%s: SPLIT_MODE_TENSOR requires flash_attn to be enabled\n", __func__); return nullptr; } - if (ggml_is_quantized(params.type_k) || ggml_is_quantized(params.type_v)) { - LLAMA_LOG_ERROR("%s: simultaneous use of SPLIT_MODE_TENSOR and KV cache quantization not implemented\n", __func__); - return nullptr; - } } if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && ggml_is_quantized(params.type_k)) { const uint32_t blck_size = ggml_blck_size(params.type_k); - for (uint32_t il = 0; il < model->hparams.n_layer; ++il) { + for (uint32_t il = 0; il < model->hparams.n_layer(); ++il) { if (model->hparams.n_embd_head_k(il) % blck_size != 0) { LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n", __func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k(il)); @@ -3422,7 +3530,7 @@ llama_context * llama_init_from_model( if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && ggml_is_quantized(params.type_v)) { const uint32_t blck_size = ggml_blck_size(params.type_v); - for (uint32_t il = 0; il < model->hparams.n_layer; ++il) { + for (uint32_t il = 0; il < model->hparams.n_layer(); ++il) { if (model->hparams.n_embd_head_v(il) % blck_size != 0) { LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_v=%u\n", __func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v(il)); @@ -3444,12 +3552,11 @@ llama_context * llama_init_from_model( } if (params.ctx_type == LLAMA_CONTEXT_TYPE_MTP && - model->hparams.nextn_predict_layers == 0) { + model->hparams.n_layer_nextn == 0) { LLAMA_LOG_WARN("%s: context type MTP requested but model doesn't contain MTP layers\n", __func__); return nullptr; } - try { auto * ctx = new llama_context(*model, params); return ctx; @@ -3584,20 +3691,38 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) { return ctx->get_embeddings_seq(seq_id); } -void llama_set_embeddings_pre_norm(llama_context * ctx, bool value, bool masked) { - ctx->set_embeddings_pre_norm(value, masked); +void llama_set_embeddings_nextn(llama_context * ctx, bool value, bool masked) { + ctx->set_embeddings_nextn(value, masked); +} + +void llama_set_embeddings_layer_inp(llama_context * ctx, uint32_t lid, bool value) { + ctx->set_embeddings_layer_inp(lid, value); +} + +llama_memory_t llama_get_memory(const struct llama_context * ctx) { + if (!ctx) { + return nullptr; + } + + return ctx->get_memory(); +} + +float * llama_get_embeddings_nextn(llama_context * ctx) { + ctx->synchronize(); + + return ctx->get_embeddings_nextn(); } -float * llama_get_embeddings_pre_norm(llama_context * ctx) { +float * llama_get_embeddings_nextn_ith(llama_context * ctx, int32_t i) { ctx->synchronize(); - return ctx->get_embeddings_pre_norm(); + return ctx->get_embeddings_nextn_ith(i); } -float * llama_get_embeddings_pre_norm_ith(llama_context * ctx, int32_t i) { +float * llama_get_embeddings_layer_inp(llama_context * ctx, uint32_t lid) { ctx->synchronize(); - return ctx->get_embeddings_pre_norm_ith(i); + return ctx->get_embeddings_layer_inp(lid); } bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) { @@ -3651,7 +3776,7 @@ struct ggml_cgraph * llama_graph_reserve( uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs) { - auto * memory = ctx->get_memory(); + auto memory = ctx->get_memory(); llama_memory_context_ptr mctx; if (memory) { mctx = memory->init_full(); @@ -3691,10 +3816,6 @@ int32_t llama_set_adapter_cvec( // memory // -llama_memory_t llama_get_memory(const struct llama_context * ctx) { - return ctx->get_memory(); -} - void llama_memory_clear(llama_memory_t mem, bool data) { if (!mem) { return; @@ -4005,3 +4126,7 @@ void llama_opt_epoch( llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx) { return ctx->memory_breakdown(); } + +llama_context * llama_get_ctx_other(struct llama_context * ctx) { + return ctx->get_cparams().ctx_other; +} diff --git a/examples/talk-llama/llama-context.h b/examples/talk-llama/llama-context.h index d03f681d4a1..853052be2ca 100644 --- a/examples/talk-llama/llama-context.h +++ b/examples/talk-llama/llama-context.h @@ -6,6 +6,7 @@ #include "llama-graph.h" #include "llama-adapter.h" #include "llama-impl.h" +#include "llama-memory.h" #include "ggml-cpp.h" #include "ggml-opt.h" @@ -84,8 +85,10 @@ struct llama_context { float * get_embeddings_ith(int32_t i); float * get_embeddings_seq(llama_seq_id seq_id); - float * get_embeddings_pre_norm(); - float * get_embeddings_pre_norm_ith(int32_t i); + float * get_embeddings_nextn(); + float * get_embeddings_nextn_ith(int32_t i); + + float * get_embeddings_layer_inp(uint32_t lid); llama_token * get_sampled_tokens() const; llama_token get_sampled_token_ith(int32_t idx); @@ -110,7 +113,8 @@ struct llama_context { void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data); void set_embeddings (bool value); - void set_embeddings_pre_norm(bool value, bool masked); + void set_embeddings_nextn(bool value, bool masked); + void set_embeddings_layer_inp(uint32_t lid, bool enable); void set_causal_attn(bool value); void set_warmup(bool value); @@ -225,6 +229,10 @@ struct llama_context { // map the output row index `i` to batch index int64_t output_resolve_row(int32_t i) const; + // async-copy enabled layer-input tensors (per cparams.output_layer_inp) + // from backend into host-side embd_layer_inp buffers + void extract_layer_inputs(const llm_graph_result * res, size_t token_offset, size_t n_tokens); + // // graph // @@ -273,7 +281,7 @@ struct llama_context { llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably - std::unique_ptr<llama_memory_i> memory; + llama_memory_ptr memory; // decode output (2-dimensional array: [n_outputs][n_vocab]) buffer_view<float> logits = {nullptr, 0}; @@ -282,10 +290,14 @@ struct llama_context { // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE buffer_view<float> embd = {nullptr, 0}; - // hidden state before the final output norm (2-dimensional array: [n_outputs][n_embd]) - // populated only when cparams.embeddings_pre_norm is enabled and the model graph - // sets llm_graph_result::t_h_pre_norm - buffer_view<float> embd_pre_norm = {nullptr, 0}; + // hidden state required by the nextn layers (2-dimensional array: [n_outputs][n_embd]) + // populated only when cparams.embeddings_nextn is enabled and the model graph + // sets llm_graph_result::t_h_nextn + buffer_view<float> embd_nextn = {nullptr, 0}; + + // host buffers for output layer input embeddings, per layer + // populated when cparams.output_layer_inp[il] is true + std::vector<buffer_view<float>> embd_layer_inp; struct sampling_info { // !samplers.empty() to check if any samplers are active diff --git a/examples/talk-llama/llama-cparams.h b/examples/talk-llama/llama-cparams.h index 20ec59fe335..2b109f909c0 100644 --- a/examples/talk-llama/llama-cparams.h +++ b/examples/talk-llama/llama-cparams.h @@ -3,6 +3,7 @@ #include "llama.h" #include <cstdint> +#include <vector> #define LLAMA_MAX_SEQ 256 @@ -13,6 +14,7 @@ struct llama_cparams { uint32_t n_ubatch; uint32_t n_seq_max; uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback + uint32_t n_outputs_max; // max outputs supported by the context int32_t n_threads; // number of threads to use for generation int32_t n_threads_batch; // number of threads to use for batch processing @@ -28,8 +30,8 @@ struct llama_cparams { float yarn_beta_slow; bool embeddings; - bool embeddings_pre_norm; // also extract the hidden state before the final output norm - bool embeddings_pre_norm_masked; // extract for only rows where batch.logits != 0 + bool embeddings_nextn; // also extract the hidden state before the final output norm + bool embeddings_nextn_masked; // extract for only rows where batch.logits != 0 bool causal_attn; bool offload_kqv; bool flash_attn; @@ -38,14 +40,18 @@ struct llama_cparams { bool fused_gdn_ch; // use fused gated delta net (chunked) bool auto_fgdn; bool no_perf; - bool warmup; + bool warmup; // TODO: remove [TAG_LLAMA_GRAPH_NO_WARMUP] bool op_offload; bool kv_unified; bool pipeline_parallel; + std::vector<bool> embeddings_layer_inp; // [n_layer()] extract input embeddings for layer + enum llama_context_type ctx_type; enum llama_pooling_type pooling_type; ggml_backend_sched_eval_callback cb_eval; void * cb_eval_user_data; + + llama_context * ctx_other; }; diff --git a/examples/talk-llama/llama-ext.h b/examples/talk-llama/llama-ext.h index edfa71c207c..b744af52864 100644 --- a/examples/talk-llama/llama-ext.h +++ b/examples/talk-llama/llama-ext.h @@ -89,18 +89,32 @@ LLAMA_API ggml_backend_dev_t llama_model_get_device(const struct llama_model * m LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx); -// -// pre-norm embeddings (hidden state before the final output norm) -// - -// Set whether the context outputs pre-norm embeddings or not +// Set whether the context outputs nextn embeddings or not // If masked == true, output the embeddings only for the tokens with batch.logits != 0 // If masked == false, output the embeddings for all tokens in the batch regardless of batch.logits -LLAMA_API void llama_set_embeddings_pre_norm(struct llama_context * ctx, bool value, bool masked); +LLAMA_API void llama_set_embeddings_nextn(struct llama_context * ctx, bool value, bool masked); // mirrors: // LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); -LLAMA_API float * llama_get_embeddings_pre_norm (struct llama_context * ctx); +LLAMA_API float * llama_get_embeddings_nextn(struct llama_context * ctx); // LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i); -LLAMA_API float * llama_get_embeddings_pre_norm_ith(struct llama_context * ctx, int32_t i); +LLAMA_API float * llama_get_embeddings_nextn_ith(struct llama_context * ctx, int32_t i); + +// Set whether the context outputs the input embeddings of a specific layer +LLAMA_API void llama_set_embeddings_layer_inp(struct llama_context * ctx, uint32_t lid, bool value); + +// mirrors: +// LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); +LLAMA_API float * llama_get_embeddings_layer_inp(struct llama_context * ctx, uint32_t lid); + +LLAMA_API llama_context * llama_get_ctx_other(struct llama_context * ctx); + +// +// model/context data extraction +// + +// returns pointer to the target-model layer indices +LLAMA_API const int32_t * llama_model_target_layer_ids (const struct llama_model * model); +// returns the number of extracted layers from target model +LLAMA_API uint32_t llama_model_target_layer_ids_n(const struct llama_model * model); diff --git a/examples/talk-llama/llama-graph.cpp b/examples/talk-llama/llama-graph.cpp index fc027de8b39..7468bd9b79e 100644 --- a/examples/talk-llama/llama-graph.cpp +++ b/examples/talk-llama/llama-graph.cpp @@ -7,6 +7,7 @@ #include "llama-kv-cache.h" #include "llama-kv-cache-iswa.h" +#include "llama-kv-cache-dsa.h" #include "llama-memory-hybrid.h" #include "llama-memory-hybrid-iswa.h" #include "llama-memory-recurrent.h" @@ -29,7 +30,10 @@ static ggml_tensor * build_attn_inp_kq_mask( const auto n_tokens = ubatch.n_tokens; const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; - ggml_tensor * res = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + // flash attention requires an f16 mask + const auto type = cparams.flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32; + + ggml_tensor * res = ggml_new_tensor_4d(ctx, type, n_kv, n_tokens/n_stream, 1, n_stream); ggml_set_input(res); ggml_set_name(res, "attn_inp_kq_mask"); @@ -102,6 +106,39 @@ bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) { return res; } +void llm_graph_input_embd_h::set_input(const llama_ubatch * ubatch) { + const int64_t n_tokens = ubatch->n_tokens; + + if (ubatch->token) { + ggml_backend_tensor_set(tokens, ubatch->token, 0, n_tokens*ggml_element_size(tokens)); + } else { + // note: mtmd embedding input goes through here + GGML_ASSERT(ubatch->embd); + GGML_ASSERT(n_embd == embd->ne[0]); + + ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(h)); + } + + // TODO: extend llama_ubatch to differentiate between token embeddings and hidden states + // for now, we assume that the hidden state is always provided as an embedding + // ref: https://github.com/ggml-org/llama.cpp/pull/23643 + if (ubatch->embd) { + GGML_ASSERT(n_embd == h->ne[0]); + + ggml_backend_tensor_set(h, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(h)); + } +} + +bool llm_graph_input_embd_h::can_reuse(const llm_graph_params & params) { + bool res = true; + + res &= (!params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens); + res &= (!params.ubatch.embd) || (embd && embd->ne[1] == params.ubatch.n_tokens); + res &= (!params.ubatch.embd) || (h && h->ne[1] == params.ubatch.n_tokens); + + return res; +} + void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) { if (ubatch->pos && pos) { const int64_t n_tokens = ubatch->n_tokens; @@ -348,7 +385,8 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) { } } -static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) { +template <typename T> +static void print_mask(const T * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) { LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__); const char * swa_type_str = "unknown"; @@ -359,7 +397,7 @@ static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64 case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break; }; - LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str); + LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swa_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str); LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__); LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__); @@ -372,7 +410,7 @@ static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64 for (int i = 0; i < std::min((int64_t)20, n_tokens); ++i) { LLAMA_LOG_DEBUG(" %2d ", i); for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) { - float val = data[i * n_kv + j]; + float val = llama_cast<float>(data[i * n_kv + j]); if (val == -INFINITY) { LLAMA_LOG_DEBUG(" ∞"); } else { @@ -387,7 +425,10 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { const int64_t n_kv = ubatch->n_tokens; const int64_t n_tokens = ubatch->n_tokens; - const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) { + const auto fill_mask = [&](auto * data, int64_t ne, int n_swa, llama_swa_type swa_type) { + using T = std::remove_reference_t<decltype(*data)>; + std::fill(data, data + ne, llama_cast<T>(-INFINITY)); + for (int i1 = 0; i1 < n_tokens; ++i1) { const llama_seq_id s1 = ubatch->seq_id[i1][0]; const llama_pos p1 = ubatch->pos[i1]; @@ -413,38 +454,30 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { continue; } - data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f; + data[idst + i0] = llama_cast<T>(hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f); } } - }; - - { - GGML_ASSERT(self_kq_mask); - GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer)); - - float * data = (float *) self_kq_mask->data; - - std::fill(data, data + ggml_nelements(self_kq_mask), -INFINITY); - - fill_mask(data, 0, LLAMA_SWA_TYPE_NONE); if (debug) { - print_mask(data, n_tokens, n_kv, 0, LLAMA_SWA_TYPE_NONE); + print_mask(data, n_tokens, n_kv, n_swa, swa_type); } + }; + + GGML_ASSERT(self_kq_mask); + GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer)); + if (self_kq_mask->type == GGML_TYPE_F16) { + fill_mask((ggml_fp16_t *) self_kq_mask->data, ggml_nelements(self_kq_mask), 0, LLAMA_SWA_TYPE_NONE); + } else { + fill_mask((float *) self_kq_mask->data, ggml_nelements(self_kq_mask), 0, LLAMA_SWA_TYPE_NONE); } if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { GGML_ASSERT(self_kq_mask_swa); GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer)); - - float * data = (float *) self_kq_mask_swa->data; - - std::fill(data, data + ggml_nelements(self_kq_mask_swa), -INFINITY); - - fill_mask(data, hparams.n_swa, hparams.swa_type); - - if (debug) { - print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type); + if (self_kq_mask_swa->type == GGML_TYPE_F16) { + fill_mask((ggml_fp16_t *) self_kq_mask_swa->data, ggml_nelements(self_kq_mask_swa), hparams.n_swa, hparams.swa_type); + } else { + fill_mask((float *) self_kq_mask_swa->data, ggml_nelements(self_kq_mask_swa), hparams.n_swa, hparams.swa_type); } } } @@ -499,12 +532,43 @@ bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) { return res; } +void llm_graph_input_attn_k_dsa::set_input(const llama_ubatch * ubatch) { + mctx->get_mla()->set_input_k_idxs(self_k_idxs_mla, ubatch); + + mctx->get_mla()->set_input_kq_mask(self_kq_mask_mla, ubatch, cparams.causal_attn); + + mctx->get_lid()->set_input_k_idxs(self_k_idxs_lid, ubatch); + + mctx->get_lid()->set_input_kq_mask(self_kq_mask_lid, ubatch, cparams.causal_attn); + + mctx->get_lid()->set_input_k_rot(self_k_rot_lid); +} + +bool llm_graph_input_attn_k_dsa::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast<const llama_kv_cache_dsa_context *>(params.mctx); + + this->mctx = mctx; + + bool res = true; + + res &= self_k_idxs_mla->ne[0] == params.ubatch.n_tokens; + res &= self_k_idxs_lid->ne[0] == params.ubatch.n_tokens; + + res &= can_reuse_kq_mask(self_kq_mask_mla, mctx->get_mla(), params.ubatch, params.cparams); + res &= can_reuse_kq_mask(self_kq_mask_lid, mctx->get_lid(), params.ubatch, params.cparams); + + return res; +} + void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) { // base tensors may not be allocated if there are no non-SWA attention layers if (self_k_idxs && self_k_idxs->buffer) { mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch); mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch); + } + // the kq mask guards on its own buffer: shared cells leave idxs unbacked while the mask stays live + if (self_kq_mask && self_kq_mask->buffer) { mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); } @@ -512,7 +576,9 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) { if (self_k_idxs_swa && self_k_idxs_swa->buffer) { mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch); mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch); + } + if (self_kq_mask_swa && self_kq_mask_swa->buffer) { mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); } @@ -544,7 +610,9 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) { if (self_k_idxs && self_k_idxs->buffer) { res &= self_k_idxs->ne[0] == params.ubatch.n_tokens; //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + } + if (self_kq_mask && self_kq_mask->buffer) { res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams); } @@ -552,7 +620,9 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) { if (self_k_idxs_swa && self_k_idxs_swa->buffer) { res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + } + if (self_kq_mask_swa && self_kq_mask_swa->buffer) { res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams); } @@ -568,23 +638,30 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer)); GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing - float * data = (float *) cross_kq_mask->data; + const auto fill_mask = [&](auto * data) { + using T = std::remove_reference_t<decltype(*data)>; + for (int i = 0; i < n_tokens; ++i) { + GGML_ASSERT(!cross->seq_ids_enc.empty() && "llama_encode must be called first"); + for (int j = 0; j < n_enc; ++j) { + float f = -INFINITY; - for (int i = 0; i < n_tokens; ++i) { - GGML_ASSERT(!cross->seq_ids_enc.empty() && "llama_encode must be called first"); - for (int j = 0; j < n_enc; ++j) { - float f = -INFINITY; + for (int s = 0; s < ubatch->n_seq_id[i]; ++s) { + const llama_seq_id seq_id = ubatch->seq_id[i][s]; - for (int s = 0; s < ubatch->n_seq_id[i]; ++s) { - const llama_seq_id seq_id = ubatch->seq_id[i][s]; - - if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) { - f = 0.0f; + if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) { + f = 0.0f; + } } - } - data[i*n_enc + j] = f; + data[i*n_enc + j] = llama_cast<T>(f); + } } + }; + + if (cross_kq_mask->type == GGML_TYPE_F16) { + fill_mask((ggml_fp16_t *) cross_kq_mask->data); + } else { + fill_mask((float *) cross_kq_mask->data); } } @@ -688,7 +765,9 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) { if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) { attn_ctx->get_base()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch); attn_ctx->get_base()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch); + } + if (inp_attn->self_kq_mask && inp_attn->self_kq_mask->buffer) { attn_ctx->get_base()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn); } @@ -696,7 +775,9 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) { if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) { attn_ctx->get_swa()->set_input_k_idxs(inp_attn->self_k_idxs_swa, ubatch); attn_ctx->get_swa()->set_input_v_idxs(inp_attn->self_v_idxs_swa, ubatch); + } + if (inp_attn->self_kq_mask_swa && inp_attn->self_kq_mask_swa->buffer) { attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn); } @@ -742,18 +823,18 @@ bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params) if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) { res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens; //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - - res &= can_reuse_kq_mask(inp_attn->self_kq_mask, attn_ctx->get_base(), params.ubatch, params.cparams); } + res &= can_reuse_kq_mask(inp_attn->self_kq_mask, attn_ctx->get_base(), params.ubatch, params.cparams); + // swa tensors may not be allocated if there are no SWA attention layers if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) { res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; //res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - - res &= can_reuse_kq_mask(inp_attn->self_kq_mask_swa, attn_ctx->get_swa(), params.ubatch, params.cparams); } + res &= can_reuse_kq_mask(inp_attn->self_kq_mask_swa, attn_ctx->get_swa(), params.ubatch, params.cparams); + res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs(); res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs; @@ -823,6 +904,10 @@ void llm_graph_result::reset() { t_logits = nullptr; t_embd = nullptr; t_embd_pooled = nullptr; + + t_layer_inp.resize(LLAMA_MAX_LAYERS); + std::fill(t_layer_inp.begin(), t_layer_inp.end(), nullptr); + t_sampled.clear(); t_sampled_probs.clear(); t_sampled_logits.clear(); @@ -851,7 +936,7 @@ void llm_graph_result::set_inputs(const llama_ubatch * ubatch) { } } -void llm_graph_result::set_outputs() { +void llm_graph_result::set_outputs(const llm_graph_params & params) { if (t_logits != nullptr) { ggml_set_output(t_logits); } @@ -861,8 +946,17 @@ void llm_graph_result::set_outputs() { if (t_embd_pooled != nullptr) { ggml_set_output(t_embd_pooled); } - if (t_h_pre_norm != nullptr) { - ggml_set_output(t_h_pre_norm); + if (t_h_nextn != nullptr) { + ggml_set_output(t_h_nextn); + } + { + const auto & embeddings_layer_inp = params.cparams.embeddings_layer_inp; + for (size_t il = 0; il < embeddings_layer_inp.size(); ++il) { + if (embeddings_layer_inp[il]) { + GGML_ASSERT(t_layer_inp[il] != nullptr && "layer input tensor is null"); + ggml_set_output(t_layer_inp[il]); + } + } } for (auto & [seq_id, t] : t_sampled) { if (t != nullptr) { @@ -937,7 +1031,8 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) : cparams (params.cparams), ubatch (params.ubatch), n_embd (hparams.n_embd), - n_layer (hparams.n_layer), + n_layer (hparams.n_layer()), + n_layer_nextn (hparams.n_layer_nextn), n_rot (hparams.n_rot()), n_ctx (cparams.n_ctx), n_head (hparams.n_head()), @@ -1791,7 +1886,12 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { res->t_inp_embd = cur; // For Granite architecture - if (hparams.f_embedding_scale != 0.0f) { + // NOTE: For deepstack models, only apply scale to token inputs (ie text-only input). + // Raw embeddings are assumed to be multimodal inputs that should not be scaled. + if (hparams.f_embedding_scale != 0.0f && (ubatch.token || hparams.n_deepstack_layers == 0)) { + if (!ggml_is_contiguous(cur)) { + cur = ggml_cont(ctx0, cur); + } cur = ggml_scale(ctx0, cur, hparams.f_embedding_scale); } @@ -2088,17 +2188,20 @@ ggml_tensor * llm_graph_context::build_attn_mha( llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() const { auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams); + // flash attention requires an f16 mask + const auto type_mask = cparams.flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32; + // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch - inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1); + inp->self_kq_mask = ggml_new_tensor_4d(ctx0, type_mask, n_tokens, n_tokens, 1, 1); ggml_set_input(inp->self_kq_mask); - inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + inp->self_kq_mask_cnv = inp->self_kq_mask; if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { - inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1); + inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, type_mask, n_tokens, n_tokens, 1, 1); ggml_set_input(inp->self_kq_mask_swa); - inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; + inp->self_kq_mask_swa_cnv = inp->self_kq_mask_swa; } else { inp->self_kq_mask_swa = nullptr; inp->self_kq_mask_swa_cnv = nullptr; @@ -2175,7 +2278,7 @@ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl( inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch); inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams); - inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + inp->self_kq_mask_cnv = inp->self_kq_mask; } inp->self_k_rot = mctx_cur->build_input_k_rot(ctx0); @@ -2282,7 +2385,7 @@ static std::unique_ptr<llm_graph_input_attn_k> build_attn_inp_k_impl( inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams); - inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + inp->self_kq_mask_cnv = inp->self_kq_mask; } return inp; @@ -2354,6 +2457,82 @@ ggml_tensor * llm_graph_context::build_attn( return cur; } +ggml_tensor * llm_graph_context::build_attn( + llm_graph_input_attn_k_dsa * inp, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * wo_s, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + ggml_tensor * kq_b, + ggml_tensor * sinks, + ggml_tensor * v_mla, + ggml_tensor * top_k, + float kq_scale, + int il) const { + // these nodes are added to the graph together so that they are not reordered + // by doing so, the number of splits in the graph is reduced + // expand k later to enable rope fusion which directly writes into k-v cache + ggml_build_forward_expand(gf, q_cur); + ggml_build_forward_expand(gf, v_cur); + ggml_build_forward_expand(gf, k_cur); + + const auto * mctx_cur = inp->mctx->get_mla(); + + // store to KV cache + { + const auto & k_idxs = inp->get_k_idxs_mla(); + + ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il)); + } + + const auto & kq_mask = inp->get_kq_mask_mla(); + + // prepare new kq mask - starts filled with -INFINITY + ggml_tensor * kq_mask_all = ggml_fill(ctx0, kq_mask, -INFINITY); + + // reshape KQ mask into tensor with rows of size 1: + // [n_kv, n_batch, 1, n_stream] -> [1, n_kv, n_batch, n_stream] + kq_mask_all = ggml_view_4d(ctx0, kq_mask_all, 1, kq_mask_all->ne[0], kq_mask_all->ne[1], kq_mask_all->ne[3], kq_mask_all->nb[0], kq_mask_all->nb[1], kq_mask_all->nb[2], 0); + + // reshape top_k indices: [n_top_k, n_batch, 1, n_stream] -> [n_top_k, n_batch, n_stream, 1] + ggml_tensor * top_k_3d = ggml_view_4d(ctx0, top_k, top_k->ne[0], top_k->ne[1], top_k->ne[3], 1, top_k->nb[1], top_k->nb[2], top_k->ne[3]*top_k->nb[3], 0); + + // prepare zero-filled tensor with rows of size 1: [1, n_top_k, n_batch, n_stream] + // this will be our source of zero values for unmasking top k mask elements + ggml_tensor * zeros = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, 1, top_k_3d->ne[0], top_k_3d->ne[1], top_k_3d->ne[2]); + zeros = ggml_fill(ctx0, zeros, 0.0f); + + // modify KQ mask by unmasking elements that are in top_k indices + // ggml_set_rows([1, n_kv, n_batch, n_stream], [1, n_top_k, n_batch, n_stream], [n_top_k, n_batch, n_stream, 1]) + ggml_tensor * kq_mask_top_k = ggml_set_rows(ctx0, kq_mask_all, zeros, top_k_3d); + + // reshape to restore the original shape of KQ mask: + // [1, n_kv, n_batch, n_stream] -> [n_kv, n_batch, 1, n_stream] + kq_mask_top_k = ggml_view_4d(ctx0, kq_mask_top_k, kq_mask_top_k->ne[1], kq_mask_top_k->ne[2], 1, kq_mask_top_k->ne[3], kq_mask_top_k->nb[2], kq_mask_top_k->nb[3], kq_mask_top_k->nb[3], 0); + + // combine with the original kq mask + kq_mask_top_k = ggml_add(ctx0, kq_mask_top_k, kq_mask); + + ggml_tensor * q = q_cur; + ggml_tensor * k = mctx_cur->get_k(ctx0, il); + ggml_tensor * v = ggml_view_4d(ctx0, k, v_cur->ne[0], k->ne[1], k->ne[2], k->ne[3], k->nb[1], k->nb[2], k->nb[3], 0); + + ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask_top_k, sinks, v_mla, kq_scale, il); + cb(cur, "kqv_out", il); + + if (wo) { + cur = build_lora_mm(wo, cur, wo_s); + } + + if (wo_b) { + cur = ggml_add(ctx0, cur, wo_b); + } + + return cur; +} + ggml_tensor * llm_graph_context::build_attn( llm_graph_input_attn_kv_iswa * inp, ggml_tensor * wo, @@ -2446,10 +2625,13 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const { const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train; - inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, n_tokens, 1, 1); + // flash attention requires an f16 mask + const auto type_mask = cparams.flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32; + + inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, type_mask, n_enc, n_tokens, 1, 1); ggml_set_input(inp->cross_kq_mask); - inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask; + inp->cross_kq_mask_cnv = inp->cross_kq_mask; return (llm_graph_input_attn_cross *) res->add_input(std::move(inp)); } @@ -2497,6 +2679,34 @@ ggml_tensor * llm_graph_context::build_attn( return cur; } +llm_graph_input_attn_k_dsa * llm_graph_context::build_attn_inp_k_dsa() const { + const auto * mctx_cur = static_cast<const llama_kv_cache_dsa_context *>(mctx); + + auto inp = std::make_unique<llm_graph_input_attn_k_dsa>(hparams, cparams, mctx_cur); + + { + inp->self_k_idxs_mla = mctx_cur->get_mla()->build_input_k_idxs(ctx0, ubatch); + + inp->self_kq_mask_mla = build_attn_inp_kq_mask(ctx0, mctx_cur->get_mla(), ubatch, cparams); + inp->self_kq_mask_mla_cnv = inp->self_kq_mask_mla; + } + + { + inp->self_k_idxs_lid = mctx_cur->get_lid()->build_input_k_idxs(ctx0, ubatch); + + // ensure F32 mask + auto cparams_copy = cparams; + cparams_copy.flash_attn = false; + + inp->self_kq_mask_lid = build_attn_inp_kq_mask(ctx0, mctx_cur->get_lid(), ubatch, cparams_copy); + inp->self_kq_mask_lid_cnv = inp->self_kq_mask_lid; + + inp->self_k_rot_lid = mctx_cur->get_lid()->build_input_k_rot(ctx0); + } + + return (llm_graph_input_attn_k_dsa *) res->add_input(std::move(inp)); +} + // TODO: maybe separate the inner implementation into a separate function // like with the non-sliding window equivalent // once sliding-window hybrid caches are a thing. @@ -2510,7 +2720,7 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch); inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams); - inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + inp->self_kq_mask_cnv = inp->self_kq_mask; } { @@ -2520,7 +2730,7 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch); inp->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams); - inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; + inp->self_kq_mask_swa_cnv = inp->self_kq_mask_swa; } inp->self_k_rot = mctx_cur->get_base()->build_input_k_rot(ctx0); @@ -2689,7 +2899,7 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch); inp_attn->self_kq_mask = build_attn_inp_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams); - inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask; + inp_attn->self_kq_mask_cnv = inp_attn->self_kq_mask; } { @@ -2697,7 +2907,7 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch); inp_attn->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams); - inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa; + inp_attn->self_kq_mask_swa_cnv = inp_attn->self_kq_mask_swa; } auto inp = std::make_unique<llm_graph_input_mem_hybrid_iswa>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur); diff --git a/examples/talk-llama/llama-graph.h b/examples/talk-llama/llama-graph.h index bf6778237e6..cc5cfe51dcd 100644 --- a/examples/talk-llama/llama-graph.h +++ b/examples/talk-llama/llama-graph.h @@ -22,6 +22,7 @@ struct llama_layer; struct llama_memory_context_i; class llama_kv_cache_context; +class llama_kv_cache_dsa_context; class llama_kv_cache_iswa_context; class llama_memory_recurrent_context; class llama_memory_hybrid_context; @@ -35,7 +36,8 @@ enum llm_graph_type { LLM_GRAPH_TYPE_DECODER_MTP, }; -enum llm_ffn_op_type { +enum llm_ffn_op_type : int { + LLM_FFN_NONE = 0, // sentinel: unset; archs must assign before use LLM_FFN_SILU, LLM_FFN_GELU, LLM_FFN_RELU, @@ -121,6 +123,23 @@ class llm_graph_input_embd : public llm_graph_input_i { const int64_t n_embd = 0; }; +// similar to llm_graph_input_embd but with an additional hidden state input +class llm_graph_input_embd_h : public llm_graph_input_i { +public: + llm_graph_input_embd_h(int64_t n_embd) : n_embd(n_embd) {} + virtual ~llm_graph_input_embd_h() = default; + + void set_input(const llama_ubatch * ubatch) override; + + bool can_reuse(const llm_graph_params & params) override; + + ggml_tensor * tokens = nullptr; // I32 [n_batch] + ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch] + ggml_tensor * h = nullptr; // F32 [n_embd, n_batch] + + const int64_t n_embd = 0; +}; + class llm_graph_input_pos : public llm_graph_input_i { public: llm_graph_input_pos(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {} @@ -274,10 +293,10 @@ class llm_graph_input_attn_no_cache : public llm_graph_input_i { ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; } // n_tokens == n_batch - ggml_tensor * self_kq_mask = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask = nullptr; // F32/F16 [n_tokens, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_swa = nullptr; // F32/F16 [n_tokens, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream] const llama_hparams hparams; const llama_cparams cparams; @@ -307,8 +326,8 @@ class llm_graph_input_attn_kv : public llm_graph_input_i { ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch] ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa] - ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask = nullptr; // F32/F16 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] // note: assumes v_rot^2 == I ggml_tensor * self_k_rot = nullptr; @@ -347,8 +366,8 @@ class llm_graph_input_attn_k : public llm_graph_input_i { ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch] - ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask = nullptr; // F32/F16 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] const llama_hparams hparams; const llama_cparams cparams; @@ -356,6 +375,44 @@ class llm_graph_input_attn_k : public llm_graph_input_i { const llama_kv_cache_context * mctx; }; +class llm_graph_input_attn_k_dsa : public llm_graph_input_i { +public: + llm_graph_input_attn_k_dsa( + const llama_hparams & hparams, + const llama_cparams & cparams, + const llama_kv_cache_dsa_context * mctx) : + hparams(hparams), + cparams(cparams), + mctx(mctx) { + } + ~llm_graph_input_attn_k_dsa() = default; + + void set_input(const llama_ubatch * ubatch) override; + + bool can_reuse(const llm_graph_params & params) override; + + ggml_tensor * get_k_idxs_mla() const { return self_k_idxs_mla; } + ggml_tensor * get_k_idxs_lid() const { return self_k_idxs_lid; } + + ggml_tensor * get_kq_mask_mla() const { return self_kq_mask_mla_cnv; } + ggml_tensor * get_kq_mask_lid() const { return self_kq_mask_lid; } + + ggml_tensor * self_k_idxs_mla = nullptr; // I64 [n_batch] + ggml_tensor * self_k_idxs_lid = nullptr; // I64 [n_batch] + + ggml_tensor * self_kq_mask_mla = nullptr; // F32/F16 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_mla_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_lid = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_lid_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + + ggml_tensor * self_k_rot_lid = nullptr; + + const llama_hparams hparams; + const llama_cparams cparams; + + const llama_kv_cache_dsa_context * mctx; +}; + class llm_graph_input_attn_kv_iswa : public llm_graph_input_i { public: llm_graph_input_attn_kv_iswa( @@ -385,10 +442,10 @@ class llm_graph_input_attn_kv_iswa : public llm_graph_input_i { ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch] ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa] - ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask = nullptr; // F32/F16 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_swa = nullptr; // F32/F16 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] ggml_tensor * self_k_rot = nullptr; ggml_tensor * self_v_rot = nullptr; @@ -411,8 +468,8 @@ class llm_graph_input_attn_cross : public llm_graph_input_i { ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; } - ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1] - ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1] + ggml_tensor * cross_kq_mask = nullptr; // F32/F16 [n_outputs_enc, n_batch, 1, 1] + ggml_tensor * cross_kq_mask_cnv = nullptr; // F32/F16 [n_outputs_enc, n_batch, 1, 1] const llama_cross * cross = nullptr; }; @@ -646,7 +703,9 @@ class llm_graph_result { ggml_tensor * get_logits() const { return t_logits; } ggml_tensor * get_embd() const { return t_embd; } ggml_tensor * get_embd_pooled() const { return t_embd_pooled; } - ggml_tensor * get_h_pre_norm() const { return t_h_pre_norm; } + ggml_tensor * get_h_nextn() const { return t_h_nextn; } + + ggml_tensor * get_layer_inp(int il) const { return t_layer_inp[il]; } ggml_cgraph * get_gf() const { return gf; } ggml_context * get_ctx() const { return ctx_compute.get(); } @@ -656,7 +715,7 @@ class llm_graph_result { void reset(); void set_inputs(const llama_ubatch * ubatch); - void set_outputs(); + void set_outputs(const llm_graph_params & params); // try to update the existing graph result using the new graph parameters in order to reuse it // this can only be done if we determine that the resulting graph using the new graph parameters @@ -675,12 +734,14 @@ class llm_graph_result { ggml_tensor * t_logits = nullptr; ggml_tensor * t_embd = nullptr; ggml_tensor * t_embd_pooled = nullptr; - ggml_tensor * t_h_pre_norm = nullptr; // [n_embd, n_outputs] hidden state before final output norm + ggml_tensor * t_h_nextn = nullptr; // [n_embd, n_outputs] hidden state before final output norm - std::map<llama_seq_id, ggml_tensor*> t_sampled_logits; - std::map<llama_seq_id, ggml_tensor*> t_candidates; - std::map<llama_seq_id, ggml_tensor*> t_sampled; - std::map<llama_seq_id, ggml_tensor*> t_sampled_probs; + std::vector<ggml_tensor *> t_layer_inp; + + std::map<llama_seq_id, ggml_tensor *> t_sampled_logits; + std::map<llama_seq_id, ggml_tensor *> t_candidates; + std::map<llama_seq_id, ggml_tensor *> t_sampled; + std::map<llama_seq_id, ggml_tensor *> t_sampled_probs; std::vector<llm_graph_input_ptr> inputs; @@ -727,6 +788,7 @@ struct llm_graph_context { const int64_t n_embd; const int64_t n_layer; + const int64_t n_layer_nextn; const int64_t n_rot; const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train) const int64_t n_head; @@ -956,6 +1018,23 @@ struct llm_graph_context { float kq_scale, int il) const; + llm_graph_input_attn_k_dsa * build_attn_inp_k_dsa() const; + + ggml_tensor * build_attn( + llm_graph_input_attn_k_dsa * inp, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * wo_s, + ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] + ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] + ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] + ggml_tensor * kq_b, + ggml_tensor * sinks, // [n_head_q] + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] + ggml_tensor * top_k, // [n_indexer_top_k, n_tokens] + float kq_scale, + int il) const; + llm_graph_input_attn_kv_iswa * build_attn_inp_kv_iswa() const; // note: if k_cur or v_cur are not provided, they will not be stored in the memory diff --git a/examples/talk-llama/llama-hparams.cpp b/examples/talk-llama/llama-hparams.cpp index 2239309c8fb..0be045707a4 100644 --- a/examples/talk-llama/llama-hparams.cpp +++ b/examples/talk-llama/llama-hparams.cpp @@ -7,19 +7,39 @@ void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) { if (dense_first) { - for (uint32_t il = 0; il < n_layer; ++il) { - swa_layers[il] = n_pattern == 0 || (il % n_pattern != 0); + for (uint32_t il = 0; il < n_layer(); ++il) { + is_swa_impl[il] = n_pattern == 0 || (il % n_pattern != 0); } } else { - for (uint32_t il = 0; il < n_layer; ++il) { - swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1)); + for (uint32_t il = 0; il < n_layer(); ++il) { + is_swa_impl[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1)); } } + + for (uint32_t il = n_layer(); il < n_layer_all; ++il) { + is_swa_impl[il] = false; + } +} + +void llama_hparams::set_recr_pattern(uint32_t n_pattern, bool dense_first) { + if (dense_first) { + for (uint32_t il = 0; il < n_layer(); ++il) { + is_recr_impl[il] = n_pattern == 0 || (il % n_pattern != 0); + } + } else { + for (uint32_t il = 0; il < n_layer(); ++il) { + is_recr_impl[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1)); + } + } + + for (uint32_t il = n_layer(); il < n_layer_all; ++il) { + is_recr_impl[il] = false; + } } bool llama_hparams::is_swa_any() const { - for (uint32_t il = 0; il < n_layer; ++il) { - if (swa_layers[il]) { + for (uint32_t il = 0; il < n_layer_all; ++il) { + if (is_swa_impl[il]) { return true; } } @@ -28,7 +48,7 @@ bool llama_hparams::is_swa_any() const { } uint32_t llama_hparams::n_head(uint32_t il) const { - if (il < n_layer) { + if (il < n_layer_all) { return n_head_arr[il]; } @@ -36,7 +56,7 @@ uint32_t llama_hparams::n_head(uint32_t il) const { } uint32_t llama_hparams::n_head_kv(uint32_t il) const { - if (il < n_layer) { + if (il < n_layer_all) { return n_head_kv_arr[il]; } @@ -44,7 +64,7 @@ uint32_t llama_hparams::n_head_kv(uint32_t il) const { } uint32_t llama_hparams::n_ff(uint32_t il) const { - if (il < n_layer) { + if (il < n_layer_all) { return n_ff_arr[il]; } @@ -63,7 +83,7 @@ uint32_t llama_hparams::n_gqa(uint32_t il) const { } uint32_t llama_hparams::n_rot(uint32_t il) const { - if (il < n_layer) { + if (il < n_layer_all) { return is_swa(il) ? n_rot_swa : n_rot_full; } @@ -71,6 +91,10 @@ uint32_t llama_hparams::n_rot(uint32_t il) const { } uint32_t llama_hparams::n_embd_inp() const { + if (n_embd_inp_impl > 0) { + return n_embd_inp_impl; + } + uint32_t n_embd_inp = n_embd; if (n_deepstack_layers > 0) { @@ -85,7 +109,7 @@ uint32_t llama_hparams::n_embd_out() const { } uint32_t llama_hparams::n_embd_head_k(uint32_t il) const { - if (il < n_layer) { + if (il < n_layer_all) { return is_swa(il) ? n_embd_head_k_swa : n_embd_head_k_full; } @@ -93,7 +117,7 @@ uint32_t llama_hparams::n_embd_head_k(uint32_t il) const { } uint32_t llama_hparams::n_embd_head_v(uint32_t il) const { - if (il < n_layer) { + if (il < n_layer_all) { return is_swa(il) ? n_embd_head_v_swa : n_embd_head_v_full; } @@ -114,7 +138,7 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const { bool llama_hparams::is_n_embd_k_gqa_variable() const { const uint32_t val = n_embd_k_gqa(); - for (uint32_t il = 0; il < n_layer; ++il) { + for (uint32_t il = 0; il < n_layer_all; ++il) { if (val != n_embd_k_gqa(il)) { return true; } @@ -125,7 +149,7 @@ bool llama_hparams::is_n_embd_k_gqa_variable() const { bool llama_hparams::is_n_embd_v_gqa_variable() const { const uint32_t val = n_embd_v_gqa(); - for (uint32_t il = 0; il < n_layer; ++il) { + for (uint32_t il = 0; il < n_layer_all; ++il) { if (val != n_embd_v_gqa(il)) { return true; } @@ -136,7 +160,7 @@ bool llama_hparams::is_n_embd_v_gqa_variable() const { uint32_t llama_hparams::n_embd_k_gqa_max() const { uint32_t val = n_embd_k_gqa(); - for (uint32_t il = 0; il < n_layer; ++il) { + for (uint32_t il = 0; il < n_layer_all; ++il) { val = std::max(val, n_embd_k_gqa(il)); } @@ -145,7 +169,7 @@ uint32_t llama_hparams::n_embd_k_gqa_max() const { uint32_t llama_hparams::n_embd_v_gqa_max() const { uint32_t val = n_embd_v_gqa(); - for (uint32_t il = 0; il < n_layer; ++il) { + for (uint32_t il = 0; il < n_layer_all; ++il) { val = std::max(val, n_embd_v_gqa(il)); } @@ -193,12 +217,12 @@ uint32_t llama_hparams::n_embd_s() const { return ssm_d_state * ssm_d_inner; } -bool llama_hparams::is_recurrent(uint32_t il) const { - if (il < n_layer) { - return recurrent_layer_arr[il]; +bool llama_hparams::is_recr(uint32_t il) const { + if (il < n_layer_all) { + return is_recr_impl[il]; } - GGML_ABORT("%s: il (%u) out of bounds (n_layer: %u)\n", __func__, il, n_layer); + GGML_ABORT("%s: il (%u) out of bounds (n_layer_all: %u)\n", __func__, il, n_layer_all); } uint32_t llama_hparams::n_pos_per_embd() const { @@ -206,11 +230,11 @@ uint32_t llama_hparams::n_pos_per_embd() const { } bool llama_hparams::is_swa(uint32_t il) const { - if (il < n_layer) { - return swa_layers[il]; + if (il < n_layer_all) { + return is_swa_impl[il]; } - GGML_ABORT("fatal error"); + GGML_ABORT("%s: il (%u) out of bounds (n_layer_all: %u)\n", __func__, il, n_layer_all); } bool llama_hparams::is_mla() const { @@ -247,16 +271,8 @@ bool llama_hparams::has_kv(uint32_t il) const { return true; } -uint32_t llama_hparams::n_layer_kv() const { - uint32_t res = 0; - - for (uint32_t il = 0; il < n_layer; ++il) { - if (has_kv(il)) { - res++; - } - } - - return res; +uint32_t llama_hparams::n_layer() const { + return n_layer_all - n_layer_nextn; } bool llama_hparams::use_mrope() const { diff --git a/examples/talk-llama/llama-hparams.h b/examples/talk-llama/llama-hparams.h index e2d051edc6c..617aa5c9664 100644 --- a/examples/talk-llama/llama-hparams.h +++ b/examples/talk-llama/llama-hparams.h @@ -23,6 +23,9 @@ enum llama_swa_type { LLAMA_SWA_TYPE_SYMMETRIC = 3, }; +// forward declaration; full definition in llama-graph.h +enum llm_ffn_op_type : int; + struct llama_hparams_posnet { uint32_t n_embd; uint32_t n_layer; @@ -34,20 +37,27 @@ struct llama_hparams_convnext { }; struct llama_hparams { + // note: use the `_impl` suffix to avoid name conflict between members and getters + // for example: n_embd_out() vs n_embd_out_impl + bool vocab_only; bool no_alloc; bool rope_finetuned; bool use_par_res; bool swin_norm; + bool norm_before_residual = false; uint32_t n_ctx_train; // context size the model was trained on uint32_t n_embd; - uint32_t n_layer; - int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache + uint32_t n_layer_all; + uint32_t n_layer_nextn = 0; uint32_t n_expert = 0; uint32_t n_expert_used = 0; uint32_t n_rel_attn_bkts = 0; + // TODO: this needs to be reworked + int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache + // different head size for full_attention and SWA layers uint32_t n_embd_head_k_full; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads uint32_t n_embd_head_v_full; // dimension of values (d_v) aka n_embd_head @@ -90,7 +100,6 @@ struct llama_hparams { uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE; uint32_t moe_every_n_layers = 0; uint32_t moe_latent_size = 0; - uint32_t nextn_predict_layers = 0; bool kv_only_nextn = false; // if true, only the last nextn_predict_layers blocks have a KV cache (MTP head arches) @@ -134,11 +143,15 @@ struct llama_hparams { llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; // the size of the sliding window (0 - no SWA) uint32_t n_swa = 0; - // if swa_layers[il] == 1, then layer il is SWA - // if swa_layers[il] == 0, then layer il is dense (i.e. non-SWA) + + // if is_swa_impl[il] == 1, then layer il is SWA + // if is_swa_impl[il] == 0, then layer il is dense (i.e. non-SWA) // by default, all layers are dense // note: using uint32_t type for compatibility reason - std::array<uint32_t, LLAMA_MAX_LAYERS> swa_layers; + std::array<uint32_t, LLAMA_MAX_LAYERS> is_swa_impl; + + // for hybrid state space models + std::array<uint32_t, LLAMA_MAX_LAYERS> is_recr_impl; // for State Space Models uint32_t ssm_d_conv = 0; @@ -150,9 +163,6 @@ struct llama_hparams { // for Kimi Linear KDA uint32_t n_embd_head_kda = 0; - // for hybrid state space models - std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr; - bool ssm_dt_b_c_rms = false; float f_clamp_kqv = 0.0f; @@ -178,6 +188,9 @@ struct llama_hparams { // for Classifiers uint32_t n_cls_out = 1; + // input embedding dimension (0 = use n_embd) + uint32_t n_embd_inp_impl = 0; + // output embedding dimension (0 = use n_embd) uint32_t n_embd_out_impl = 0; @@ -212,8 +225,19 @@ struct llama_hparams { uint32_t indexer_top_k = 0; // qwen3vl deepstack + // When parsed from GGUF, this implies the first N layers consume the first + // N deepstack embeddings. Use deepstack_mapping_arr if you need a more + // complex mapping. If using deepstack_mapping_arr, also make sure to set + // n_deepstack_layers to the number of unique deepstack layers so that + // n_embd_imp is accurate (see granite.cpp). + // TODO: can be expressed via the `new n_embd_inp_impl` and remove this param uint32_t n_deepstack_layers = 0; + // deepstack layer array (Granite4 Vision) + // -1 => no deepstack + // >=0 => input embedding index for deepstack injection + std::array<int32_t, LLAMA_MAX_LAYERS> deepstack_mapping_arr; + // gemma4 per-layer embedding uint32_t n_embd_per_layer = 0; @@ -227,6 +251,14 @@ struct llama_hparams { enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE; + // Resolved FFN gated activation flavor for archs that read + // `<arch>.hidden_activation` from the GGUF (e.g. ModernBert derivatives). + // Defaults to LLM_FFN_NONE (sentinel = 0); the mapping from the GGUF + // string to a real op is done at hparam-load time via + // llm_ffn_op_type_from_string() in llama-model.cpp, mirroring how + // rope_scaling_type_train is handled. + enum llm_ffn_op_type llm_ffn_op; + // Step35: optional per-layer clamps for (Swi)GLU std::array<float, LLAMA_MAX_LAYERS> swiglu_clamp_exp; // clamping for expert FFN std::array<float, LLAMA_MAX_LAYERS> swiglu_clamp_shexp; // shared expert @@ -255,6 +287,13 @@ struct llama_hparams { // return true if one of the layers is SWA bool is_swa_any() const; + bool is_swa(uint32_t il) const; + + void set_recr_pattern(uint32_t n_pattern, bool dense_first = false); + + // whether or not the given layer is recurrent (for hybrid models) + bool is_recr(uint32_t il) const; + uint32_t n_head(uint32_t il = 0) const; uint32_t n_head_kv(uint32_t il = 0) const; @@ -296,13 +335,8 @@ struct llama_hparams { // dimension of the recurrent state embeddings uint32_t n_embd_s() const; - // whether or not the given layer is recurrent (for hybrid models) - bool is_recurrent(uint32_t il) const; - uint32_t n_pos_per_embd() const; - bool is_swa(uint32_t il) const; - // note: currently only support if either all or none of the layers are MLA bool is_mla() const; @@ -311,8 +345,8 @@ struct llama_hparams { bool has_kv(uint32_t il) const; - // number of layers for which has_kv() returns true - uint32_t n_layer_kv() const; + // number of effective layers (excludes nextn layers) + uint32_t n_layer() const; // note that this function uses different SWA parameters from those in the hparams // note: inlined on purpose for performance reasons diff --git a/examples/talk-llama/llama-impl.h b/examples/talk-llama/llama-impl.h index e4f35c8e53d..7923c3f7ed5 100644 --- a/examples/talk-llama/llama-impl.h +++ b/examples/talk-llama/llama-impl.h @@ -3,6 +3,7 @@ #include "ggml.h" // for ggml_log_level #include <string> +#include <type_traits> #include <vector> #ifdef __GNUC__ @@ -40,6 +41,19 @@ struct no_init { no_init() = default; }; +template <typename dst_t, typename src_t> +static inline dst_t llama_cast(src_t v) { + if constexpr (std::is_same_v<src_t, dst_t>) { + return v; + } else if constexpr (std::is_same_v<src_t, ggml_fp16_t> && std::is_same_v<dst_t, float>) { + return ggml_fp16_to_fp32(v); + } else if constexpr (std::is_same_v<src_t, float> && std::is_same_v<dst_t, ggml_fp16_t>) { + return ggml_fp32_to_fp16(v); + } else { + static_assert(std::is_same_v<dst_t, void>, "unsupported type combination"); + } +} + struct time_meas { time_meas(int64_t & t_acc, bool disable = false); ~time_meas(); diff --git a/examples/talk-llama/llama-kv-cache-dsa.cpp b/examples/talk-llama/llama-kv-cache-dsa.cpp new file mode 100644 index 00000000000..916ab653756 --- /dev/null +++ b/examples/talk-llama/llama-kv-cache-dsa.cpp @@ -0,0 +1,261 @@ +#include "llama-kv-cache-dsa.h" + +#include "llama-impl.h" +#include "llama-batch.h" +#include "llama-model.h" + +#include <algorithm> +#include <cassert> + +// +// llama_kv_cache_dsa +// + +llama_kv_cache_dsa::llama_kv_cache_dsa( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + bool unified, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t n_pad, + uint32_t n_swa, + llama_swa_type swa_type, + const layer_filter_cb & filter, + const layer_reuse_cb & reuse) : + hparams_lid(model.hparams), n_stream(unified ? 1 : n_seq_max) { + + LLAMA_LOG_INFO("%s: creating main KV cache, size = %u cells\n", __func__, kv_size); + + kv_mla = std::make_unique<llama_kv_cache>( + model, model.hparams, type_k, type_v, + v_trans, offload, unified, kv_size, n_seq_max, n_pad, + n_swa, swa_type, nullptr, filter, reuse, nullptr); + + // we use llama_kv_cache for caching indexer keys + // by hand-tweaking some hparams we fool it to create + // indexer key cache tensors with correct dimensions + // https://github.com/ggml-org/llama.cpp/pull/21149#discussion_r3015940823 + + // DSA lightning indexer uses MQA with single key head + std::fill(hparams_lid.n_head_kv_arr.begin(), hparams_lid.n_head_kv_arr.end(), 1); + hparams_lid.n_embd_head_k_full = model.hparams.indexer_head_size; + hparams_lid.rope_type = LLAMA_ROPE_TYPE_NEOX; + + LLAMA_LOG_INFO("%s: creating indexer KV cache, size = %u cells\n", __func__, kv_size); + + kv_lid = std::make_unique<llama_kv_cache>( + model, hparams_lid, type_k, type_v, + v_trans, offload, unified, kv_size, n_seq_max, n_pad, + n_swa, swa_type, nullptr, filter, reuse, nullptr); +} + +void llama_kv_cache_dsa::clear(bool data) { + kv_mla->clear(data); + kv_lid->clear(data); +} + +bool llama_kv_cache_dsa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + bool res = true; + + res = res & kv_mla->seq_rm(seq_id, p0, p1); + res = res & kv_lid->seq_rm(seq_id, p0, p1); + + return res; +} + +void llama_kv_cache_dsa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + kv_mla->seq_cp(seq_id_src, seq_id_dst, p0, p1); + kv_lid->seq_cp(seq_id_src, seq_id_dst, p0, p1); +} + +void llama_kv_cache_dsa::seq_keep(llama_seq_id seq_id) { + kv_mla->seq_keep(seq_id); + kv_lid->seq_keep(seq_id); +} + +void llama_kv_cache_dsa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + kv_mla->seq_add(seq_id, p0, p1, shift); + kv_lid->seq_add(seq_id, p0, p1, shift); +} + +void llama_kv_cache_dsa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + kv_mla->seq_div(seq_id, p0, p1, d); + kv_lid->seq_div(seq_id, p0, p1, d); +} + +llama_pos llama_kv_cache_dsa::seq_pos_min(llama_seq_id seq_id) const { + return kv_mla->seq_pos_min(seq_id); +} + +llama_pos llama_kv_cache_dsa::seq_pos_max(llama_seq_id seq_id) const { + return kv_mla->seq_pos_max(seq_id); +} + +std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache_dsa::memory_breakdown() const { + std::map<ggml_backend_buffer_type_t, size_t> mb = kv_mla->memory_breakdown(); + for (const auto & buft_size : kv_lid->memory_breakdown()) { + mb[buft_size.first] += buft_size.second; + } + return mb; +} + +llama_memory_context_ptr llama_kv_cache_dsa::init_batch( + llama_batch_allocr & balloc, + uint32_t n_ubatch, + bool embd_all) { + GGML_UNUSED(embd_all); + + do { + balloc.split_reset(); + + std::vector<llama_ubatch> ubatches; + while (true) { + auto ubatch = n_stream == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true); + + if (ubatch.n_tokens == 0) { + break; + } + + ubatches.push_back(std::move(ubatch)); // NOLINT + } + + if (balloc.get_n_used() < balloc.get_n_tokens()) { + // failed to find a suitable split + break; + } + + auto sinfos_mla = kv_mla->prepare(ubatches); + if (sinfos_mla.empty()) { + break; + } + + auto sinfos_lid = kv_lid->prepare(ubatches); + if (sinfos_lid.empty()) { + break; + } + + assert(sinfos_mla.size() == sinfos_lid.size()); + + return std::make_unique<llama_kv_cache_dsa_context>( + this, std::move(sinfos_mla), std::move(sinfos_lid), std::move(ubatches)); + } while (false); + + return std::make_unique<llama_kv_cache_dsa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE); +} + +llama_memory_context_ptr llama_kv_cache_dsa::init_full() { + return std::make_unique<llama_kv_cache_dsa_context>(this); +} + +llama_memory_context_ptr llama_kv_cache_dsa::init_update(llama_context * lctx, bool optimize) { + return std::make_unique<llama_kv_cache_dsa_context>(this, lctx, optimize); +} + +bool llama_kv_cache_dsa::get_can_shift() const { + return kv_mla->get_can_shift() && + kv_lid->get_can_shift() && + kv_mla->get_size() == kv_lid->get_size(); +} + +void llama_kv_cache_dsa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { + kv_mla->state_write(io, seq_id, flags); + kv_lid->state_write(io, seq_id, flags); +} + +void llama_kv_cache_dsa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { + kv_mla->state_read(io, seq_id, flags); + kv_lid->state_read(io, seq_id, flags); +} + +llama_kv_cache * llama_kv_cache_dsa::get_mla() const { + return kv_mla.get(); +} + +llama_kv_cache * llama_kv_cache_dsa::get_lid() const { + return kv_lid.get(); +} + +// +// llama_kv_cache_dsa_context +// + +llama_kv_cache_dsa_context::llama_kv_cache_dsa_context(llama_memory_status status) : status(status) {} + +llama_kv_cache_dsa_context::llama_kv_cache_dsa_context( + llama_kv_cache_dsa * kv) : + ctx_mla(kv->get_mla()->init_full()), + ctx_lid(kv->get_lid()->init_full()), + status(llama_memory_status_combine(ctx_mla->get_status(), ctx_lid->get_status())) { +} + +llama_kv_cache_dsa_context::llama_kv_cache_dsa_context( + llama_kv_cache_dsa * kv, + llama_context * lctx, + bool optimize) : + ctx_mla(kv->get_mla()->init_update(lctx, optimize)), + ctx_lid(kv->get_lid()->init_update(lctx, optimize)), + status(llama_memory_status_combine(ctx_mla->get_status(), ctx_lid->get_status())) { +} + +llama_kv_cache_dsa_context::llama_kv_cache_dsa_context( + llama_kv_cache_dsa * kv, + slot_info_vec_t sinfos_mla, + slot_info_vec_t sinfos_lid, + std::vector<llama_ubatch> ubatches) : + ubatches(std::move(ubatches)), + // note: here we copy the ubatches. not sure if this is ideal + ctx_mla(new llama_kv_cache_context(kv->get_mla(), std::move(sinfos_mla), this->ubatches)), + ctx_lid(new llama_kv_cache_context(kv->get_lid(), std::move(sinfos_lid), this->ubatches)), + status(llama_memory_status_combine(ctx_mla->get_status(), ctx_lid->get_status())) { +} + +llama_kv_cache_dsa_context:: ~llama_kv_cache_dsa_context() = default; + +bool llama_kv_cache_dsa_context::next() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + ctx_mla->next(); + ctx_lid->next(); + + if (++i_next >= ubatches.size()) { + return false; + } + + return true; +} + +bool llama_kv_cache_dsa_context::apply() { + assert(!llama_memory_status_is_fail(status)); + + bool res = true; + + res = res & ctx_mla->apply(); + res = res & ctx_lid->apply(); + + return res; +} + +llama_memory_status llama_kv_cache_dsa_context::get_status() const { + return status; +} + +const llama_ubatch & llama_kv_cache_dsa_context::get_ubatch() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return ubatches[i_next]; +} + +const llama_kv_cache_context * llama_kv_cache_dsa_context::get_mla() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return static_cast<const llama_kv_cache_context *>(ctx_mla.get()); +} + +const llama_kv_cache_context * llama_kv_cache_dsa_context::get_lid() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return static_cast<const llama_kv_cache_context *>(ctx_lid.get()); +} diff --git a/examples/talk-llama/llama-kv-cache-dsa.h b/examples/talk-llama/llama-kv-cache-dsa.h new file mode 100644 index 00000000000..e2b330993b8 --- /dev/null +++ b/examples/talk-llama/llama-kv-cache-dsa.h @@ -0,0 +1,138 @@ +#pragma once + +#include "llama-kv-cache.h" + +#include <vector> + +// +// llama_kv_cache_dsa +// + +// utilizes two instances of llama_kv_cache: +// - the first instance is for caching key tensors of the model, +// - the second instance is for caching lightning indexer key tensors + +class llama_kv_cache_dsa : public llama_memory_i { +public: + llama_kv_cache_dsa( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + bool unified, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t n_pad, + uint32_t n_swa, + llama_swa_type swa_type, + const layer_filter_cb & filter, + const layer_reuse_cb & reuse); + + ~llama_kv_cache_dsa() = default; + + // + // llama_memory_i + // + + llama_memory_context_ptr init_batch( + llama_batch_allocr & balloc, + uint32_t n_ubatch, + bool embd_all) override; + + llama_memory_context_ptr init_full() override; + + llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override; + + bool get_can_shift() const override; + + void clear(bool data) override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_min(llama_seq_id seq_id) const override; + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + + std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override; + + // state write/load + + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override; + + // + // llama_kv_cache_dsa specific API + // + + llama_kv_cache * get_mla() const; + llama_kv_cache * get_lid() const; + +private: + // we keep indexer KV cache hparams instance here as llama_kv_cache stores only reference to it + llama_hparams hparams_lid; + const uint32_t n_stream = 1; + + std::unique_ptr<llama_kv_cache> kv_mla; + std::unique_ptr<llama_kv_cache> kv_lid; +}; + +class llama_kv_cache_dsa_context : public llama_memory_context_i { +public: + using slot_info_vec_t = llama_kv_cache::slot_info_vec_t; + + // used for errors + llama_kv_cache_dsa_context(llama_memory_status status); + + // used to create a full-cache context + llama_kv_cache_dsa_context( + llama_kv_cache_dsa * kv); + + // used to create an update context + llama_kv_cache_dsa_context( + llama_kv_cache_dsa * kv, + llama_context * lctx, + bool optimize); + + // used to create a batch processing context from a batch + llama_kv_cache_dsa_context( + llama_kv_cache_dsa * kv, + slot_info_vec_t sinfos_base, + slot_info_vec_t sinfos_ik, + std::vector<llama_ubatch> ubatches); + + virtual ~llama_kv_cache_dsa_context(); + + // + // llama_memory_context_i + // + + bool next() override; + bool apply() override; + + llama_memory_status get_status() const override; + const llama_ubatch & get_ubatch() const override; + + // + // llama_kv_cache_dsa_context specific API + // + + const llama_kv_cache_context * get_mla() const; + const llama_kv_cache_context * get_lid() const; + +private: + //llama_kv_cache_dsa * kv; + + // the index of the next ubatch to process + size_t i_next = 0; + + std::vector<llama_ubatch> ubatches; + + const llama_memory_context_ptr ctx_mla; + const llama_memory_context_ptr ctx_lid; + + const llama_memory_status status; +}; diff --git a/examples/talk-llama/llama-kv-cache-iswa.cpp b/examples/talk-llama/llama-kv-cache-iswa.cpp index 26e2cb4270b..aa1b1b72ebe 100644 --- a/examples/talk-llama/llama-kv-cache-iswa.cpp +++ b/examples/talk-llama/llama-kv-cache-iswa.cpp @@ -23,8 +23,10 @@ llama_kv_cache_iswa::llama_kv_cache_iswa( uint32_t n_seq_max, uint32_t n_ubatch, uint32_t n_pad, + llama_memory_t mem_other, const layer_filter_cb & filter, - const layer_reuse_cb & reuse) : hparams(model.hparams), unified(unified) { + const layer_reuse_cb & reuse, + const layer_share_cb & share) : hparams(model.hparams), unified(unified) { // chain filters const layer_filter_cb filter_base = [&](int32_t il) { @@ -59,17 +61,27 @@ llama_kv_cache_iswa::llama_kv_cache_iswa( LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base); + llama_memory_t mem_other_base = nullptr; + if (mem_other) { + mem_other_base = static_cast<llama_kv_cache_iswa *>(mem_other)->get_base(); + } + + llama_memory_t mem_other_swa = nullptr; + if (mem_other) { + mem_other_swa = static_cast<llama_kv_cache_iswa *>(mem_other)->get_swa(); + } + kv_base = std::make_unique<llama_kv_cache>( - model, type_k, type_v, + model, hparams, type_k, type_v, v_trans, offload, unified, size_base, n_seq_max, n_pad, - 0, LLAMA_SWA_TYPE_NONE, filter_base, reuse); + 0, LLAMA_SWA_TYPE_NONE, mem_other_base, filter_base, reuse, share); LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa); kv_swa = std::make_unique<llama_kv_cache>( - model, type_k, type_v, + model, hparams, type_k, type_v, v_trans, offload, unified, size_swa, n_seq_max, n_pad, - hparams.n_swa, hparams.swa_type, filter_swa, reuse); + hparams.n_swa, hparams.swa_type, mem_other_swa, filter_swa, reuse, share); } void llama_kv_cache_iswa::clear(bool data) { diff --git a/examples/talk-llama/llama-kv-cache-iswa.h b/examples/talk-llama/llama-kv-cache-iswa.h index 70ab22f0d60..dfafc1ef510 100644 --- a/examples/talk-llama/llama-kv-cache-iswa.h +++ b/examples/talk-llama/llama-kv-cache-iswa.h @@ -25,8 +25,10 @@ class llama_kv_cache_iswa : public llama_memory_i { uint32_t n_seq_max, uint32_t n_ubatch, uint32_t n_pad, + llama_memory_t mem_other, const layer_filter_cb & filter, - const layer_reuse_cb & reuse); + const layer_reuse_cb & reuse, + const layer_share_cb & share); ~llama_kv_cache_iswa() = default; diff --git a/examples/talk-llama/llama-kv-cache.cpp b/examples/talk-llama/llama-kv-cache.cpp index a49a055a630..2802103bdd8 100644 --- a/examples/talk-llama/llama-kv-cache.cpp +++ b/examples/talk-llama/llama-kv-cache.cpp @@ -79,6 +79,7 @@ static ggml_tensor * ggml_mul_mat_aux( llama_kv_cache::llama_kv_cache( const llama_model & model, + const llama_hparams & hparams, ggml_type type_k, ggml_type type_v, bool v_trans, @@ -89,14 +90,30 @@ llama_kv_cache::llama_kv_cache( uint32_t n_pad, uint32_t n_swa, llama_swa_type swa_type, + llama_memory_t mem_other, const layer_filter_cb & filter, - const layer_reuse_cb & reuse) : - model(model), hparams(model.hparams), v_trans(v_trans), - n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) { + const layer_reuse_cb & reuse, + const layer_share_cb & share) : + model(model), hparams(hparams), v_trans(v_trans), + n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type), + other(static_cast<llama_kv_cache *>(mem_other)), + v_cells_impl(other ? other->v_cells_impl : std::make_shared<llama_kv_cells_vec>()), + v_cells(*v_cells_impl) { + + // shared cells view the source cache's K/V tensors, so the cell count + // follows the source allocation: a fitted target can be smaller than the + // draft default and oversized views would overflow the source tensors + if (other) { + const uint32_t size_other = other->get_size(); + if (kv_size != size_other) { + LLAMA_LOG_WARN("%s: kv_size = %u overridden to %u to match the shared source cache\n", __func__, kv_size, size_other); + kv_size = size_other; + } + } GGML_ASSERT(kv_size % n_pad == 0); - const uint32_t n_layer_kv = hparams.n_layer_kv(); + const uint32_t n_layer = hparams.n_layer_all; // define a comparator for the buft -> ctx map to ensure that the order is well-defined: struct ggml_backend_buft_comparator { @@ -111,7 +128,7 @@ llama_kv_cache::llama_kv_cache( auto it = ctx_map.find(buft); if (it == ctx_map.end()) { ggml_init_params params = { - /*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_kv*ggml_tensor_overhead()), + /*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer*ggml_tensor_overhead()), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; @@ -159,7 +176,7 @@ llama_kv_cache::llama_kv_cache( const bool is_mla = hparams.is_mla(); - for (uint32_t il = 0; il < hparams.n_layer; il++) { + for (uint32_t il = 0; il < n_layer; il++) { if (!hparams.has_kv(il)) { LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il); continue; @@ -170,6 +187,24 @@ llama_kv_cache::llama_kv_cache( continue; } + if (share && other) { + const int32_t il_share = share(il); + + if (il_share >= 0) { + const auto & layer_share = other->layers[other->map_layer_ids[il_share]]; + + LLAMA_LOG_WARN("%s: layer %3d: sharing with layer %d. k = %p, v = %p\n", __func__, il, il_share, + layer_share.k->data, layer_share.v->data); + + map_layer_ids[il] = layers.size(); + + layers.push_back(layer_share); + layers.back().il = il; + + continue; + } + } + if (n_embd_head_k_all == 0) { n_embd_head_k_all = (int32_t) hparams.n_embd_head_k(il); } else if (n_embd_head_k_all > 0 && n_embd_head_k_all != (int32_t) hparams.n_embd_head_k(il)) { @@ -229,7 +264,7 @@ llama_kv_cache::llama_kv_cache( if (reuse) { LLAMA_LOG_DEBUG("%s: reusing layers:\n", __func__); - for (uint32_t il = 0; il < hparams.n_layer; il++) { + for (uint32_t il = 0; il < n_layer; il++) { const int32_t il_reuse = reuse(il); if (il_reuse < 0) { @@ -253,7 +288,7 @@ llama_kv_cache::llama_kv_cache( // allocate tensors and initialize the buffers to avoid NaNs in the padding for (auto & [buft, ctx] : ctx_map) { ggml_backend_buffer_t buf; - if (model.hparams.no_alloc) { + if (hparams.no_alloc) { buf = ggml_backend_buft_alloc_buffer(buft, /*size =*/ 0); // dummy buffer for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != nullptr; t = ggml_get_next_tensor(ctx.get(), t)) { t->buffer = buf; // set dummy buffer for KV cache so that the backend scheduler won't try to allocate it @@ -281,23 +316,37 @@ llama_kv_cache::llama_kv_cache( ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); } - const char * LLAMA_ATTN_ROT_DISABLE = getenv("LLAMA_ATTN_ROT_DISABLE"); - const bool attn_rot_disable = LLAMA_ATTN_ROT_DISABLE ? atoi(LLAMA_ATTN_ROT_DISABLE) : false; - if (attn_rot_disable) { - LLAMA_LOG_WARN("%s: attention rotation force disabled (LLAMA_ATTN_ROT_DISABLE)\n", __func__); - } + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + n_embd_head_k_all = other->n_embd_head_k_all; + n_embd_head_v_all = other->n_embd_head_v_all; + + attn_rot_k = other->attn_rot_k; + attn_rot_v = other->attn_rot_v; + } else { + const char * LLAMA_ATTN_ROT_DISABLE = getenv("LLAMA_ATTN_ROT_DISABLE"); + const bool attn_rot_disable = LLAMA_ATTN_ROT_DISABLE ? atoi(LLAMA_ATTN_ROT_DISABLE) : false; + if (attn_rot_disable) { + LLAMA_LOG_WARN("%s: attention rotation force disabled (LLAMA_ATTN_ROT_DISABLE)\n", __func__); + } + + attn_rot_k = + !attn_rot_disable && + n_embd_head_k_all > 0 && + ggml_is_quantized(type_k) && + hparams.n_embd_head_k() % 64 == 0; - attn_rot_k = - !attn_rot_disable && - n_embd_head_k_all > 0 && - ggml_is_quantized(type_k) && - hparams.n_embd_head_k() % 64 == 0; + // always create Hadamard rotation tensors for DeepSeek V3.2 DSA lightning indexer + if (model.arch == LLM_ARCH_DEEPSEEK32 && hparams.n_embd_head_k_full == hparams.indexer_head_size) { + attn_rot_k = true; + } - attn_rot_v = - !attn_rot_disable && - n_embd_head_v_all > 0 && - ggml_is_quantized(type_v) && - hparams.n_embd_head_v() % 64 == 0; + attn_rot_v = + !attn_rot_disable && + n_embd_head_v_all > 0 && + ggml_is_quantized(type_v) && + hparams.n_embd_head_v() % 64 == 0; + } LLAMA_LOG_INFO("%s: attn_rot_k = %d, n_embd_head_k_all = %d\n", __func__, attn_rot_k, n_embd_head_k_all); LLAMA_LOG_INFO("%s: attn_rot_v = %d, n_embd_head_k_all = %d\n", __func__, attn_rot_v, n_embd_head_v_all); @@ -341,6 +390,11 @@ void llama_kv_cache::clear(bool data) { } bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return true; + } + GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size())); if (p0 < 0) { @@ -404,6 +458,11 @@ bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { } void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return; + } + GGML_ASSERT(seq_id_src >= 0 && (size_t) seq_id_src < seq_to_stream.size()); GGML_ASSERT(seq_id_dst >= 0 && (size_t) seq_id_dst < seq_to_stream.size()); @@ -491,6 +550,11 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll } void llama_kv_cache::seq_keep(llama_seq_id seq_id) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return; + } + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); auto & cells = v_cells[seq_to_stream[seq_id]]; @@ -513,6 +577,11 @@ void llama_kv_cache::seq_keep(llama_seq_id seq_id) { } void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return; + } + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_add() is only supported for n_pos_per_embd() == 1"); @@ -558,6 +627,11 @@ void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, ll } void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return; + } + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_div() is only supported for n_pos_per_embd() == 1"); @@ -592,6 +666,11 @@ void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, in } llama_pos llama_kv_cache::seq_pos_min(llama_seq_id seq_id) const { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return other->seq_pos_min(seq_id); + } + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); const auto & cells = v_cells[seq_to_stream[seq_id]]; @@ -600,6 +679,11 @@ llama_pos llama_kv_cache::seq_pos_min(llama_seq_id seq_id) const { } llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return other->seq_pos_max(seq_id); + } + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); const auto & cells = v_cells[seq_to_stream[seq_id]]; @@ -740,6 +824,11 @@ llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector<llama_ } bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return true; + } + bool updated = false; auto * sched = lctx->get_sched(); @@ -1015,6 +1104,11 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch, } void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return; + } + // keep track of the max sequence position that we would overwrite with this ubatch // for non-SWA cache, this would be always empty llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ]; @@ -1430,8 +1524,8 @@ struct args_set_input_kq_mask { int64_t n_tps; }; -template<bool causal, bool swa, bool is_2d, bool alibi> -static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) { +template<typename T, bool causal, bool swa, bool is_2d, bool alibi> +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data) { //const auto & hparams = args.hparams; const auto & ubatch = args.ubatch; @@ -1445,6 +1539,9 @@ static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * const int64_t n_stream = args.n_stream; const int64_t n_tps = args.n_tps; + const T mask_keep = llama_cast<T>(0.0f); + const T mask_drop = llama_cast<T>(-INFINITY); + // the min position in the batch for each sequence llama_pos seq_pos_min[LLAMA_MAX_SEQ]; std::fill(seq_pos_min, seq_pos_min + LLAMA_MAX_SEQ, INT32_MAX); @@ -1563,46 +1660,55 @@ static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * } if (alibi) { - data[idst + j] = -std::abs(p0 - p1); + data[idst + j] = llama_cast<T>(static_cast<float>(-std::abs(p0 - p1))); } else { - data[idst + j] = 0.0f; + data[idst + j] = mask_keep; } continue; skip: - data[idst + j] = -INFINITY; + data[idst + j] = mask_drop; } } } } -template<bool causal, bool swa, bool is_2d> -static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) { +template<typename T, bool causal, bool swa, bool is_2d> +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data) { const bool alibi = args.hparams.use_alibi; if (alibi) { - set_input_kq_mask_impl<causal, swa, is_2d, true> (args, data); + set_input_kq_mask_impl<T, causal, swa, is_2d, true> (args, data); } else { - set_input_kq_mask_impl<causal, swa, is_2d, false>(args, data); + set_input_kq_mask_impl<T, causal, swa, is_2d, false>(args, data); } } -template<bool causal, bool swa> -static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) { +template<typename T, bool causal, bool swa> +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data) { const bool is_2d = args.ubatch->is_pos_2d(); if (is_2d) { - set_input_kq_mask_impl<causal, swa, true> (args, data); + set_input_kq_mask_impl<T, causal, swa, true> (args, data); } else { - set_input_kq_mask_impl<causal, swa, false>(args, data); + set_input_kq_mask_impl<T, causal, swa, false>(args, data); } } -template<bool causal> -static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) { +template<typename T, bool causal> +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data) { const bool swa = args.swa_type != LLAMA_SWA_TYPE_NONE; if (swa) { - set_input_kq_mask_impl<causal, true> (args, data); + set_input_kq_mask_impl<T, causal, true> (args, data); + } else { + set_input_kq_mask_impl<T, causal, false>(args, data); + } +} + +template<typename T> +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data, bool causal_attn) { + if (causal_attn) { + set_input_kq_mask_impl<T, true> (args, data); } else { - set_input_kq_mask_impl<causal, false>(args, data); + set_input_kq_mask_impl<T, false>(args, data); } } @@ -1610,7 +1716,6 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u const uint32_t n_tokens = ubatch->n_tokens; GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); - float * data = (float *) dst->data; const int64_t n_kv = dst->ne[0]; const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch @@ -1634,10 +1739,10 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u /*.n_tps =*/ n_tps, }; - if (causal_attn) { - set_input_kq_mask_impl<true> (args, data); + if (dst->type == GGML_TYPE_F16) { + set_input_kq_mask_impl<ggml_fp16_t>(args, (ggml_fp16_t *) dst->data, causal_attn); } else { - set_input_kq_mask_impl<false>(args, data); + set_input_kq_mask_impl<float>(args, (float *) dst->data, causal_attn); } //const int64_t t_end = ggml_time_us(); @@ -1798,6 +1903,9 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) { } ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_context * lctx) const { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + GGML_ASSERT(!other); + auto * ctx = res->get_ctx(); auto * gf = res->get_gf(); @@ -1843,6 +1951,11 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co } void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return; + } + GGML_UNUSED(flags); io.write(&n_stream, sizeof(n_stream)); @@ -1859,7 +1972,19 @@ void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, lla uint32_t cell_range_begin = cells.size(); for (uint32_t i = 0; i < cells.size(); ++i) { - if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) { + bool add_cell = true; + + add_cell = add_cell && !cells.is_empty(i); + add_cell = add_cell && (seq_id == -1 || cells.seq_has(i, seq_id)); + + // check the cell is not SWA-masked + if (add_cell && seq_id != -1) { + const bool is_masked = llama_hparams::is_masked_swa(n_swa, swa_type, cells.pos_get(i), cells.seq_pos_max(seq_id)); + + add_cell = !is_masked; + } + + if (add_cell) { ++cell_count; if (cell_range_begin == cells.size()) { cell_range_begin = i; @@ -1896,6 +2021,11 @@ void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, lla } void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return; + } + GGML_UNUSED(flags); GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size())); @@ -2112,7 +2242,7 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32 sinfo = find_slot(ubatch, false); if (sinfo.empty()) { - LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); + LLAMA_LOG_ERROR("%s: failed to find %d available cells in kv cache\n", __func__, cell_count); return false; } diff --git a/examples/talk-llama/llama-kv-cache.h b/examples/talk-llama/llama-kv-cache.h index 0b62dc7b232..3d68f98c142 100644 --- a/examples/talk-llama/llama-kv-cache.h +++ b/examples/talk-llama/llama-kv-cache.h @@ -93,8 +93,12 @@ class llama_kv_cache : public llama_memory_i { using slot_info_vec_t = std::vector<slot_info>; + // TODO: refactor the memory instances to not depend on `llama_model` + // instead pass all necessary info (e.g. hparams, dev layers, arch, etc.) directly + // likely through `struct llama_memory_params` llama_kv_cache( const llama_model & model, + const llama_hparams & hparams, ggml_type type_k, ggml_type type_v, bool v_trans, @@ -105,8 +109,10 @@ class llama_kv_cache : public llama_memory_i { uint32_t n_pad, uint32_t n_swa, llama_swa_type swa_type, + llama_memory_t mem_other, const layer_filter_cb & filter, - const layer_reuse_cb & reuse); + const layer_reuse_cb & reuse, + const layer_share_cb & share); ~llama_kv_cache() = default; @@ -260,7 +266,12 @@ class llama_kv_cache : public llama_memory_i { // note: this is not part of the KV state and it's only used to speed-up the find_slot() method std::vector<uint32_t> v_heads; - std::vector<llama_kv_cells> v_cells; + // TODO: temporary until we refactor to be able to share the same cells between 2 kv caches [TAG_KV_CACHE_SHARE_CELLS] + llama_kv_cache * other; + + std::shared_ptr<llama_kv_cells_vec> v_cells_impl; + + llama_kv_cells_vec & v_cells; // maps from a sequence id to a stream id std::vector<uint32_t> seq_to_stream; diff --git a/examples/talk-llama/llama-kv-cells.h b/examples/talk-llama/llama-kv-cells.h index 10063bf4272..fddd31a0b21 100644 --- a/examples/talk-llama/llama-kv-cells.h +++ b/examples/talk-llama/llama-kv-cells.h @@ -531,3 +531,5 @@ class llama_kv_cells { } } }; + +using llama_kv_cells_vec = std::vector<llama_kv_cells>; diff --git a/examples/talk-llama/llama-memory-hybrid-iswa.cpp b/examples/talk-llama/llama-memory-hybrid-iswa.cpp index 72f5c2fea72..c7d4bcd413e 100644 --- a/examples/talk-llama/llama-memory-hybrid-iswa.cpp +++ b/examples/talk-llama/llama-memory-hybrid-iswa.cpp @@ -43,9 +43,11 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa( n_seq_max, n_ubatch, n_pad, + nullptr, filter_attn == nullptr ? - [&](int32_t il) { return !hparams.is_recurrent(il); } + [&](int32_t il) { return !hparams.is_recr(il); } : filter_attn, + nullptr, nullptr )), mem_recr(new llama_memory_recurrent( @@ -57,7 +59,7 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa( n_seq_max, n_rs_seq, filter_recr == nullptr ? - [&](int32_t il) { return hparams.is_recurrent(il); } + [&](int32_t il) { return hparams.is_recr(il); } : filter_recr )) {} diff --git a/examples/talk-llama/llama-memory-hybrid.cpp b/examples/talk-llama/llama-memory-hybrid.cpp index 33b3b395e0c..f2d49cbce54 100644 --- a/examples/talk-llama/llama-memory-hybrid.cpp +++ b/examples/talk-llama/llama-memory-hybrid.cpp @@ -33,6 +33,7 @@ llama_memory_hybrid::llama_memory_hybrid( hparams(model.hparams), mem_attn(new llama_kv_cache( model, + model.hparams, type_k, type_v, v_trans, @@ -43,9 +44,11 @@ llama_memory_hybrid::llama_memory_hybrid( n_pad, n_swa, swa_type, + nullptr, filter_attn == nullptr ? - [&](int32_t il) { return !hparams.is_recurrent(il); } + [&](int32_t il) { return !hparams.is_recr(il); } : filter_attn, + nullptr, nullptr )), mem_recr(new llama_memory_recurrent( @@ -57,7 +60,7 @@ llama_memory_hybrid::llama_memory_hybrid( n_seq_max, n_rs_seq, filter_recr == nullptr ? - [&](int32_t il) { return hparams.is_recurrent(il); } + [&](int32_t il) { return hparams.is_recr(il); } : filter_recr )) {} diff --git a/examples/talk-llama/llama-memory-recurrent.cpp b/examples/talk-llama/llama-memory-recurrent.cpp index ec5dc5835dd..6a4892fb471 100644 --- a/examples/talk-llama/llama-memory-recurrent.cpp +++ b/examples/talk-llama/llama-memory-recurrent.cpp @@ -26,7 +26,7 @@ llama_memory_recurrent::llama_memory_recurrent( uint32_t n_seq_max, uint32_t n_rs_seq, const layer_filter_cb & filter) : hparams(model.hparams), n_seq_max(n_seq_max) { - const int32_t n_layer = hparams.n_layer; + const int32_t n_layer = hparams.n_layer(); head = 0; size = mem_size; @@ -863,7 +863,7 @@ void llama_memory_recurrent::state_write_meta(llama_io_write_i & io, const std:: void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const { const uint32_t s_trans = 0; - const uint32_t n_layer = hparams.n_layer; + const uint32_t n_layer = hparams.n_layer(); io.write(&s_trans, sizeof(s_trans)); io.write(&n_layer, sizeof(n_layer)); @@ -1047,8 +1047,8 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell io.read(&s_trans, sizeof(s_trans)); io.read(&n_layer, sizeof(n_layer)); - if (n_layer != hparams.n_layer) { - LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer); + if (n_layer != hparams.n_layer()) { + LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer()); return false; } if (cell_count > size) { diff --git a/examples/talk-llama/llama-memory.h b/examples/talk-llama/llama-memory.h index 4ad1612e45b..db825396645 100644 --- a/examples/talk-llama/llama-memory.h +++ b/examples/talk-llama/llama-memory.h @@ -23,6 +23,8 @@ struct llama_memory_params { bool swa_full; llama_context_type ctx_type; + + llama_memory_t mem_other; }; enum llama_memory_status { @@ -76,6 +78,8 @@ struct llama_memory_i { // return negative value to indicate that the layer il should not reuse memory using layer_reuse_cb = std::function<int32_t(int32_t il)>; + using layer_share_cb = std::function<int32_t(int32_t il)>; + virtual ~llama_memory_i() = default; // split the input batch into a set of ubatches and verify that they can fit into the cache diff --git a/examples/talk-llama/llama-model-loader.cpp b/examples/talk-llama/llama-model-loader.cpp index c645d0785ab..474cabdfc09 100644 --- a/examples/talk-llama/llama-model-loader.cpp +++ b/examples/talk-llama/llama-model-loader.cpp @@ -146,7 +146,7 @@ namespace GGUFMeta { const enum gguf_type arr_type = gguf_get_arr_type(ctx, k); return ArrayInfo { arr_type, - size_t(gguf_get_arr_n(ctx, k)), + gguf_get_arr_n(ctx, k), arr_type == GGUF_TYPE_STRING ? nullptr : gguf_get_arr_data(ctx, k), }; } @@ -393,6 +393,8 @@ namespace GGUFMeta { } template bool llama_model_loader::get_arr<std::vector<std::string>>(enum llm_kv kid, std::vector<std::string> & result, bool required); + template bool llama_model_loader::get_arr<std::array<int32_t, 512>>(enum llm_kv kid, std::array<int32_t, 512> & result, bool required); + template bool llama_model_loader::get_arr<std::vector<int32_t>>(enum llm_kv kid, std::vector<int32_t> & result, bool required); template<typename T> bool llama_model_loader::get_key(const std::string & key, T & result, bool required) { @@ -445,7 +447,7 @@ namespace GGUFMeta { } if (n > N_MAX) { - throw std::runtime_error(format("n > N_MAX: %u > %u for key %s", (uint32_t) n, (uint32_t) N_MAX, key.c_str())); + throw std::runtime_error(format("n > N_MAX: %u > %u for key %s", n, (uint32_t) N_MAX, key.c_str())); } if (gguf_get_kv_type(metadata, kid) == GGUF_TYPE_ARRAY) { @@ -502,9 +504,9 @@ namespace GGUFMeta { } // TODO: this is not very clever - figure out something better - template bool llama_model_loader::get_key_or_arr<std::array<int, 4>>(enum llm_kv kid, std::array<int, 4> & result, uint32_t n, bool required); + template bool llama_model_loader::get_key_or_arr<std::array<int, 4>> (enum llm_kv kid, std::array<int, 4> & result, uint32_t n, bool required); template bool llama_model_loader::get_key_or_arr<std::array<uint32_t, 512>>(enum llm_kv kid, std::array<uint32_t, 512> & result, uint32_t n, bool required); - template bool llama_model_loader::get_key_or_arr<std::array<float, 512>>(enum llm_kv kid, std::array<float, 512> & result, uint32_t n, bool required); + template bool llama_model_loader::get_key_or_arr<std::array<float, 512>>(enum llm_kv kid, std::array<float, 512> & result, uint32_t n, bool required); llama_model_loader::llama_model_loader( @@ -1050,10 +1052,10 @@ struct ggml_tensor * llama_model_loader::create_tensor( if (it == ctx_map.end()) { // one ggml context per buffer type int max_n_tensors = n_tensors; - max_n_tensors += 1; // duplicated output tensor - max_n_tensors += hparams.n_layer*2; // duplicated rope freq tensors + max_n_tensors += 1; // duplicated output tensor + max_n_tensors += hparams.n_layer()*2; // duplicated rope freq tensors if (files.empty()) { - max_n_tensors += hparams.n_layer*256; // this should be well above what any model actually uses + max_n_tensors += hparams.n_layer()*256; // this should be well above what any model actually uses } const size_t ctx_size = ggml_tensor_overhead()*max_n_tensors; diff --git a/examples/talk-llama/llama-model-saver.cpp b/examples/talk-llama/llama-model-saver.cpp index 528e4c9c069..67d4a9df0f0 100644 --- a/examples/talk-llama/llama-model-saver.cpp +++ b/examples/talk-llama/llama-model-saver.cpp @@ -14,9 +14,6 @@ bool llama_model_saver_supports_arch(llm_arch arch) { switch (arch) { - case LLM_ARCH_QWEN3NEXT: - case LLM_ARCH_QWEN35: - case LLM_ARCH_QWEN35MOE: case LLM_ARCH_PLAMO3: case LLM_ARCH_GEMMA3: case LLM_ARCH_GEMMA3N: @@ -29,6 +26,7 @@ bool llama_model_saver_supports_arch(llm_arch arch) { case LLM_ARCH_APERTUS: case LLM_ARCH_MIMO2: case LLM_ARCH_STEP35: + case LLM_ARCH_MELLUM: return false; default: return true; @@ -79,7 +77,7 @@ void llama_model_saver::add_kv(const enum llm_kv key, const char value) { template <typename Container> void llama_model_saver::add_kv(const enum llm_kv key, const Container & value, const bool per_layer) { GGML_ASSERT(model != nullptr || !per_layer); - const size_t n_values = per_layer ? size_t(model->hparams.n_layer) : value.size(); + const size_t n_values = per_layer ? size_t(model->hparams.n_layer()) : value.size(); GGML_ASSERT(n_values <= value.size()); if (n_values == 0) { @@ -106,6 +104,8 @@ void llama_model_saver::add_kv(const enum llm_kv key, const Container & value, c gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_INT8, value.data(), n_values); } else if (std::is_same<typename Container::value_type, uint32_t>::value) { gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_UINT32, value.data(), n_values); + } else if (std::is_same<typename Container::value_type, bool>::value) { + gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_BOOL, value.data(), n_values); } else if (std::is_same<typename Container::value_type, int32_t>::value) { gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_INT32, value.data(), n_values); } else if (std::is_same<typename Container::value_type, float>::value) { @@ -206,7 +206,7 @@ void llama_model_saver::add_kv_from_model() { if (hparams.n_embd_out_impl > 0) { add_kv(LLM_KV_EMBEDDING_LENGTH_OUT, hparams.n_embd_out_impl); } - add_kv(LLM_KV_BLOCK_COUNT, hparams.n_layer); + add_kv(LLM_KV_BLOCK_COUNT, hparams.n_layer_all); add_kv(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); add_kv(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, true); add_kv(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); @@ -227,8 +227,9 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_EXPERT_GROUP_SCALE, hparams.expert_group_scale); add_kv(LLM_KV_EXPERTS_PER_GROUP, hparams.n_group_experts); add_kv(LLM_KV_MOE_EVERY_N_LAYERS, hparams.moe_every_n_layers); - add_kv(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers); + add_kv(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn); add_kv(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers); + add_kv(LLM_KV_DEEPSTACK_MAPPING, hparams.deepstack_mapping_arr); add_kv(LLM_KV_POOLING_TYPE, uint32_t(hparams.pooling_type)); add_kv(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); add_kv(LLM_KV_DECODER_START_TOKEN_ID, hparams.dec_start_token_id); @@ -244,7 +245,7 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale); add_kv(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count); add_kv(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step); - // add_kv(LLM_KV_FULL_ATTENTION_INTERVAL, ???); + // add_kv(LLM_KV_FULL_ATTENTION_INTERVAL, ???); // saved as LLM_KV_ATTENTION_RECURRENT_LAYERS instead add_kv(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, true); add_kv(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, true); @@ -278,6 +279,7 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, hparams.indexer_n_head); add_kv(LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, hparams.indexer_head_size); add_kv(LLM_KV_ATTENTION_INDEXER_TOP_K, hparams.indexer_top_k); + add_kv(LLM_KV_ATTENTION_RECURRENT_LAYERS, hparams.is_recr_impl, true); const float rope_scaling_factor = hparams.rope_freq_scale_train == 1.0f ? 0.0f : 1.0f/hparams.rope_freq_scale_train; diff --git a/examples/talk-llama/llama-model.cpp b/examples/talk-llama/llama-model.cpp index 0d21b2a53c5..7281ed79f10 100644 --- a/examples/talk-llama/llama-model.cpp +++ b/examples/talk-llama/llama-model.cpp @@ -10,6 +10,7 @@ #include "llama-kv-cache.h" #include "llama-kv-cache-iswa.h" +#include "llama-kv-cache-dsa.h" #include "llama-memory-hybrid.h" #include "llama-memory-hybrid-iswa.h" #include "llama-memory-recurrent.h" @@ -44,6 +45,8 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params return new llama_model_llama_embed(params); case LLM_ARCH_MAINCODER: return new llama_model_maincoder(params); + case LLM_ARCH_TALKIE: + return new llama_model_talkie(params); case LLM_ARCH_DECI: return new llama_model_deci(params); case LLM_ARCH_BAICHUAN: @@ -78,6 +81,8 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params return new llama_model_mpt(params); case LLM_ARCH_STABLELM: return new llama_model_stablelm(params); + case LLM_ARCH_MELLUM: + return new llama_model_mellum(params); case LLM_ARCH_QWEN: return new llama_model_qwen(params); case LLM_ARCH_QWEN2: @@ -134,6 +139,8 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params return new llama_model_gemma3n(params); case LLM_ARCH_GEMMA4: return new llama_model_gemma4(params); + case LLM_ARCH_GEMMA4_ASSISTANT: + return new llama_model_gemma4_assistant(params); case LLM_ARCH_GEMMA_EMBEDDING: return new llama_model_gemma_embedding(params); case LLM_ARCH_STARCODER2: @@ -170,6 +177,8 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params return new llama_model_deepseek2(params); case LLM_ARCH_DEEPSEEK2OCR: return new llama_model_deepseek2ocr(params); + case LLM_ARCH_DEEPSEEK32: + return new llama_model_deepseek32(params); case LLM_ARCH_GLM_DSA: return new llama_model_glm_dsa(params); case LLM_ARCH_MISTRAL4: @@ -278,6 +287,8 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params return new llama_model_qwen35moe(params); case LLM_ARCH_MISTRAL3: return new llama_model_mistral3(params); + case LLM_ARCH_EAGLE3: + return new llama_model_eagle3(params); case LLM_ARCH_MIMO2: return new llama_model_mimo2(params); case LLM_ARCH_KIMI_LINEAR: @@ -366,10 +377,10 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str // count only the same type of previous layers to avoid this auto get_il_eff = [&](const size_t il){ size_t ret = 0; - const bool il_is_recurrent = hparams.is_recurrent(il); - const bool il_is_swa = hparams.is_swa(il); + const bool il_is_recr = hparams.is_recr(il); + const bool il_is_swa = hparams.is_swa(il); for (size_t il_prev = 0; il_prev < il; il_prev++) { - ret += hparams.is_recurrent(il_prev) == il_is_recurrent && hparams.is_swa(il_prev) == il_is_swa; + ret += hparams.is_recr(il_prev) == il_is_recr && hparams.is_swa(il_prev) == il_is_swa; } return ret; }; @@ -391,7 +402,7 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str rotation = get_il_eff(il) % ud->n_devices; } else { il = 0; - rotation = hparams.n_layer % ud->n_devices; + rotation = hparams.n_layer() % ud->n_devices; } const ggml_tensor * tensor_axis_0 = suffix.empty() ? tensor : ud->model->get_tensor((prefix + suffix).c_str()); if (tensor_axis_0 == nullptr) { @@ -405,16 +416,16 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str auto get_tensor_config = [&]() -> tensor_config { // standard attention if (std::regex_match(tensor_name, pattern_q_weight) || std::regex_match(tensor_name, pattern_kv_weight)) { - return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "attn_output.weight"); + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "attn_output.weight", "ssm_out.weight"); } if (std::regex_match(tensor_name, pattern_q_bias) || std::regex_match(tensor_name, pattern_kv_bias)) { - return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "attn_output.weight"); + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "attn_output.weight", "ssm_out.weight"); } if (std::regex_match(tensor_name, pattern_qkv_weight)) { - return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1); + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "attn_output.weight", "ssm_out.weight"); } if ( std::regex_match(tensor_name, pattern_qkv_bias)) { - return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0); + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "attn_output.weight", "ssm_out.weight"); } if (std::regex_match(tensor_name, pattern_qk_norm)) { return get_tensor_config_impl(tensor->ne[1] == 1 ? GGML_BACKEND_SPLIT_AXIS_MIRRORED : GGML_BACKEND_SPLIT_AXIS_1, "attn_output.weight"); @@ -430,7 +441,7 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str } if (std::regex_match(tensor_name, pattern_attn_gate_weight)) { - return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1); + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "attn_output.weight", "ssm_out.weight"); } if (std::regex_match(tensor_name, pattern_ssm_dt) || std::regex_match(tensor_name, pattern_ssm_a)) { return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "ssm_out.weight"); @@ -483,7 +494,7 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_MIRRORED); }; - auto get_split_segments = [&](int axis, uint32_t il) -> std::vector<int64_t> { + auto get_split_segments = [&](int axis, uint32_t il) -> std::vector<std::pair<int64_t, uint32_t>> { if (ud->model->arch == LLM_ARCH_QWEN3NEXT || ud->model->arch == LLM_ARCH_QWEN35 || ud->model->arch == LLM_ARCH_QWEN35MOE) { const int64_t head_k_dim = hparams.ssm_d_state; const int64_t head_v_dim = hparams.ssm_d_state; @@ -498,26 +509,26 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str if (ud->model->arch == LLM_ARCH_QWEN3NEXT) { if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_ssm_conv1d)) { GGML_ASSERT(tensor->ne[axis] == 2*key_dim + value_dim); - return {key_dim, key_dim, value_dim}; + return {{key_dim, 2}, {value_dim, 1}}; } } else { const int64_t head_ratio = n_v_heads / n_k_heads; if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_ssm_conv1d)) { GGML_ASSERT(tensor->ne[axis] == 2*key_dim + value_dim); - return std::vector<int64_t>(2 + head_ratio, key_dim); + return {{key_dim, 2 + head_ratio}}; } if (std::regex_match(tensor_name, pattern_attn_gate_weight) || std::regex_match(tensor_name, pattern_ssm_out_weight)) { - return std::vector<int64_t>(head_ratio, key_dim); + return {{key_dim, head_ratio}}; } if (std::regex_match(tensor_name, pattern_ssm_dt) || std::regex_match(tensor_name, pattern_ssm_a) || std::regex_match(tensor_name, pattern_ssm_alpha) || std::regex_match(tensor_name, pattern_ssm_beta)) { - return std::vector<int64_t>(head_ratio, n_k_heads); + return {{n_k_heads, head_ratio}}; } if (std::regex_match(tensor_name, pattern_r_cache)) { - return std::vector<int64_t>(2 + head_ratio, key_dim * (hparams.ssm_d_conv - 1)); + return {{key_dim * (hparams.ssm_d_conv - 1), 2 + head_ratio}}; } if (std::regex_match(tensor_name, pattern_s_cache)) { - return std::vector<int64_t>(head_ratio, n_k_heads * head_v_dim * head_v_dim); + return {{n_k_heads * head_v_dim * head_v_dim, head_ratio}}; } } @@ -525,9 +536,9 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str if (std::regex_match(tensor_name, pattern_ffn_gate_up_weight)) { const int64_t n_ff_exp = hparams.n_ff_exp; GGML_ASSERT(tensor->ne[axis] == 2*n_ff_exp); - return {n_ff_exp, n_ff_exp}; + return {{n_ff_exp, 2}}; } - return {tensor->ne[axis]}; + return {{tensor->ne[axis], 1}}; } if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_qkv_bias)) { @@ -535,21 +546,23 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str const int64_t n_embd_gqa = hparams.n_embd_v_gqa(il); GGML_ASSERT(hparams.n_embd_k_gqa() == n_embd_gqa); GGML_ASSERT(tensor->ne[axis] == n_embd + 2*n_embd_gqa); - return {n_embd, n_embd_gqa, n_embd_gqa}; + return {{n_embd, 1}, {n_embd_gqa, 2}}; } if (std::regex_match(tensor_name, pattern_ffn_gate_up_weight)) { const int64_t n_ff_exp = hparams.n_ff_exp; GGML_ASSERT(tensor->ne[axis] == 2*n_ff_exp); - return {n_ff_exp, n_ff_exp}; + return {{n_ff_exp, 2}}; } - return {tensor->ne[axis]}; + return {{tensor->ne[axis], 1}}; }; - auto get_split_granularity = [&](int64_t blck_size, uint32_t il, const std::vector<int64_t> & segments) -> std::vector<int64_t> { - if (hparams.is_recurrent(il)) { + auto get_split_granularity = [&](int64_t blck_size, uint32_t il, const std::vector<std::pair<int64_t, uint32_t>> & segments) -> std::vector<int64_t> { + // for better performance it may make sense to round up blck_size to a higher power of 2 so that more efficient kernels can be used + if (hparams.is_recr(il)) { // linear attention - const int64_t head_dim = hparams.ssm_d_state; - const int64_t granularity_qkv = std::lcm(blck_size, head_dim); + const int64_t head_dim = hparams.ssm_d_state; + const int64_t blck_size_perf = std::lcm(blck_size, 128); + const int64_t granularity_qkv = std::lcm(blck_size_perf, head_dim); if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_attn_gate_weight) || std::regex_match(tensor_name, pattern_ssm_conv1d) || std::regex_match(tensor_name, pattern_ssm_out_weight)) { return std::vector<int64_t>(segments.size(), granularity_qkv); @@ -571,17 +584,24 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str // regular attention const uint32_t n_gqa = hparams.n_gqa(il); const uint32_t n_embd_q = n_gqa * hparams.n_embd_head_k(il); + + // to handle head sizes like 80, only increase granularity while it doesn't cause underutilization + int64_t blck_size_perf = blck_size; + while (blck_size_perf < 128 && blck_size_perf*ud->n_devices < n_embd_q) { + blck_size_perf *= 2; + } + if (std::regex_match(tensor_name, pattern_attn_sinks)) { GGML_ASSERT(segments.size() == 1); - return {std::lcm(n_embd_q, blck_size)/n_embd_q * n_gqa}; + return {std::lcm(n_embd_q, blck_size_perf)/n_embd_q * n_gqa}; } - const int64_t granularity_q = std::lcm(n_embd_q, blck_size); + const int64_t granularity_q = std::lcm(n_embd_q, blck_size_perf); if (std::regex_match(tensor_name, pattern_q_weight) || std::regex_match(tensor_name, pattern_q_bias)) { GGML_ASSERT(segments.size() == 1); // some models have Q gate tensors, for those cases the granularity needs to be doubled: if (ud->model->arch == LLM_ARCH_QWEN3NEXT || ud->model->arch == LLM_ARCH_QWEN35 || ud->model->arch == LLM_ARCH_QWEN35MOE) { - return {std::lcm(2*n_embd_q, blck_size)}; + return {std::lcm(2*n_embd_q, blck_size_perf)}; } return {granularity_q}; } @@ -598,16 +618,17 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str return {granularity_kv}; } if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_qkv_bias)) { - GGML_ASSERT(segments.size() == 3); - return {granularity_q, granularity_kv, granularity_kv}; + GGML_ASSERT(segments.size() == 2); + return {granularity_q, granularity_kv}; } } // FFN if (std::regex_match(tensor_name, pattern_ffn_up_gate_weight) || std::regex_match(tensor_name, pattern_ffn_up_gate_bias) || std::regex_match(tensor_name, pattern_ffn_gate_up_weight) || std::regex_match(tensor_name, pattern_ffn_down_weight)) { - GGML_ASSERT(segments.size() <= 2); - return std::vector<int64_t>(segments.size(), blck_size); + const int64_t blck_size_perf = std::lcm(blck_size, 128); + GGML_ASSERT(segments.size() == 1); + return {blck_size_perf}; } // everything else @@ -620,7 +641,6 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str tensor_config tc = get_tensor_config(); split_state.axis = tc.axis; if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) { - const int64_t ne_full = tensor->ne[split_state.axis]; const int64_t blck_size = ggml_blck_size(tc.tensor_axis_0->type); const float * tensor_split = ud->model->tensor_split(); std::vector<float> tensor_split_scan; @@ -631,12 +651,12 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str tensor_split_scan[j] += tensor_split_scan[j - 1]; } } - const std::vector<int64_t> segments = get_split_segments(split_state.axis, tc.il); + const std::vector<std::pair<int64_t, uint32_t>> segments = get_split_segments(split_state.axis, tc.il); const std::vector<int64_t> granularity = get_split_granularity(blck_size, tc.il, segments); for (size_t is = 0; is < segments.size(); is++) { - const int64_t ne_s = segments[is]; - const int64_t g_s = granularity[is]; - GGML_ASSERT(ne_full % g_s == 0); + const int64_t ne_s = segments[is].first; + const uint32_t nr_s = segments[is].second; + const int64_t g_s = granularity[is]; int64_t low = 0; size_t j = 0; for (; j < ud->n_devices - 1; j++) { @@ -649,10 +669,12 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str low = high; } split_state.ne[is*ud->n_devices + (j + tc.rotation) % ud->n_devices] = ne_s - low; + split_state.nr[is] = nr_s; } split_state.n_segments = segments.size(); } else { memset(split_state.ne, 0, sizeof(split_state.ne)); + split_state.nr[0] = 1; split_state.n_segments = 1; } return split_state; @@ -756,6 +778,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_A13B: return "A13B"; case LLM_TYPE_7B_A1B: return "7B.A1B"; case LLM_TYPE_8B_A1B: return "8B.A1B"; + case LLM_TYPE_12B_A2_5B: return "12B.A2.5B"; case LLM_TYPE_16B_A1B: return "16B.A1B"; case LLM_TYPE_21B_A3B: return "21B.A3B"; case LLM_TYPE_24B_A2B: return "24B.A2B"; @@ -777,6 +800,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_310B_A15B: return "310B.A15B"; case LLM_TYPE_355B_A32B: return "355B.A32B"; case LLM_TYPE_397B_A17B: return "397B.A17B"; + case LLM_TYPE_685B_A37B: return "685B.A37B"; case LLM_TYPE_744B_A40B: return "744B.A40B"; case LLM_TYPE_E2B: return "E2B"; case LLM_TYPE_E4B: return "E4B"; @@ -813,6 +837,28 @@ static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::st return LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED; } +// Maps the GGUF `<arch>.hidden_activation` string to the FFN op type used by the +// graph builders. Only gated activations that map cleanly to llm_ffn_op_type are +// listed; unrecognized values fall back to GeGLU, which matches the historical +// default for ModernBert-style architectures. +static const std::map<std::string, llm_ffn_op_type> LLM_FFN_OP_TYPES_FROM_STRING = { + { "gelu", LLM_FFN_GEGLU }, + { "geglu", LLM_FFN_GEGLU }, + { "silu", LLM_FFN_SWIGLU }, + { "swish", LLM_FFN_SWIGLU }, + { "swiglu", LLM_FFN_SWIGLU }, + { "relu", LLM_FFN_RELU }, + { "reglu", LLM_FFN_REGLU }, +}; + +llm_ffn_op_type llm_ffn_op_type_from_string(const std::string & name, llm_ffn_op_type fallback) { + const auto it = LLM_FFN_OP_TYPES_FROM_STRING.find(name); + if (it != LLM_FFN_OP_TYPES_FROM_STRING.end()) { + return it->second; + } + return fallback; +} + // CPU: ACCEL -> GPU host -> CPU extra -> CPU static buft_list_t make_cpu_buft_list(const std::vector<llama_device> & devices, bool use_extra_bufts, bool no_host) { buft_list_t buft_list; @@ -1000,7 +1046,7 @@ void llama_model_base::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EMBEDDING_LENGTH_OUT, hparams.n_embd_out_impl, false); ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); - ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer); + ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer_all); ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false); ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false); ml.get_key(LLM_KV_EXPERT_GROUP_COUNT, hparams.n_expert_groups, false); @@ -1042,28 +1088,29 @@ void llama_model_base::load_hparams(llama_model_loader & ml) { std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0); std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0); std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0); - std::fill( - hparams.recurrent_layer_arr.begin(), - hparams.recurrent_layer_arr.end(), - llm_arch_is_recurrent(ml.get_arch())); std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0); - std::fill(hparams.swa_layers.begin(), hparams.swa_layers.end(), 0); + std::fill(hparams.is_swa_impl.begin(), hparams.is_swa_impl.end(), 0); + std::fill(hparams.is_recr_impl.begin(), hparams.is_recr_impl.end(), llm_arch_is_recurrent(ml.get_arch()) ? 1 : 0); std::fill(hparams.xielu_alpha_n.begin(), hparams.xielu_alpha_n.end(), 0.0f); std::fill(hparams.xielu_alpha_p.begin(), hparams.xielu_alpha_p.end(), 0.0f); - std::fill(hparams.xielu_beta.begin(), hparams.xielu_beta.end(), 0.0f); - std::fill(hparams.xielu_eps.begin(), hparams.xielu_eps.end(), 0.0f); + std::fill(hparams.xielu_beta.begin(), hparams.xielu_beta.end(), 0.0f); + std::fill(hparams.xielu_eps.begin(), hparams.xielu_eps.end(), 0.0f); + std::fill(hparams.swiglu_clamp_exp.begin(), hparams.swiglu_clamp_exp.end(), 0.0f); std::fill(hparams.swiglu_clamp_shexp.begin(), hparams.swiglu_clamp_shexp.end(), 0.0f); - ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false); - ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false); + ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer(), false); + ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer(), false); + + // Populate deepstack_mapping_arr - initialized to -1 (no deepstack) + std::fill(hparams.deepstack_mapping_arr.begin(), hparams.deepstack_mapping_arr.end(), -1); // n_head_kv is optional, default to n_head hparams.n_head_kv_arr = hparams.n_head_arr; - ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, hparams.n_layer, false); + ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, hparams.n_layer(), false); bool rope_finetuned = false; ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false); @@ -1162,7 +1209,7 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) { const auto & use_mlock = params.use_mlock; const auto & tensor_split = params.tensor_split; - const int n_layer = hparams.n_layer; + const int n_layer_all = hparams.n_layer_all; const int n_gpu_layers = this->n_gpu_layers(); const bool use_mmap_buffer = true; @@ -1219,10 +1266,10 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) { splits[i] /= split_sum; } - const int i_gpu_start = std::max(int(hparams.n_layer) + 1 - n_gpu_layers, 0); - const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, int(n_layer) + 1); + const int i_gpu_start = std::max(n_layer_all + 1 - n_gpu_layers, 0); + const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, n_layer_all + 1); auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev { - const bool is_swa = il < int(hparams.n_layer) && hparams.is_swa(il); + const bool is_swa = il < n_layer_all && hparams.is_swa(il); if (il < i_gpu_start || (il - i_gpu_start) >= act_gpu_layers) { LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s, is_swa = %d\n", il, ggml_backend_dev_name(cpu_dev), is_swa); return {cpu_dev, &pimpl->cpu_buft_list}; @@ -1238,13 +1285,13 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) { pimpl->dev_input = { cpu_dev, &pimpl->cpu_buft_list }; // assign the repeating layers to the devices according to the splits - pimpl->dev_layer.resize(n_layer); - for (int il = 0; il < n_layer; ++il) { + pimpl->dev_layer.resize(n_layer_all); + for (int il = 0; il < n_layer_all; ++il) { pimpl->dev_layer[il] = get_layer_buft_list(il); } // assign the output layer - pimpl->dev_output = get_layer_buft_list(n_layer); + pimpl->dev_output = get_layer_buft_list(n_layer_all); const auto TENSOR_NOT_REQUIRED = llama_model_loader::TENSOR_NOT_REQUIRED; @@ -1260,14 +1307,14 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) { throw std::runtime_error("model has expert layers but no expert layers are used"); } - layers.resize(n_layer); + layers.resize(n_layer_all); // call the per-model loading function load_arch_tensors(ml); // generic pass: load optional per-tensor/per-expert ".scale" tensors (e.g. NVFP4 scale2) // this avoids having to add scale loading to every architecture - for (int i = 0; i < n_layer; ++i) { + for (int i = 0; i < n_layer_all; ++i) { auto & layer = layers[i]; // attention weight scales (per-tensor, shape {1}) @@ -1525,7 +1572,7 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) { } if (llama_supports_gpu_offload()) { - const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer)); + const int n_gpu = std::min(n_gpu_layers, n_layer_all); int n_repeating = n_gpu; if (n_repeating > 0) { @@ -1534,8 +1581,8 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) { } LLAMA_LOG_INFO("%s: offloading %d repeating layers to GPU\n", __func__, n_repeating); - const int max_backend_supported_layers = hparams.n_layer + 1; - const int max_offloadable_layers = hparams.n_layer + 1; + const int max_backend_supported_layers = n_layer_all + 1; + const int max_offloadable_layers = n_layer_all + 1; LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers); } @@ -1604,7 +1651,8 @@ const float * llama_model::tensor_split() const { } uint32_t llama_model::n_gpu_layers() const { - return params.n_gpu_layers >= 0 ? params.n_gpu_layers : hparams.n_layer + 1; + // note: plus 1 for the "output" layer + return params.n_gpu_layers >= 0 ? params.n_gpu_layers : hparams.n_layer_all + 1; } llama_split_mode llama_model::split_mode() const { @@ -1637,10 +1685,10 @@ uint64_t llama_model::n_elements() const { void llama_model::print_info() const { const std::string rope_scaling_type = llama_rope_scaling_type_name(hparams.rope_scaling_type_train); - auto print_f = [](const std::function<uint32_t(uint32_t)> & f, uint32_t n) { + auto print_f = [](const std::function<int32_t(uint32_t)> & f, uint32_t n) { bool is_var = false; - std::vector<uint32_t> v; + std::vector<int32_t> v; for (uint32_t i = 0; i < n; ++i) { v.push_back(f(i)); if (v[i] != v[0]) { @@ -1673,19 +1721,21 @@ void llama_model::print_info() const { if (!hparams.vocab_only) { LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train); - LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); LLAMA_LOG_INFO("%s: n_embd_inp = %u\n", __func__, hparams.n_embd_inp()); - LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer); - LLAMA_LOG_INFO("%s: n_head = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); + LLAMA_LOG_INFO("%s: n_embd_out = %u\n", __func__, hparams.n_embd_out()); + LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer()); + LLAMA_LOG_INFO("%s: n_layer_all = %u\n", __func__, hparams.n_layer_all); + LLAMA_LOG_INFO("%s: n_head = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head(il); }, hparams.n_layer_all).c_str()); + LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer_all).c_str()); LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot_full); LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa); LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any()); LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k_full); LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v_full); - LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_embd_k_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_embd_v_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer_all).c_str()); + LLAMA_LOG_INFO("%s: n_embd_k_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer_all).c_str()); + LLAMA_LOG_INFO("%s: n_embd_v_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer_all).c_str()); LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps); LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps); LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv); @@ -1693,7 +1743,7 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: f_logit_scale = %.1e\n", __func__, hparams.f_logit_scale); LLAMA_LOG_INFO("%s: f_attn_scale = %.1e\n", __func__, hparams.f_attention_scale); LLAMA_LOG_INFO("%s: f_attn_value_scale = %.4f\n", __func__, hparams.f_attn_value_scale); - LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer_all).c_str()); LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert); LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used); LLAMA_LOG_INFO("%s: n_expert_groups = %d\n", __func__, hparams.n_expert_groups); @@ -1714,6 +1764,14 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn); LLAMA_LOG_INFO("%s: rope_yarn_log_mul = %.4f\n", __func__, hparams.rope_yarn_log_mul); LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown"); + if (arch == LLM_ARCH_GRANITE && + std::any_of(hparams.deepstack_mapping_arr.begin(), + hparams.deepstack_mapping_arr.end(), + [](const auto & entry) { return entry >= 0; })) { + LLAMA_LOG_INFO("%s: deepstack_mapping_arr = %s\n", __func__, + print_f([&](uint32_t il) { return hparams.deepstack_mapping_arr[il]; }, + hparams.n_layer_all).c_str()); + } // MRoPE (Multi-axis Rotary Position Embedding) sections if (const auto & s = hparams.rope_sections; s[0] || s[1] || s[2] || s[3]) { LLAMA_LOG_INFO("%s: mrope sections = [%d, %d, %d, %d]\n", __func__, s[0], s[1], s[2], s[3]); @@ -1767,7 +1825,7 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); } - if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_DEEPSEEK2OCR || arch == LLM_ARCH_GLM_DSA || arch == LLM_ARCH_MISTRAL4) { + if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_DEEPSEEK2OCR || arch == LLM_ARCH_DEEPSEEK32 || arch == LLM_ARCH_GLM_DSA || arch == LLM_ARCH_MISTRAL4) { LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); @@ -1785,7 +1843,11 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); } - if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE || arch == LLM_ARCH_QWEN3VLMOE || arch == LLM_ARCH_RND1) { + if (arch == LLM_ARCH_MELLUM || + arch == LLM_ARCH_QWEN3MOE || + arch == LLM_ARCH_OPENAI_MOE || + arch == LLM_ARCH_QWEN3VLMOE || + arch == LLM_ARCH_RND1) { LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); } @@ -1816,7 +1878,7 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); - LLAMA_LOG_INFO("%s: nextn_predict_layers = %d\n", __func__, hparams.nextn_predict_layers); + LLAMA_LOG_INFO("%s: n_layer_nextn = %d\n", __func__, hparams.n_layer_nextn); } if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) { @@ -1955,6 +2017,23 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, { res = nullptr; } break; + case LLM_ARCH_DEEPSEEK32: + { + res = new llama_kv_cache_dsa( + *this, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + cparams.kv_unified, + cparams.n_ctx_seq, + cparams.n_seq_max, + 1, + hparams.n_swa, + hparams.swa_type, + nullptr, + nullptr); + } break; // Models that need standard caching should rely on recurrent/hybrid // checks default: @@ -1981,22 +2060,21 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_memory_hybrid::layer_filter_cb filter_attn = nullptr; llama_memory_hybrid::layer_filter_cb filter_recr = nullptr; if (arch == LLM_ARCH_FALCON_H1) { - filter_attn = [&](int32_t) { return true; }; - filter_recr = [&](int32_t) { return true; }; + filter_attn = [&](uint32_t) { return true; }; + filter_recr = [&](uint32_t) { return true; }; } else if (arch == LLM_ARCH_NEMOTRON_H || arch == LLM_ARCH_NEMOTRON_H_MOE) { - filter_attn = [&](int32_t il) { - return !hparams.is_recurrent(il) && hparams.n_ff(il) == 0; + filter_attn = [&](uint32_t il) { + return !hparams.is_recr(il) && hparams.n_ff(il) == 0; }; - filter_recr = [&](int32_t il) { - return hparams.is_recurrent(il) && hparams.n_ff(il) == 0; + filter_recr = [&](uint32_t il) { + return hparams.is_recr(il) && hparams.n_ff(il) == 0; }; } else if (arch == LLM_ARCH_QWEN35 || arch == LLM_ARCH_QWEN35MOE) { - const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; - filter_attn = [&, n_main](int32_t il) { - return (uint32_t)il < n_main && !hparams.is_recurrent(il); + filter_attn = [&](uint32_t il) { + return il < hparams.n_layer() && !hparams.is_recr(il); }; - filter_recr = [&, n_main](int32_t il) { - return (uint32_t)il < n_main && hparams.is_recurrent(il); + filter_recr = [&](uint32_t il) { + return il < hparams.n_layer() && hparams.is_recr(il); }; } @@ -2041,13 +2119,16 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* filter_recr */ std::move(filter_recr)); } } else { - llama_memory_i::layer_reuse_cb reuse = nullptr; llama_kv_cache::layer_filter_cb filter = nullptr; + llama_memory_i::layer_reuse_cb reuse = nullptr; + llama_kv_cache::layer_share_cb share = nullptr; if (arch == LLM_ARCH_GEMMA3N || arch == LLM_ARCH_GEMMA4) { - reuse = [&](int32_t il) { - if (il >= (int32_t) hparams.n_layer_kv_from_start) { - return (int32_t) hparams.n_layer_kv_from_start - (hparams.is_swa(il) ? 2 : 1); + reuse = [&](uint32_t il) { + GGML_ASSERT(hparams.n_layer_kv_from_start >= 2); + + if (il >= (uint32_t)hparams.n_layer_kv_from_start) { + return hparams.n_layer_kv_from_start - (hparams.is_swa(il) ? 2 : 1); } return -1; @@ -2055,32 +2136,73 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, } if (mtp_on_hybrid_qwen35) { - const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; - filter = [n_main](int32_t il) { return (uint32_t)il >= n_main; }; + filter = [&](uint32_t il) { return il >= hparams.n_layer(); }; + } + + if (arch == LLM_ARCH_STEP35 && hparams.n_layer_nextn > 0) { + if (params.ctx_type == LLAMA_CONTEXT_TYPE_MTP) { + filter = [&](uint32_t il) { return il >= hparams.n_layer(); }; + } else { + filter = [&](uint32_t il) { return il < hparams.n_layer(); }; + } } if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { GGML_ASSERT(hparams.is_swa_any()); - res = new llama_kv_cache_iswa( - *this, - params.type_k, - params.type_v, - !cparams.flash_attn, - cparams.offload_kqv, - params.swa_full, - cparams.kv_unified, - cparams.n_ctx_seq, - cparams.n_seq_max, - cparams.n_ubatch, - 1, - filter, - reuse); + if (arch == LLM_ARCH_GEMMA4_ASSISTANT) { + llama_memory_t mem_other = llama_get_memory(cparams.ctx_other); + + share = [&](int32_t il) { + const llama_model * model_other = llama_get_model(cparams.ctx_other); + + if (hparams.is_swa(il)) { + return llama_model_n_layer(model_other) - 2; + } + + return llama_model_n_layer(model_other) - 1; + }; + + res = new llama_kv_cache_iswa( + *this, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + params.swa_full, + cparams.kv_unified, + cparams.n_ctx_seq, + cparams.n_seq_max, + cparams.n_ubatch, + 1, + mem_other, + filter, + reuse, + share); + } else { + res = new llama_kv_cache_iswa( + *this, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + params.swa_full, + cparams.kv_unified, + cparams.n_ctx_seq, + cparams.n_seq_max, + cparams.n_ubatch, + 1, + nullptr, + filter, + reuse, + share); + } } else { GGML_ASSERT(!hparams.is_swa_any()); res = new llama_kv_cache( *this, + hparams, params.type_k, params.type_v, !cparams.flash_attn, @@ -2091,7 +2213,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, 1, hparams.n_swa, hparams.swa_type, + nullptr, filter, + nullptr, nullptr); } } @@ -2116,7 +2240,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { // TODO: move reranking logic here and generalize llm->build_dense_out(dense_2_out_layers, dense_2_out_layers_b, dense_3_out_layers); - llm->res->set_outputs(); + llm->res->set_outputs(params); return llm->res->get_gf(); } @@ -2179,7 +2303,7 @@ int32_t llama_model_n_embd_out(const llama_model * model) { } int32_t llama_model_n_layer(const llama_model * model) { - return model->hparams.n_layer; + return model->hparams.n_layer(); } int32_t llama_model_n_head(const llama_model * model) { @@ -2270,6 +2394,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_DEEPSEEK: case LLM_ARCH_DEEPSEEK2: case LLM_ARCH_DEEPSEEK2OCR: + case LLM_ARCH_DEEPSEEK32: case LLM_ARCH_PLM: case LLM_ARCH_CHATGLM: case LLM_ARCH_GRANITE: @@ -2283,6 +2408,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_ERNIE4_5: case LLM_ARCH_ERNIE4_5_MOE: case LLM_ARCH_MISTRAL3: + case LLM_ARCH_EAGLE3: case LLM_ARCH_MISTRAL4: case LLM_ARCH_LLAMA_EMBED: case LLM_ARCH_MAINCODER: @@ -2323,6 +2449,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GEMMA3: case LLM_ARCH_GEMMA3N: case LLM_ARCH_GEMMA4: + case LLM_ARCH_GEMMA4_ASSISTANT: case LLM_ARCH_GEMMA_EMBEDDING: case LLM_ARCH_STARCODER2: case LLM_ARCH_OPENELM: @@ -2353,6 +2480,8 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_QWEN3NEXT: case LLM_ARCH_MIMO2: case LLM_ARCH_STEP35: + case LLM_ARCH_TALKIE: + case LLM_ARCH_MELLUM: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: @@ -2474,8 +2603,9 @@ uint64_t llama_model_n_params(const llama_model * model) { bool llama_model_has_encoder(const llama_model * model) { switch (model->arch) { - case LLM_ARCH_T5: return true; - case LLM_ARCH_T5ENCODER: return true; + case LLM_ARCH_T5: + case LLM_ARCH_T5ENCODER: + case LLM_ARCH_EAGLE3: return true; default: return false; } } @@ -2561,3 +2691,12 @@ void llama_model_base::create_tensor_qkv(llama_layer & layer, int bid, layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", bid), {n_embd_v_}, TENSOR_NOT_REQUIRED); } } + +const int32_t * llama_model_target_layer_ids(const struct llama_model * model) { + const auto & v = model->target_layer_ids; + return v.empty() ? nullptr : v.data(); +} + +uint32_t llama_model_target_layer_ids_n(const struct llama_model * model) { + return (uint32_t) model->target_layer_ids.size(); +} diff --git a/examples/talk-llama/llama-model.h b/examples/talk-llama/llama-model.h index 398a0aa725c..f4718f6d584 100644 --- a/examples/talk-llama/llama-model.h +++ b/examples/talk-llama/llama-model.h @@ -116,6 +116,7 @@ enum llm_type { LLM_TYPE_A13B, LLM_TYPE_7B_A1B, LLM_TYPE_8B_A1B, // lfm2moe + LLM_TYPE_12B_A2_5B, LLM_TYPE_16B_A1B, LLM_TYPE_21B_A3B, // Ernie MoE small LLM_TYPE_24B_A2B, // lfm2moe @@ -137,6 +138,7 @@ enum llm_type { LLM_TYPE_310B_A15B, // /MiMo-V2-Flash LLM_TYPE_355B_A32B, // GLM-4.5 LLM_TYPE_397B_A17B, // Qwen3.5 + LLM_TYPE_685B_A37B, // DeepSeek V3.2 LLM_TYPE_744B_A40B, // GLM-5 LLM_TYPE_E2B, LLM_TYPE_E4B, @@ -144,6 +146,10 @@ enum llm_type { std::string llama_rope_scaling_type_name(llama_rope_scaling_type rope_scaling_type); +// Map a GGUF activation-name string to llm_ffn_op_type. Returns `fallback` if +// the string is empty or not recognized. +llm_ffn_op_type llm_ffn_op_type_from_string(const std::string & name, llm_ffn_op_type fallback); + struct llama_layer_posnet { // resnet struct ggml_tensor * norm1 = nullptr; @@ -488,7 +494,7 @@ struct llama_layer { struct ggml_tensor * indexer_attn_k = nullptr; struct ggml_tensor * indexer_attn_q_b = nullptr; // note: for lora a/b, not bias - // gemma4 layer output scale + // gemma4 layer output scale, reused for talkie embedding skip scale struct ggml_tensor * out_scale = nullptr; struct llama_layer_posnet posnet; @@ -542,6 +548,10 @@ struct llama_model { struct ggml_tensor * output_s = nullptr; struct ggml_tensor * output_in_s = nullptr; + // NextN/MTP model-level projections + struct ggml_tensor * nextn_proj_pre = nullptr; + struct ggml_tensor * nextn_proj_post = nullptr; + // classifier struct ggml_tensor * cls = nullptr; struct ggml_tensor * cls_b = nullptr; @@ -559,6 +569,13 @@ struct llama_model { struct ggml_tensor * per_layer_model_proj = nullptr; struct ggml_tensor * per_layer_proj_norm = nullptr; + // eagle3 + struct ggml_tensor * fc = nullptr; // feature fusion layer + struct ggml_tensor * d2t = nullptr; // draft to target vocabulary mapping + + // unified vector to store target-model extracted layer ids in eagle3, dflash, etc. + std::vector<int32_t> target_layer_ids; + std::vector<llama_layer> layers; //Dense linear projections for SentenceTransformers models like embeddinggemma @@ -694,7 +711,9 @@ const char * llm_type_name(llm_type type); // convenience macro for loading local variables for load_tensors() in llama_model_base // note: cast to int64_t since we will use these for the tensor dimensions #define LLAMA_LOAD_LOCALS \ - const int n_layer = hparams.n_layer; GGML_UNUSED(n_layer); \ + const int n_layer = hparams.n_layer(); GGML_UNUSED(n_layer); \ + const int n_layer_all = hparams.n_layer_all; GGML_UNUSED(n_layer_all); \ + const int n_layer_nextn = hparams.n_layer_nextn; GGML_UNUSED(n_layer_nextn); \ const int64_t n_head = hparams.n_head(); GGML_UNUSED(n_head); \ const int64_t n_head_kv = hparams.n_head_kv(); GGML_UNUSED(n_head_kv); \ const int64_t n_embd = hparams.n_embd; GGML_UNUSED(n_embd); \ diff --git a/examples/talk-llama/llama-quant.cpp b/examples/talk-llama/llama-quant.cpp index 43e05c3d56f..cf92ce4bb8b 100644 --- a/examples/talk-llama/llama-quant.cpp +++ b/examples/talk-llama/llama-quant.cpp @@ -847,7 +847,7 @@ static void init_quantize_state_counters(quantize_state_impl & qs, std::vector<t qs.has_tied_embeddings = false; } } - qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)qs.model.hparams.n_layer; + qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)qs.model.hparams.n_layer(); } // @@ -1348,7 +1348,7 @@ llama_model * llama_quant_model_from_metadata(const llama_quant_model_desc * des model->hparams.n_embd = desc->n_embd; model->hparams.n_embd_head_k_full = desc->n_embd_head_k; model->hparams.n_embd_head_v_full = desc->n_embd_head_v; - model->hparams.n_layer = desc->n_layer; + model->hparams.n_layer_all = desc->n_layer; model->hparams.n_expert = desc->n_expert; for (uint32_t i = 0; i < desc->n_layer; i++) { diff --git a/examples/talk-llama/llama-vocab.cpp b/examples/talk-llama/llama-vocab.cpp index a5cf148b268..8543e178dba 100644 --- a/examples/talk-llama/llama-vocab.cpp +++ b/examples/talk-llama/llama-vocab.cpp @@ -353,6 +353,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { case LLAMA_VOCAB_PRE_TYPE_CODESHELL: case LLAMA_VOCAB_PRE_TYPE_EXAONE: case LLAMA_VOCAB_PRE_TYPE_MINERVA: + case LLAMA_VOCAB_PRE_TYPE_MELLUM2: regex_exprs = { "\\p{N}", "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", @@ -432,6 +433,15 @@ struct llm_tokenizer_bpe : llm_tokenizer { "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_GRANITE_EMB_MULTI: + // Same lookaheads as GPT4O but with \p{M} added so combining marks + // (diacritics) attach to their base letters. Avoids excessive + // backtracking on scripts that use them heavily (Bengali, Hindi, + // Telugu, Thai, ...). See PR #22716 for benchmarks. + regex_exprs = { + "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}\\p{M}])([^a-z]))*((?=[\\p{L}\\p{M}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}\\p{M}])([^a-z]))+((?=[\\p{L}\\p{M}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; case LLAMA_VOCAB_PRE_TYPE_TINY_AYA: regex_exprs = { // original regex from tokenizer.json: "\\d{1,3}(?=(?:\\d{3})*\\b)" @@ -511,6 +521,21 @@ struct llm_tokenizer_bpe : llm_tokenizer { }; byte_encode = false; break; + case LLAMA_VOCAB_PRE_TYPE_MINICPM5: + regex_exprs = { + // original regex from tokenizer.json (openbmb/MiniCPM5-1B) + "\\p{N}{1,3}", + // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}+| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}+| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; + case LLAMA_VOCAB_PRE_TYPE_WHITESPACE: + // whitespace pre-tokenizer (jinaai/jina-embeddings-v2-base-zh) + regex_exprs = { + "\\S+", + }; + byte_encode = false; + break; default: // default regex for BPE tokenization pre-processing regex_exprs = { @@ -739,7 +764,7 @@ struct llm_tokenizer_wpm_session { void tokenize(const std::string & text, std::vector<llama_token> & output) { // normalize and split by whitespace - std::vector<std::string> words = preprocess(text); + std::vector<std::string> words = preprocess(text, vocab.get_normalizer_opts()); // bos token prepended already // find the longest tokens that form the words @@ -784,11 +809,14 @@ struct llm_tokenizer_wpm_session { } // TODO: reduce string copies by using cpts_offs array - static std::vector<std::string> preprocess(const std::string & text) { - const std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text)); + static std::vector<std::string> preprocess(const std::string & text, const llama_vocab::normalizer_options & normalizer_opts) { + std::vector<uint32_t> cpts = unicode_cpts_from_utf8(text); + if (normalizer_opts.strip_accents) { + cpts = unicode_cpts_normalize_nfd(cpts); + } std::vector<std::string> words(1, ""); - for (const uint32_t cpt : cpts_nfd) { + for (const uint32_t cpt : cpts) { const auto flags = unicode_cpt_flags_from_cpt(cpt); if (flags.is_whitespace) { @@ -803,7 +831,11 @@ struct llm_tokenizer_wpm_session { continue; } - const std::string s = unicode_cpt_to_utf8(unicode_tolower(cpt)); + if (normalizer_opts.strip_accents && flags.is_accent_mark) { + continue; + } + + const std::string s = unicode_cpt_to_utf8(normalizer_opts.lowercase ? unicode_tolower(cpt) : cpt); if (flags.is_punctuation || ( cpt < 0x7F && flags.is_symbol ) || is_chinese_char(cpt)) { if (words.back().size()) { // finish previous word if any words.emplace_back(); @@ -1663,6 +1695,35 @@ struct llm_tokenizer_hybriddna_session : llm_tokenizer_bpe_session { const llama_vocab & vocab; }; +struct llm_tokenizer_whitespace_session : llm_tokenizer_bpe_session { + llm_tokenizer_whitespace_session(const llama_vocab & vocab, const llm_tokenizer_bpe & tokenizer) : llm_tokenizer_bpe_session{vocab, tokenizer}, vocab{vocab} {} + + void tokenize(const std::string & text, std::vector<llama_token> & output) override { + const bool lowercase = vocab.get_normalizer_opts().lowercase; + + std::string segment; + auto flush = [&]() { + if (!segment.empty()) { + llm_tokenizer_bpe_session::tokenize(segment, output); + segment.clear(); + } + }; + + for (uint32_t cpt : unicode_cpts_from_utf8(text)) { + // drop whitespace + if (unicode_cpt_flags_from_cpt(cpt).is_whitespace) { + flush(); + } else { + segment += unicode_cpt_to_utf8(lowercase ? unicode_tolower(cpt) : cpt); + } + } + flush(); + } + +private: + const llama_vocab & vocab; +}; + // // impl // @@ -1744,6 +1805,9 @@ struct llama_vocab::impl { bool escape_whitespaces = true; bool treat_whitespace_as_suffix = false; + // BertNormalizer options + llama_vocab::normalizer_options normalizer_opts; + std::unordered_map<std::string, llama_token> token_to_id; std::vector<token_data> id_to_token; @@ -1760,6 +1824,8 @@ struct llama_vocab::impl { // set of all tokens that cause "end of generation" std::set<llama_token> special_eog_ids; + std::vector<llama_token> suppress_tokens; + std::unique_ptr<llm_tokenizer> tokenizer; std::vector<char> precompiled_charsmap; @@ -1892,7 +1958,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { special_mask_id = 103; add_sep = true; - } else if (tokenizer_model == "gpt2" || tokenizer_model == "hybriddna") { + } else if (tokenizer_model == "gpt2" || tokenizer_model == "hybriddna" || tokenizer_model == "whitespace") { type = LLAMA_VOCAB_TYPE_BPE; // read bpe merges and populate bpe ranks @@ -2039,6 +2105,9 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; } else if (tokenizer_pre == "default") { pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + } else if (tokenizer_pre == "minicpm5") { + pre_type = LLAMA_VOCAB_PRE_TYPE_MINICPM5; + ignore_merges = true; } else if ( tokenizer_pre == "llama3" || tokenizer_pre == "llama-v3" || @@ -2094,7 +2163,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "jais-2") { pre_type = LLAMA_VOCAB_PRE_TYPE_JAIS2; } else if ( - tokenizer_pre == "gemma4") { + tokenizer_pre == "gemma4" || + tokenizer_pre == "granite-embed-multi-311m") { pre_type = LLAMA_VOCAB_PRE_TYPE_GEMMA4; escape_whitespaces = true; } else if ( @@ -2108,6 +2178,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "roberta-bpe") { pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2; add_sep = true; + } else if ( + tokenizer_pre == "whitespace") { + pre_type = LLAMA_VOCAB_PRE_TYPE_WHITESPACE; + normalizer_opts.lowercase = false; } else if ( tokenizer_pre == "refact") { pre_type = LLAMA_VOCAB_PRE_TYPE_REFACT; @@ -2196,9 +2270,15 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { } else if ( tokenizer_pre == "gpt-4o" || tokenizer_pre == "llama4" || - tokenizer_pre == "kanana2") { + tokenizer_pre == "kanana2" || + tokenizer_pre == "talkie") { pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O; clean_spaces = false; + } else if ( + tokenizer_pre == "granite-embed-multi-97m") { + pre_type = LLAMA_VOCAB_PRE_TYPE_GRANITE_EMB_MULTI; + clean_spaces = false; + ignore_merges = true; } else if ( tokenizer_pre == "tiny_aya") { pre_type = LLAMA_VOCAB_PRE_TYPE_TINY_AYA; @@ -2257,6 +2337,9 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "solar-open") { pre_type = LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN; clean_spaces = false; + } else if ( + tokenizer_pre == "mellum2") { + pre_type = LLAMA_VOCAB_PRE_TYPE_MELLUM2; } else { throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); } @@ -2458,6 +2541,21 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { } } + // BertNormalizer options + ml.get_key(LLM_KV_TOKENIZER_NORMALIZER_LOWERCASE, normalizer_opts.lowercase, false); + normalizer_opts.strip_accents = normalizer_opts.lowercase; + ml.get_key(LLM_KV_TOKENIZER_NORMALIZER_STRIP_ACCENTS, normalizer_opts.strip_accents, false); + + // suppress tokens + { + const int suppress_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SUPPRESS_TOKENS).c_str()); + if (suppress_idx != -1) { + const int n = gguf_get_arr_n(ctx, suppress_idx); + const int32_t * data = (const int32_t *) gguf_get_arr_data(ctx, suppress_idx); + suppress_tokens.assign(data, data + n); + } + } + // auto-detect special tokens by text // TODO: convert scripts should provide these tokens through the KV metadata LLM_KV_TOKENIZER_... // for now, we apply this workaround to find the tokens based on their text @@ -3252,6 +3350,8 @@ std::vector<llama_token> llama_vocab::impl::tokenize( std::unique_ptr<llm_tokenizer_bpe_session> session; if (vocab.get_tokenizer_model() == "hybriddna") { session = std::make_unique<llm_tokenizer_hybriddna_session>(vocab, *tok_bpe); + } else if (vocab.get_tokenizer_model() == "whitespace") { + session = std::make_unique<llm_tokenizer_whitespace_session>(vocab, *tok_bpe); } else { session = std::make_unique<llm_tokenizer_bpe_session>(vocab, *tok_bpe); } @@ -3880,6 +3980,14 @@ bool llama_vocab::get_treat_whitespace_as_suffix() const { return pimpl->treat_whitespace_as_suffix; } +const llama_vocab::normalizer_options & llama_vocab::get_normalizer_opts() const { + return pimpl->normalizer_opts; +} + +const std::vector<llama_token> & llama_vocab::get_suppress_tokens() const { + return pimpl->suppress_tokens; +} + int llama_vocab::max_token_len() const { return pimpl->max_token_len; } diff --git a/examples/talk-llama/llama-vocab.h b/examples/talk-llama/llama-vocab.h index 8b040b912e2..707cd4bac4b 100644 --- a/examples/talk-llama/llama-vocab.h +++ b/examples/talk-llama/llama-vocab.h @@ -8,58 +8,62 @@ // pre-tokenization types enum llama_vocab_pre_type { - LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0, - LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1, - LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2, - LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3, - LLAMA_VOCAB_PRE_TYPE_FALCON = 4, - LLAMA_VOCAB_PRE_TYPE_MPT = 5, - LLAMA_VOCAB_PRE_TYPE_STARCODER = 6, - LLAMA_VOCAB_PRE_TYPE_GPT2 = 7, - LLAMA_VOCAB_PRE_TYPE_REFACT = 8, - LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9, - LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10, - LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11, - LLAMA_VOCAB_PRE_TYPE_OLMO = 12, - LLAMA_VOCAB_PRE_TYPE_DBRX = 13, - LLAMA_VOCAB_PRE_TYPE_SMAUG = 14, - LLAMA_VOCAB_PRE_TYPE_PORO = 15, - LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16, - LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17, - LLAMA_VOCAB_PRE_TYPE_VIKING = 18, - LLAMA_VOCAB_PRE_TYPE_JAIS = 19, - LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20, - LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21, - LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22, - LLAMA_VOCAB_PRE_TYPE_BLOOM = 23, - LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24, - LLAMA_VOCAB_PRE_TYPE_EXAONE = 25, - LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26, - LLAMA_VOCAB_PRE_TYPE_MINERVA = 27, - LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, - LLAMA_VOCAB_PRE_TYPE_GPT4O = 29, - LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30, - LLAMA_VOCAB_PRE_TYPE_TRILLION = 31, - LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32, - LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33, - LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34, - LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35, - LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36, - LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37, - LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE = 38, - LLAMA_VOCAB_PRE_TYPE_GROK_2 = 39, - LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING = 40, - LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2 = 41, - LLAMA_VOCAB_PRE_TYPE_AFMOE = 42, - LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN = 43, - LLAMA_VOCAB_PRE_TYPE_YOUTU = 44, - LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE = 45, - LLAMA_VOCAB_PRE_TYPE_QWEN35 = 46, - LLAMA_VOCAB_PRE_TYPE_TINY_AYA = 47, - LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM = 48, - LLAMA_VOCAB_PRE_TYPE_JAIS2 = 49, - LLAMA_VOCAB_PRE_TYPE_GEMMA4 = 50, - LLAMA_VOCAB_PRE_TYPE_SARVAM_MOE = 51, + LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0, + LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3, + LLAMA_VOCAB_PRE_TYPE_FALCON = 4, + LLAMA_VOCAB_PRE_TYPE_MPT = 5, + LLAMA_VOCAB_PRE_TYPE_STARCODER = 6, + LLAMA_VOCAB_PRE_TYPE_GPT2 = 7, + LLAMA_VOCAB_PRE_TYPE_REFACT = 8, + LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9, + LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10, + LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11, + LLAMA_VOCAB_PRE_TYPE_OLMO = 12, + LLAMA_VOCAB_PRE_TYPE_DBRX = 13, + LLAMA_VOCAB_PRE_TYPE_SMAUG = 14, + LLAMA_VOCAB_PRE_TYPE_PORO = 15, + LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16, + LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17, + LLAMA_VOCAB_PRE_TYPE_VIKING = 18, + LLAMA_VOCAB_PRE_TYPE_JAIS = 19, + LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20, + LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21, + LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22, + LLAMA_VOCAB_PRE_TYPE_BLOOM = 23, + LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24, + LLAMA_VOCAB_PRE_TYPE_EXAONE = 25, + LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26, + LLAMA_VOCAB_PRE_TYPE_MINERVA = 27, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, + LLAMA_VOCAB_PRE_TYPE_GPT4O = 29, + LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30, + LLAMA_VOCAB_PRE_TYPE_TRILLION = 31, + LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32, + LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33, + LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34, + LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35, + LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36, + LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37, + LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE = 38, + LLAMA_VOCAB_PRE_TYPE_GROK_2 = 39, + LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING = 40, + LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2 = 41, + LLAMA_VOCAB_PRE_TYPE_AFMOE = 42, + LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN = 43, + LLAMA_VOCAB_PRE_TYPE_YOUTU = 44, + LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE = 45, + LLAMA_VOCAB_PRE_TYPE_QWEN35 = 46, + LLAMA_VOCAB_PRE_TYPE_TINY_AYA = 47, + LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM = 48, + LLAMA_VOCAB_PRE_TYPE_JAIS2 = 49, + LLAMA_VOCAB_PRE_TYPE_GEMMA4 = 50, + LLAMA_VOCAB_PRE_TYPE_SARVAM_MOE = 51, + LLAMA_VOCAB_PRE_TYPE_MINICPM5 = 52, + LLAMA_VOCAB_PRE_TYPE_WHITESPACE = 53, + LLAMA_VOCAB_PRE_TYPE_GRANITE_EMB_MULTI = 54, + LLAMA_VOCAB_PRE_TYPE_MELLUM2 = 55, }; struct LLM_KV; @@ -72,6 +76,12 @@ struct llama_vocab { llama_token_attr attr; }; + struct normalizer_options { + bool lowercase = true; + bool strip_accents = true; + // TODO: clean_text, handle_chinese_chars + }; + llama_vocab(); ~llama_vocab(); @@ -137,6 +147,9 @@ struct llama_vocab { bool get_remove_extra_whitespaces () const; bool get_escape_whitespaces () const; bool get_treat_whitespace_as_suffix() const; + const normalizer_options & get_normalizer_opts() const; + + const std::vector<llama_token> & get_suppress_tokens() const; int max_token_len() const; diff --git a/examples/talk-llama/llama.cpp b/examples/talk-llama/llama.cpp index dfe30ce8f61..a67fa8039a4 100644 --- a/examples/talk-llama/llama.cpp +++ b/examples/talk-llama/llama.cpp @@ -225,7 +225,9 @@ static bool llama_prepare_model_devices(const llama_model_params & params, llama } case GGML_BACKEND_DEVICE_TYPE_IGPU: - igpus.push_back({false, dev}); + if (igpus.empty()) { + igpus.push_back({false, dev}); + } break; case GGML_BACKEND_DEVICE_TYPE_META: GGML_ABORT("fatal error"); @@ -239,8 +241,9 @@ static bool llama_prepare_model_devices(const llama_model_params & params, llama // add GPUs model->devices.insert(model->devices.end(), gpus.begin(), gpus.end()); - // add integrated GPUs only if no other devices were found - if (model->devices.empty()) { + // add integrated GPUs only if no discrete GPUs were found + // (RPC servers do not count, otherwise the local iGPU would be dropped on iGPU+RPC setups) + if (gpus.empty()) { model->devices.insert(model->devices.end(), igpus.begin(), igpus.end()); } } diff --git a/examples/talk-llama/llama.h b/examples/talk-llama/llama.h index e8374c53b70..27e48067428 100644 --- a/examples/talk-llama/llama.h +++ b/examples/talk-llama/llama.h @@ -339,6 +339,7 @@ extern "C" { uint32_t n_ubatch; // physical maximum batch size uint32_t n_seq_max; // max number of sequences (i.e. distinct states for recurrent models) uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback (0 = no rollback) [EXPERIMENTAL] + uint32_t n_outputs_max; // max outputs in a ubatch (0 = n_batch) int32_t n_threads; // number of threads to use for generation int32_t n_threads_batch; // number of threads to use for batch processing @@ -387,6 +388,10 @@ extern "C" { // note: the samplers must be sampler chains (i.e. use llama_sampler_chain_init) struct llama_sampler_seq_config * samplers; size_t n_samplers; + + // a source/target/parent context + // can be utilized in various ways, for example by sharing results or llama_memory between 2 contexts + struct llama_context * ctx_other; }; struct llama_model_tensor_override { @@ -975,7 +980,11 @@ extern "C" { // Set whether the model is in warmup mode or not // If true, all model tensors are activated during llama_decode() to load and cache their weights. - LLAMA_API void llama_set_warmup(struct llama_context * ctx, bool warmup); + // + // note: using this can cause extra graph reallocations because it changes the graph topology with MoE models, + // so it is generally not recommended to use in practice. will be removed in the future + DEPRECATED(LLAMA_API void llama_set_warmup(struct llama_context * ctx, bool warmup), + "user code should do warmup runs manually [TAG_LLAMA_GRAPH_NO_WARMUP]"); // Set abort callback LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data); diff --git a/examples/talk-llama/models/afmoe.cpp b/examples/talk-llama/models/afmoe.cpp index a7c77ee5d28..063b214256e 100644 --- a/examples/talk-llama/models/afmoe.cpp +++ b/examples/talk-llama/models/afmoe.cpp @@ -30,7 +30,7 @@ void llama_model_afmoe::load_arch_hparams(llama_model_loader & ml) { hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; } - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 56: type = LLM_TYPE_6B; break; case 32: type = LLM_TYPE_26B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/apertus.cpp b/examples/talk-llama/models/apertus.cpp index bec7136521c..6dfb8905fbe 100644 --- a/examples/talk-llama/models/apertus.cpp +++ b/examples/talk-llama/models/apertus.cpp @@ -2,12 +2,13 @@ void llama_model_apertus::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_N, hparams.xielu_alpha_n, hparams.n_layer); - ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_P, hparams.xielu_alpha_p, hparams.n_layer); - ml.get_key_or_arr(LLM_KV_XIELU_BETA, hparams.xielu_beta, hparams.n_layer); - ml.get_key_or_arr(LLM_KV_XIELU_EPS, hparams.xielu_eps, hparams.n_layer); - switch (hparams.n_layer) { + ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_N, hparams.xielu_alpha_n, hparams.n_layer()); + ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_P, hparams.xielu_alpha_p, hparams.n_layer()); + ml.get_key_or_arr(LLM_KV_XIELU_BETA, hparams.xielu_beta, hparams.n_layer()); + ml.get_key_or_arr(LLM_KV_XIELU_EPS, hparams.xielu_eps, hparams.n_layer()); + + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_8B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/arcee.cpp b/examples/talk-llama/models/arcee.cpp index d086c4717ff..9536e7c5d42 100644 --- a/examples/talk-llama/models/arcee.cpp +++ b/examples/talk-llama/models/arcee.cpp @@ -4,7 +4,7 @@ void llama_model_arcee::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); // Arcee uses the same structure as Llama - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 36: type = LLM_TYPE_4B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/arctic.cpp b/examples/talk-llama/models/arctic.cpp index 27deadffeb7..09ee0f752f0 100644 --- a/examples/talk-llama/models/arctic.cpp +++ b/examples/talk-llama/models/arctic.cpp @@ -4,7 +4,7 @@ void llama_model_arctic::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); if (hparams.n_expert == 128) { - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 35: type = LLM_TYPE_10B_128x3_66B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/arwkv7.cpp b/examples/talk-llama/models/arwkv7.cpp index 9bd04127b25..b38b2064785 100644 --- a/examples/talk-llama/models/arwkv7.cpp +++ b/examples/talk-llama/models/arwkv7.cpp @@ -10,7 +10,7 @@ void llama_model_arwkv7::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_GATE_LORA_RANK, hparams.n_lora_gate, false); ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 12: switch (hparams.n_embd) { case 768: type = LLM_TYPE_190M; break; diff --git a/examples/talk-llama/models/baichuan.cpp b/examples/talk-llama/models/baichuan.cpp index 4d26081cd5d..585f3614174 100644 --- a/examples/talk-llama/models/baichuan.cpp +++ b/examples/talk-llama/models/baichuan.cpp @@ -2,7 +2,7 @@ void llama_model_baichuan::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_7B; break; case 40: type = LLM_TYPE_13B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/bailingmoe.cpp b/examples/talk-llama/models/bailingmoe.cpp index fe1ae10864b..7faf73c835b 100644 --- a/examples/talk-llama/models/bailingmoe.cpp +++ b/examples/talk-llama/models/bailingmoe.cpp @@ -8,7 +8,7 @@ void llama_model_bailingmoe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 28: type = LLM_TYPE_16B; break; case 88: type = LLM_TYPE_290B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/bailingmoe2.cpp b/examples/talk-llama/models/bailingmoe2.cpp index 2f0d44a6259..5000e9c6db8 100644 --- a/examples/talk-llama/models/bailingmoe2.cpp +++ b/examples/talk-llama/models/bailingmoe2.cpp @@ -9,17 +9,13 @@ void llama_model_bailingmoe2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); - ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); - // TODO: when MTP is implemented, this should probably be updated if needed - hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer_impl"); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 20: type = LLM_TYPE_16B_A1B; break; - case 21: type = LLM_TYPE_16B_A1B; break; case 32: type = LLM_TYPE_100B_A6B; break; - case 33: type = LLM_TYPE_100B_A6B; break; default: type = LLM_TYPE_UNKNOWN; } } @@ -39,9 +35,9 @@ void llama_model_bailingmoe2::load_arch_tensors(llama_model_loader &) { GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for bailingmoe2"); GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for bailingmoe2"); - for (int i = 0; i < n_layer; ++i) { + for (int i = 0; i < n_layer_all; ++i) { int flags = 0; - if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) { + if (i >= n_layer) { // skip all tensors in the NextN layers flags |= TENSOR_SKIP; } @@ -78,7 +74,7 @@ void llama_model_bailingmoe2::load_arch_tensors(llama_model_loader &) { } // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers - if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) { + if (i >= n_layer) { layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED | flags); layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); @@ -112,8 +108,7 @@ llama_model_bailingmoe2::graph::graph(const llama_model & model, const llm_graph ggml_tensor * inp_out_ids = build_inp_out_ids(); - const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; - for (int il = 0; il < n_transformer_layers; ++il) { + for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; // norm @@ -146,7 +141,7 @@ llama_model_bailingmoe2::graph::graph(const llama_model & model, const llm_graph Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } - if (il == n_transformer_layers - 1 && inp_out_ids) { + if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } diff --git a/examples/talk-llama/models/bert.cpp b/examples/talk-llama/models/bert.cpp index 3c28f419ccf..53ce29f23ca 100644 --- a/examples/talk-llama/models/bert.cpp +++ b/examples/talk-llama/models/bert.cpp @@ -1,9 +1,9 @@ #include "models.h" void llama_model_bert::load_arch_hparams(llama_model_loader & ml) { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 3: type = LLM_TYPE_17M; break; // bge-micro case 6: diff --git a/examples/talk-llama/models/bitnet.cpp b/examples/talk-llama/models/bitnet.cpp index 7e8125deec4..c8330274580 100644 --- a/examples/talk-llama/models/bitnet.cpp +++ b/examples/talk-llama/models/bitnet.cpp @@ -3,7 +3,7 @@ void llama_model_bitnet::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 26: type = LLM_TYPE_3B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/bloom.cpp b/examples/talk-llama/models/bloom.cpp index 30b0f3d07d0..609d2ddf998 100644 --- a/examples/talk-llama/models/bloom.cpp +++ b/examples/talk-llama/models/bloom.cpp @@ -3,7 +3,7 @@ void llama_model_bloom::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_1B; break; case 30: switch (hparams.n_embd) { diff --git a/examples/talk-llama/models/chameleon.cpp b/examples/talk-llama/models/chameleon.cpp index 4bceaefd63b..4f45acecf84 100644 --- a/examples/talk-llama/models/chameleon.cpp +++ b/examples/talk-llama/models/chameleon.cpp @@ -6,7 +6,7 @@ void llama_model_chameleon::load_arch_hparams(llama_model_loader & ml) { hparams.f_norm_eps = 1e-5; // eps for qk-norm, torch default ml.get_key(LLM_KV_SWIN_NORM, hparams.swin_norm, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_7B; break; case 48: type = LLM_TYPE_34B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/chatglm.cpp b/examples/talk-llama/models/chatglm.cpp index 6766fa71c15..7ae5b938fde 100644 --- a/examples/talk-llama/models/chatglm.cpp +++ b/examples/talk-llama/models/chatglm.cpp @@ -2,7 +2,8 @@ void llama_model_chatglm::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 28: { if (hparams.n_head(0) == 16) { type = LLM_TYPE_1_5B; diff --git a/examples/talk-llama/models/codeshell.cpp b/examples/talk-llama/models/codeshell.cpp index 274dd3342a7..de53bb98184 100644 --- a/examples/talk-llama/models/codeshell.cpp +++ b/examples/talk-llama/models/codeshell.cpp @@ -2,7 +2,8 @@ void llama_model_codeshell::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 42: type = LLM_TYPE_7B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/cogvlm.cpp b/examples/talk-llama/models/cogvlm.cpp index 2e231bb3f93..750f57a394e 100644 --- a/examples/talk-llama/models/cogvlm.cpp +++ b/examples/talk-llama/models/cogvlm.cpp @@ -2,7 +2,8 @@ void llama_model_cogvlm::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_13B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/cohere2.cpp b/examples/talk-llama/models/cohere2.cpp index a514cf88fc6..61a5945a194 100644 --- a/examples/talk-llama/models/cohere2.cpp +++ b/examples/talk-llama/models/cohere2.cpp @@ -5,6 +5,7 @@ void llama_model_cohere2::load_arch_hparams(llama_model_loader & ml) { uint32_t swa_period = 4; ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); hparams.set_swa_pattern(swa_period); + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; @@ -12,7 +13,8 @@ void llama_model_cohere2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_8B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/command-r.cpp b/examples/talk-llama/models/command-r.cpp index adf7fcaa20f..94a46188bb8 100644 --- a/examples/talk-llama/models/command-r.cpp +++ b/examples/talk-llama/models/command-r.cpp @@ -3,7 +3,8 @@ void llama_model_command_r::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 40: type = LLM_TYPE_35B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/dbrx.cpp b/examples/talk-llama/models/dbrx.cpp index af71c775365..4f5ac4d06a4 100644 --- a/examples/talk-llama/models/dbrx.cpp +++ b/examples/talk-llama/models/dbrx.cpp @@ -1,14 +1,14 @@ #include "models.h" void llama_model_dbrx::load_arch_hparams(llama_model_loader & ml) { -ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); -ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv); -switch (hparams.n_layer) { - case 40: type = LLM_TYPE_16x12B; break; - default: type = LLM_TYPE_UNKNOWN; + switch (hparams.n_layer()) { + case 40: type = LLM_TYPE_16x12B; break; + default: type = LLM_TYPE_UNKNOWN; + } } - } void llama_model_dbrx::load_arch_tensors(llama_model_loader &) { LLAMA_LOAD_LOCALS; diff --git a/examples/talk-llama/models/deci.cpp b/examples/talk-llama/models/deci.cpp index 567e3535276..cdfcf29e02f 100644 --- a/examples/talk-llama/models/deci.cpp +++ b/examples/talk-llama/models/deci.cpp @@ -2,7 +2,8 @@ void llama_model_deci::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_7B; break; case 80: type = LLM_TYPE_70B; break; case 162: type = LLM_TYPE_405B; break; diff --git a/examples/talk-llama/models/deepseek2.cpp b/examples/talk-llama/models/deepseek2.cpp index 1fe54adc13e..a9e8bc51403 100644 --- a/examples/talk-llama/models/deepseek2.cpp +++ b/examples/talk-llama/models/deepseek2.cpp @@ -5,7 +5,7 @@ void llama_model_deepseek2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab, false); // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B, Kanana-2-30B-A3B - const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26 || (hparams.n_layer == 48 && n_vocab == 128256)); + const bool is_lite = (hparams.n_layer() == 27 || hparams.n_layer() == 26 || (hparams.n_layer() == 48 && n_vocab == 128256)); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); @@ -23,7 +23,7 @@ void llama_model_deepseek2::load_arch_hparams(llama_model_loader & ml) { if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { // for compatibility with existing DeepSeek V2 and V2.5 GGUFs // that have no expert_gating_func model parameter set - if ((hparams.n_layer == 47 || hparams.n_layer == 48) && n_vocab == 154880) { + if ((hparams.n_layer() == 47 || hparams.n_layer() == 48) && n_vocab == 154880) { // GLM 4.7 Lite hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; } else { @@ -43,7 +43,7 @@ void llama_model_deepseek2::load_arch_hparams(llama_model_loader & ml) { hparams.f_attn_temp_offset = 0.0f; - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 27: type = LLM_TYPE_16B; break; case 47: type = LLM_TYPE_30B_A3B; break; case 60: type = LLM_TYPE_236B; break; @@ -191,8 +191,7 @@ llama_model_deepseek2::graph::graph(const llama_model & model, const llm_graph_p ggml_tensor * inp_out_ids = build_inp_out_ids(); - int effective_n_layers = hparams.n_layer - hparams.nextn_predict_layers; - for (int il = 0; il < effective_n_layers; ++il) { + for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; // norm @@ -366,7 +365,7 @@ llama_model_deepseek2::graph::graph(const llama_model & model, const llm_graph_p Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); } } - if (il == effective_n_layers - 1 && inp_out_ids) { + if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } diff --git a/examples/talk-llama/models/deepseek2ocr.cpp b/examples/talk-llama/models/deepseek2ocr.cpp index f9e4c98785c..65d31c31b93 100644 --- a/examples/talk-llama/models/deepseek2ocr.cpp +++ b/examples/talk-llama/models/deepseek2ocr.cpp @@ -14,7 +14,7 @@ void llama_model_deepseek2ocr::load_arch_hparams(llama_model_loader & ml) { hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX; } - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 12: type = LLM_TYPE_3B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/deepseek32.cpp b/examples/talk-llama/models/deepseek32.cpp new file mode 100644 index 00000000000..9a20e2ce907 --- /dev/null +++ b/examples/talk-llama/models/deepseek32.cpp @@ -0,0 +1,499 @@ +#include "models.h" + +#include "llama-kv-cache.h" +#include "llama-kv-cache-dsa.h" + +void llama_model_deepseek32::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + hparams.f_norm_eps = 1e-6; // eps for layer norm + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); + + // MoE parameters + ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert); + ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + + // deepseek MLA parameters + ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl, false); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + + // DSA parameters + ml.get_key(LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, hparams.indexer_n_head); + ml.get_key(LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, hparams.indexer_head_size); + ml.get_key(LLM_KV_ATTENTION_INDEXER_TOP_K, hparams.indexer_top_k); + + // Expert gating function + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + + if (ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f)) { + // [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX] + // cancel the factor from the convert script + hparams.rope_yarn_log_mul /= 0.1f; + } + + // NextN/MTP parameters + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer"); + + switch (hparams.n_layer()) { + case 62: type = LLM_TYPE_685B_A37B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_deepseek32::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + const bool is_mla = hparams.is_mla(); + if (!is_mla) { + throw std::runtime_error("DEEPSEEK32 architecture requires MLA"); + } + + // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA + const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); + const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); + + const int64_t n_embd_head_qk_rope = hparams.n_rot(); + const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; + + const int64_t q_lora_rank = hparams.n_lora_q; + const int64_t kv_lora_rank = hparams.n_lora_kv; + + const int64_t n_ff_exp = hparams.n_ff_exp; + const int64_t n_expert_shared = hparams.n_expert_shared; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // try to load output.weight, if not found, use token_embd (tied embeddings) + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer_all; ++i) { + int flags = 0; + if (i >= n_layer) { + // skip all tensors in the NextN layers + // TODO @ngxson : TENSOR_NOT_REQUIRED was a hack, need to remove it later + flags |= TENSOR_SKIP | TENSOR_NOT_REQUIRED; + } + + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); + layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, flags); + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, flags); + + layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, flags); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, flags); + + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope}, flags); + + // note: only old legacy GGUF files will have the unsplit wkv_b tensor in + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, kv_lora_rank, n_head}, flags); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, flags); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, flags); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); + + // DSA indexer + layer.indexer_k_norm = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "weight", i), {hparams.indexer_head_size}, flags); + layer.indexer_k_norm_b = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "bias", i), {hparams.indexer_head_size}, flags); + layer.indexer_proj = create_tensor(tn(LLM_TENSOR_INDEXER_PROJ, "weight", i), {n_embd, hparams.indexer_n_head}, flags); + layer.indexer_attn_k = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_K, "weight", i), {n_embd, hparams.indexer_head_size}, flags); + layer.indexer_attn_q_b = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_Q_B, "weight", i), {q_lora_rank, hparams.indexer_n_head * hparams.indexer_head_size}, flags); + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, flags); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); + + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, flags); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, flags); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, flags); + } + + // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers + if (i >= n_layer) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); + + // Optional tensors + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED); + } + } +} + +std::unique_ptr<llm_graph_context> llama_model_deepseek32::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_deepseek32::graph::graph(const llama_model & model, const llm_graph_params & params) : + llm_graph_context(params) { + const bool is_mla = hparams.is_mla(); + GGML_ASSERT(is_mla); + + // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA + const int64_t n_embd_head_k = hparams.n_embd_head_k_mla(); + const int64_t n_embd_head_v = hparams.n_embd_head_v_mla(); + GGML_UNUSED(n_embd_head_v); + + const int64_t n_embd_head_qk_rope = hparams.n_rot(); + const int64_t n_embd_head_qk_nope = n_embd_head_k - n_embd_head_qk_rope; + + const int64_t n_indexer_head = hparams.indexer_n_head; + const int64_t n_embd_indexer_head = hparams.indexer_head_size; + const int64_t n_embd_indexer_head_rope = hparams.n_rot(); + const int64_t n_embd_indexer_head_nope = n_embd_indexer_head - n_embd_indexer_head_rope; + const uint32_t n_indexer_top_k = hparams.indexer_top_k; + + const uint32_t kv_lora_rank = hparams.n_lora_kv; + + // We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly. + // See https://github.com/ggml-org/llama.cpp/discussions/7416 for detailed explanation. + // And also: https://github.com/ggml-org/llama.cpp/pull/17945 [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX] + + // first cancel the adjustment from llama_hparams::yarn_attn_factor_adjust to get the original attn_factor + GGML_ASSERT(ext_factor >= 0.0f); + const float attn_factor_org = attn_factor * (1.0f + 0.1f * logf(1.0f / freq_scale)); + + // use the original attn_factor to pre-scale the kq_scale + const float mscale = attn_factor_org * (1.0f + 0.1f * hparams.rope_yarn_log_mul * logf(1.0f / freq_scale)); + const float kq_scale = 1.0f * mscale * mscale / sqrtf(float(n_embd_head_k)); + + ggml_tensor * cur; + ggml_tensor * inpL; + + // {n_embd, n_tokens} + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + llm_graph_input_attn_k_dsa * inp_attn_dsa = build_attn_inp_k_dsa(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self_attention + { + ggml_tensor * qr = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur); + cb(qr, "qr", il); + + qr = build_norm(qr, model.layers[il].attn_q_a_norm, nullptr, LLM_NORM_RMS, il); + cb(qr, "qr", il); + + ggml_tensor * top_k = nullptr; + + // lightning indexer + { + ggml_tensor * indexer_q = ggml_mul_mat(ctx0, model.layers[il].indexer_attn_q_b, qr); + cb(indexer_q, "indexer_q", il); + + // split into {n_embd_indexer_head_rope, n_indexer_head, n_tokens} + ggml_tensor * indexer_q_pe = + ggml_view_3d(ctx0, indexer_q, n_embd_indexer_head_rope, n_indexer_head, n_tokens, + ggml_row_size(indexer_q->type, n_embd_indexer_head), + ggml_row_size(indexer_q->type, n_embd_indexer_head) * n_indexer_head, 0); + cb(indexer_q_pe, "indexer_q_pe", il); + + // and {n_embd_indexer_head_nope, n_indexer_head, n_tokens} + ggml_tensor * indexer_q_nope = + ggml_view_3d(ctx0, indexer_q, n_embd_indexer_head_nope, n_indexer_head, n_tokens, + ggml_row_size(indexer_q->type, n_embd_indexer_head), + ggml_row_size(indexer_q->type, n_embd_indexer_head) * n_indexer_head, + ggml_row_size(indexer_q->type, n_embd_indexer_head_nope)); + cb(indexer_q_nope, "indexer_q_nope", il); + + indexer_q_pe = ggml_rope_ext(ctx0, indexer_q_pe, inp_pos, nullptr, n_rot, + LLAMA_ROPE_TYPE_NEOX, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(indexer_q_pe, "indexer_q_pe", il); + + // {n_embd_indexer_head_rope + n_embd_indexer_head_nope, n_head, n_tokens} + indexer_q = ggml_concat(ctx0, indexer_q_pe, indexer_q_nope, 0); + cb(indexer_q, "indexer_q", il); + + ggml_tensor * indexer_k = ggml_mul_mat(ctx0, model.layers[il].indexer_attn_k, cur); + cb(indexer_k, "indexer_k", il); + + indexer_k = build_norm(indexer_k, model.layers[il].indexer_k_norm, model.layers[il].indexer_k_norm_b, LLM_NORM, il); + cb(indexer_k, "indexer_k", il); + + // split into {n_embd_indexer_head_rope, 1, n_tokens} + ggml_tensor * indexer_k_pe = + ggml_view_3d(ctx0, indexer_k, n_embd_indexer_head_rope, 1, n_tokens, + ggml_row_size(indexer_k->type, n_embd_indexer_head), + ggml_row_size(indexer_k->type, n_embd_indexer_head) * 1, 0); + cb(indexer_k_pe, "indexer_k_pe", il); + + // and {n_embd_indexer_head_nope, 1, n_tokens} + ggml_tensor * indexer_k_nope = + ggml_view_3d(ctx0, indexer_k, n_embd_indexer_head_nope, 1, n_tokens, + ggml_row_size(indexer_k->type, n_embd_indexer_head), + ggml_row_size(indexer_k->type, n_embd_indexer_head) * 1, + ggml_row_size(indexer_k->type, n_embd_indexer_head_nope)); + cb(indexer_k_nope, "indexer_k_nope", il); + + indexer_k_pe = ggml_rope_ext(ctx0, indexer_k_pe, inp_pos, nullptr, n_rot, + LLAMA_ROPE_TYPE_NEOX, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(indexer_k_pe, "indexer_k_pe", il); + + // {n_embd_indexer_head_rope + n_embd_indexer_head_nope, 1, n_tokens} + indexer_k = ggml_concat(ctx0, indexer_k_pe, indexer_k_nope, 0); + cb(indexer_k, "indexer_k", il); + + // perform Hadamard transform on indexer q and k + indexer_q = ggml_mul_mat(ctx0, inp_attn_dsa->self_k_rot_lid, indexer_q); + cb(indexer_q, "indexer_q", il); + indexer_k = ggml_mul_mat(ctx0, inp_attn_dsa->self_k_rot_lid, indexer_k); + cb(indexer_k, "indexer_k", il); + + // store indexer keys to KV cache + const auto * mctx_lid = inp_attn_dsa->mctx->get_lid(); + const auto & k_idxs_lid = inp_attn_dsa->get_k_idxs_lid(); + ggml_build_forward_expand(gf, mctx_lid->cpy_k(ctx0, indexer_k, k_idxs_lid, il)); + + // prepare indexer weights + ggml_tensor * indexer_weights = ggml_mul_mat(ctx0, model.layers[il].indexer_proj, cur); + cb(indexer_weights, "indexer_weights", il); + + // get cached indexer keys + indexer_k = mctx_lid->get_k(ctx0, il); + + // split the batch into streams if needed + const auto n_stream = indexer_k->ne[3]; + indexer_q = ggml_view_4d(ctx0, indexer_q, indexer_q->ne[0], indexer_q->ne[1], indexer_q->ne[2]/n_stream, n_stream, indexer_q->nb[1], indexer_q->nb[2], indexer_q->nb[3]/n_stream, 0); + indexer_weights = ggml_view_4d(ctx0, indexer_weights, indexer_weights->ne[0], indexer_weights->ne[1]/n_stream, indexer_weights->ne[2], n_stream, indexer_weights->nb[1], indexer_weights->nb[2]/n_stream, indexer_weights->nb[3]/n_stream, 0); + + // calculate indexer kq + indexer_q = ggml_permute(ctx0, indexer_q, 0, 2, 1, 3); + cb(indexer_q, "indexer_q", il); + indexer_k = ggml_permute(ctx0, indexer_k, 0, 2, 1, 3); + cb(indexer_k, "indexer_k", il); + + ggml_tensor * indexer_kq = ggml_mul_mat(ctx0, indexer_k, indexer_q); + cb(indexer_kq, "indexer_kq", il); + + // ReLU requires contiguous tensors + indexer_kq = ggml_cont(ctx0, ggml_permute(ctx0, indexer_kq, 2, 1, 0, 3)); + cb(indexer_kq, "indexer_kq", il); + + // apply ReLU + ggml_tensor * indexer_score = ggml_relu(ctx0, indexer_kq); + cb(indexer_score, "indexer_score", il); + + // pre-scale weights to avoid scaling operations on huge indexer_score tensor + indexer_weights = ggml_scale(ctx0, indexer_weights, 1.0f / sqrtf(float(n_embd_indexer_head * n_indexer_head))); + cb(indexer_weights, "indexer_weights", il); + + // multiply scores by indexer weights + indexer_score = ggml_mul(ctx0, indexer_score, indexer_weights); + cb(indexer_score, "indexer_score", il); + + // sum by q n_indexer_head dimension + indexer_score = ggml_sum_rows(ctx0, indexer_score); + cb(indexer_score, "indexer_score", il); + + // permute result to match KQ mask + indexer_score = ggml_cont(ctx0, ggml_permute(ctx0, indexer_score, 2, 1, 0, 3)); + cb(indexer_score, "indexer_score", il); + + // mask indexer scores + ggml_tensor * indexer_kq_mask = inp_attn_dsa->get_kq_mask_lid(); + indexer_score = ggml_add(ctx0, indexer_score, indexer_kq_mask); + cb(indexer_score, "indexer_score", il); + + // get indices of top k indexer scores + uint32_t n_top_k = indexer_score->ne[0] < n_indexer_top_k ? indexer_score->ne[0] : n_indexer_top_k; + top_k = ggml_cont(ctx0, ggml_top_k(ctx0, indexer_score, n_top_k)); + cb(top_k, "top_k", il); + } + + ggml_tensor * q = ggml_mul_mat(ctx0, model.layers[il].wq_b, qr); + cb(q, "q", il); + + // split into {n_embd_head_qk_nope, n_head, n_tokens} + ggml_tensor * q_nope = + ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens, ggml_row_size(q->type, n_embd_head_k), + ggml_row_size(q->type, n_embd_head_k) * n_head, 0); + cb(q_nope, "q_nope", il); + + // and {n_embd_head_qk_rope, n_head, n_tokens} + ggml_tensor * q_pe = ggml_view_3d( + ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, ggml_row_size(q->type, n_embd_head_k), + ggml_row_size(q->type, n_embd_head_k) * n_head, ggml_row_size(q->type, n_embd_head_qk_nope)); + cb(q_pe, "q_pe", il); + + ggml_tensor * kv_cmpr_pe = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur); + cb(kv_cmpr_pe, "kv_cmpr_pe", il); + + // split into {kv_lora_rank, n_tokens} + ggml_tensor * kv_cmpr = + ggml_view_2d(ctx0, kv_cmpr_pe, kv_lora_rank, n_tokens, + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), 0); + cb(kv_cmpr, "kv_cmpr", il); + + // and {n_embd_head_qk_rope, 1, n_tokens} + ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_cmpr_pe, n_embd_head_qk_rope, 1, n_tokens, + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank)); + cb(k_pe, "k_pe", il); + + q_pe = ggml_rope_ext(ctx0, q_pe, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(q_pe, "q_pe", il); + + k_pe = ggml_rope_ext(ctx0, k_pe, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(k_pe, "k_pe", il); + + kv_cmpr = build_norm(kv_cmpr, model.layers[il].attn_kv_a_norm, nullptr, LLM_NORM_RMS, il); + cb(kv_cmpr, "kv_cmpr", il); + + // MLA attention + { + // {n_embd_head_qk_nope, n_tokens, n_head} + q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3); + cb(q_nope, "q_nope_perm", il); + + // {n_embd_head_qk_nope, kv_lora_rank, n_head} x {n_embd_head_qk_nope, n_tokens, n_head} + ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, model.layers[il].wk_b, q_nope); + cb(q_nope_absorbed, "q_nope_absorbed", il); + + // {kv_lora_rank, n_head, n_tokens} + q_nope_absorbed = ggml_permute(ctx0, q_nope_absorbed, 0, 2, 1, 3); + cb(q_nope_absorbed, "q_nope_absorbed_perm", il); + + // {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens} + // note: rope must go first for in-place context shifting in build_rope_shift() + ggml_tensor * Qcur = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0); + cb(Qcur, "Qcur", il); + + kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens); + cb(kv_cmpr, "kv_cmpr_reshape", il); + + // {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens} + ggml_tensor * Kcur = ggml_concat(ctx0, kv_cmpr, k_pe, 0); + cb(Kcur, "Kcur", il); + + // {kv_lora_rank, 1, n_tokens} + ggml_tensor * Vcur = kv_cmpr; + cb(Vcur, "Vcur", il); + + // note: MLA with the absorption optimization converts into MQA (ie: GQA with 1 group) + cur = build_attn(inp_attn_dsa, + model.layers[il].wo, NULL, model.layers[il].wo_s, + Qcur, Kcur, Vcur, nullptr, nullptr, model.layers[il].wv_b, top_k, kq_scale, il); + } + } + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + if ((uint32_t) il < hparams.n_layer_dense_lead) { + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, model.layers[il].ffn_up_s, + model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_s, + model.layers[il].ffn_down, NULL, model.layers[il].ffn_down_s, + NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // MoE branch + ggml_tensor * moe_out = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + hparams.expert_weights_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il, + nullptr, + model.layers[il].ffn_gate_up_exps, + model.layers[il].ffn_up_exps_s, + model.layers[il].ffn_gate_exps_s, + model.layers[il].ffn_down_exps_s); + cb(moe_out, "ffn_moe_out", il); + + // FFN shared expert + { + ggml_tensor * ffn_shexp = + build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, model.layers[il].ffn_up_shexp_s, + model.layers[il].ffn_gate_shexp, NULL, model.layers[il].ffn_gate_shexp_s, + model.layers[il].ffn_down_shexp, NULL, model.layers[il].ffn_down_shexp_s, + NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } + } + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/delta-net-base.cpp b/examples/talk-llama/models/delta-net-base.cpp index 4f4c7cac7a8..ad9ce771408 100644 --- a/examples/talk-llama/models/delta-net-base.cpp +++ b/examples/talk-llama/models/delta-net-base.cpp @@ -398,9 +398,8 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); - // K=1 (final state only): reshape to 3D (S_v*S_v*H_v, 1, n_seqs) for ggml_gated_delta_net. - ggml_tensor * s_3d = ggml_reshape_3d(ctx0, s, S_v * S_v * H_v, 1, n_seqs); - ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s_3d); + // K=1: output carries the final state only. state s is 4D [S_v, S_v, H_v, n_seqs]. + ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s, /*K=*/1); if (n_tokens == 1) { cb(result, LLAMA_TENSOR_NAME_FGDN_AR, il); } else { @@ -564,11 +563,8 @@ ggml_tensor * llm_build_delta_net_base::build_recurrent_attn( const int64_t D = S_v * S_v * H_v; const int64_t K = cparams.n_rs_seq + 1; - // TODO: remove pad + simplify - ggml_tensor * s_3d = ggml_reshape_3d(ctx0, s, D, 1, n_seqs); - ggml_tensor * s_3d_pad = ggml_pad (ctx0, s_3d, 0, K - 1, 0, 0); - - ggml_tensor * gdn_out = ggml_gated_delta_net(ctx0, q, k, v, g, b, s_3d_pad); + // state s is 4D [S_v, S_v, H_v, n_seqs]; K snapshot slots are written into the output. + ggml_tensor * gdn_out = ggml_gated_delta_net(ctx0, q, k, v, g, b, s, K); if (n_seq_tokens > 1) { cb(gdn_out, LLAMA_TENSOR_NAME_FGDN_CH, il); } else { @@ -587,21 +583,24 @@ ggml_tensor * llm_build_delta_net_base::build_recurrent_attn( cb(output, "attn_output", il); const size_t row_size = hparams.n_embd_s() * ggml_element_size(ssm_states_all); - for (int64_t k_i = 0; k_i < K; ++k_i) { - const uint32_t cache_slot = (uint32_t) (K - 1 - k_i); - ggml_tensor * src = ggml_view_4d(ctx0, gdn_out, - S_v, S_v, H_v, n_seqs, - ggml_row_size(gdn_out->type, S_v), - ggml_row_size(gdn_out->type, S_v * S_v), - ggml_row_size(gdn_out->type, S_v * S_v * H_v), - ggml_row_size(gdn_out->type, attn_score_elems + k_i * state_size_per_snap)); - ggml_tensor * dst = ggml_view_2d(ctx0, ssm_states_all, - hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], - ((size_t) cache_slot * mem_size + kv_head) * row_size); + // op writes the last min(n_seq_tokens, K) snapshots; trailing slots are left unwritten + const int64_t n_written = std::min<int64_t>(n_seq_tokens, K); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, src, dst)); - } + // write the produced snapshots into the recurrent cache (snapshot slot i -> rollback group i) + ggml_tensor * src = ggml_view_3d(ctx0, gdn_out, + D, n_seqs, n_written, + ggml_row_size(gdn_out->type, D), + ggml_row_size(gdn_out->type, state_size_per_snap), + ggml_row_size(gdn_out->type, attn_score_elems)); + + ggml_tensor * dst = ggml_view_3d(ctx0, ssm_states_all, + D, n_seqs, n_written, + ssm_states_all->nb[1], + (size_t) mem_size * row_size, + (size_t) kv_head * row_size); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, src, dst)); return output; } diff --git a/examples/talk-llama/models/dots1.cpp b/examples/talk-llama/models/dots1.cpp index 435d27281c6..07d6ab1b7cd 100644 --- a/examples/talk-llama/models/dots1.cpp +++ b/examples/talk-llama/models/dots1.cpp @@ -8,7 +8,8 @@ void llama_model_dots1::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 62: type = LLM_TYPE_142B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/dream.cpp b/examples/talk-llama/models/dream.cpp index 12ac6f1ce88..abe737c335a 100644 --- a/examples/talk-llama/models/dream.cpp +++ b/examples/talk-llama/models/dream.cpp @@ -2,8 +2,9 @@ void llama_model_dream::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // Dream models are primarily 7B with 28 layers - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 28: type = LLM_TYPE_7B; break; diff --git a/examples/talk-llama/models/eagle3.cpp b/examples/talk-llama/models/eagle3.cpp new file mode 100644 index 00000000000..3321b390515 --- /dev/null +++ b/examples/talk-llama/models/eagle3.cpp @@ -0,0 +1,323 @@ +#include "models.h" + +void llama_model_eagle3::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + if (!ml.get_arr(LLM_KV_TARGET_LAYERS, target_layer_ids, false)) { + throw std::runtime_error("EAGLE3 model requires 'extract_layers' in GGUF metadata"); + } + if (target_layer_ids.size() != 3) { + throw std::runtime_error("EAGLE3 requires exactly 3 entries in 'extract_layers'"); + } + LLAMA_LOG_INFO("%s: EAGLE3 extract_layers = [%d, %d, %d]\n", __func__, + target_layer_ids[0], + target_layer_ids[1], + target_layer_ids[2]); + + uint32_t n_embd_tgt = 0; + + ml.get_key(LLM_KV_TARGET_HIDDEN_SIZE, n_embd_tgt); + LLAMA_LOG_INFO("%s: EAGLE3 n_embd_tgt = %u (draft n_embd = %u)\n", __func__, n_embd_tgt, hparams.n_embd); + + hparams.n_embd_inp_impl = (uint32_t) target_layer_ids.size() * n_embd_tgt; + + // eagle3 norm_before_residual (optional, default false) + // compatible with Readhat eagle3 speculator model + ml.get_key(LLM_KV_NORM_BEFORE_RESIDUAL, hparams.norm_before_residual, false); + if (hparams.norm_before_residual) { + LLAMA_LOG_INFO("%s: EAGLE3gnorm_before_residual = true\n", __func__); + } + + type = LLM_TYPE_UNKNOWN; +} + +void llama_model_eagle3::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const int64_t n_embd_inp = hparams.n_embd_inp(); + const int64_t n_embd_attn_input = 2 * n_embd; + + // Get vocab size from the d2t tensor in the GGUF file (optional - only needed if eagle3 has different vocab_size than target) + // d2t: draft to target vocabulary mapping + int64_t n_draft_vocab = n_vocab; // Default: same as target vocab + const struct ggml_tensor * d2t_meta = ml->get_tensor_meta("d2t"); + if (d2t_meta) { + n_draft_vocab = d2t_meta->ne[0]; // update draft vocab size + d2t = create_tensor(tn(LLM_TENSOR_D2T), {n_draft_vocab}, 0); + LLAMA_LOG_INFO("%s: EAGLE3 using d2t mapping (draft_vocab_size = %lld)\n", __func__, (long long)n_draft_vocab); + } else { + d2t = nullptr; // no d2t, use default vocab size + LLAMA_LOG_INFO("%s: EAGLE3 without d2t - sharing same vocab_size with target (vocab_size = %lld)\n", __func__, (long long)n_draft_vocab); + } + + // Feature fusion layer: projects 3 target layers to draft hidden size + fc = create_tensor(tn(LLM_TENSOR_FC, "weight"), {n_embd_inp, n_embd}, 0); + + // Output layer (uses draft vocab size) + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_draft_vocab}, TENSOR_NOT_REQUIRED); + + // Token embeddings (optional - Llama 3.3 70B EAGLE3 has its own) + const struct ggml_tensor * tok_embd_meta = ml->get_tensor_meta(tn(LLM_TENSOR_TOKEN_EMBD, "weight").str().c_str()); + if (tok_embd_meta) { + const int64_t n_target_vocab = tok_embd_meta->ne[1]; + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_target_vocab}, 0); + LLAMA_LOG_INFO("%s: EAGLE3 using its own token_embd (vocab = %lld)\n", __func__, (long long)n_target_vocab); + } + + // Single decoder layer + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + // input_layernorm: applied to token embeddings + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + // eagle3 specific: hidden_norm applied to fused target features + layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, 0); + + // Attention takes input_embeds_normed + fused_target_normed as input + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd_attn_input, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd_attn_input, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd_attn_input, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // rope_freqs for llama3 rope scaling (optional - only if eagle3 config has rope_scaling) + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED); + } +} + +std::unique_ptr<llm_graph_context> llama_model_eagle3::build_arch_graph(const llm_graph_params & params) const { + switch (params.gtype) { + case LLM_GRAPH_TYPE_ENCODER: + return std::make_unique<graph<true>>(*this, params); + case LLM_GRAPH_TYPE_DEFAULT: + case LLM_GRAPH_TYPE_DECODER: + return std::make_unique<graph<false>>(*this, params); + default: + GGML_ABORT("invalid graph type"); + }; +} + +template <> +ggml_tensor * llama_model_eagle3::graph<true>::build_inp_embd_enc() const { + ggml_tensor * cur = nullptr; + + // Input: Target model features (3 layers concatenated: low, mid, high) + // Data will be provided via ubatch->embd in encode_eagle3_features() + auto inp_target = std::make_unique<llm_graph_input_embd>(hparams.n_embd_inp()); + inp_target->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32,hparams.n_embd_inp(), n_tokens); + ggml_set_input(inp_target->embd); + + cur = inp_target->embd; + cb(cur, "inp_embd", -1); + + res->add_input(std::move(inp_target)); + + return cur; +} + +// eagle3 Encoder: processes target model features through feature fusion layer +// Input: target_features e.g. [12288, n_tokens] from target model layers low, middle, high +// Output: g_embeddings e.g. [4096, n_tokens] stored in context +template <> +llama_model_eagle3::graph<true>::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + ggml_tensor * cur = nullptr; + + cur = build_inp_embd_enc(); + + // Feature fusion layer + cur = build_lora_mm(model.fc, cur); + cb(cur, "fc_out", -1); + + // Output: g_embeddings e.g. [4096, n_tokens] + // store in t_h_nextn (same as MTP) so can be read via llama_get_embeddings_nextn(ctx_dft) + ggml_set_output(cur); + res->t_h_nextn = cur; + + ggml_build_forward_expand(gf, cur); +} + +// eagle3 Decoder: processes draft tokens using g_embeddings from encoder +// Input: draft tokens + g_embeddings from encoder +// Output: draft logits +template <> +llama_model_eagle3::graph<false>::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_layer == 1); // eagle3 has only one decoder layer + + ggml_tensor * cur; + ggml_tensor * inpL; + + // eagle3 Decoder receives: + // 1. Token embeddings (e.g.from eagle3's own tok_embd for Llama 3.3 70B, or target model for Llama 3.1 8B) + // 2. g_embeddings from encoder + auto * tok_embd = model.tok_embd; + if (model.tok_embd == nullptr) { + GGML_ASSERT(cparams.ctx_other != nullptr); + const auto * model_other = llama_get_model(cparams.ctx_other); + + GGML_ASSERT(model_other->tok_embd != nullptr && "EAGLE3 decoder requires token embeddings (own or from target model)"); + tok_embd = model_other->tok_embd; + } + + auto inp = std::make_unique<llm_graph_input_embd>(n_embd); + + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_input(inp->tokens); + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); + ggml_set_input(inp->embd); + + ggml_tensor * inp_embd = ggml_get_rows(ctx0, tok_embd, inp->tokens); + cb(inp_embd, "inp_embd", -1); + + ggml_tensor * inp_g = inp->embd; + cb(inp_g, "inp_g_embeddings", -1); + + res->add_input(std::move(inp)); + + inpL = inp_g; + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv(); + + const float kq_scale = 1.0f/sqrtf(float(n_embd_head)); + + // Single decoder layer (il = 0) + const int il = 0; + { + // Apply input_layernorm to the token embeddings + ggml_tensor * embd_norm = build_norm(inp_embd, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(embd_norm, "embd_norm", il); + + // Apply hidden_norm to inp_g + ggml_tensor * g_norm = build_norm(inp_g, + model.layers[il].attn_norm_2, NULL, + LLM_NORM_RMS, -1); + cb(g_norm, "g_norm", il); + + // norm_before_residual: determines what goes into the residual connection (compatible with Readhat eagle3 speculator model) + // - false (default): use raw inp_g for residual + // - true: use normalized g_norm for residual + // inpL is the concatenated input (normalized inp_embd + normalized inp_g) + ggml_tensor * inpSA = hparams.norm_before_residual ? g_norm : inpL; + + // Concatenate normalized inp_embd and normalized inp_g + cur = ggml_concat(ctx0, embd_norm, g_norm, il); + cb(cur, "concat_embd", il); + + // Self-attention with concatenated input + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // rope freq factors, returns nullptr if not available + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // RoPE + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur_rope", il); + cb(Kcur, "Kcur_rope", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + + // Add residual and update it + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // Apply FFN norm to the sum + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "post_attn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + // Output norm with residual + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "eagle3_prenorm", il); + + inpL = cur; + } + + cur = inpL; + + // Output prenorm state (for next token's g_embeddings in autoregressive generation) + ggml_set_output(cur); + res->t_h_nextn = cur; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); + + // lm_head - projects to draft vocabulary + // if the draft has no own output projection, inherit the target model's lm_head + auto * output = model.output; + if (output == nullptr) { + GGML_ASSERT(cparams.ctx_other != nullptr); + const auto * model_other = llama_get_model(cparams.ctx_other); + + GGML_ASSERT(model_other->output != nullptr && "EAGLE3 decoder requires an output projection (own or from target model)"); + output = model_other->output; + } + cur = build_lora_mm(output, cur); + + if (model.d2t) { + const int64_t n_draft_vocab = cur->ne[0]; + const int64_t n_outputs = cur->ne[1]; + const int64_t n_vocab = (int64_t) model.vocab.n_tokens(); + + GGML_ASSERT(model.d2t->type == GGML_TYPE_I64); + GGML_ASSERT(model.d2t->ne[0] == n_draft_vocab); + + ggml_tensor * logits = ggml_fill(ctx0, ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, n_vocab, n_outputs), -INFINITY); + cur = ggml_set_rows(ctx0, logits, + ggml_reshape_3d(ctx0, cur, 1, n_draft_vocab, n_outputs), + ggml_reshape_3d(ctx0, model.d2t, n_draft_vocab, 1, 1)); + cur = ggml_reshape_2d(ctx0, cur, n_vocab, n_outputs); + } + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/ernie4-5.cpp b/examples/talk-llama/models/ernie4-5.cpp index 9b39c605e35..895cf690bd2 100644 --- a/examples/talk-llama/models/ernie4-5.cpp +++ b/examples/talk-llama/models/ernie4-5.cpp @@ -12,7 +12,7 @@ void llama_model_ernie4_5::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); } - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 18: type = LLM_TYPE_0_3B; break; case 28: type = LLM_TYPE_21B_A3B; break; case 54: type = LLM_TYPE_300B_A47B; break; diff --git a/examples/talk-llama/models/eurobert.cpp b/examples/talk-llama/models/eurobert.cpp index ddf13c3028f..0948d7de656 100644 --- a/examples/talk-llama/models/eurobert.cpp +++ b/examples/talk-llama/models/eurobert.cpp @@ -3,7 +3,7 @@ void llama_model_eurobert::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - if (hparams.n_layer == 12) { + if (hparams.n_layer() == 12) { type = LLM_TYPE_SMALL; // 0.2B } } diff --git a/examples/talk-llama/models/exaone-moe.cpp b/examples/talk-llama/models/exaone-moe.cpp index 76d91982fc5..5aed9379400 100644 --- a/examples/talk-llama/models/exaone-moe.cpp +++ b/examples/talk-llama/models/exaone-moe.cpp @@ -20,13 +20,12 @@ void llama_model_exaone_moe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); - ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer_impl"); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_30B_A3B; break; - case 48: - case 49: type = LLM_TYPE_235B_A22B; break; + case 48: type = LLM_TYPE_235B_A22B; break; default: type = LLM_TYPE_UNKNOWN; } } @@ -50,9 +49,9 @@ void llama_model_exaone_moe::load_arch_tensors(llama_model_loader &) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } - for (int i = 0; i < n_layer; ++i) { + for (int i = 0; i < n_layer_all; ++i) { int flags = 0; - if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) { + if (i >= n_layer) { // skip all tensors in the NextN layers flags |= TENSOR_SKIP; } @@ -70,7 +69,7 @@ void llama_model_exaone_moe::load_arch_tensors(llama_model_loader &) { layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); // dense layers for first n_layer_dense_lead layers or nextn_predict_layers layers at the end - if (i < (int) hparams.n_layer_dense_lead || (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers)) { + if (i < (int) hparams.n_layer_dense_lead || (i >= n_layer)) { layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, flags); layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags); @@ -95,7 +94,7 @@ void llama_model_exaone_moe::load_arch_tensors(llama_model_loader &) { } // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers - if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) { + if (i >= n_layer) { layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), {2 * n_embd, n_embd}, flags); layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), {n_embd}, flags); layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), {n_embd}, flags); @@ -130,8 +129,7 @@ llama_model_exaone_moe::graph::graph(const llama_model & model, const llm_graph_ ggml_tensor * inp_out_ids = build_inp_out_ids(); - const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; - for (int il = 0; il < n_transformer_layers; ++il) { + for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; // use RoPE for SWA layers @@ -170,7 +168,7 @@ llama_model_exaone_moe::graph::graph(const llama_model & model, const llm_graph_ Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); cb(cur, "attn_out", il); } - if (il == n_transformer_layers - 1 && inp_out_ids) { + if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } diff --git a/examples/talk-llama/models/exaone.cpp b/examples/talk-llama/models/exaone.cpp index c7e9960d718..676fb37b5a6 100644 --- a/examples/talk-llama/models/exaone.cpp +++ b/examples/talk-llama/models/exaone.cpp @@ -3,7 +3,7 @@ void llama_model_exaone::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_8B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/exaone4.cpp b/examples/talk-llama/models/exaone4.cpp index 499e22dde81..863268abcef 100644 --- a/examples/talk-llama/models/exaone4.cpp +++ b/examples/talk-llama/models/exaone4.cpp @@ -1,7 +1,7 @@ #include "models.h" void llama_model_exaone4::load_arch_hparams(llama_model_loader & ml) { - if (hparams.n_layer == 64) { // 32B + if (hparams.n_layer() == 64) { // 32B hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.n_swa = 4096; uint32_t swa_period = 4; @@ -15,8 +15,11 @@ void llama_model_exaone4::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); - switch (hparams.n_layer) { + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer"); + + switch (hparams.n_layer()) { case 30: type = LLM_TYPE_1_2B; break; case 64: type = LLM_TYPE_32B; break; default: type = LLM_TYPE_UNKNOWN; @@ -37,22 +40,38 @@ void llama_model_exaone4::load_arch_tensors(llama_model_loader &) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } - for (int i = 0; i < n_layer; ++i) { + for (int i = 0; i < n_layer_all; ++i) { + const bool is_nextn = i >= n_layer; + int flags = 0; + if (is_nextn) { + // NextN/MTP layers are preserved in GGUF but are not executed yet. + flags |= TENSOR_SKIP; + } + auto & layer = layers[i]; - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, flags); - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + if (!is_nextn) { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, flags); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, flags); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, flags); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, flags); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + if (is_nextn) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), {2 * n_embd, n_embd}, flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), {n_embd}, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), {n_embd}, flags); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), {n_embd}, flags | TENSOR_NOT_REQUIRED); + } } } diff --git a/examples/talk-llama/models/falcon-h1.cpp b/examples/talk-llama/models/falcon-h1.cpp index 94b65a3c7c9..d6ef2d51986 100644 --- a/examples/talk-llama/models/falcon-h1.cpp +++ b/examples/talk-llama/models/falcon-h1.cpp @@ -11,9 +11,9 @@ void llama_model_falcon_h1::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - std::fill(hparams.recurrent_layer_arr.begin(), hparams.recurrent_layer_arr.end(), true); + std::fill(hparams.is_recr_impl.begin(), hparams.is_recr_impl.end(), true); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 36: type = LLM_TYPE_0_5B; break; case 24: diff --git a/examples/talk-llama/models/falcon.cpp b/examples/talk-llama/models/falcon.cpp index ad546ef2db5..b2ad90b3272 100644 --- a/examples/talk-llama/models/falcon.cpp +++ b/examples/talk-llama/models/falcon.cpp @@ -3,7 +3,7 @@ void llama_model_falcon::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_7B; break; case 60: type = LLM_TYPE_40B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/gemma-embedding.cpp b/examples/talk-llama/models/gemma-embedding.cpp index 4e07f5f2bda..80ed3b1a460 100644 --- a/examples/talk-llama/models/gemma-embedding.cpp +++ b/examples/talk-llama/models/gemma-embedding.cpp @@ -21,7 +21,7 @@ void llama_model_gemma_embedding::load_arch_hparams(llama_model_loader & ml) { GGML_ASSERT((hparams.dense_2_feat_in == 0 || hparams.dense_2_feat_in == hparams.n_embd) && "dense_2_feat_in must be equal to n_embd"); GGML_ASSERT((hparams.dense_3_feat_out == 0 || hparams.dense_3_feat_out == hparams.n_embd) && "dense_3_feat_out must be equal to n_embd"); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_0_3B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/gemma.cpp b/examples/talk-llama/models/gemma.cpp index 1519682fdf6..651cd7e64de 100644 --- a/examples/talk-llama/models/gemma.cpp +++ b/examples/talk-llama/models/gemma.cpp @@ -3,7 +3,7 @@ void llama_model_gemma::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 18: type = LLM_TYPE_2B; break; case 28: type = LLM_TYPE_7B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/gemma2.cpp b/examples/talk-llama/models/gemma2.cpp index ae3f9ffb530..2fbfb15a94a 100644 --- a/examples/talk-llama/models/gemma2.cpp +++ b/examples/talk-llama/models/gemma2.cpp @@ -16,7 +16,7 @@ void llama_model_gemma2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false); ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 26: type = LLM_TYPE_2B; break; case 42: type = LLM_TYPE_9B; break; case 46: type = LLM_TYPE_27B; break; diff --git a/examples/talk-llama/models/gemma3.cpp b/examples/talk-llama/models/gemma3.cpp index 63a2b380e71..690194529e3 100644 --- a/examples/talk-llama/models/gemma3.cpp +++ b/examples/talk-llama/models/gemma3.cpp @@ -17,7 +17,7 @@ void llama_model_gemma3::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 18: type = LLM_TYPE_270M; break; case 26: type = LLM_TYPE_1B; break; case 32: type = LLM_TYPE_8B; break; // Rnj-1 diff --git a/examples/talk-llama/models/gemma3n.cpp b/examples/talk-llama/models/gemma3n.cpp index 6ec3a006081..83eb8250aa9 100644 --- a/examples/talk-llama/models/gemma3n.cpp +++ b/examples/talk-llama/models/gemma3n.cpp @@ -6,14 +6,14 @@ void llama_model_gemma3n::load_arch_hparams(llama_model_loader & ml) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.set_swa_pattern(swa_period); - hparams.n_layer_kv_from_start = 20; - hparams.f_attention_scale = 1.0f; + hparams.n_layer_kv_from_start = 20; + hparams.f_attention_scale = 1.0f; ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 30: type = LLM_TYPE_E2B; break; case 35: type = LLM_TYPE_E4B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/gemma4-assistant.cpp b/examples/talk-llama/models/gemma4-assistant.cpp new file mode 100644 index 00000000000..6378130e79e --- /dev/null +++ b/examples/talk-llama/models/gemma4-assistant.cpp @@ -0,0 +1,203 @@ +#include "models.h" + +void llama_model_gemma4_assistant::load_arch_hparams(llama_model_loader & ml) { + hparams.n_embd_inp_impl = hparams.n_embd_out(); + + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.is_swa_impl, hparams.n_layer()); + + uint32_t n_kv_shared_layers = 0; + ml.get_key(LLM_KV_ATTENTION_SHARED_KV_LAYERS, n_kv_shared_layers, false); + + hparams.f_attention_scale = 1.0f; + + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + GGML_ASSERT(hparams.n_layer_nextn == hparams.n_layer_all && "n_layer_nextn must be == n_layer_impl"); + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa); +} + +void llama_model_gemma4_assistant::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + if (n_embd_head_k != n_embd_head_v) { + throw std::runtime_error("Gemma 4 assistant requires n_embd_head_k == n_embd_head_v"); + } + if (hparams.n_embd_head_k_swa != hparams.n_embd_head_v_swa) { + throw std::runtime_error("Gemma 4 assistant requires n_embd_head_k_swa == n_embd_head_v_swa"); + } + if (hparams.n_embd_out() == n_embd) { + throw std::runtime_error("Gemma 4 assistant requires embedding_length_out to carry the target hidden size"); + } + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + + create_tensor(tn(LLM_TENSOR_MASKED_EMBD_CENTROIDS, "weight"), {}, TENSOR_NOT_REQUIRED); + create_tensor(tn(LLM_TENSOR_MASKED_EMBD_ORDERING), {}, TENSOR_NOT_REQUIRED); + + const int64_t n_embd_backbone = hparams.n_embd_inp(); + nextn_proj_post = create_tensor(tn(LLM_TENSOR_NEXTN_PROJ_POST, "weight"), { n_embd, n_embd_backbone }, 0); + + int rope_freqs_flag = 0; + + for (int i = 0; i < n_layer_nextn; ++i) { + auto & layer = layers[i]; + + const int64_t n_head = hparams.n_head(i); + const int64_t n_embd_head = hparams.n_embd_head_k(i); + const int64_t n_ff = hparams.n_ff(i); + + if (i == 0) { + nextn_proj_pre = create_tensor(tn(LLM_TENSOR_NEXTN_PROJ_PRE, "weight", i), { 2*n_embd_backbone, n_embd }, 0); + } + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head*n_head }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head*n_head, n_embd }, 0); + + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + + layer.out_scale = create_tensor(tn(LLM_TENSOR_LAYER_OUT_SCALE, "weight", i), { 1u }, 0); + + if (!hparams.is_swa(i)) { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_embd_head/2 }, rope_freqs_flag); + rope_freqs_flag = TENSOR_DUPLICATED; + } + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), { n_embd }, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_gemma4_assistant::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_gemma4_assistant::graph::graph(const llama_model & model, const llm_graph_params & params) : + llm_graph_context(params) { + const int64_t n_embd_backbone = hparams.n_embd_inp(); + + ggml_tensor * inp_tokens; + ggml_tensor * inp_h; + { + auto inp = std::make_unique<llm_graph_input_embd>(n_embd_backbone); + + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); + cb(inp->tokens, "inp_tokens", -1); + ggml_set_input(inp->tokens); + inp_tokens = inp->tokens; + res->t_inp_tokens = inp->tokens; + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_backbone, ubatch.n_tokens); + cb(inp->embd, "inp_h", -1); + ggml_set_input(inp->embd); + inp_h = inp->embd; + res->t_inp_embd = inp->embd; + + res->add_input(std::move(inp)); + } + + GGML_ASSERT(cparams.ctx_other != nullptr); + const auto * model_other = llama_get_model(cparams.ctx_other); + + ggml_tensor * x = ggml_get_rows(ctx0, model_other->tok_embd, inp_tokens); + x = ggml_scale(ctx0, x, sqrtf((float) n_embd_backbone)); + cb(x, "inp_embd_target", -1); + + ggml_tensor * xh = ggml_concat(ctx0, x, inp_h, 0); + cb(xh, "inp_xh", -1); + + ggml_tensor * cur = ggml_mul_mat(ctx0, model.nextn_proj_pre, xh); + cb(cur, "pre_proj", -1); + + auto * inp_attn = build_attn_inp_kv_iswa(); + ggml_tensor * inp_pos = build_inp_pos(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + ggml_tensor * inpL = cur; + + for (int il = 0; il < n_layer_nextn; ++il) { + const bool is_swa = hparams.is_swa(il); + + const int64_t n_embd_head = hparams.n_embd_head_k(il); + const int64_t n_head = hparams.n_head(il); + + const float freq_base_l = model.get_rope_freq_base(cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + const int n_rot_l = hparams.n_rot(il); + + ggml_tensor * cur_norm = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur_norm, "attn_norm", il); + + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur_norm); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + ggml_tensor * freq_factors = is_swa ? nullptr : model.layers[il].rope_freqs; + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, freq_factors, n_rot_l, rope_type, n_ctx_orig, + freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow); + cb(Qcur, "Qcur_pos", il); + + cur = build_attn(inp_attn, model.layers[il].wo, nullptr, nullptr, + Qcur, nullptr, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il); + + if (il == n_layer_nextn - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + cur = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "attn_post_norm", il); + + ggml_tensor * attn_out = ggml_add(ctx0, cur, inpL); + cb(attn_out, "attn_out", il); + + cur = build_norm(attn_out, model.layers[il].ffn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, nullptr, nullptr, + model.layers[il].ffn_gate, nullptr, nullptr, + model.layers[il].ffn_down, nullptr, nullptr, + nullptr, + LLM_FFN_GELU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = build_norm(cur, model.layers[il].ffn_post_norm, nullptr, LLM_NORM_RMS, -1); + cb(cur, "ffn_post_norm", il); + + cur = ggml_add(ctx0, cur, attn_out); + + cur = ggml_mul(ctx0, cur, model.layers[il].out_scale); + cb(cur, "out_scaled", il); + + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); + + ggml_tensor * logits = build_lora_mm(model.output, cur); + cb(logits, "result_output", -1); + res->t_logits = logits; + + ggml_tensor * h_next = ggml_mul_mat(ctx0, model.nextn_proj_post, cur); + cb(h_next, "h_nextn", -1); + res->t_h_nextn = h_next; + + ggml_build_forward_expand(gf, logits); + ggml_build_forward_expand(gf, h_next); +} diff --git a/examples/talk-llama/models/gemma4.cpp b/examples/talk-llama/models/gemma4.cpp index 4f9d8b18bc7..6a96979cebd 100644 --- a/examples/talk-llama/models/gemma4.cpp +++ b/examples/talk-llama/models/gemma4.cpp @@ -2,12 +2,12 @@ void llama_model_gemma4::load_arch_hparams(llama_model_loader & ml) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.is_swa_impl, hparams.n_layer()); uint32_t n_kv_shared_layers = 0; ml.get_key(LLM_KV_ATTENTION_SHARED_KV_LAYERS, n_kv_shared_layers, false); - hparams.n_layer_kv_from_start = hparams.n_layer - (int32_t)n_kv_shared_layers; + hparams.n_layer_kv_from_start = hparams.n_layer_all - (int32_t)n_kv_shared_layers; hparams.f_attention_scale = 1.0f; // Gemma4 uses self.scaling = 1.0 (no pre-attn scaling) ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); @@ -19,7 +19,7 @@ void llama_model_gemma4::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa); ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 30: type = LLM_TYPE_26B_A4B; break; case 35: type = LLM_TYPE_E2B; break; case 42: type = LLM_TYPE_E4B; break; @@ -142,6 +142,33 @@ static ggml_tensor * ggml_view_2d_slice(ggml_context * ctx0, ggml_tensor * x, in idx * x->ne[0] * x->ne[1] * ggml_element_size(x)); } +// TODO @ngxson : maybe improve this in the future +class llm_graph_input_logits_bias : public llm_graph_input_i { +public: + llm_graph_input_logits_bias(const llama_vocab & vocab) { + arr.resize(vocab.n_tokens(), 0.0f); + for (llama_token id : vocab.get_suppress_tokens()) { + if (0 <= id && id < (int32_t)vocab.n_tokens()) { + arr[id] = -INFINITY; + } + } + } + virtual ~llm_graph_input_logits_bias() = default; + + void set_input(const llama_ubatch * /*ubatch*/) override { + const int64_t n_vocab = arr.size(); + ggml_backend_tensor_set(logits_bias, arr.data(), 0, n_vocab*ggml_element_size(logits_bias)); + } + + bool can_reuse(const llm_graph_params & /*params*/) override { + return true; + } + + ggml_tensor * logits_bias = nullptr; // F32 [n_vocab] + + std::vector<float> arr; +}; + llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model), @@ -183,6 +210,8 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para const float freq_scale_l = model.get_rope_freq_scale(cparams, il); const int n_rot_l = hparams.n_rot(il); + res->t_layer_inp[il] = inpL; + // norm cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); cb(cur, "attn_norm", il); @@ -245,7 +274,8 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para } // TODO @ngxson : strip unused token right after the last KV layer to speed up prompt processing - if (il == n_layer - 1 && inp_out_ids) { + // keep all rows when extracting unmasked nextn embeddings (MTP target needs the hidden state for every token) + if (il == n_layer - 1 && inp_out_ids && cparams.embeddings_nextn_masked) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); } @@ -345,7 +375,7 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para ggml_tensor * inp_this_layer = ggml_view_2d_slice(ctx0, inp_per_layer, il); // [n_embd_per_layer, n_tokens] // TODO @ngxson : improve this - if (il == n_layer - 1 && inp_out_ids) { + if (il == n_layer - 1 && inp_out_ids && cparams.embeddings_nextn_masked) { inp_this_layer = ggml_get_rows(ctx0, inp_this_layer, inp_out_ids); } @@ -376,6 +406,17 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para model.output_norm, nullptr, LLM_NORM_RMS, -1); + // Expose the post-output-norm hidden state (the LM-head input feature) so that + // MTP draft contexts can read it via llama_get_embeddings_nextn_ith() as the + // recurrent h input. This matches the reference (transformers/vLLM/SGLang), + // which feeds the drafter the target's post-final-norm hidden state. + cb(cur, "h_nextn", -1); + res->t_h_nextn = cur; + + if (!cparams.embeddings_nextn_masked && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + } + cb(cur, "result_norm", -1); res->t_embd = cur; @@ -388,6 +429,16 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping); } + // apply logits bias if needed (e.g. for gemma4_unified patch) + // this is to mirror the suppress_tokens patch on transformers, to avoid model from outputing <image|> and <audio|> tokens (which is a known issue related to the checkpoint) + // TODO: maybe handle this inside the sampling system in the future + if (!model.vocab.get_suppress_tokens().empty()) { + auto inp_bias = std::make_unique<llm_graph_input_logits_bias>(model.vocab); + inp_bias->logits_bias = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, inp_bias->arr.size()); + cur = ggml_add(ctx0, cur, inp_bias->logits_bias); + res->add_input(std::move(inp_bias)); + } + cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/glm-dsa.cpp b/examples/talk-llama/models/glm-dsa.cpp index af2b55ef563..11d91312def 100644 --- a/examples/talk-llama/models/glm-dsa.cpp +++ b/examples/talk-llama/models/glm-dsa.cpp @@ -33,13 +33,10 @@ void llama_model_glm_dsa::load_arch_hparams(llama_model_loader & ml) { } // NextN/MTP parameters - ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer_impl"); - // TODO: when MTP is implemented, this should probably be updated if needed - hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; - - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 79: type = LLM_TYPE_744B_A40B; break; default: type = LLM_TYPE_UNKNOWN; } @@ -76,9 +73,9 @@ void llama_model_glm_dsa::load_arch_tensors(llama_model_loader &) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } - for (int i = 0; i < n_layer; ++i) { + for (int i = 0; i < n_layer_all; ++i) { int flags = 0; - if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) { + if (i >= n_layer) { // skip all tensors in the NextN layers // TODO @ngxson : TENSOR_NOT_REQUIRED was a hack, need to remove it later flags |= TENSOR_SKIP | TENSOR_NOT_REQUIRED; @@ -135,8 +132,8 @@ void llama_model_glm_dsa::load_arch_tensors(llama_model_loader &) { layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, flags); } - // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers - if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) { + // NextN/MTP tensors (preserved but unused) - conditionally load for last n_layer_nextn + if (i >= n_layer) { layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); diff --git a/examples/talk-llama/models/glm4-moe.cpp b/examples/talk-llama/models/glm4-moe.cpp index 27654b8cba3..d60e47ddf0c 100644 --- a/examples/talk-llama/models/glm4-moe.cpp +++ b/examples/talk-llama/models/glm4-moe.cpp @@ -20,16 +20,13 @@ void llama_model_glm4_moe::load_arch_hparams(llama_model_loader & ml) { } // NextN/MTP parameters - ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer_impl"); - // TODO: when MTP is implemented, this should probably be updated if needed - hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; - - switch (hparams.n_layer) { - case 47: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer) + switch (hparams.n_layer()) { + case 46: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air case 48: type = LLM_TYPE_102B_A12B; break; // Solar Open - case 93: type = LLM_TYPE_355B_A32B; break; // GLM-4.5 (92 layers + 1 NextN layer) + case 92: type = LLM_TYPE_355B_A32B; break; // GLM-4.5 default: type = LLM_TYPE_UNKNOWN; } } @@ -54,9 +51,9 @@ void llama_model_glm4_moe::load_arch_tensors(llama_model_loader &) { // Load ALL tensors including NextN layer to satisfy total tensor count // but only PROCESS up to last layer (skipping final NextN layer) in forward pass - for (int i = 0; i < n_layer; ++i) { + for (int i = 0; i < n_layer_all; ++i) { int flags = 0; - if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) { + if (i >= n_layer) { // skip all tensors in the NextN layers flags |= TENSOR_SKIP; } @@ -116,7 +113,7 @@ void llama_model_glm4_moe::load_arch_tensors(llama_model_loader &) { } // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers - if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) { + if (i >= n_layer) { layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); @@ -161,8 +158,7 @@ llama_model_glm4_moe::graph::graph(const llama_model & model, const llm_graph_pa // Only process up to last layer (skip final NextN layer) // Final layer tensors are loaded but not processed in forward pass - const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; - for (int il = 0; il < n_transformer_layers; ++il) { + for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; // Pre-attention norm @@ -211,7 +207,7 @@ llama_model_glm4_moe::graph::graph(const llama_model & model, const llm_graph_pa model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } - if (il == n_transformer_layers - 1 && inp_out_ids) { + if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } diff --git a/examples/talk-llama/models/glm4.cpp b/examples/talk-llama/models/glm4.cpp index 7c242fed298..b4326c5f210 100644 --- a/examples/talk-llama/models/glm4.cpp +++ b/examples/talk-llama/models/glm4.cpp @@ -5,13 +5,10 @@ void llama_model_glm4::load_arch_hparams(llama_model_loader & ml) { ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); // NextN/MTP parameters (GLM-OCR) - ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer_impl"); - // TODO: when MTP is implemented, this should probably be updated if needed - hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; - - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 17: type = LLM_TYPE_1B; break; // GLM-OCR case 40: type = LLM_TYPE_9B; break; case 61: type = LLM_TYPE_32B; break; @@ -32,9 +29,9 @@ void llama_model_glm4::load_arch_tensors(llama_model_loader &) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } - for (int i = 0; i < n_layer; ++i) { + for (int i = 0; i < n_layer_all; ++i) { int flags = 0; - if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) { + if (i >= n_layer) { // skip all tensors in the NextN layers flags |= TENSOR_SKIP; } @@ -55,7 +52,7 @@ void llama_model_glm4::load_arch_tensors(llama_model_loader &) { layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, flags); // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers - if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) { + if (i >= n_layer) { layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); @@ -100,8 +97,7 @@ llama_model_glm4::graph::graph(const llama_model & model, const llm_graph_params // Only process up to last layer (skip final NextN layer) // Final layer tensors are loaded but not processed in forward pass - const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; - for (int il = 0; il < n_transformer_layers; ++il) { + for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; // Pre-attention norm @@ -140,7 +136,7 @@ llama_model_glm4::graph::graph(const llama_model & model, const llm_graph_params model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } - if (il == n_transformer_layers - 1 && inp_out_ids) { + if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } diff --git a/examples/talk-llama/models/gpt2.cpp b/examples/talk-llama/models/gpt2.cpp index e2dcc8b1521..45afbccc121 100644 --- a/examples/talk-llama/models/gpt2.cpp +++ b/examples/talk-llama/models/gpt2.cpp @@ -2,7 +2,8 @@ void llama_model_gpt2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 12: type = LLM_TYPE_SMALL; break; case 24: type = LLM_TYPE_MEDIUM; break; case 36: type = LLM_TYPE_LARGE; break; diff --git a/examples/talk-llama/models/gptneox.cpp b/examples/talk-llama/models/gptneox.cpp index 443e35addf2..ed5e8c50da2 100644 --- a/examples/talk-llama/models/gptneox.cpp +++ b/examples/talk-llama/models/gptneox.cpp @@ -3,7 +3,8 @@ void llama_model_gptneox::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_USE_PARALLEL_RESIDUAL, hparams.use_par_res); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 6: switch (hparams.n_ff()) { case 512: type = LLM_TYPE_14M; break; diff --git a/examples/talk-llama/models/granite-hybrid.cpp b/examples/talk-llama/models/granite-hybrid.cpp index 27f6706ea10..eb23095aece 100644 --- a/examples/talk-llama/models/granite-hybrid.cpp +++ b/examples/talk-llama/models/granite-hybrid.cpp @@ -19,8 +19,8 @@ void llama_model_granite_hybrid::load_arch_hparams(llama_model_loader & ml) { hparams.rope_finetuned = rope_finetuned; // A layer is recurrent IFF the n_head_kv value is set to 0 - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; + for (uint32_t i = 0; i < hparams.n_layer(); ++i) { + hparams.is_recr_impl[i] = hparams.n_head_kv(i) == 0; } ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -71,7 +71,7 @@ void llama_model_granite_hybrid::load_arch_tensors(llama_model_loader &) { // norm layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - if (hparams.is_recurrent(i)) { + if (hparams.is_recr(i)) { // ssm layers layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0); @@ -158,7 +158,7 @@ llama_model_granite_hybrid::graph::graph(const llama_model & model, const llm_gr cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); cb(cur, "attn_norm", il); - if (hparams.is_recurrent(il)) { + if (hparams.is_recr(il)) { // ssm layer // cur = build_mamba2_layer(inp->get_recr(), cur, model, ubatch, il); } else { diff --git a/examples/talk-llama/models/granite-moe.cpp b/examples/talk-llama/models/granite-moe.cpp index 0d89bc1f340..115263c418f 100644 --- a/examples/talk-llama/models/granite-moe.cpp +++ b/examples/talk-llama/models/granite-moe.cpp @@ -12,7 +12,7 @@ void llama_model_granite_moe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false); hparams.rope_finetuned = rope_finetuned; - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_3B; break; case 40: type = LLM_TYPE_3B; break; // Add additional layer/vocab/etc checks here for other model sizes diff --git a/examples/talk-llama/models/granite.cpp b/examples/talk-llama/models/granite.cpp index cda4aa231fa..4a75c5ff3cc 100644 --- a/examples/talk-llama/models/granite.cpp +++ b/examples/talk-llama/models/granite.cpp @@ -1,5 +1,7 @@ #include "models.h" +#include <sstream> + void llama_model_granite::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); @@ -7,12 +9,33 @@ void llama_model_granite::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, false); ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale, false); + // Granite4 Vision uses array deepstack_mapping + ml.get_arr(LLM_KV_DEEPSTACK_MAPPING, hparams.deepstack_mapping_arr, false); + + // Count the unique deepstack input indices + std::unordered_set<uint32_t> unique_deepstack_idxs; + for (const auto val : hparams.deepstack_mapping_arr) { + if (val >= 0) { + unique_deepstack_idxs.insert(val); + } + } + hparams.n_deepstack_layers = unique_deepstack_idxs.size(); + + // Ensure all values are valid (avoid overflow attacks) + for (const auto val : unique_deepstack_idxs) { + if (val > hparams.n_deepstack_layers) { + std::stringstream ss; + ss << "Invalid deepstack index: " << val << " > " << hparams.n_deepstack_layers; + throw std::runtime_error(ss.str()); + } + } + // Granite uses rope_finetuned as a switch for rope, so default to true bool rope_finetuned = true; ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false); hparams.rope_finetuned = rope_finetuned; - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_3B; break; case 40: type = LLM_TYPE_3B; break; // Add additional layer/vocab/etc checks here for other model sizes @@ -112,6 +135,20 @@ llama_model_granite::graph::graph( ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { + + // Granite Vision 4.1 deepstack: inject the projector stream that + // targets decoder layer `il` before the decoder runs. + // NOTE: skip the first deepstack layer since that's inpL + const auto & deepstack_emb_idx = hparams.deepstack_mapping_arr[il]; + if (il > 0 && deepstack_emb_idx >= 0) { + ggml_tensor * ds = ggml_view_2d(ctx0, + res->t_inp_embd, n_embd, n_tokens, + res->t_inp_embd->nb[1], + deepstack_emb_idx * n_embd * sizeof(float)); + inpL = ggml_add(ctx0, inpL, ds); + cb(inpL, "deepstack_in", il); + } + ggml_tensor * inpSA = inpL; // norm diff --git a/examples/talk-llama/models/grok.cpp b/examples/talk-llama/models/grok.cpp index 7c46ec1c0f2..42f38af6724 100644 --- a/examples/talk-llama/models/grok.cpp +++ b/examples/talk-llama/models/grok.cpp @@ -26,7 +26,7 @@ void llama_model_grok::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false); ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 64: type = LLM_TYPE_314B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/grovemoe.cpp b/examples/talk-llama/models/grovemoe.cpp index 1cab75adc7f..643a448e59a 100644 --- a/examples/talk-llama/models/grovemoe.cpp +++ b/examples/talk-llama/models/grovemoe.cpp @@ -7,7 +7,7 @@ void llama_model_grovemoe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERTS_PER_GROUP, hparams.n_group_experts); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 48: type = LLM_TYPE_30B_A3B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/hunyuan-moe.cpp b/examples/talk-llama/models/hunyuan-moe.cpp index deb3c9671f3..4d55f5e7f31 100644 --- a/examples/talk-llama/models/hunyuan-moe.cpp +++ b/examples/talk-llama/models/hunyuan-moe.cpp @@ -5,7 +5,7 @@ void llama_model_hunyuan_moe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_A13B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/internlm2.cpp b/examples/talk-llama/models/internlm2.cpp index f9ee37a24b6..f6cfdfb9458 100644 --- a/examples/talk-llama/models/internlm2.cpp +++ b/examples/talk-llama/models/internlm2.cpp @@ -2,7 +2,8 @@ void llama_model_internlm2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_7B; break; case 48: type = LLM_TYPE_20B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/jais.cpp b/examples/talk-llama/models/jais.cpp index 2ba162605f1..415103ce23a 100644 --- a/examples/talk-llama/models/jais.cpp +++ b/examples/talk-llama/models/jais.cpp @@ -4,7 +4,7 @@ void llama_model_jais::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_1_3B; break; case 40: type = LLM_TYPE_13B; break; /* TODO: add variants */ diff --git a/examples/talk-llama/models/jais2.cpp b/examples/talk-llama/models/jais2.cpp index 8966131441c..8610fcc9f82 100644 --- a/examples/talk-llama/models/jais2.cpp +++ b/examples/talk-llama/models/jais2.cpp @@ -3,7 +3,7 @@ void llama_model_jais2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_8B; break; case 68: type = LLM_TYPE_70B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/jamba.cpp b/examples/talk-llama/models/jamba.cpp index 84ea63c3136..dba160b014f 100644 --- a/examples/talk-llama/models/jamba.cpp +++ b/examples/talk-llama/models/jamba.cpp @@ -8,11 +8,11 @@ void llama_model_jamba::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; + for (uint32_t i = 0; i < hparams.n_layer(); ++i) { + hparams.is_recr_impl[i] = hparams.n_head_kv(i) == 0; } - switch (hparams.n_layer) { + switch (hparams.n_layer()) { // TODO: Jamba layers are a bit heterogeneous, so naming this is hard. case 12: // 900M 8x???M case 32: // 51B 16x?B diff --git a/examples/talk-llama/models/jina-bert-v2.cpp b/examples/talk-llama/models/jina-bert-v2.cpp index 4f8866ece4d..86ff1c84d1a 100644 --- a/examples/talk-llama/models/jina-bert-v2.cpp +++ b/examples/talk-llama/models/jina-bert-v2.cpp @@ -4,7 +4,7 @@ void llama_model_jina_bert_v2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); hparams.f_max_alibi_bias = 8.0f; - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 4: type = LLM_TYPE_33M; break; // jina-embeddings-small case 12: type = LLM_TYPE_137M; break; // jina-embeddings-base default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/jina-bert-v3.cpp b/examples/talk-llama/models/jina-bert-v3.cpp index e0527529f56..1c974a6f16c 100644 --- a/examples/talk-llama/models/jina-bert-v3.cpp +++ b/examples/talk-llama/models/jina-bert-v3.cpp @@ -3,7 +3,7 @@ void llama_model_jina_bert_v3::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_558M; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/kimi-linear.cpp b/examples/talk-llama/models/kimi-linear.cpp index ecffb105496..367f6990d1f 100644 --- a/examples/talk-llama/models/kimi-linear.cpp +++ b/examples/talk-llama/models/kimi-linear.cpp @@ -14,8 +14,8 @@ void llama_model_kimi_linear::load_arch_hparams(llama_model_loader & ml) { // Mark KDA layers as recurrent using n_head_kv pattern (like Jamba) // Set n_head_kv = 0 for KDA layers (recurrent), n_head_kv = n_head for MLA layers (attention) - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; // KDA layers are recurrent + for (uint32_t i = 0; i < hparams.n_layer(); ++i) { + hparams.is_recr_impl[i] = hparams.n_head_kv(i) == 0; // KDA layers are recurrent } // MoE parameters - Kimi uses moe_intermediate_size = 1024 @@ -25,7 +25,7 @@ void llama_model_kimi_linear::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 27: type = LLM_TYPE_48B_A3B; break; // Kimi-Linear-48B-A3B default: type = LLM_TYPE_UNKNOWN; } @@ -53,7 +53,7 @@ void llama_model_kimi_linear::load_arch_tensors(llama_model_loader &) { const int64_t n_embd_head_v_kda = hparams.n_embd_head_kda; const int64_t ssm_d_conv = hparams.ssm_d_conv; - if (hparams.is_recurrent(i)) { + if (hparams.is_recr(i)) { // Conv1d weights: try 4D first, then 3D (quantization may remove trailing 1) // 4D: [d_conv, 1, d_inner, 1], 3D: [d_conv, 1, d_inner] layer.ssm_q_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_Q, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head, 1}, TENSOR_NOT_REQUIRED); @@ -285,7 +285,7 @@ llama_model_kimi_linear::graph::graph(const llama_model & model, const llm_graph ggml_build_forward_expand(gf, cur); - if (hparams.is_recurrent(il)) { + if (hparams.is_recr(il)) { // === KDA Layer (Kimi Delta Attention) with Recurrent State === // Reference: vLLM kda.py const auto * mctx_cur = inp_rs->mctx; diff --git a/examples/talk-llama/models/lfm2.cpp b/examples/talk-llama/models/lfm2.cpp index 29081344b24..97da8a6abb8 100644 --- a/examples/talk-llama/models/lfm2.cpp +++ b/examples/talk-llama/models/lfm2.cpp @@ -5,10 +5,13 @@ void llama_model_lfm2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - for (uint32_t il = 0; il < hparams.n_layer; ++il) { - hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0; + + for (uint32_t il = 0; il < hparams.n_layer(); ++il) { + hparams.is_recr_impl[il] = hparams.n_head_kv(il) == 0; } - hparams.n_layer_dense_lead = hparams.n_layer; + + hparams.n_layer_dense_lead = hparams.n_layer(); + switch (hparams.n_ff()) { case 4608: type = LLM_TYPE_350M; break; case 6912: type = LLM_TYPE_700M; break; @@ -16,10 +19,11 @@ void llama_model_lfm2::load_arch_hparams(llama_model_loader & ml) { case 10752: type = LLM_TYPE_2_6B; break; default: type = LLM_TYPE_UNKNOWN; } + if (const auto is_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); is_swa && hparams.n_swa > 0) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - for (uint32_t il = 0; il < hparams.n_layer; ++il) { - hparams.swa_layers[il] = !hparams.recurrent_layer_arr[il]; + for (uint32_t il = 0; il < hparams.n_layer(); ++il) { + hparams.is_swa_impl[il] = !hparams.is_recr_impl[il]; } } } @@ -59,7 +63,7 @@ void llama_model_lfm2::load_arch_tensors(llama_model_loader &) { // for operator_norm layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - if (!hparams.is_recurrent(i)) { + if (!hparams.is_recr(i)) { layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); GGML_ASSERT(n_embd_v_gqa == n_embd_k_gqa); @@ -235,8 +239,8 @@ llama_model_lfm2::graph<iswa>::graph(const llama_model & model, const llm_graph_ cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); cb(cur, "model.layers.{}.operator_norm", il); - cur = hparams.is_recurrent(il) ? build_shortconv_block(cur, inp_hybrid->get_recr(), il) : - build_attn_block(cur, inp_pos, inp_hybrid->get_attn(), il); + cur = hparams.is_recr(il) ? build_shortconv_block(cur, inp_hybrid->get_recr(), il) : + build_attn_block(cur, inp_pos, inp_hybrid->get_attn(), il); if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); diff --git a/examples/talk-llama/models/lfm2moe.cpp b/examples/talk-llama/models/lfm2moe.cpp index 12a66c05c7d..490f5c223eb 100644 --- a/examples/talk-llama/models/lfm2moe.cpp +++ b/examples/talk-llama/models/lfm2moe.cpp @@ -9,11 +9,11 @@ void llama_model_lfm2moe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); - for (uint32_t il = 0; il < hparams.n_layer; ++il) { - hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0; + for (uint32_t il = 0; il < hparams.n_layer(); ++il) { + hparams.is_recr_impl[il] = hparams.n_head_kv(il) == 0; } - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_8B_A1B; break; case 40: type = LLM_TYPE_24B_A2B; break; default: type = LLM_TYPE_UNKNOWN; @@ -55,7 +55,7 @@ void llama_model_lfm2moe::load_arch_tensors(llama_model_loader &) { // for operator_norm layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - if (!hparams.is_recurrent(i)) { + if (!hparams.is_recr(i)) { layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); GGML_ASSERT(n_embd_v_gqa == n_embd_k_gqa); diff --git a/examples/talk-llama/models/llada-moe.cpp b/examples/talk-llama/models/llada-moe.cpp index 9722dde9f17..2ae89386447 100644 --- a/examples/talk-llama/models/llada-moe.cpp +++ b/examples/talk-llama/models/llada-moe.cpp @@ -2,11 +2,12 @@ void llama_model_llada_moe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // diffusion language model uses non-causal attention hparams.causal_attn = false; - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 16: type = LLM_TYPE_A1_7B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/llada.cpp b/examples/talk-llama/models/llada.cpp index 58b2c466e17..87d4259f9a7 100644 --- a/examples/talk-llama/models/llada.cpp +++ b/examples/talk-llama/models/llada.cpp @@ -2,14 +2,16 @@ void llama_model_llada::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // LLaDA-8B has 32 layers, similar to LLaMA but for diffusion - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_8B; break; default: type = LLM_TYPE_UNKNOWN; } + // Set non-causal attention for diffusion models hparams.causal_attn = false; } diff --git a/examples/talk-llama/models/llama.cpp b/examples/talk-llama/models/llama.cpp index cef66d054b0..4bfebc8843c 100644 --- a/examples/talk-llama/models/llama.cpp +++ b/examples/talk-llama/models/llama.cpp @@ -7,13 +7,13 @@ void llama_model_llama::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); if (hparams.n_expert == 8) { - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_8x7B; break; case 56: type = LLM_TYPE_8x22B; break; default: type = LLM_TYPE_UNKNOWN; } } else { - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 16: type = LLM_TYPE_1B; break; // Llama 3.2 1B case 22: type = LLM_TYPE_1B; break; case 26: type = LLM_TYPE_3B; break; @@ -124,6 +124,8 @@ llama_model_llama::graph<embed>::graph(const llama_model & model, const llm_grap ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { + res->t_layer_inp[il] = inpL; + ggml_tensor * inpSA = inpL; // norm diff --git a/examples/talk-llama/models/llama4.cpp b/examples/talk-llama/models/llama4.cpp index 0ff5376d571..7194c72a585 100644 --- a/examples/talk-llama/models/llama4.cpp +++ b/examples/talk-llama/models/llama4.cpp @@ -8,14 +8,15 @@ void llama_model_llama4::load_arch_hparams(llama_model_loader & ml) { const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); if (found_swa && hparams.n_swa == 0) { hparams.swa_type = LLAMA_SWA_TYPE_NONE; - hparams.n_no_rope_layer_step = hparams.n_layer; // always use rope + hparams.n_no_rope_layer_step = hparams.n_layer(); // always use rope } else { hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED; hparams.n_swa = 8192; hparams.n_attn_temp_floor_scale = 8192; hparams.f_attn_temp_scale = 0.1f; hparams.f_attn_temp_offset = 1.0f; - uint32_t swa_period = 4; // pattern: 3 chunked - 1 full + + uint32_t swa_period = 4; // pattern: 3 chunked - 1 full ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); hparams.set_swa_pattern(swa_period); diff --git a/examples/talk-llama/models/maincoder.cpp b/examples/talk-llama/models/maincoder.cpp index 84cfe399027..ae56a26a1f6 100644 --- a/examples/talk-llama/models/maincoder.cpp +++ b/examples/talk-llama/models/maincoder.cpp @@ -2,7 +2,8 @@ void llama_model_maincoder::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_1B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/mamba.cpp b/examples/talk-llama/models/mamba.cpp index 887a1fa509a..0d94e98281c 100644 --- a/examples/talk-llama/models/mamba.cpp +++ b/examples/talk-llama/models/mamba.cpp @@ -9,7 +9,7 @@ void llama_model_mamba::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: switch (hparams.n_embd) { case 768: type = LLM_TYPE_SMALL; break; diff --git a/examples/talk-llama/models/mamba2.cpp b/examples/talk-llama/models/mamba2.cpp index 3277ca53ec4..c5951cf0f7f 100644 --- a/examples/talk-llama/models/mamba2.cpp +++ b/examples/talk-llama/models/mamba2.cpp @@ -9,7 +9,7 @@ void llama_model_mamba2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: switch (hparams.n_embd) { case 768: type = LLM_TYPE_SMALL; break; diff --git a/examples/talk-llama/models/mellum.cpp b/examples/talk-llama/models/mellum.cpp new file mode 100644 index 00000000000..28823018bc0 --- /dev/null +++ b/examples/talk-llama/models/mellum.cpp @@ -0,0 +1,225 @@ +#include "models.h" + +void llama_model_mellum::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + + if (hparams.n_swa > 0) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + + uint32_t swa_period = 4; + const auto res = ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + if (res) { + hparams.set_swa_pattern(swa_period); + } else { + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.is_swa_impl, hparams.n_layer()); + } + + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + } else { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + } + + switch (hparams.n_layer()) { + case 28: type = LLM_TYPE_12B_A2_5B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_mellum::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0 for Mellum"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0 for Mellum"); + } + + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_mellum::build_arch_graph(const llm_graph_params & params) const { + if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { + return std::make_unique<graph<true>>(*this, params); + } + return std::make_unique<graph<false>>(*this, params); +} + +template <bool iswa> +llama_model_mellum::graph<iswa>::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_iswa, llm_graph_input_attn_kv>; + inp_attn_type * inp_attn = nullptr; + + if constexpr (iswa) { + inp_attn = build_attn_inp_kv_iswa(); + } else { + inp_attn = build_attn_inp_kv(); + } + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, nullptr, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self_attention + { + // compute Q and K and RoPE them + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + const bool is_swa = hparams.is_swa(il); + + if (is_swa) { + // For sliding window layers, use regular rope with no yarn rope scaling. + // This is achieved here by setting freq_scale and attn_factor to 1. + // We also set ext_factor to 0 to avoid a few unnecessary computations. + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, 1.0, + 0.0, 1.0, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, 1.0, + 0.0, 1.0, beta_fast, beta_slow + ); + } else { + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // MoE + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, nullptr, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = + build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + hparams.expert_weights_scale, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il, + nullptr, nullptr, + model.layers[il].ffn_up_exps_s, + model.layers[il].ffn_gate_exps_s, + model.layers[il].ffn_down_exps_s); + cb(moe_out, "ffn_moe_out", il); + cur = moe_out; + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, + model.output_norm, nullptr, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur, model.output_s); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} + +template struct llama_model_mellum::graph<false>; +template struct llama_model_mellum::graph<true>; diff --git a/examples/talk-llama/models/mimo2.cpp b/examples/talk-llama/models/mimo2.cpp index d0295ec116f..88989160570 100644 --- a/examples/talk-llama/models/mimo2.cpp +++ b/examples/talk-llama/models/mimo2.cpp @@ -8,18 +8,18 @@ void llama_model_mimo2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); + + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.is_swa_impl, hparams.n_layer()); float value_scale = 0.0f; if (ml.get_key(LLM_KV_ATTENTION_VALUE_SCALE, value_scale, false) && value_scale != 1.0f) { hparams.f_attn_value_scale = value_scale; } - ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); - hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer_impl"); - switch (hparams.n_layer - hparams.nextn_predict_layers) { + switch (hparams.n_layer()) { case 48: type = LLM_TYPE_310B_A15B; break; default: type = LLM_TYPE_UNKNOWN; } @@ -34,16 +34,14 @@ void llama_model_mimo2::load_arch_tensors(llama_model_loader &) { output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - const uint32_t n_nextn = hparams.nextn_predict_layers; - - for (int i = 0; i < n_layer; ++i) { + for (int i = 0; i < n_layer_all; ++i) { auto & layer = layers[i]; uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); uint32_t n_head = hparams.n_head(i); // NextN/MTP layers (the last n_nextn blocks) are preserved but disabled pending support - const bool is_nextn = (n_nextn > 0) && (static_cast<uint32_t>(i) >= n_layer - n_nextn); + const bool is_nextn = i >= n_layer; const int skip = is_nextn ? TENSOR_SKIP : 0; create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, skip); @@ -92,10 +90,7 @@ llama_model_mimo2::graph::graph(const llama_model & model, const llm_graph_param const float v_scale = hparams.f_attn_value_scale; - // The last hparams.nextn_predict_layers blocks are MTP heads, currently inactive - const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; - - for (int il = 0; il < n_transformer_layers; ++il) { + for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; uint32_t n_head_l = hparams.n_head(il); @@ -173,7 +168,7 @@ llama_model_mimo2::graph::graph(const llama_model & model, const llm_graph_param } } - if (il == n_transformer_layers - 1 && inp_out_ids) { + if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } diff --git a/examples/talk-llama/models/minicpm.cpp b/examples/talk-llama/models/minicpm.cpp index 966d3af615c..fc3e5b171d5 100644 --- a/examples/talk-llama/models/minicpm.cpp +++ b/examples/talk-llama/models/minicpm.cpp @@ -3,7 +3,7 @@ void llama_model_minicpm::load_arch_hparams(llama_model_loader & ml) { // Backward-compatible defaults for older MiniCPM GGUFs hparams.f_embedding_scale = 12.0f; - hparams.f_residual_scale = 1.4f / sqrtf(float(hparams.n_layer)); + hparams.f_residual_scale = 1.4f / sqrtf(float(hparams.n_layer())); hparams.f_logit_scale = hparams.n_embd ? (256.0f / float(hparams.n_embd)) : 1.0f; ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -16,7 +16,7 @@ void llama_model_minicpm::load_arch_hparams(llama_model_loader & ml) { // MiniCPM uses rope by default, unlike Granite which uses it as a switch hparams.rope_finetuned = true; - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 52: type = LLM_TYPE_1B; break; case 40: type = LLM_TYPE_2B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/minicpm3.cpp b/examples/talk-llama/models/minicpm3.cpp index 1ffc54fa7c6..e011b1ff0a8 100644 --- a/examples/talk-llama/models/minicpm3.cpp +++ b/examples/talk-llama/models/minicpm3.cpp @@ -5,7 +5,7 @@ void llama_model_minicpm3::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 62: type = LLM_TYPE_4B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/minimax-m2.cpp b/examples/talk-llama/models/minimax-m2.cpp index 22e291d73a3..b25435e4d97 100644 --- a/examples/talk-llama/models/minimax-m2.cpp +++ b/examples/talk-llama/models/minimax-m2.cpp @@ -5,7 +5,7 @@ void llama_model_minimax_m2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 62: type = LLM_TYPE_230B_A10B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/mistral3.cpp b/examples/talk-llama/models/mistral3.cpp index 4e6ebef82cb..9a8e3f9a50b 100644 --- a/examples/talk-llama/models/mistral3.cpp +++ b/examples/talk-llama/models/mistral3.cpp @@ -18,7 +18,7 @@ void llama_model_mistral3::load_arch_hparams(llama_model_loader & ml) { } } - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 26: type = LLM_TYPE_3B; break; case 34: type = LLM_TYPE_8B; break; case 40: type = LLM_TYPE_14B; break; @@ -177,9 +177,9 @@ llama_model_mistral3::graph::graph(const llama_model & model, const llm_graph_pa cb(cur, "ffn_norm", il); cur = build_ffn(cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, - model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, model.layers[il].ffn_up_s, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, model.layers[il].ffn_gate_s, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, model.layers[il].ffn_down_s, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); cb(cur, "ffn_out", il); @@ -200,7 +200,11 @@ llama_model_mistral3::graph::graph(const llama_model & model, const llm_graph_pa LLM_FFN_SILU, true, hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, - il); + il, + nullptr, nullptr, + model.layers[il].ffn_up_exps_s, + model.layers[il].ffn_gate_exps_s, + model.layers[il].ffn_down_exps_s); cb(cur, "ffn_moe_out", il); } cur = ggml_add(ctx0, cur, ffn_inp); diff --git a/examples/talk-llama/models/models.h b/examples/talk-llama/models/models.h index 7e551eb965b..ee3aff07b9a 100644 --- a/examples/talk-llama/models/models.h +++ b/examples/talk-llama/models/models.h @@ -46,7 +46,7 @@ struct llm_build_delta_net_base : public llm_graph_context { ggml_tensor * s, int il); - // use the ggml_gated_delta_net fused operator (K=1; state has shape (D, 1, n_seqs)) + // use the ggml_gated_delta_net fused operator (K=1; state has shape [S_v, S_v, H_v, n_seqs]) std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_fused( ggml_tensor * q, ggml_tensor * k, @@ -186,6 +186,19 @@ struct llama_model_maincoder : public llama_model_base { }; +struct llama_model_talkie : public llama_model_base { + llama_model_talkie(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + struct llama_model_deci : public llama_model_base { llama_model_deci(const struct llama_model_params & params) : llama_model_base(params) {} void load_arch_hparams(llama_model_loader & ml) override; @@ -398,6 +411,18 @@ struct llama_model_stablelm : public llama_model_base { std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; +struct llama_model_mellum : public llama_model_base { + llama_model_mellum(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + template <bool iswa> + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; struct llama_model_qwen : public llama_model_base { llama_model_qwen(const struct llama_model_params & params) : llama_model_base(params) {} @@ -797,6 +822,19 @@ struct llama_model_gemma4 : public llama_model_base { }; +struct llama_model_gemma4_assistant : public llama_model_base { + llama_model_gemma4_assistant(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + struct llama_model_gemma_embedding : public llama_model_base { llama_model_gemma_embedding(const struct llama_model_params & params) : llama_model_base(params) {} void load_arch_hparams(llama_model_loader & ml) override; @@ -1017,6 +1055,19 @@ struct llama_model_deepseek2 : public llama_model_base { }; +struct llama_model_deepseek32 : public llama_model_base { + llama_model_deepseek32(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + struct llama_model_deepseek2ocr : public llama_model_base { llama_model_deepseek2ocr(const struct llama_model_params & params) : llama_model_base(params) {} void load_arch_hparams(llama_model_loader & ml) override; @@ -1038,6 +1089,21 @@ struct llama_model_glm_dsa : public llama_model_base { std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; +struct llama_model_eagle3 : public llama_model_base { + llama_model_eagle3(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + template <bool is_enc> + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + + ggml_tensor * build_inp_embd_enc() const; + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + struct llama_model_mistral4 : public llama_model_deepseek2 { llama_model_mistral4(const struct llama_model_params & params) : llama_model_deepseek2(params) {} @@ -1887,5 +1953,9 @@ struct llama_model_step35 : public llama_model_base { graph(const llama_model & model, const llm_graph_params & params); }; + struct graph_mtp : public llm_graph_context { + graph_mtp(const llama_model & model, const llm_graph_params & params); + }; + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; diff --git a/examples/talk-llama/models/modern-bert.cpp b/examples/talk-llama/models/modern-bert.cpp index e9b79ffc6dc..f3e9407e012 100644 --- a/examples/talk-llama/models/modern-bert.cpp +++ b/examples/talk-llama/models/modern-bert.cpp @@ -14,7 +14,15 @@ void llama_model_modern_bert::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + // Some ModernBert derivatives (e.g. IBM Granite Embedding 97m R2) use + // SiLU/SwiGLU in the FFN instead of the default GELU/GeGLU. + hparams.llm_ffn_op = LLM_FFN_GEGLU; + std::string hidden_act; + if (ml.get_key(LLM_KV_HIDDEN_ACT, hidden_act, false)) { + hparams.llm_ffn_op = llm_ffn_op_type_from_string(hidden_act, LLM_FFN_GEGLU); + } + + switch (hparams.n_layer()) { case 12: type = LLM_TYPE_47M; break; // granite-embedding-small case 22: @@ -144,7 +152,8 @@ llama_model_modern_bert::graph::graph(const llama_model & model, const llm_graph NULL, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, - LLM_FFN_GEGLU, LLM_FFN_SEQ, il); + hparams.llm_ffn_op, + LLM_FFN_SEQ, il); // attentions bypass the intermediate layer cur = ggml_add(ctx0, cur, ffn_inp); diff --git a/examples/talk-llama/models/mpt.cpp b/examples/talk-llama/models/mpt.cpp index 0229d20ed36..d094fd9f80b 100644 --- a/examples/talk-llama/models/mpt.cpp +++ b/examples/talk-llama/models/mpt.cpp @@ -5,7 +5,7 @@ void llama_model_mpt::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv, false); ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_7B; break; case 48: type = LLM_TYPE_30B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/nemotron-h.cpp b/examples/talk-llama/models/nemotron-h.cpp index a82f9c170b4..a456269347b 100644 --- a/examples/talk-llama/models/nemotron-h.cpp +++ b/examples/talk-llama/models/nemotron-h.cpp @@ -9,8 +9,8 @@ void llama_model_nemotron_h::load_arch_hparams(llama_model_loader & ml) { // A layer is recurrent IFF the n_head_kv value is set to 0 and // the n_ff value is set to 0 - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = (hparams.n_head_kv(i) == 0 && hparams.n_ff(i) == 0); + for (uint32_t i = 0; i < hparams.n_layer(); ++i) { + hparams.is_recr_impl[i] = (hparams.n_head_kv(i) == 0 && hparams.n_ff(i) == 0); } ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -22,7 +22,7 @@ void llama_model_nemotron_h::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); ml.get_key(LLM_KV_MOE_LATENT_SIZE, hparams.moe_latent_size, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 52: type = LLM_TYPE_31B_A3_5B; break; // Nemotron-H_MOE 31B case 56: type = LLM_TYPE_9B; break; case 88: type = LLM_TYPE_120B_A12B; break; @@ -62,7 +62,7 @@ void llama_model_nemotron_h::load_arch_tensors(llama_model_loader &) { // all blocks use the attn norm layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - if (hparams.is_recurrent(i)) { + if (hparams.is_recr(i)) { // ssm layers layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0); @@ -143,7 +143,7 @@ llama_model_nemotron_h::graph::graph(const llama_model & model, const llm_graph_ cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); cb(cur, "attn_norm", il); - if (hparams.is_recurrent(il)) { + if (hparams.is_recr(il)) { // ssm layer // cur = build_mamba2_layer(inp->get_recr(), cur, model, ubatch, il); } else if (hparams.n_ff(il) == 0) { diff --git a/examples/talk-llama/models/nemotron.cpp b/examples/talk-llama/models/nemotron.cpp index 5d4a3b5c69e..6e2bd9a33ca 100644 --- a/examples/talk-llama/models/nemotron.cpp +++ b/examples/talk-llama/models/nemotron.cpp @@ -2,7 +2,8 @@ void llama_model_nemotron::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_4B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/neo-bert.cpp b/examples/talk-llama/models/neo-bert.cpp index f00d6eddfc9..4a08d7abd40 100644 --- a/examples/talk-llama/models/neo-bert.cpp +++ b/examples/talk-llama/models/neo-bert.cpp @@ -3,7 +3,7 @@ void llama_model_neo_bert::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - if (hparams.n_layer == 28) { + if (hparams.n_layer() == 28) { type = LLM_TYPE_250M; } } diff --git a/examples/talk-llama/models/nomic-bert-moe.cpp b/examples/talk-llama/models/nomic-bert-moe.cpp index a17abe2c269..da4b62919bb 100644 --- a/examples/talk-llama/models/nomic-bert-moe.cpp +++ b/examples/talk-llama/models/nomic-bert-moe.cpp @@ -4,7 +4,7 @@ void llama_model_nomic_bert_moe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_MOE_EVERY_N_LAYERS, hparams.moe_every_n_layers, 0); - if (hparams.n_layer == 12 && hparams.n_embd == 768) { + if (hparams.n_layer() == 12 && hparams.n_embd == 768) { if (arch == LLM_ARCH_NOMIC_BERT) { type = LLM_TYPE_137M; } else if (arch == LLM_ARCH_NOMIC_BERT_MOE && hparams.moe_every_n_layers == 2) { diff --git a/examples/talk-llama/models/nomic-bert.cpp b/examples/talk-llama/models/nomic-bert.cpp index 5a8a5584457..e7fc72286a6 100644 --- a/examples/talk-llama/models/nomic-bert.cpp +++ b/examples/talk-llama/models/nomic-bert.cpp @@ -4,7 +4,7 @@ void llama_model_nomic_bert::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_MOE_EVERY_N_LAYERS, hparams.moe_every_n_layers, 0); - if (hparams.n_layer == 12 && hparams.n_embd == 768) { + if (hparams.n_layer() == 12 && hparams.n_embd == 768) { if (arch == LLM_ARCH_NOMIC_BERT) { type = LLM_TYPE_137M; } else if (arch == LLM_ARCH_NOMIC_BERT_MOE && hparams.moe_every_n_layers == 2) { diff --git a/examples/talk-llama/models/olmo.cpp b/examples/talk-llama/models/olmo.cpp index cfcf17bcb03..9f7a2ba60ef 100644 --- a/examples/talk-llama/models/olmo.cpp +++ b/examples/talk-llama/models/olmo.cpp @@ -4,7 +4,7 @@ void llama_model_olmo::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 22: type = LLM_TYPE_1B; break; case 32: type = LLM_TYPE_7B; break; case 80: type = LLM_TYPE_70B; break; diff --git a/examples/talk-llama/models/olmo2.cpp b/examples/talk-llama/models/olmo2.cpp index 7cc262f5504..cb52cdef720 100644 --- a/examples/talk-llama/models/olmo2.cpp +++ b/examples/talk-llama/models/olmo2.cpp @@ -17,7 +17,7 @@ void llama_model_olmo2::load_arch_hparams(llama_model_loader & ml) { hparams.swa_type = LLAMA_SWA_TYPE_NONE; } - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 16: type = LLM_TYPE_1B; break; case 32: type = LLM_TYPE_7B; break; case 40: type = LLM_TYPE_13B; break; diff --git a/examples/talk-llama/models/olmoe.cpp b/examples/talk-llama/models/olmoe.cpp index 7976ae44a51..1e2baeb207f 100644 --- a/examples/talk-llama/models/olmoe.cpp +++ b/examples/talk-llama/models/olmoe.cpp @@ -2,7 +2,8 @@ void llama_model_olmoe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 16: type = LLM_TYPE_A1_7B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/openai-moe.cpp b/examples/talk-llama/models/openai-moe.cpp index 15b6c8c1205..6d74f9c7e6e 100644 --- a/examples/talk-llama/models/openai-moe.cpp +++ b/examples/talk-llama/models/openai-moe.cpp @@ -14,7 +14,7 @@ void llama_model_openai_moe::load_arch_hparams(llama_model_loader & ml) { hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_20B; break; case 36: type = LLM_TYPE_120B; break; default: type = LLM_TYPE_UNKNOWN; @@ -75,6 +75,8 @@ llama_model_openai_moe::graph::graph(const llama_model & model, const llm_graph_ ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { + res->t_layer_inp[il] = inpL; + const float freq_base_l = model.get_rope_freq_base (cparams, il); const float freq_scale_l = model.get_rope_freq_scale(cparams, il); diff --git a/examples/talk-llama/models/openelm.cpp b/examples/talk-llama/models/openelm.cpp index 9f76350fd4d..13120bd3236 100644 --- a/examples/talk-llama/models/openelm.cpp +++ b/examples/talk-llama/models/openelm.cpp @@ -3,12 +3,12 @@ void llama_model_openelm::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 16: type = LLM_TYPE_270M; break; - case 20: type = LLM_TYPE_450M; break; - case 28: type = LLM_TYPE_1B; break; - case 36: type = LLM_TYPE_3B; break; - default: type = LLM_TYPE_UNKNOWN; + switch (hparams.n_layer()) { + case 16: type = LLM_TYPE_270M; break; + case 20: type = LLM_TYPE_450M; break; + case 28: type = LLM_TYPE_1B; break; + case 36: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; } } diff --git a/examples/talk-llama/models/orion.cpp b/examples/talk-llama/models/orion.cpp index bcb4bbba4b1..863a2822269 100644 --- a/examples/talk-llama/models/orion.cpp +++ b/examples/talk-llama/models/orion.cpp @@ -3,7 +3,7 @@ void llama_model_orion::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 40: type = LLM_TYPE_14B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/pangu-embed.cpp b/examples/talk-llama/models/pangu-embed.cpp index 7593f879b24..90f05c088c1 100644 --- a/examples/talk-llama/models/pangu-embed.cpp +++ b/examples/talk-llama/models/pangu-embed.cpp @@ -2,7 +2,8 @@ void llama_model_pangu_embed::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 26: type = LLM_TYPE_1B; break; // openPangu-Embedded-1B-V1.1 case 34: type = LLM_TYPE_7B; break; // openPangu-Embedded-7B-V1.1 default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/phi2.cpp b/examples/talk-llama/models/phi2.cpp index 8f3ed5f7b7d..81b1ad12cc0 100644 --- a/examples/talk-llama/models/phi2.cpp +++ b/examples/talk-llama/models/phi2.cpp @@ -3,7 +3,7 @@ void llama_model_phi2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_1B; break; case 32: type = LLM_TYPE_3B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/phi3.cpp b/examples/talk-llama/models/phi3.cpp index f8a4a4d5aa5..716ff814cc1 100644 --- a/examples/talk-llama/models/phi3.cpp +++ b/examples/talk-llama/models/phi3.cpp @@ -3,7 +3,7 @@ void llama_model_phi3::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_1B; break; case 32: type = LLM_TYPE_3B; break; case 40: type = LLM_TYPE_14B; break; diff --git a/examples/talk-llama/models/phimoe.cpp b/examples/talk-llama/models/phimoe.cpp index 4575d6139cf..c332553bc7d 100644 --- a/examples/talk-llama/models/phimoe.cpp +++ b/examples/talk-llama/models/phimoe.cpp @@ -3,7 +3,7 @@ void llama_model_phimoe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_16x3_8B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/plamo.cpp b/examples/talk-llama/models/plamo.cpp index c7ed1211c31..246144519e4 100644 --- a/examples/talk-llama/models/plamo.cpp +++ b/examples/talk-llama/models/plamo.cpp @@ -3,7 +3,7 @@ void llama_model_plamo::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 40: type = LLM_TYPE_13B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/plamo2.cpp b/examples/talk-llama/models/plamo2.cpp index b713889fe72..0b81513c368 100644 --- a/examples/talk-llama/models/plamo2.cpp +++ b/examples/talk-llama/models/plamo2.cpp @@ -11,11 +11,15 @@ void llama_model_plamo2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; + // Load attention parameters + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k_full, false); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v_full, false); + + for (uint32_t i = 0; i < hparams.n_layer(); ++i) { + hparams.is_recr_impl[i] = hparams.n_head_kv(i) == 0; } - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 16: type = LLM_TYPE_1B; break; case 32: if (hparams.n_embd == 2048) { @@ -54,7 +58,7 @@ void llama_model_plamo2::load_arch_tensors(llama_model_loader &) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; - bool is_mamba_layer = hparams.is_recurrent(i); + bool is_mamba_layer = hparams.is_recr(i); layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); @@ -128,7 +132,7 @@ llama_model_plamo2::graph::graph(const llama_model & model, const llm_graph_para cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); // check if this layer is Mamba or Attention - const bool is_mamba_layer = hparams.is_recurrent(il); + const bool is_mamba_layer = hparams.is_recr(il); if (is_mamba_layer) { // PLaMo-2 Mamba layer @@ -273,7 +277,7 @@ ggml_tensor * llama_model_plamo2::graph::build_plamo2_mamba_layer(llm_graph_inpu GGML_ASSERT(n_seqs != 0); GGML_ASSERT(ubatch.equal_seqs()); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); - GGML_ASSERT(d_inner % n_head == 0); + GGML_ASSERT(d_inner % n_heads == 0); GGML_ASSERT(n_group == 0); ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); diff --git a/examples/talk-llama/models/plamo3.cpp b/examples/talk-llama/models/plamo3.cpp index 29f3e803d68..16d0b1dcef7 100644 --- a/examples/talk-llama/models/plamo3.cpp +++ b/examples/talk-llama/models/plamo3.cpp @@ -13,7 +13,7 @@ void llama_model_plamo3::load_arch_hparams(llama_model_loader & ml) { hparams.swa_type = LLAMA_SWA_TYPE_NONE; } - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_2B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/plm.cpp b/examples/talk-llama/models/plm.cpp index ce050919e6a..8ca325f5e2c 100644 --- a/examples/talk-llama/models/plm.cpp +++ b/examples/talk-llama/models/plm.cpp @@ -3,7 +3,8 @@ void llama_model_plm::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_1_8B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/qwen.cpp b/examples/talk-llama/models/qwen.cpp index 00467dbad7d..1f5dff3843c 100644 --- a/examples/talk-llama/models/qwen.cpp +++ b/examples/talk-llama/models/qwen.cpp @@ -3,7 +3,7 @@ void llama_model_qwen::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_7B; break; case 40: type = LLM_TYPE_13B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/qwen2.cpp b/examples/talk-llama/models/qwen2.cpp index a5147460bae..e9c2ea80a6b 100644 --- a/examples/talk-llama/models/qwen2.cpp +++ b/examples/talk-llama/models/qwen2.cpp @@ -2,7 +2,8 @@ void llama_model_qwen2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_5B : LLM_TYPE_1B; break; case 28: type = hparams.n_embd == 1536 ? LLM_TYPE_1_5B : LLM_TYPE_7B; break; case 32: type = LLM_TYPE_7B; break; diff --git a/examples/talk-llama/models/qwen2moe.cpp b/examples/talk-llama/models/qwen2moe.cpp index 7cb03859deb..e831ed11aad 100644 --- a/examples/talk-llama/models/qwen2moe.cpp +++ b/examples/talk-llama/models/qwen2moe.cpp @@ -5,7 +5,8 @@ void llama_model_qwen2moe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_A2_7B; break; case 28: type = LLM_TYPE_57B_A14B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/qwen3.cpp b/examples/talk-llama/models/qwen3.cpp index 41b97fed956..f4b2a2aebe0 100644 --- a/examples/talk-llama/models/qwen3.cpp +++ b/examples/talk-llama/models/qwen3.cpp @@ -2,7 +2,8 @@ void llama_model_qwen3::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 28: type = hparams.n_embd == 1024 ? LLM_TYPE_0_6B : LLM_TYPE_1_7B; break; case 36: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_8B; break; case 40: type = LLM_TYPE_14B; break; @@ -68,6 +69,8 @@ llama_model_qwen3::graph::graph(const llama_model & model, const llm_graph_param ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { + res->t_layer_inp[il] = inpL; + ggml_tensor * inpSA = inpL; // norm diff --git a/examples/talk-llama/models/qwen35.cpp b/examples/talk-llama/models/qwen35.cpp index 04ecc18fcdc..6783d98ec20 100644 --- a/examples/talk-llama/models/qwen35.cpp +++ b/examples/talk-llama/models/qwen35.cpp @@ -13,21 +13,20 @@ void llama_model_qwen35::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); // NextN/MTP (Qwen3.5/3.6): extra decoder block appended beyond the main stack - ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer_impl"); // Mark recurrent layers (linear attention layers). MTP layers are dense // attention-only and must be flagged non-recurrent. - { - const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; + if (!ml.get_key_or_arr(LLM_KV_ATTENTION_RECURRENT_LAYERS, hparams.is_recr_impl, hparams.n_layer_all, false)) { uint32_t full_attn_interval = 4; ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = (i < n_main) && ((i + 1) % full_attn_interval != 0); + for (uint32_t i = 0; i < hparams.n_layer_all; ++i) { + hparams.is_recr_impl[i] = (i < hparams.n_layer()) && ((i + 1) % full_attn_interval != 0); } } - switch (hparams.n_layer - hparams.nextn_predict_layers) { + switch (hparams.n_layer()) { case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_8B : LLM_TYPE_2B; break; case 32: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_9B; break; case 64: type = LLM_TYPE_27B; break; @@ -38,9 +37,7 @@ void llama_model_qwen35::load_arch_hparams(llama_model_loader & ml) { void llama_model_qwen35::load_arch_tensors(llama_model_loader & ml) { LLAMA_LOAD_LOCALS; - const uint32_t n_main = n_layer - hparams.nextn_predict_layers; - const bool mtp_only = (hparams.nextn_predict_layers > 0) && - (ml.get_weight("blk.0.attn_norm.weight") == nullptr); + const bool mtp_only = (hparams.n_layer_nextn > 0) && (ml.get_weight("blk.0.attn_norm.weight") == nullptr); const int trunk_flags = mtp_only ? TENSOR_NOT_REQUIRED : 0; tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); @@ -69,7 +66,7 @@ void llama_model_qwen35::load_arch_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, flags); layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, flags); - if (!hparams.is_recurrent(il)) { + if (!hparams.is_recr(il)) { // Attention layers create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, flags); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, flags); @@ -121,10 +118,10 @@ void llama_model_qwen35::load_arch_tensors(llama_model_loader & ml) { layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", il), { n_embd }, TENSOR_NOT_REQUIRED); }; - for (int i = 0; i < (int) n_main; ++i) { + for (int i = 0; i < n_layer; ++i) { load_block_trunk(i, trunk_flags); } - for (int i = (int) n_main; i < n_layer; ++i) { + for (int i = n_layer; i < n_layer_all; ++i) { load_block_mtp(i); } } @@ -158,8 +155,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para ggml_tensor * inp_out_ids = build_inp_out_ids(); // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass. - const int n_transformer_layers = n_layer - (int) hparams.nextn_predict_layers; - for (int il = 0; il < n_transformer_layers; ++il) { + for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); @@ -168,7 +164,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para ggml_build_forward_expand(gf, cur); // Determine layer type and build appropriate attention mechanism - if (hparams.is_recurrent(il)) { + if (hparams.is_recr(il)) { // Linear attention layer (gated delta net) cur = build_layer_attn_linear(inp->get_recr(), cur, il); } else { @@ -176,8 +172,8 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); } - if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); + if (il == n_layer - 1 && inp_out_ids && cparams.embeddings_nextn_masked) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -208,16 +204,15 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para } cur = inpL; - cb(cur, "h_pre_norm", -1); - res->t_h_pre_norm = cur; + cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); + + cb(cur, "h_nextn", -1); + res->t_h_nextn = cur; - if (!cparams.embeddings_pre_norm_masked && inp_out_ids) { + if (!cparams.embeddings_nextn_masked && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); } - // Final norm - cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); - cb(cur, "result_norm", -1); res->t_embd = cur; @@ -490,15 +485,15 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_ffn(ggml_tensor * cur, cons // LLM_GRAPH_TYPE_DECODER_MTP draft head for Qwen3.5/3.6 dense series llama_model_qwen35::graph_mtp::graph_mtp(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35 MTP requires nextn_predict_layers > 0"); - GGML_ASSERT(hparams.nextn_predict_layers == 1 && "QWEN35 MTP currently only supports a single MTP block"); + GGML_ASSERT(hparams.n_layer_nextn > 0 && "QWEN35 MTP requires n_layer_nextn > 0"); + GGML_ASSERT(hparams.n_layer_nextn == 1 && "QWEN35 MTP currently only supports a single MTP block"); const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); // hparams.n_layer includes both main model layers and MTP layers. The MTP // layer is stored immediately after the main layers in model.layers[]. - const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers; + const int il = hparams.n_layer(); const auto & layer = model.layers[il]; GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); @@ -508,28 +503,41 @@ llama_model_qwen35::graph_mtp::graph_mtp(const llama_model & model, const llm_gr int sections[4]; std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); - auto inp = std::make_unique<llm_graph_input_embd>(hparams.n_embd); + // TODO: extract in a common llm_graph_context::build_inp_embd_h() + auto inp = std::make_unique<llm_graph_input_embd_h>(hparams.n_embd); inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); ggml_set_input(inp->tokens); - inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd_inp(), n_tokens); ggml_set_input(inp->embd); - ggml_set_name(inp->embd, "mtp_h_input"); - ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; + // TODO: make static using `ggml_build_forward_select()` + // see llm_graph_context::build_inp_embd() for reference + ggml_tensor * tok_embd; + if (ubatch.token) { + ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; - ggml_tensor * h_input = inp->embd; - ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); + tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); + } else { + tok_embd = inp->embd; + } cb(tok_embd, "mtp_tok_embd", il); + inp->h = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + ggml_set_input(inp->h); + ggml_set_name(inp->h, "mtp_h_input"); + + ggml_tensor * h_embd = inp->h; + res->add_input(std::move(inp)); ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_out_ids = build_inp_out_ids(); - auto * inp_attn = build_attn_inp_kv(); - ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * h_norm = build_norm(h_embd, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); cb(h_norm, "mtp_hnorm", il); ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il); @@ -611,18 +619,16 @@ llama_model_qwen35::graph_mtp::graph_mtp(const llama_model & model, const llm_gr cur = ggml_add(ctx0, cur, ffn_residual); cb(cur, "mtp_post_ffn", il); - // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step. - // (In the trunk graph this is `t_h_pre_norm`; the MTP head reuses the same slot.) - cb(cur, "h_pre_norm", -1); - res->t_h_pre_norm = cur; - - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - ggml_tensor * head_norm_w = layer.nextn.shared_head_norm ? layer.nextn.shared_head_norm : model.output_norm; GGML_ASSERT(head_norm_w && "QWEN35 MTP: missing both nextn.shared_head_norm and output_norm"); cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1); + + cb(cur, "h_nextn", -1); + res->t_h_nextn = cur; + + cur = ggml_get_rows(ctx0, cur, inp_out_ids); cb(cur, "mtp_shared_head_norm", -1); ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output; diff --git a/examples/talk-llama/models/qwen35moe.cpp b/examples/talk-llama/models/qwen35moe.cpp index dc24f6ed537..eb5e9a406a1 100644 --- a/examples/talk-llama/models/qwen35moe.cpp +++ b/examples/talk-llama/models/qwen35moe.cpp @@ -16,21 +16,20 @@ void llama_model_qwen35moe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); // NextN/MTP (Qwen3.5/3.6): extra decoder block appended beyond the main stack - ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer_impl"); // Mark recurrent layers (linear attention layers). MTP layers are dense // attention-only and must be flagged non-recurrent. - { - const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; + if (!ml.get_key_or_arr(LLM_KV_ATTENTION_RECURRENT_LAYERS, hparams.is_recr_impl, hparams.n_layer_all, false)) { uint32_t full_attn_interval = 4; ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = (i < n_main) && ((i + 1) % full_attn_interval != 0); + for (uint32_t i = 0; i < hparams.n_layer_all; ++i) { + hparams.is_recr_impl[i] = (i < hparams.n_layer()) && ((i + 1) % full_attn_interval != 0); } } - switch (hparams.n_layer - hparams.nextn_predict_layers) { + switch (hparams.n_layer()) { case 40: type = LLM_TYPE_35B_A3B; break; case 48: type = LLM_TYPE_122B_A10B; break; case 60: type = LLM_TYPE_397B_A17B; break; @@ -41,9 +40,7 @@ void llama_model_qwen35moe::load_arch_hparams(llama_model_loader & ml) { void llama_model_qwen35moe::load_arch_tensors(llama_model_loader & ml) { LLAMA_LOAD_LOCALS; - const uint32_t n_main = n_layer - hparams.nextn_predict_layers; - const bool mtp_only = (hparams.nextn_predict_layers > 0) && - (ml.get_weight("blk.0.attn_norm.weight") == nullptr); + const bool mtp_only = (hparams.n_layer_nextn > 0) && (ml.get_weight("blk.0.attn_norm.weight") == nullptr); const int trunk_flags = mtp_only ? TENSOR_NOT_REQUIRED : 0; tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); @@ -75,7 +72,7 @@ void llama_model_qwen35moe::load_arch_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, flags); layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, flags); - if (!hparams.is_recurrent(il)) { + if (!hparams.is_recr(il)) { // Attention layers create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, flags); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, flags); @@ -144,10 +141,10 @@ void llama_model_qwen35moe::load_arch_tensors(llama_model_loader & ml) { layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", il), { n_embd }, TENSOR_NOT_REQUIRED); }; - for (int i = 0; i < (int) n_main; ++i) { + for (int i = 0; i < n_layer; ++i) { load_block_trunk(i, trunk_flags); } - for (int i = (int) n_main; i < n_layer; ++i) { + for (int i = n_layer; i < n_layer_all; ++i) { load_block_mtp(i); } } @@ -181,8 +178,7 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p ggml_tensor * inp_out_ids = build_inp_out_ids(); // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass. - const int n_transformer_layers = n_layer - (int) hparams.nextn_predict_layers; - for (int il = 0; il < n_transformer_layers; ++il) { + for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); @@ -191,7 +187,7 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p ggml_build_forward_expand(gf, cur); // Determine layer type and build appropriate attention mechanism - if (hparams.is_recurrent(il)) { + if (hparams.is_recr(il)) { // Linear attention layer (gated delta net) cur = build_layer_attn_linear(inp->get_recr(), cur, il); } else { @@ -199,7 +195,7 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); } - if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) { + if (il == n_layer - 1 && inp_out_ids && cparams.embeddings_nextn_masked) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -231,16 +227,16 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p } cur = inpL; - cb(cur, "h_pre_norm", -1); - res->t_h_pre_norm = cur; + // post-norm hidden state feeds both the LM head and the MTP seed below + cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); + + cb(cur, "h_nextn", -1); + res->t_h_nextn = cur; - if (!cparams.embeddings_pre_norm_masked && inp_out_ids) { + if (!cparams.embeddings_nextn_masked && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); } - // Final norm - cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); - cb(cur, "result_norm", -1); res->t_embd = cur; @@ -554,13 +550,13 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_ffn(ggml_tensor * cur, c // LLM_GRAPH_TYPE_DECODER_MTP draft head for Qwen3.5/3.6 MoE llama_model_qwen35moe::graph_mtp::graph_mtp(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35MOE MTP requires nextn_predict_layers > 0"); - GGML_ASSERT(hparams.nextn_predict_layers == 1 && "QWEN35MOE MTP currently only supports a single MTP block"); + GGML_ASSERT(hparams.n_layer_nextn > 0 && "QWEN35MOE MTP requires n_layer_nextn > 0"); + GGML_ASSERT(hparams.n_layer_nextn == 1 && "QWEN35MOE MTP currently only supports a single MTP block"); const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); - const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers; + const int il = hparams.n_layer(); const auto & layer = model.layers[il]; GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); @@ -571,29 +567,41 @@ llama_model_qwen35moe::graph_mtp::graph_mtp(const llama_model & model, const llm int sections[4]; std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); - auto inp = std::make_unique<llm_graph_input_embd>(hparams.n_embd); + // TODO: extract in a common llm_graph_context::build_inp_embd_h() + auto inp = std::make_unique<llm_graph_input_embd_h>(hparams.n_embd); inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); ggml_set_input(inp->tokens); - inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd_inp(), n_tokens); ggml_set_input(inp->embd); - ggml_set_name(inp->embd, "mtp_h_input"); - ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; + // TODO: make static using `ggml_build_forward_select()` + // see llm_graph_context::build_inp_embd() for reference + ggml_tensor * tok_embd; + if (ubatch.token) { + ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; - ggml_tensor * h_input = inp->embd; - ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); + tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); + } else { + tok_embd = inp->embd; + } cb(tok_embd, "mtp_tok_embd", il); + inp->h = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + ggml_set_input(inp->h); + ggml_set_name(inp->h, "mtp_h_input"); + + ggml_tensor * h_embd = inp->h; + res->add_input(std::move(inp)); ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_out_ids = build_inp_out_ids(); - auto * inp_attn = build_attn_inp_kv(); + auto * inp_attn = build_attn_inp_kv(); - ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); + ggml_tensor * h_norm = build_norm(h_embd, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); cb(h_norm, "mtp_hnorm", il); ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il); @@ -708,17 +716,16 @@ llama_model_qwen35moe::graph_mtp::graph_mtp(const llama_model & model, const llm cur = ggml_add(ctx0, cur, ffn_residual); cb(cur, "mtp_post_ffn", il); - // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step. - cb(cur, "h_pre_norm", -1); - res->t_h_pre_norm = cur; - - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - ggml_tensor * head_norm_w = layer.nextn.shared_head_norm ? layer.nextn.shared_head_norm : model.output_norm; GGML_ASSERT(head_norm_w && "QWEN35MOE MTP: missing both nextn.shared_head_norm and output_norm"); cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1); + + cb(cur, "h_nextn", -1); + res->t_h_nextn= cur; + + cur = ggml_get_rows(ctx0, cur, inp_out_ids); cb(cur, "mtp_shared_head_norm", -1); ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output; diff --git a/examples/talk-llama/models/qwen3moe.cpp b/examples/talk-llama/models/qwen3moe.cpp index a4f8e1379c9..6f6df5390e3 100644 --- a/examples/talk-llama/models/qwen3moe.cpp +++ b/examples/talk-llama/models/qwen3moe.cpp @@ -1,10 +1,10 @@ #include "models.h" void llama_model_qwen3moe::load_arch_hparams(llama_model_loader & ml) { - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); - + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 48: type = LLM_TYPE_30B_A3B; break; case 94: type = LLM_TYPE_235B_A22B; break; default: type = LLM_TYPE_UNKNOWN; @@ -78,6 +78,8 @@ llama_model_qwen3moe::graph::graph(const llama_model & model, const llm_graph_pa ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { + res->t_layer_inp[il] = inpL; + ggml_tensor * inpSA = inpL; // norm diff --git a/examples/talk-llama/models/qwen3next.cpp b/examples/talk-llama/models/qwen3next.cpp index 1d873427db5..97200a44072 100644 --- a/examples/talk-llama/models/qwen3next.cpp +++ b/examples/talk-llama/models/qwen3next.cpp @@ -14,15 +14,15 @@ void llama_model_qwen3next::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); // Mark recurrent layers (linear attention layers) - { + if (!ml.get_key_or_arr(LLM_KV_ATTENTION_RECURRENT_LAYERS, hparams.is_recr_impl, hparams.n_layer_all, false)) { uint32_t full_attn_interval = 4; ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + for (uint32_t i = 0; i < hparams.n_layer_all; ++i) { + hparams.is_recr_impl[i] = (i < hparams.n_layer()) && ((i + 1) % full_attn_interval != 0); } } - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 48: type = LLM_TYPE_80B_A3B; break; default: type = LLM_TYPE_UNKNOWN; } @@ -68,7 +68,7 @@ void llama_model_qwen3next::load_arch_tensors(llama_model_loader &) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); - if (!hparams.is_recurrent(i)) { + if (!hparams.is_recr(i)) { // Attention layers create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); @@ -129,7 +129,7 @@ llama_model_qwen3next::graph::graph(const llama_model & model, const llm_graph_p ggml_build_forward_expand(gf, cur); // Determine layer type and build appropriate attention mechanism - if (hparams.is_recurrent(il)) { + if (hparams.is_recr(il)) { // Linear attention layer (gated delta net) cur = build_layer_attn_linear(inp->get_recr(), cur, il); } else { diff --git a/examples/talk-llama/models/qwen3vl.cpp b/examples/talk-llama/models/qwen3vl.cpp index 5defd893944..724d6140d19 100644 --- a/examples/talk-llama/models/qwen3vl.cpp +++ b/examples/talk-llama/models/qwen3vl.cpp @@ -4,7 +4,8 @@ void llama_model_qwen3vl::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers, false); ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 28: type = LLM_TYPE_1_7B; break; case 36: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_8B; break; case 64: type = LLM_TYPE_32B; break; diff --git a/examples/talk-llama/models/qwen3vlmoe.cpp b/examples/talk-llama/models/qwen3vlmoe.cpp index 5b77df57122..7c41592f772 100644 --- a/examples/talk-llama/models/qwen3vlmoe.cpp +++ b/examples/talk-llama/models/qwen3vlmoe.cpp @@ -5,7 +5,8 @@ void llama_model_qwen3vlmoe::load_arch_hparams(llama_model_loader & ml) { ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 48: type = LLM_TYPE_30B_A3B; break; case 94: type = LLM_TYPE_235B_A22B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/refact.cpp b/examples/talk-llama/models/refact.cpp index bf3949a9092..a46c358fa68 100644 --- a/examples/talk-llama/models/refact.cpp +++ b/examples/talk-llama/models/refact.cpp @@ -2,7 +2,8 @@ void llama_model_refact::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_1B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/rnd1.cpp b/examples/talk-llama/models/rnd1.cpp index ca8e009615e..fc276ce591b 100644 --- a/examples/talk-llama/models/rnd1.cpp +++ b/examples/talk-llama/models/rnd1.cpp @@ -2,12 +2,13 @@ void llama_model_rnd1::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 48: type = LLM_TYPE_30B_A3B; break; default: type = LLM_TYPE_UNKNOWN; } + // Set non-causal attention for diffusion models hparams.causal_attn = false; } diff --git a/examples/talk-llama/models/rwkv6.cpp b/examples/talk-llama/models/rwkv6.cpp index ba2a9dfa0db..0b5013dc758 100644 --- a/examples/talk-llama/models/rwkv6.cpp +++ b/examples/talk-llama/models/rwkv6.cpp @@ -9,7 +9,7 @@ void llama_model_rwkv6::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false); ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_1_6B; break; case 32: switch (hparams.n_embd) { diff --git a/examples/talk-llama/models/rwkv6qwen2.cpp b/examples/talk-llama/models/rwkv6qwen2.cpp index 566b8cdcb54..6c7db514435 100644 --- a/examples/talk-llama/models/rwkv6qwen2.cpp +++ b/examples/talk-llama/models/rwkv6qwen2.cpp @@ -9,7 +9,7 @@ void llama_model_rwkv6qwen2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false); ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_1_6B; break; case 32: switch (hparams.n_embd) { diff --git a/examples/talk-llama/models/rwkv7.cpp b/examples/talk-llama/models/rwkv7.cpp index 7574b252621..67c51f5b59c 100644 --- a/examples/talk-llama/models/rwkv7.cpp +++ b/examples/talk-llama/models/rwkv7.cpp @@ -10,7 +10,7 @@ void llama_model_rwkv7::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_GATE_LORA_RANK, hparams.n_lora_gate, false); ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 12: switch (hparams.n_embd) { case 768: type = LLM_TYPE_190M; break; diff --git a/examples/talk-llama/models/seed-oss.cpp b/examples/talk-llama/models/seed-oss.cpp index 806cba574be..57de881a091 100644 --- a/examples/talk-llama/models/seed-oss.cpp +++ b/examples/talk-llama/models/seed-oss.cpp @@ -2,7 +2,8 @@ void llama_model_seed_oss::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 64: type = LLM_TYPE_36B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/smallthinker.cpp b/examples/talk-llama/models/smallthinker.cpp index 4231cccc666..a8e3d957f1f 100644 --- a/examples/talk-llama/models/smallthinker.cpp +++ b/examples/talk-llama/models/smallthinker.cpp @@ -15,14 +15,14 @@ void llama_model_smallthinker::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); } else { hparams.swa_type = LLAMA_SWA_TYPE_NONE; - hparams.n_no_rope_layer_step = hparams.n_layer; + hparams.n_no_rope_layer_step = hparams.n_layer(); } ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_4B; break; case 52: type = LLM_TYPE_20B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/smollm3.cpp b/examples/talk-llama/models/smollm3.cpp index 90e7d473eaf..c67d967b204 100644 --- a/examples/talk-llama/models/smollm3.cpp +++ b/examples/talk-llama/models/smollm3.cpp @@ -4,7 +4,7 @@ void llama_model_smollm3::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); hparams.n_no_rope_layer_step = 4; - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 36: type = LLM_TYPE_3B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/stablelm.cpp b/examples/talk-llama/models/stablelm.cpp index 4da7f7aefcf..bf6087b8796 100644 --- a/examples/talk-llama/models/stablelm.cpp +++ b/examples/talk-llama/models/stablelm.cpp @@ -3,7 +3,7 @@ void llama_model_stablelm::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_1B; break; case 32: type = LLM_TYPE_3B; break; case 40: type = LLM_TYPE_12B; break; diff --git a/examples/talk-llama/models/starcoder.cpp b/examples/talk-llama/models/starcoder.cpp index e131af058bc..f73a88fd4e9 100644 --- a/examples/talk-llama/models/starcoder.cpp +++ b/examples/talk-llama/models/starcoder.cpp @@ -2,7 +2,8 @@ void llama_model_starcoder::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_1B; break; case 36: type = LLM_TYPE_3B; break; case 42: type = LLM_TYPE_7B; break; diff --git a/examples/talk-llama/models/starcoder2.cpp b/examples/talk-llama/models/starcoder2.cpp index 9c207c02885..b81b469374a 100644 --- a/examples/talk-llama/models/starcoder2.cpp +++ b/examples/talk-llama/models/starcoder2.cpp @@ -2,7 +2,8 @@ void llama_model_starcoder2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 30: type = LLM_TYPE_3B; break; case 32: type = LLM_TYPE_7B; break; case 40: type = LLM_TYPE_15B; break; diff --git a/examples/talk-llama/models/step35.cpp b/examples/talk-llama/models/step35.cpp index 3b68e68707a..e2218c58704 100644 --- a/examples/talk-llama/models/step35.cpp +++ b/examples/talk-llama/models/step35.cpp @@ -22,24 +22,39 @@ void llama_model_step35::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); - ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_EXP, hparams.swiglu_clamp_exp, hparams.n_layer, false); - ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_SHEXP, hparams.swiglu_clamp_shexp, hparams.n_layer, false); - switch (hparams.n_layer) { + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.is_swa_impl, hparams.n_layer()); + + ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_EXP, hparams.swiglu_clamp_exp, hparams.n_layer(), false); + ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_SHEXP, hparams.swiglu_clamp_shexp, hparams.n_layer(), false); + + // NextN/MTP (Step3p5): extra decoder block appended beyond the main stack. + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer_impl"); + + switch (hparams.n_layer()) { case 45: type = LLM_TYPE_196B_A11B; break; default: type = LLM_TYPE_UNKNOWN; } } -void llama_model_step35::load_arch_tensors(llama_model_loader &) { +void llama_model_step35::load_arch_tensors(llama_model_loader & ml) { LLAMA_LOAD_LOCALS; + const bool mtp_only = (hparams.n_layer_nextn > 0) && (ml.get_weight("blk.0.attn_norm.weight") == nullptr); + // Trunk-only: the GGUF declares MTP layers in metadata but the actual MTP + // tensors live in a separate file (e.g. user split target/draft). Mark + // MTP tensors NOT_REQUIRED so the trunk loads cleanly. + const std::string mtp_probe = "blk." + std::to_string(n_layer) + ".nextn.eh_proj.weight"; + const bool trunk_only = (hparams.n_layer_nextn > 0) && (ml.get_weight(mtp_probe.c_str()) == nullptr); + const int trunk_flags = mtp_only ? TENSOR_NOT_REQUIRED : 0; + const int mtp_flags = trunk_only ? TENSOR_NOT_REQUIRED : 0; + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, trunk_flags); // STEP35 supports per-layer partial RoPE dims; rope factors are stored as a single shared tensor // ("rope_freqs.weight") and ggml uses only the first (n_rot_l/2) entries per layer. @@ -51,14 +66,14 @@ void llama_model_step35::load_arch_tensors(llama_model_loader &) { n_rot_max = n_rot; } - for (int i = 0; i < n_layer; ++i) { + auto load_block_trunk = [&](int i, int flags) { auto & layer = layers[i]; const uint32_t n_head_l = hparams.n_head(i); const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); @@ -70,13 +85,13 @@ void llama_model_step35::load_arch_tensors(llama_model_loader &) { layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); } - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, flags); // head-wise attention gate (Step35 self_attn.g_proj) layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_head_l}, TENSOR_NOT_REQUIRED); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); // dense MLP (leading dense blocks) layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); @@ -95,10 +110,86 @@ void llama_model_step35::load_arch_tensors(llama_model_loader &) { layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED); + }; + + auto load_block_mtp = [&](int i, bool is_first_mtp) { + auto & layer = layers[i]; + + const uint32_t n_head_l = hparams.n_head(i); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); + + // The MTP block is a full Step3p5 decoder layer (mtp_block) plus the + // NextN-specific wiring (enorm/hnorm/eh_proj + optional shared head). + // `mtp_flags` becomes NOT_REQUIRED when the GGUF is trunk-only. + // + // Only the FIRST MTP block (i == n_main) is required for the + // single-block MTP runtime; trailing MTP blocks are always tolerated + // as missing so pruned GGUFs (block 0 only) load cleanly. Override + // mtp_flags to NOT_REQUIRED for those. + const int eff_mtp_flags = is_first_mtp ? mtp_flags : (mtp_flags | TENSOR_NOT_REQUIRED); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, eff_mtp_flags); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | TENSOR_DUPLICATED); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | TENSOR_DUPLICATED); + } else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | TENSOR_DUPLICATED); + } + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, eff_mtp_flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, eff_mtp_flags); + + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_head_l}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, eff_mtp_flags); + + // dense MLP (leading dense blocks) — present if the MTP block isn't MoE + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + + // MoE routed experts + selection bias (router_bias) + const int64_t n_ff_exp = hparams.n_ff_exp; + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED); + + // NextN-specific tensors that define the MTP block. + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, eff_mtp_flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, eff_mtp_flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, eff_mtp_flags); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); + }; + + for (int i = 0; i < n_layer; ++i) { + load_block_trunk(i, trunk_flags); + } + // Only the first MTP block (i == n_main) is required at runtime — the + // single-block-MTP graph in build_arch_graph always uses that one. + // Trailing MTP blocks are loaded if present (so an un-pruned GGUF with + // all MTP layers still works) but tolerated when absent via the pruning + // path. See scripts/prune_step35_extra_mtp.py for the pruner. + for (int i = n_layer; i < n_layer_all; ++i) { + load_block_mtp(i, /*is_first_mtp=*/ i == n_layer); } } std::unique_ptr<llm_graph_context> llama_model_step35::build_arch_graph(const llm_graph_params & params) const { + if (params.gtype == LLM_GRAPH_TYPE_DECODER_MTP) { + return std::make_unique<graph_mtp>(*this, params); + } return std::make_unique<graph>(*this, params); } @@ -111,6 +202,7 @@ llama_model_step35::graph::graph(const llama_model & model, const llm_graph_para auto * inp_attn = build_attn_inp_kv_iswa(); ggml_tensor * inp_out_ids = build_inp_out_ids(); + // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass. for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; @@ -198,8 +290,8 @@ llama_model_step35::graph::graph(const llama_model & model, const llm_graph_para cb(cur, "attn_proj", il); } - if (il == n_layer - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); + if (il == n_layer - 1 && inp_out_ids && cparams.embeddings_nextn_masked) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -257,6 +349,13 @@ llama_model_step35::graph::graph(const llama_model & model, const llm_graph_para cur = inpL; + cb(cur, "h_nextn", -1); + res->t_h_nextn = cur; + + if (!cparams.embeddings_nextn_masked && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + } + cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); cb(cur, "result_norm", -1); res->t_embd = cur; @@ -267,3 +366,192 @@ llama_model_step35::graph::graph(const llama_model & model, const llm_graph_para ggml_build_forward_expand(gf, cur); } + +// LLM_GRAPH_TYPE_DECODER_MTP draft head for Step3p5 (MoE) +llama_model_step35::graph_mtp::graph_mtp(const llama_model & model, const llm_graph_params & params) + : llm_graph_context(params) { + GGML_ASSERT(hparams.n_layer_nextn > 0 && "STEP35 MTP requires n_layer_nextn > 0"); + + // Single-block MTP only: always run the first trained MTP block (Qwen + // MTP / vLLM single-MTP-layer style). Multi-block round-robin proved to + // be a much deeper refactor than this PR justifies; the trailing MTP + // blocks are loaded with TENSOR_NOT_REQUIRED so pruned GGUFs (with just + // block 0) also work — see load_arch_tensors below and + // scripts/prune_step35_extra_mtp.py. + const int il = hparams.n_layer(); + const auto & layer = model.layers[il]; + + GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); + GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm"); + GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm"); + + const uint32_t n_head_l = hparams.n_head(il); + const uint32_t n_head_kv_l = hparams.n_head_kv(il); + + const float freq_base_l = model.get_rope_freq_base(cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + + auto inp = std::make_unique<llm_graph_input_embd>(hparams.n_embd); + + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_input(inp->tokens); + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + ggml_set_input(inp->embd); + ggml_set_name(inp->embd, "mtp_h_input"); + + ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; + + ggml_tensor * h_input = inp->embd; + ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); + cb(tok_embd, "mtp_tok_embd", il); + + res->add_input(std::move(inp)); + + ggml_tensor * inp_pos = build_inp_pos(); + auto * inp_attn = build_attn_inp_kv_iswa(); + + ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); + cb(h_norm, "mtp_hnorm", il); + + ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il); + cb(e_norm, "mtp_enorm", il); + + ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0); + cb(concat, "mtp_concat", il); + + ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat); + cb(cur, "mtp_eh_proj", il); + + ggml_tensor * inpSA = cur; + + // mtp_block: full Step3p5 decoder layer (attention with optional head-wise gate, then MoE/dense FFN) + cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_norm", il); + + ggml_tensor * Qcur = build_lora_mm(layer.wq, cur, layer.wq_s); + ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s); + ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s); + cb(Qcur, "mtp_Qcur", il); + cb(Kcur, "mtp_Kcur", il); + cb(Vcur, "mtp_Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head_l, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv_l, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv_l, n_tokens); + + if (layer.attn_q_norm) { + Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "mtp_Qcur_normed", il); + } + if (layer.attn_k_norm) { + Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "mtp_Kcur_normed", il); + } + + const bool is_swa = hparams.is_swa(il); + ggml_tensor * rope_factors = is_swa ? nullptr : model.get_rope_factors(cparams, il); + const int64_t n_rot_l = hparams.n_rot(il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(Qcur, "mtp_Qcur_pos", il); + cb(Kcur, "mtp_Kcur_pos", il); + + const float kq_scale = 1.0f / sqrtf(float(n_embd_head_k)); + ggml_tensor * attn_out = build_attn(inp_attn, + nullptr, nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(attn_out, "mtp_attn_out", il); + + // head-wise attention gate: sigmoid(g_proj(x)) + if (layer.wqkv_gate) { + ggml_tensor * gate = build_lora_mm(layer.wqkv_gate, cur); // [n_head_l, n_tokens] + cb(gate, "mtp_attn_gate", il); + + gate = ggml_sigmoid(ctx0, gate); + cb(gate, "mtp_attn_gate_sigmoid", il); + + ggml_tensor * attn_3d = ggml_reshape_3d(ctx0, attn_out, n_embd_head_v, n_head_l, n_tokens); + ggml_tensor * gate_3d = ggml_reshape_3d(ctx0, gate, 1, n_head_l, n_tokens); + cb(gate_3d, "mtp_attn_gate_3d", il); + + attn_3d = ggml_mul(ctx0, attn_3d, gate_3d); + cb(attn_3d, "mtp_attn_gated_3d", il); + + attn_out = ggml_reshape_2d(ctx0, attn_3d, n_embd_head_v * n_head_l, n_tokens); + cb(attn_out, "mtp_attn_gated", il); + } + + cur = build_lora_mm(layer.wo, attn_out, layer.wo_s); + cb(cur, "mtp_attn_proj", il); + + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "mtp_attn_residual", il); + + ggml_tensor * ffn_inp = cur; + cur = build_norm(cur, layer.ffn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_ffn_norm", il); + + // FFN: dense MLP or MoE (mirrors trunk path) + if (layer.ffn_gate_inp == nullptr) { + cur = build_ffn(cur, + layer.ffn_up, layer.ffn_up_b, nullptr, + layer.ffn_gate, layer.ffn_gate_b, nullptr, + layer.ffn_down, layer.ffn_down_b, nullptr, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "mtp_ffn_out", il); + } else { + ggml_tensor * moe_out = build_moe_ffn(cur, + layer.ffn_gate_inp, + layer.ffn_up_exps, + layer.ffn_gate_exps, + layer.ffn_down_exps, + layer.ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + hparams.expert_weights_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il); + cb(moe_out, "mtp_ffn_moe_out", il); + + ggml_tensor * sh_out = build_ffn(cur, + layer.ffn_up_shexp, nullptr, nullptr, + layer.ffn_gate_shexp, nullptr, nullptr, + layer.ffn_down_shexp, nullptr, nullptr, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(sh_out, "mtp_ffn_shared_out", il); + + cur = ggml_add(ctx0, moe_out, sh_out); + cb(cur, "mtp_ffn_out", il); + } + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "mtp_post_ffn", il); + + // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step. + cb(cur, "h_nextn", -1); + res->t_h_nextn = cur; + + ggml_tensor * head_norm_w = layer.nextn.shared_head_norm + ? layer.nextn.shared_head_norm + : model.output_norm; + GGML_ASSERT(head_norm_w && "STEP35 MTP: missing both nextn.shared_head_norm and output_norm"); + cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1); + cb(cur, "mtp_shared_head_norm", -1); + + ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output; + GGML_ASSERT(head_w && "STEP35 MTP: missing LM head (nextn.shared_head_head or model.output)"); + cur = build_lora_mm(head_w, cur); + cb(cur, "result_output", -1); + + res->t_logits = cur; + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/t5.cpp b/examples/talk-llama/models/t5.cpp index 73e32741406..b0e3f062572 100644 --- a/examples/talk-llama/models/t5.cpp +++ b/examples/talk-llama/models/t5.cpp @@ -9,10 +9,10 @@ void llama_model_t5::load_arch_hparams(llama_model_loader & ml) { hparams.dec_start_token_id = dec_start_token_id; } - hparams.dec_n_layer = hparams.n_layer; + hparams.dec_n_layer = hparams.n_layer(); ml.get_key(LLM_KV_DECODER_BLOCK_COUNT, hparams.dec_n_layer, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 6: type = LLM_TYPE_60M; break; // t5-small case 8: type = LLM_TYPE_80M; break; // flan-t5-small case 12: diff --git a/examples/talk-llama/models/talkie.cpp b/examples/talk-llama/models/talkie.cpp new file mode 100644 index 00000000000..393e8f65bf4 --- /dev/null +++ b/examples/talk-llama/models/talkie.cpp @@ -0,0 +1,149 @@ +#include "models.h" + +void llama_model_talkie::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + + switch (hparams.n_layer()) { + case 40: type = LLM_TYPE_13B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_talkie::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // no k gain + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {1, n_head}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + + layer.out_scale = create_tensor(tn(LLM_TENSOR_LAYER_OUT_SCALE, "weight", i), {1}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_talkie::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_talkie::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_k(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_v()); + GGML_ASSERT(n_embd_head == n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + inpL = build_norm(inpL, nullptr, nullptr, LLM_NORM_RMS, -1); + cb(inpL, "inp_norm", -1); + + ggml_tensor * embd_skip = inpL; + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + const float kq_scale = 1.0f / sqrtf(float(n_embd_head)); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + ggml_tensor * inp_skip = embd_skip; + + cur = build_norm(inpL, nullptr, nullptr, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + // reference applies qknorm after rope + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_norm", il); + + Kcur = build_norm(Kcur, nullptr, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_norm", il); + + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, nullptr, model.layers[il].wo_s, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + inp_skip = ggml_get_rows(ctx0, inp_skip, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, nullptr, nullptr, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, nullptr, nullptr, + model.layers[il].ffn_gate, nullptr, nullptr, + model.layers[il].ffn_down, nullptr, model.layers[il].ffn_down_s, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + ggml_tensor * skip = ggml_mul(ctx0, inp_skip, model.layers[il].out_scale); + cb(skip, "embd_skip", il); + + cur = ggml_add(ctx0, cur, skip); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, nullptr, nullptr, LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); + + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + cur = ggml_scale(ctx0, cur, hparams.f_logit_scale); + cb(cur, "result_output", -1); + + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/xverse.cpp b/examples/talk-llama/models/xverse.cpp index d6d1c7a2e5d..3135001293a 100644 --- a/examples/talk-llama/models/xverse.cpp +++ b/examples/talk-llama/models/xverse.cpp @@ -2,7 +2,8 @@ void llama_model_xverse::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_7B; break; case 40: type = LLM_TYPE_13B; break; case 80: type = LLM_TYPE_65B; break; diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index ad5b4379be0..cb73b1cd115 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -4,8 +4,8 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) -set(GGML_VERSION_MINOR 13) -set(GGML_VERSION_PATCH 0) +set(GGML_VERSION_MINOR 15) +set(GGML_VERSION_PATCH 1) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index b6f73739809..2924fdbe988 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -381,11 +381,15 @@ extern "C" { // - most tensors have n_segments == 1 and a contiguous slice of the tensor data // - some tensors have an inhomogenenous data layout along the split axis, // those tensors are divided into segments which are each individually split across devices - // - ne has one entry per segment and device that add up to ggml_tensor::ne for that axis, - // the outer/inner loops are over segments/devices like [seg0_dev0, seg0_dev1, seg1_dev0, seg1_dev1], + // - ne has one entry per segment and device and that segment repeats nr times, + // in total when accounting for repetitions the segments add up to ggml_tensor::ne for that axis, + // the outer/inner loops are over segments/devices like [seg0_dev0_r0, seg0_dev1_r0, seg0_dev0_r1, seg0_dev1_r1, seg1_dev0_r0, seg1_dev1_r0], // - for example, a transformer may have a fused QKV matrix rather than 3 matrices, those would be 3 separate segments - // that each need to be split individually across devices so that each device gets a slice of Q, K, and V + // that each need to be split individually across devices so that each device gets a slice of Q, K, and V, + // the Q matrix can be larger than the K and V matrices so this can either be expressed as 3 segments or as 2 segments + // where the segment for K/V repeats twice int64_t ne[16*GGML_BACKEND_META_MAX_DEVICES]; + uint32_t nr[16]; uint32_t n_segments; }; diff --git a/ggml/include/ggml-rpc.h b/ggml/include/ggml-rpc.h index 6fcf5a43393..5ad121ae57f 100644 --- a/ggml/include/ggml-rpc.h +++ b/ggml/include/ggml-rpc.h @@ -8,10 +8,10 @@ extern "C" { #define RPC_PROTO_MAJOR_VERSION 4 #define RPC_PROTO_MINOR_VERSION 0 -#define RPC_PROTO_PATCH_VERSION 0 +#define RPC_PROTO_PATCH_VERSION 1 #ifdef __cplusplus -static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT has changed - update RPC_PROTO_PATCH_VERSION"); +static_assert(GGML_OP_COUNT == 97, "GGML_OP_COUNT has changed - update RPC_PROTO_PATCH_VERSION"); #endif #define GGML_RPC_MAX_SERVERS 16 diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index f6725265504..d6807b6dd47 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -535,6 +535,7 @@ extern "C" { GGML_OP_IM2COL, GGML_OP_IM2COL_BACK, GGML_OP_IM2COL_3D, + GGML_OP_COL2IM_1D, GGML_OP_CONV_2D, GGML_OP_CONV_3D, GGML_OP_CONV_2D_DW, @@ -2007,6 +2008,16 @@ extern "C" { int d1, // dilation dimension 1 bool is_2D); + // col2im_1d: scatter-add GEMM columns back to 1D signal + // a: [K*OC, T_in] (columns from matmul, K = a->ne[0]/OC) + // result: [T_out, OC] where T_out = (T_in - 1)*s0 + K - 2*p0 + GGML_API struct ggml_tensor * ggml_col2im_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, // columns [K*OC, T_in] + int s0, // stride + int oc, // output channels + int p0); // padding to crop from both sides + GGML_API struct ggml_tensor * ggml_conv_1d( struct ggml_context * ctx, struct ggml_tensor * a, // convolution kernel @@ -2542,10 +2553,16 @@ extern "C" { // TODO: add ggml_gated_delta_net_set_bcast() to be able to configure Q, K broadcast type: tiled vs interleaved [TAG_GGML_GDN_BCAST] // ref: https://github.com/ggml-org/llama.cpp/pull/19468#discussion_r2786394306 // - // state is a 3D tensor of shape (S_v*S_v*H, K, n_seqs): - // K == 1: output carries the final state only. - // K > 1: output carries K snapshot slots; the kernel writes the last min(n_tokens, K) - // per-token snapshots into the trailing slots + // tensor shapes (S_k == S_v, H_v % H_k == 0): + // q, k : [S_k, H_k, n_tokens, n_seqs] + // v : [S_v, H_v, n_tokens, n_seqs] + // g : [1, H_v, n_tokens, n_seqs] (scalar gate) or [S_v, H_v, n_tokens, n_seqs] (KDA) + // beta : [1, H_v, n_tokens, n_seqs] + // state : [S_v, S_v, H_v, n_seqs] -- initial recurrent state s0 + // + // the output packs the attention scores [S_v, H_v, n_tokens, n_seqs] followed by K state + // snapshots, most-recent first (slot 0 = final state, slot s = state s tokens back). K == 1 + // keeps only the final state; when n_tokens < K only slots 0..n_tokens-1 are written. GGML_API struct ggml_tensor * ggml_gated_delta_net( struct ggml_context * ctx, struct ggml_tensor * q, @@ -2553,7 +2570,8 @@ extern "C" { struct ggml_tensor * v, struct ggml_tensor * g, struct ggml_tensor * beta, - struct ggml_tensor * state); + struct ggml_tensor * state, + int64_t K); // custom operators diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index d0d64523b4a..0a36f099000 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -487,6 +487,9 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(co static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( ggml_backend_meta_simple_tensor_container & stc, const struct ggml_tensor * tensor, bool assume_sync) { + // FIXME Currently this function preserves/erases the information in n_segments and nr in an inconsistent way. + // Since the operations in question are developed specifically for llama.cpp this currently does not manifest as a bug there. + // However, in a broader ggml context with arbitrary ggml graphs this can lead to unexpected results. const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer); ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context; @@ -497,11 +500,11 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( for (size_t j = 0; j < n_bufs; j++) { int64_t sum_a = 0; for (size_t s = 0; s < a.n_segments; s++) { - sum_a += a.ne[s*n_bufs + j]; + sum_a += a.ne[s*n_bufs + j] * a.nr[s]; } int64_t sum_b = 0; for (size_t s = 0; s < b.n_segments; s++) { - sum_b += b.ne[s*n_bufs + j]; + sum_b += b.ne[s*n_bufs + j] * b.nr[s]; } if (sum_a != sum_b) { return false; @@ -511,7 +514,7 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( }; auto handle_generic = [&](const std::vector<ggml_backend_meta_split_state> & src_ss, bool scalar_only) -> ggml_backend_meta_split_state { - ggml_backend_meta_split_state ret = {GGML_BACKEND_SPLIT_AXIS_NONE, {0}, 1}; + ggml_backend_meta_split_state ret = {GGML_BACKEND_SPLIT_AXIS_NONE, {0}, {1}, 1}; for (size_t i = 0; i < GGML_MAX_SRC; i++) { if (tensor->src[i] == nullptr || tensor->src[i] == tensor) { continue; @@ -519,15 +522,15 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( if (ret.axis == GGML_BACKEND_SPLIT_AXIS_NONE) { ret = src_ss[i]; } else if (!split_states_equal(src_ss[i], ret)) { - ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; break; } } if (ret.axis == GGML_BACKEND_SPLIT_AXIS_NONE) { - ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; } if (scalar_only && ret.axis >= 0 && ret.axis < GGML_MAX_DIMS) { - ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; } GGML_ASSERT(ret.axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN); return ret; @@ -571,42 +574,24 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( auto handle_mul_mat = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { - return {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}, 1}; + return {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}, {1}, 1}; } if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { ggml_backend_meta_split_state ret = src_ss[0]; ret.axis = GGML_BACKEND_SPLIT_AXIS_0; + ret.nr[0] = 1; ret.n_segments = 1; return ret; } if (src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_1 && src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { - ggml_backend_meta_split_state ret = src_ss[1]; - ret.n_segments = 1; - return ret; + return src_ss[1]; } if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_0) { GGML_ASSERT(split_states_equal(src_ss[0], src_ss[1])); - return {assume_sync ? GGML_BACKEND_SPLIT_AXIS_MIRRORED : GGML_BACKEND_SPLIT_AXIS_PARTIAL, {0}, 1}; + return {assume_sync ? GGML_BACKEND_SPLIT_AXIS_MIRRORED : GGML_BACKEND_SPLIT_AXIS_PARTIAL, {0}, {1}, 1}; } GGML_ABORT("fatal error"); - //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; - }; - - auto handle_cpy = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { - if (src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS) { - int64_t ne_split_src = tensor->src[0]->ne[0]; - for (int dim = 1; dim <= src_ss[0].axis; dim++) { - ne_split_src *= tensor->src[0]->ne[dim]; - } - int64_t ne_split_dst = 1; - for (int dim = 0; dim < GGML_MAX_DIMS; dim++) { - ne_split_dst *= tensor->ne[dim]; - if (ne_split_dst == ne_split_src) { - return {ggml_backend_meta_split_axis(dim), {0}, 1}; - } - } - } - return handle_generic(src_ss, /*scalar_only =*/ false); + //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; }; auto handle_reshape = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { @@ -615,33 +600,25 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( case GGML_BACKEND_SPLIT_AXIS_1: case GGML_BACKEND_SPLIT_AXIS_2: case GGML_BACKEND_SPLIT_AXIS_3: { - GGML_ASSERT(!ggml_is_permuted(tensor) && !ggml_is_permuted(tensor->src[0])); - if (src_ss[0].axis == ggml_n_dims(tensor->src[0]) - 1) { - return {ggml_backend_meta_split_axis(ggml_n_dims(tensor) - 1), {0}, 1}; + GGML_ASSERT(src_ss[0].n_segments == 1); + if (src_ss[0].axis == ggml_n_dims(tensor->src[0]) - 1 && src_ss[0].nr[0] == 1) { + return {ggml_backend_meta_split_axis(ggml_n_dims(tensor) - 1), {0}, {1}, 1}; } - std::vector<int64_t> base_ne_in; - base_ne_in.reserve(GGML_MAX_DIMS - src_ss[0].axis); - { - base_ne_in.push_back(1); - int dim = 0; - for (; dim <= src_ss[0].axis; dim++) { - base_ne_in[0] *= tensor->src[0]->ne[dim]; - } - for (; dim <= GGML_MAX_DIMS; dim++) { - base_ne_in.push_back(base_ne_in.back() * tensor->src[0]->ne[dim]); - } + int64_t base_ne_in = tensor->src[0]->ne[0]; + for (int dim = 1; dim <= src_ss[0].axis; dim++) { + base_ne_in *= tensor->src[0]->ne[dim]; } + base_ne_in /= src_ss[0].nr[0]; int64_t base_ne_out = 1; for (int dim = 0; dim < GGML_MAX_DIMS; dim++) { const int64_t base_ne_out_next = base_ne_out *= tensor->ne[dim]; - for (const int64_t & bni : base_ne_in) { - if (bni == base_ne_out_next) { - return {ggml_backend_meta_split_axis(dim), {0}, 1}; - } + if (base_ne_out_next % base_ne_in == 0) { + return {ggml_backend_meta_split_axis(dim), {0}, {uint32_t(base_ne_out_next/base_ne_in)}, 1}; } - if (base_ne_out_next > base_ne_in[0]) { - GGML_ASSERT(dim + 1 < GGML_MAX_DIMS); - return {ggml_backend_meta_split_axis(dim + 1), {0}, 1}; + if (base_ne_out_next > base_ne_in) { + GGML_ASSERT(src_ss[0].n_segments == 1); + GGML_ASSERT(src_ss[0].nr[0] == 1); + return {ggml_backend_meta_split_axis(dim), {0}, {1}, 1}; } base_ne_out = base_ne_out_next; } @@ -653,11 +630,18 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( } default: { GGML_ABORT("fatal error"); - //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; } } }; + auto handle_cpy = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { + if (src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS) { + return handle_reshape(src_ss); + } + return handle_generic(src_ss, /*scalar_only =*/ false); + }; + auto handle_view = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { if (ggml_is_contiguous(tensor) && ggml_is_contiguous(tensor->src[0])) { return handle_reshape(src_ss); @@ -681,7 +665,7 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( if (!ggml_is_permuted(tensor) && !ggml_is_permuted(tensor->src[0]) && axis >= 0 && axis < GGML_MAX_DIMS-1) { for (int dim = 0; dim < GGML_MAX_DIMS-1; dim++) { if (tensor->nb[dim+1] == tensor->src[0]->nb[axis+1]) { - return {ggml_backend_meta_split_axis(dim), {0}, 1}; + return {ggml_backend_meta_split_axis(dim), {0}, {1}, 1}; } } GGML_ABORT("fatal error"); @@ -690,7 +674,7 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( return src_ss[0]; } GGML_ABORT("view of permuted tensor not implemented"); - //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; }; auto handle_permute = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { @@ -699,7 +683,8 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( case GGML_BACKEND_SPLIT_AXIS_1: case GGML_BACKEND_SPLIT_AXIS_2: case GGML_BACKEND_SPLIT_AXIS_3: { - return {ggml_backend_meta_split_axis(tensor->op_params[src_ss[0].axis]), {0}, 1}; + GGML_ASSERT(src_ss[0].n_segments == 1 || src_ss[0].nr[0] == 1); + return {ggml_backend_meta_split_axis(tensor->op_params[src_ss[0].axis]), {0}, {src_ss[0].nr[0]}, 1}; } case GGML_BACKEND_SPLIT_AXIS_MIRRORED: case GGML_BACKEND_SPLIT_AXIS_PARTIAL: { @@ -707,7 +692,7 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( } default: { GGML_ABORT("fatal error"); - //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; } } }; @@ -716,7 +701,8 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( switch (src_ss[0].axis) { case GGML_BACKEND_SPLIT_AXIS_0: case GGML_BACKEND_SPLIT_AXIS_1: { - return {ggml_backend_meta_split_axis(int(src_ss[0].axis) ^ 1), {0}, 1}; + GGML_ASSERT(src_ss[0].n_segments == 1 || src_ss[0].nr[0] == 1); + return {ggml_backend_meta_split_axis(int(src_ss[0].axis) ^ 1), {0}, {src_ss[0].nr[0]}, 1}; } case GGML_BACKEND_SPLIT_AXIS_2: case GGML_BACKEND_SPLIT_AXIS_3: @@ -726,7 +712,7 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( } default: { GGML_ABORT("fatal error"); - //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; } } }; @@ -764,16 +750,16 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( GGML_ASSERT( src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_2); GGML_ASSERT(tensor->src[4] == nullptr || src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED); GGML_ASSERT(tensor->src[4] == nullptr || src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_0); - return {GGML_BACKEND_SPLIT_AXIS_1, {0}, 1}; + return {GGML_BACKEND_SPLIT_AXIS_1, {0}, {1}, 1}; }; auto handle_ssm_conv = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { if (src_ss[0].axis == src_ss[1].axis) { if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0) { - return {GGML_BACKEND_SPLIT_AXIS_1, {0}, 1}; + return {GGML_BACKEND_SPLIT_AXIS_1, {0}, {1}, 1}; } if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1) { - return {GGML_BACKEND_SPLIT_AXIS_0, {0}, 1}; + return {GGML_BACKEND_SPLIT_AXIS_0, {0}, {1}, 1}; } } return handle_generic(src_ss, /*scalar_only =*/ false); @@ -781,8 +767,8 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( auto handle_gated_delta_net = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && - src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && - src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && + src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { return src_ss[0]; } GGML_ASSERT(src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1); @@ -790,15 +776,15 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( GGML_ASSERT(src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_1); GGML_ASSERT(src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_1); GGML_ASSERT(src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_1); - // state shape is (S_v*S_v*H, K, n_seqs); the heads dim is nested inside axis 0, - // so a head-aligned split on the input cache reshapes to axis 0 here (not axis 2). + // state shape is [S_v, S_v, H_v, n_seqs] (s0 only); the heads dim is its own axis 2, + // so a head-aligned split on the input cache lands on axis 2 here. GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_1 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_0); - return {GGML_BACKEND_SPLIT_AXIS_0, {0}, 1}; + return {GGML_BACKEND_SPLIT_AXIS_0, {0}, {1}, 1}; }; auto calculate_split_state = [&]() -> ggml_backend_meta_split_state { if (ggml_nelements(tensor) == 0) { - return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; } if (ggml_backend_buffer_get_usage(tensor->buffer) != GGML_BACKEND_BUFFER_USAGE_COMPUTE && tensor->view_src == nullptr) { ggml_backend_dev_t dev = ggml_backend_buft_get_device(ggml_backend_buffer_get_type(tensor->buffer)); @@ -807,19 +793,21 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( if (ret.axis >= 0 && ret.axis <= GGML_MAX_DIMS) { const int64_t granularity = ret.axis == GGML_BACKEND_SPLIT_AXIS_0 ? ggml_blck_size(tensor->type) : 1; int64_t ne_sum = 0; - for (size_t sj = 0; sj < ret.n_segments*n_bufs; sj++) { - GGML_ASSERT(ret.ne[sj] % granularity == 0); - ne_sum += ret.ne[sj]; + for (size_t s = 0; s < ret.n_segments; s++) { + for (size_t j = 0; j < n_bufs; j++) { + GGML_ASSERT(ret.ne[s*n_bufs + j] % granularity == 0); + ne_sum += ret.ne[s*n_bufs + j] * ret.nr[s]; + } } GGML_ASSERT(ne_sum == tensor->ne[ret.axis]); } return ret; } - std::vector<ggml_backend_meta_split_state> src_ss(GGML_MAX_SRC, {GGML_BACKEND_SPLIT_AXIS_NONE, {0}, 1}); + std::vector<ggml_backend_meta_split_state> src_ss(GGML_MAX_SRC, {GGML_BACKEND_SPLIT_AXIS_NONE, {0}, {1}, 1}); for (size_t i = 0; i < GGML_MAX_SRC; i++) { if (tensor->src[i] == nullptr || tensor->src[i] == tensor) { - src_ss[i] = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + src_ss[i] = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; continue; } src_ss[i] = ggml_backend_meta_get_split_state(stc, tensor->src[i], /*assume_sync =*/ true); @@ -829,7 +817,7 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( ggml_backend_meta_split_state split_state; switch (tensor->op) { case GGML_OP_NONE: { - split_state = {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}, 1}; + split_state = {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}, {1}, 1}; } break; case GGML_OP_DUP: { split_state = handle_generic(src_ss, /*scalar_only =*/ true); @@ -1016,7 +1004,7 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( } break; default: { GGML_ABORT("ggml op not implemented: %s", ggml_op_name(tensor->op)); - split_state = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + split_state = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; } break; } if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) { @@ -1034,23 +1022,25 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( split_state.ne[s*n_bufs + j] = 0; } for (size_t s = 0; s < src_ss[i].n_segments; s++) { - split_state.ne[j] += src_ss[i].ne[s*n_bufs + j]; + split_state.ne[j] += src_ss[i].ne[s*n_bufs + j] * src_ss[i].nr[s]; } split_state.ne[j] *= tensor->ne[split_state.axis]; if (split_state.ne[j] != 0 || tensor->src[i]->ne[src_ss[i].axis] != 0) { - GGML_ASSERT(split_state.ne[j] % tensor->src[i]->ne[src_ss[i].axis] == 0); - split_state.ne[j] /= tensor->src[i]->ne[src_ss[i].axis]; + const int64_t div = tensor->src[i]->ne[src_ss[i].axis] * split_state.nr[0]; + GGML_ASSERT(split_state.ne[j] % div == 0); + split_state.ne[j] /= div; } } } else { + GGML_ASSERT(split_state.n_segments == 1); for (size_t j = 0; j < n_bufs; j++) { + // Assert that ratio is consistent: int64_t sum = 0; for (size_t s = 0; s < src_ss[i].n_segments; s++) { - sum += src_ss[i].ne[s*n_bufs + j]; + sum += src_ss[i].ne[s*n_bufs + j] * src_ss[i].nr[s]; } - // Assert that ratio is consistent: - GGML_ASSERT(split_state.ne[j] * tensor->src[i]->ne[src_ss[i].axis] - == sum * tensor->ne[split_state.axis]); + GGML_ASSERT(split_state.ne[j]*split_state.nr[0] * tensor->src[i]->ne[src_ss[i].axis] + == sum * tensor->ne[split_state.axis]); } } first_src_split_by_axis = false; @@ -1080,13 +1070,14 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( srcs_info += ", "; } const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor->src[0], true); + GGML_ASSERT(split_state.n_segments == 1); const char * axis_name = ggml_backend_meta_split_axis_name(split_state.axis); std::string ne_info; for (size_t j = 0; j < n_bufs; j++) { if (!ne_info.empty()) { ne_info += ", "; } - ne_info += std::to_string(split_state.ne[j]); + ne_info += std::to_string(split_state.ne[j]) + "x" + std::to_string(split_state.nr[0]); } srcs_info += std::string(tensor->src[i]->name) + "[" + ggml_op_name(tensor->src[i]->op) + ", " + axis_name + ", {" + ne_info + "}]"; } @@ -1095,7 +1086,8 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( if (!ne_info.empty()) { ne_info += ", "; } - ne_info += std::to_string(buf_ctx->split_state_cache[key].first.ne[j]); + const ggml_backend_meta_split_state & ss = buf_ctx->split_state_cache[key].first; + ne_info += std::to_string(ss.ne[j]) + "x" + std::to_string(ss.nr[0]); } GGML_LOG_DEBUG("SPLIT_STATE: {%s} -> %s[%s, %s, {%s}]\n", srcs_info.c_str(), tensor->name, ggml_op_name(tensor->op), ggml_backend_meta_split_axis_name(buf_ctx->split_state_cache[key].first.axis), ne_info.c_str()); @@ -1107,8 +1099,10 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( #ifndef NDEBUG if (ret.axis >= 0 && ret.axis < GGML_MAX_DIMS) { int64_t ne_ret = 0; - for (size_t sj = 0; sj < ret.n_segments*n_bufs; sj++) { - ne_ret += ret.ne[sj]; + for (size_t s = 0; s < ret.n_segments; s++) { + for (size_t j = 0; j < n_bufs; j++) { + ne_ret += ret.ne[s*n_bufs + j] * ret.nr[s]; + } } assert(ne_ret == tensor->ne[int(ret.axis)]); } @@ -1155,7 +1149,7 @@ static enum ggml_status ggml_backend_meta_buffer_init_tensor_impl(ggml_backend_m // GGML_ASSERT(ggml_is_contiguously_allocated(tensor)); ne[split_dim] = 0; for (size_t s = 0; s < split_state.n_segments; s++) { - ne[split_dim] += split_state.ne[s*n_simple_bufs + j]; + ne[split_dim] += split_state.ne[s*n_simple_bufs + j] * split_state.nr[s]; } for (int i = 0; i < GGML_MAX_DIMS; i++) { if (tensor->nb[i] > tensor->nb[split_dim]) { @@ -1229,7 +1223,7 @@ static enum ggml_status ggml_backend_meta_buffer_init_tensor_impl(ggml_backend_m for (size_t j = 0; j < n_simple_bufs; j++) { int64_t ne_sum = 0; for (size_t s = 0; s < split_state_src.n_segments; s++) { - ne_sum += split_state_src.ne[s*n_simple_bufs + j]; + ne_sum += split_state_src.ne[s*n_simple_bufs + j] * split_state_src.nr[s]; } if (ne_sum == 0) { simple_tensors[j]->flags &= ~GGML_TENSOR_FLAG_COMPUTE; @@ -1255,8 +1249,9 @@ static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, gg const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); - if (split_state.n_segments != 1) { + if (split_state.n_segments != 1 || split_state.nr[0] != 1) { GGML_ASSERT(split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS); + GGML_ASSERT(split_state.nr[0] != 0); GGML_ASSERT(tensor->ne[3] == 1); size_t offset_data = 0; @@ -1267,24 +1262,26 @@ static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, gg const size_t row_stride = tensor->nb[1]; GGML_ASSERT(offset % row_stride == 0); GGML_ASSERT(size % row_stride == 0); - const int64_t r_start = offset / row_stride; - const int64_t r_count = size / row_stride; - GGML_ASSERT(r_start + r_count <= tensor->ne[1]); + const int64_t row_start = offset / row_stride; + const int64_t row_count = size / row_stride; + GGML_ASSERT(row_start + row_count <= tensor->ne[1]); const int64_t blck_size = ggml_blck_size(tensor->type); for (size_t s = 0; s < split_state.n_segments; s++) { - for (size_t j = 0; j < n_bufs; j++) { - ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); - GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0); - const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0]; - ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data, - simple_offsets[j] + r_start * simple_tensor->nb[1], nbytes, - r_count, simple_tensor->nb[1], tensor->nb[1]); - offset_data += nbytes; - simple_offsets[j] += nbytes; + for (size_t r = 0; r < split_state.nr[s]; r++) { + for (size_t j = 0; j < n_bufs; j++) { + ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0); + const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0]; + ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data, + simple_offsets[j] + row_start * simple_tensor->nb[1], nbytes, + row_count, simple_tensor->nb[1], tensor->nb[1]); + offset_data += nbytes; + simple_offsets[j] += nbytes; + } } } - GGML_ASSERT(offset_data*r_count == size); + GGML_ASSERT(offset_data*row_count == size); return; } GGML_ASSERT(split_state.axis == GGML_BACKEND_SPLIT_AXIS_1); @@ -1292,22 +1289,24 @@ static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, gg const size_t row_stride = tensor->nb[2]; GGML_ASSERT(offset % row_stride == 0); GGML_ASSERT(size % row_stride == 0); - const int64_t r_start = offset / row_stride; - const int64_t r_count = size / row_stride; - GGML_ASSERT(r_start + r_count <= tensor->ne[2]); + const int64_t row_start = offset / row_stride; + const int64_t row_count = size / row_stride; + GGML_ASSERT(row_start + row_count <= tensor->ne[2]); for (size_t s = 0; s < split_state.n_segments; s++) { - for (size_t j = 0; j < n_bufs; j++) { - ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); - const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1]; - ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data, - simple_offsets[j] + r_start * simple_tensor->nb[2], nbytes, - r_count, simple_tensor->nb[2], tensor->nb[2]); - offset_data += nbytes; - simple_offsets[j] += nbytes; + for (size_t r = 0; r < split_state.nr[s]; r++) { + for (size_t j = 0; j < n_bufs; j++) { + ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1]; + ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data, + simple_offsets[j] + row_start * simple_tensor->nb[2], nbytes, + row_count, simple_tensor->nb[2], tensor->nb[2]); + offset_data += nbytes; + simple_offsets[j] += nbytes; + } } } - GGML_ASSERT(offset_data*r_count == size); + GGML_ASSERT(offset_data*row_count == size); return; } @@ -1365,8 +1364,9 @@ static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, co const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); - if (split_state.n_segments != 1) { + if (split_state.n_segments != 1 || split_state.nr[0] != 1) { GGML_ASSERT(split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS); + GGML_ASSERT(split_state.nr[0] != 0); GGML_ASSERT(tensor->ne[3] == 1); size_t offset_data = 0; @@ -1377,24 +1377,26 @@ static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, co const size_t row_stride = tensor->nb[1]; GGML_ASSERT(offset % row_stride == 0); GGML_ASSERT(size % row_stride == 0); - const int64_t r_start = offset / row_stride; - const int64_t r_count = size / row_stride; - GGML_ASSERT(r_start + r_count <= tensor->ne[1]); + const int64_t row_start = offset / row_stride; + const int64_t row_count = size / row_stride; + GGML_ASSERT(row_start + row_count <= tensor->ne[1]); const int64_t blck_size = ggml_blck_size(tensor->type); for (size_t s = 0; s < split_state.n_segments; s++) { - for (size_t j = 0; j < n_bufs; j++) { - const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); - GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0); - const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0]; - ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, - simple_offsets[j] + r_start * simple_tensor->nb[1], nbytes, - r_count, simple_tensor->nb[1], tensor->nb[1]); - offset_data += nbytes; - simple_offsets[j] += nbytes; + for (size_t r = 0; r < split_state.nr[s]; r++) { + for (size_t j = 0; j < n_bufs; j++) { + const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0); + const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0]; + ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, + simple_offsets[j] + row_start * simple_tensor->nb[1], nbytes, + row_count, simple_tensor->nb[1], tensor->nb[1]); + offset_data += nbytes; + simple_offsets[j] += nbytes; + } } } - GGML_ASSERT(offset_data*r_count == size); + GGML_ASSERT(offset_data*row_count == size); return; } GGML_ASSERT(split_state.axis == GGML_BACKEND_SPLIT_AXIS_1); @@ -1402,22 +1404,24 @@ static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, co const size_t row_stride = tensor->nb[2]; GGML_ASSERT(offset % row_stride == 0); GGML_ASSERT(size % row_stride == 0); - const int64_t r_start = offset / row_stride; - const int64_t r_count = size / row_stride; - GGML_ASSERT(r_start + r_count <= tensor->ne[2]); + const int64_t row_start = offset / row_stride; + const int64_t row_count = size / row_stride; + GGML_ASSERT(row_start + row_count <= tensor->ne[2]); for (size_t s = 0; s < split_state.n_segments; s++) { - for (size_t j = 0; j < n_bufs; j++) { - const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); - const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1]; - ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, - simple_offsets[j] + r_start * simple_tensor->nb[2], nbytes, - r_count, simple_tensor->nb[2], tensor->nb[2]); - offset_data += nbytes; - simple_offsets[j] += nbytes; + for (size_t r = 0; r < split_state.nr[s]; r++) { + for (size_t j = 0; j < n_bufs; j++) { + const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1]; + ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, + simple_offsets[j] + row_start * simple_tensor->nb[2], nbytes, + row_count, simple_tensor->nb[2], tensor->nb[2]); + offset_data += nbytes; + simple_offsets[j] += nbytes; + } } } - GGML_ASSERT(offset_data*r_count == size); + GGML_ASSERT(offset_data*row_count == size); return; } @@ -1675,6 +1679,7 @@ static void ggml_backend_meta_set_tensor_async(ggml_backend_t backend, ggml_tens const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); GGML_ASSERT(split_state.n_segments == 1); + GGML_ASSERT(split_state.nr[0] == 1); switch (split_state.axis) { case GGML_BACKEND_SPLIT_AXIS_0: @@ -1719,6 +1724,7 @@ static void ggml_backend_meta_get_tensor_async(ggml_backend_t backend, const ggm const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); GGML_ASSERT(split_state.n_segments == 1); + GGML_ASSERT(split_state.nr[0] == 1); switch (split_state.axis) { case GGML_BACKEND_SPLIT_AXIS_0: @@ -2076,6 +2082,7 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, node_zero->src[0] = node; ggml_set_op_params_f32(node_zero, 0, 0.0f); node_zero->data = node->data; + node_zero->buffer = node->buffer; node_zero->flags |= GGML_TENSOR_FLAG_COMPUTE; step_cgraphs[j] = get_cgraph_aux(); diff --git a/ggml/src/ggml-cpu/arch/loongarch/quants.c b/ggml/src/ggml-cpu/arch/loongarch/quants.c index 74e0c086c6d..9c43da6cf89 100644 --- a/ggml/src/ggml-cpu/arch/loongarch/quants.c +++ b/ggml/src/ggml-cpu/arch/loongarch/quants.c @@ -977,6 +977,35 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi sumf = hsum_float_8(acc); *s = sumf; + +#elif defined(__loongarch_sx) + + __m128 acc = (__m128)__lsx_vldi(0); + + for (; ib < nb; ++ib) { + const float d = GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d); + const __m128i qx_0 = __lsx_vld((const __m128i *)x[ib].qs, 0); + const __m128i qx_1 = __lsx_vld((const __m128i *)x[ib].qs + 1, 0); + const __m128i qy_0 = __lsx_vld((const __m128i *)y[ib].qs, 0); + const __m128i qy_1 = __lsx_vld((const __m128i *)y[ib].qs + 1, 0); + + const __m128i p16_0 = lsx_maddubs_h(qx_0, qy_0); + const __m128i p16_1 = lsx_maddubs_h(qx_1, qy_1); + + // Sum int16 pairs → int32 + const __m128i s_0 = __lsx_vaddwev_w_h(p16_0, p16_1); + const __m128i s_1 = __lsx_vaddwod_w_h(p16_0, p16_1); + + const __m128 q = __lsx_vffint_s_w(__lsx_vadd_w(s_0, s_1)); + acc = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(d), q, acc); + } + + __m128 res = lsx_hadd_s(acc, acc); + res = lsx_hadd_s(res, res); + sumf = ((v4f32)res)[0]; + + *s = sumf; + #else UNUSED(nb); UNUSED(ib); @@ -1443,6 +1472,99 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi *s = hsum_float_8(acc); +#elif defined(__loongarch_sx) + + const __m128i m32s = __lsx_vreplgr2vr_b(32); + + __m128 acc_0 = (__m128)__lsx_vldi(0); + __m128 acc_1 = (__m128)__lsx_vldi(0); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + + const uint8_t * GGML_RESTRICT q4 = x[i].ql; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const __m128i scale_i8 = __lsx_vld(x[i].scales, 0); + const __m128i scales_lo = __lsx_vsllwil_h_b(scale_i8, 0); + const __m128i scales_hi = __lsx_vsllwil_h_b(__lsx_vbsrl_v(scale_i8, 8), 0); + + __m128i sumi_0 = __lsx_vldi(0); + __m128i sumi_1 = __lsx_vldi(0); + + for (int j = 0; j < QK_K/128; ++j) { + + const __m128i q4bitsH_0 = __lsx_vld((const __m128i*)qh, 0); qh += 16; + const __m128i q4bitsH_1 = __lsx_vld((const __m128i*)qh, 0); qh += 16; + + const __m128i q4h_0 = __lsx_vslli_b(__lsx_vandi_b(q4bitsH_0, 3), 4); + const __m128i q4h_1 = __lsx_vslli_b(__lsx_vandi_b(q4bitsH_1, 3), 4); + const __m128i q4h_2 = __lsx_vslli_b(__lsx_vandi_b(q4bitsH_0, 3 << 2), 2); + const __m128i q4h_3 = __lsx_vslli_b(__lsx_vandi_b(q4bitsH_1, 3 << 2), 2); + const __m128i q4h_4 = __lsx_vandi_b(q4bitsH_0, 3 << 4); + const __m128i q4h_5 = __lsx_vandi_b(q4bitsH_1, 3 << 4); + const __m128i q4h_6 = __lsx_vsrli_b(__lsx_vandi_b(q4bitsH_0, 3 << 6), 2); + const __m128i q4h_7 = __lsx_vsrli_b(__lsx_vandi_b(q4bitsH_1, 3 << 6), 2); + + const __m128i q4bits1_0 = __lsx_vld((const __m128i*)q4, 0); q4 += 16; + const __m128i q4bits1_1 = __lsx_vld((const __m128i*)q4, 0); q4 += 16; + const __m128i q4bits2_0 = __lsx_vld((const __m128i*)q4, 0); q4 += 16; + const __m128i q4bits2_1 = __lsx_vld((const __m128i*)q4, 0); q4 += 16; + + const __m128i q4_0 = __lsx_vor_v(__lsx_vandi_b(q4bits1_0, 0xf), q4h_0); + const __m128i q4_1 = __lsx_vor_v(__lsx_vandi_b(q4bits1_1, 0xf), q4h_1); + const __m128i q4_2 = __lsx_vor_v(__lsx_vandi_b(q4bits2_0, 0xf), q4h_2); + const __m128i q4_3 = __lsx_vor_v(__lsx_vandi_b(q4bits2_1, 0xf), q4h_3); + const __m128i q4_4 = __lsx_vor_v(__lsx_vsrli_b(q4bits1_0, 4), q4h_4); + const __m128i q4_5 = __lsx_vor_v(__lsx_vsrli_b(q4bits1_1, 4), q4h_5); + const __m128i q4_6 = __lsx_vor_v(__lsx_vsrli_b(q4bits2_0, 4), q4h_6); + const __m128i q4_7 = __lsx_vor_v(__lsx_vsrli_b(q4bits2_1, 4), q4h_7); + + const __m128i q8_0 = __lsx_vld((const __m128i*)q8, 0); q8 += 16; + const __m128i q8_1 = __lsx_vld((const __m128i*)q8, 0); q8 += 16; + const __m128i q8_2 = __lsx_vld((const __m128i*)q8, 0); q8 += 16; + const __m128i q8_3 = __lsx_vld((const __m128i*)q8, 0); q8 += 16; + const __m128i q8_4 = __lsx_vld((const __m128i*)q8, 0); q8 += 16; + const __m128i q8_5 = __lsx_vld((const __m128i*)q8, 0); q8 += 16; + const __m128i q8_6 = __lsx_vld((const __m128i*)q8, 0); q8 += 16; + const __m128i q8_7 = __lsx_vld((const __m128i*)q8, 0); q8 += 16; + + __m128i p16_0 = lsx_maddubs_h(__lsx_vsub_b(q4_0, m32s), q8_0); + __m128i p16_1 = lsx_maddubs_h(__lsx_vsub_b(q4_1, m32s), q8_1); + __m128i p16_2 = lsx_maddubs_h(__lsx_vsub_b(q4_2, m32s), q8_2); + __m128i p16_3 = lsx_maddubs_h(__lsx_vsub_b(q4_3, m32s), q8_3); + __m128i p16_4 = lsx_maddubs_h(__lsx_vsub_b(q4_4, m32s), q8_4); + __m128i p16_5 = lsx_maddubs_h(__lsx_vsub_b(q4_5, m32s), q8_5); + __m128i p16_6 = lsx_maddubs_h(__lsx_vsub_b(q4_6, m32s), q8_6); + __m128i p16_7 = lsx_maddubs_h(__lsx_vsub_b(q4_7, m32s), q8_7); + + const __m128i sc_vec = j == 0 ? scales_lo : scales_hi; + + p16_0 = lsx_madd_h(__lsx_vreplvei_h(sc_vec, 0), p16_0); + p16_1 = lsx_madd_h(__lsx_vreplvei_h(sc_vec, 1), p16_1); + p16_2 = lsx_madd_h(__lsx_vreplvei_h(sc_vec, 2), p16_2); + p16_3 = lsx_madd_h(__lsx_vreplvei_h(sc_vec, 3), p16_3); + p16_4 = lsx_madd_h(__lsx_vreplvei_h(sc_vec, 4), p16_4); + p16_5 = lsx_madd_h(__lsx_vreplvei_h(sc_vec, 5), p16_5); + p16_6 = lsx_madd_h(__lsx_vreplvei_h(sc_vec, 6), p16_6); + p16_7 = lsx_madd_h(__lsx_vreplvei_h(sc_vec, 7), p16_7); + + sumi_0 = __lsx_vadd_w(sumi_0, __lsx_vadd_w(p16_0, p16_2)); + sumi_1 = __lsx_vadd_w(sumi_1, __lsx_vadd_w(p16_1, p16_3)); + sumi_0 = __lsx_vadd_w(sumi_0, __lsx_vadd_w(p16_4, p16_6)); + sumi_1 = __lsx_vadd_w(sumi_1, __lsx_vadd_w(p16_5, p16_7)); + } + + __m128 p_0 = __lsx_vfmul_s(__lsx_vreplfr2vr_s(d), __lsx_vffint_s_w(sumi_0)); + __m128 p_1 = __lsx_vfmul_s(__lsx_vreplfr2vr_s(d), __lsx_vffint_s_w(sumi_1)); + acc_0 = __lsx_vfadd_s(p_0, acc_0); + acc_1 = __lsx_vfadd_s(p_1, acc_1); + } + + *s = hsum_float_4x4(acc_0, acc_1, (__m128)__lsx_vldi(0), (__m128)__lsx_vldi(0)); + #else UNUSED(x); UNUSED(y); @@ -2149,6 +2271,35 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v *s = hsum_float_8(accum); +#elif defined(__loongarch_sx) + + const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0); + + __m128 accum = (__m128)__lsx_vldi(0); + for (int ibl = 0; ibl < nb; ++ibl) { + const uint8_t * qs = x[ibl].qs; + const int8_t * q8 = y[ibl].qs; + uint16_t sh = x[ibl].scales_h; + __m128i sumi = __lsx_vldi(0); + for (int ib = 0; ib < QK_K/32; ++ib) { + const __m128i q4bits = __lsx_vld((const __m128i*)qs, 0); qs += 16; + const __m128i q8b_0 = __lsx_vld((const __m128i*)q8, 0); q8 += 16; + const __m128i q8b_1 = __lsx_vld((const __m128i*)q8, 0); q8 += 16; + const __m128i q4b_0 = __lsx_vshuf_b(values128, values128, __lsx_vandi_b(q4bits, 0xf)); + const __m128i q4b_1 = __lsx_vshuf_b(values128, values128, __lsx_vsrli_b(q4bits, 4)); + const __m128i p16_0 = lsx_maddubs_h(q4b_0, q8b_0); + const __m128i p16_1 = lsx_maddubs_h(q4b_1, q8b_1); + const int16_t ls = (((x[ibl].scales_l[ib/2] >> ((ib & 1) * 4)) & 0xf) | ((sh & 0x3) << 4)) - 32; + sh >>= 2; + sumi = __lsx_vadd_w(lsx_madd_h(p16_0, __lsx_vreplgr2vr_h(ls)), sumi); + sumi = __lsx_vadd_w(lsx_madd_h(p16_1, __lsx_vreplgr2vr_h(ls)), sumi); + } + const float ds = GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d; + accum = __lsx_vfadd_s(__lsx_vfmul_s(__lsx_vreplfr2vr_s(ds), __lsx_vffint_s_w(sumi)), accum); + } + + *s = ((v4f32)lsx_hadd_s(lsx_hadd_s(accum, accum), lsx_hadd_s(accum, accum)))[0]; + #else UNUSED(x); UNUSED(y); diff --git a/ggml/src/ggml-cpu/arch/riscv/quants.c b/ggml/src/ggml-cpu/arch/riscv/quants.c index ee69e5ab5e5..47e9180bf9b 100644 --- a/ggml/src/ggml-cpu/arch/riscv/quants.c +++ b/ggml/src/ggml-cpu/arch/riscv/quants.c @@ -123,7 +123,7 @@ void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in assert(k % QK_K == 0); size_t nb = k / QK_K; -#if defined __riscv_v_intrinsic +#if defined __riscv_v block_q8_K * y_blocks = (block_q8_K *)y; const size_t vlmax_f32m8 = __riscv_vsetvlmax_e32m8(); @@ -578,7 +578,8 @@ void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi #endif } -void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_xtheadvector +void ggml_vec_dot_q2_K_q8_K_xtheadvector(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); UNUSED(bx); @@ -590,8 +591,6 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const int nb = n / QK_K; -#if defined __riscv_xtheadvector - float sumf = 0; uint8_t atmp[16]; @@ -686,246 +685,281 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi } *s = sumf; +} +#endif -#elif defined __riscv_v +#if defined __riscv_v +void ggml_vec_dot_q2_K_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q2_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; float sumf = 0; uint8_t atmp[16]; - const int vector_length = __riscv_vlenb() * 8; uint8_t temp_01[32] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 }; - switch (vector_length) { - case 256: - for (int i = 0; i < nb; ++i) { - const uint8_t * q2 = x[i].qs; - const int8_t * q8 = y[i].qs; - const uint8_t * sc = x[i].scales; + for (int i = 0; i < nb; ++i) { + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); + uint8_t *patmp = atmp; + int vsums; + int tmp, t1, t2, t3, t4, t5, t6, t7; + __asm__ __volatile__( + "vsetivli zero, 16, e8, m1\n\t" + "vmv.v.x v8, zero\n\t" + "lb zero, 15(%[sc])\n\t" + "vle8.v v1, (%[sc])\n\t" + "vle8.v v2, (%[bsums])\n\t" + "addi %[tmp], %[bsums], 16\n\t" + "vand.vi v0, v1, 0xF\n\t" + "vsrl.vi v1, v1, 4\n\t" + "vle8.v v3, (%[tmp])\n\t" + "vse8.v v0, (%[scale])\n\t" + "vsetivli zero, 16, e16, m2\n\t" + "vzext.vf2 v0, v1\n\t" + "vwmul.vv v4, v0, v2\n\t" + "vsetivli zero, 16, e32, m4\n\t" + "vredsum.vs v8, v4, v8\n\t" + "vmv.x.s %[vsums], v8" + : [tmp] "=&r" (tmp), [vsums] "=&r" (vsums) + : [sc] "r" (sc), [scale] "r" (atmp), [bsums] "r" (y[i].bsums) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + sumf += dmin * vsums; + int isum = 0; + + for (int j = 0; j < QK_K/128; ++j) { + __asm__ __volatile__( + "lb zero, 31(%[q2])\n\t" + "addi %[tmp], %[q2], 16\n\t" + "addi %[t1], %[q8], 16\n\t" + "vsetivli zero, 16, e8, m1\n\t" + "vle8.v v0, (%[q2])\n\t" + "vle8.v v1, (%[tmp])\n\t" + "vsrl.vi v2, v0, 2\n\t" + "vsrl.vi v3, v1, 2\n\t" + "vsrl.vi v4, v0, 4\n\t" + "addi %[tmp], %[q8], 32\n\t" + "vle8.v v8, (%[q8])\n\t" + "vle8.v v9, (%[t1])\n\t" + "addi %[t1], %[t1], 32\n\t" + "vsrl.vi v5, v1, 4\n\t" + "vsrl.vi v6, v0, 6\n\t" + "vsrl.vi v7, v1, 6\n\t" + "vle8.v v10, (%[tmp])\n\t" + "vle8.v v11, (%[t1])\n\t" + "addi %[tmp], %[tmp], 32\n\t" + "addi %[t1], %[t1], 32\n\t" + "vand.vi v0, v0, 0x3\n\t" + "vand.vi v1, v1, 0x3\n\t" + "vand.vi v2, v2, 0x3\n\t" + "vle8.v v12, (%[tmp])\n\t" + "vle8.v v13, (%[t1])\n\t" + "addi %[tmp], %[tmp], 32\n\t" + "addi %[t1], %[t1], 32\n\t" + "vand.vi v3, v3, 0x3\n\t" + "vand.vi v4, v4, 0x3\n\t" + "vand.vi v5, v5, 0x3\n\t" + "vle8.v v14, (%[tmp])\n\t" + "vle8.v v15, (%[t1])\n\t" + "vwmul.vv v16, v0, v8\n\t" + "vwmul.vv v18, v1, v9\n\t" + "vwmul.vv v20, v2, v10\n\t" + "vwmul.vv v22, v3, v11\n\t" + "vwmul.vv v24, v4, v12\n\t" + "vwmul.vv v26, v5, v13\n\t" + "vwmul.vv v28, v6, v14\n\t" + "vwmul.vv v30, v7, v15\n\t" + "vsetivli zero, 8, e16, m1\n\t" + "vmv.v.x v0, zero\n\t" + "lbu %[tmp], 0(%[scale])\n\t" + "vwredsum.vs v8, v16, v0\n\t" + "vwredsum.vs v9, v18, v0\n\t" + "lbu %[t1], 1(%[scale])\n\t" + "vwredsum.vs v10, v20, v0\n\t" + "vwredsum.vs v11, v22, v0\n\t" + "lbu %[t2], 2(%[scale])\n\t" + "vwredsum.vs v12, v24, v0\n\t" + "vwredsum.vs v13, v26, v0\n\t" + "lbu %[t3], 3(%[scale])\n\t" + "vwredsum.vs v14, v28, v0\n\t" + "vwredsum.vs v15, v30, v0\n\t" + "lbu %[t4], 4(%[scale])\n\t" + "vwredsum.vs v8, v17, v8\n\t" + "vwredsum.vs v9, v19, v9\n\t" + "lbu %[t5], 5(%[scale])\n\t" + "vwredsum.vs v10, v21, v10\n\t" + "vwredsum.vs v11, v23, v11\n\t" + "lbu %[t6], 6(%[scale])\n\t" + "vwredsum.vs v12, v25, v12\n\t" + "vwredsum.vs v13, v27, v13\n\t" + "lbu %[t7], 7(%[scale])\n\t" + "vwredsum.vs v14, v29, v14\n\t" + "vwredsum.vs v15, v31, v15\n\t" + "vsetivli zero, 4, e32, m1\n\t" + "vmul.vx v0, v8, %[tmp]\n\t" + "vmul.vx v1, v9, %[t1]\n\t" + "vmacc.vx v0, %[t2], v10\n\t" + "vmacc.vx v1, %[t3], v11\n\t" + "vmacc.vx v0, %[t4], v12\n\t" + "vmacc.vx v1, %[t5], v13\n\t" + "vmacc.vx v0, %[t6], v14\n\t" + "vmacc.vx v1, %[t7], v15\n\t" + "vmv.x.s %[tmp], v0\n\t" + "vmv.x.s %[t1], v1\n\t" + "add %[isum], %[isum], %[tmp]\n\t" + "add %[isum], %[isum], %[t1]" + : [tmp] "=&r" (tmp), [t1] "=&r" (t1), [t2] "=&r" (t2), [t3] "=&r" (t3) + , [t4] "=&r" (t4), [t5] "=&r" (t5), [t6] "=&r" (t6), [t7] "=&r" (t7) + , [isum] "+&r" (isum) + : [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + q2 += 32; q8 += 128; patmp += 8; + } + + sumf += dall * isum; + } + + *s = sumf; +} + +void ggml_vec_dot_q2_K_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); + const block_q2_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; - size_t vl = 16; + const int nb = n / QK_K; - vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl); - vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl); + float sumf = 0; + uint8_t atmp[16]; - vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl); + uint8_t temp_01[32] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 }; - vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl); - vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl); - vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl)); - vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl); - vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + for (int i = 0; i < nb; ++i) { + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; - sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums); + const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); - vl = 32; + size_t vl = 16; - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl); + vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl); + vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl); - uint8_t is = 0; - int isum = 0; + vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl); - for (int j = 0; j < QK_K / 128; ++j) { - // load Q2 - vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl); + vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl); + vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl); + vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl)); + vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl); + vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); - vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl); - vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03, vl); - vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03, vl); - vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03, vl); + sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums); - // duplicate scale elements for product - vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0 + is, vl), vl); - vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2 + is, vl), vl); - vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4 + is, vl), vl); - vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6 + is, vl), vl); + vl = 32; - vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl)); - vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl)); - vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl)); - vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl)); + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl); - // load Q8 - vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); - vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8 + 32, vl); - vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8 + 64, vl); - vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8 + 96, vl); + uint8_t is = 0; + int isum = 0; - vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl); - vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl); - vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl); - vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl); + for (int j = 0; j < QK_K / 128; ++j) { + // load Q2 + vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl); - vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl); + vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl); + vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03, vl); + vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03, vl); + vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03, vl); - isum += __riscv_vmv_x_s_i32m1_i32(isum1); + // duplicate scale elements for product + vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0 + is, vl), vl); + vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2 + is, vl), vl); + vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4 + is, vl), vl); + vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6 + is, vl), vl); - q2 += 32; - q8 += 128; - is = 8; - } + vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl)); + vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl)); + vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl)); + vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl)); - sumf += dall * isum; - } - break; - case 128: - for (int i = 0; i < nb; ++i) { - const uint8_t * q2 = x[i].qs; - const int8_t * q8 = y[i].qs; - const uint8_t * sc = x[i].scales; - const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); - uint8_t *patmp = atmp; - int vsums; - int tmp, t1, t2, t3, t4, t5, t6, t7; - __asm__ __volatile__( - "vsetivli zero, 16, e8, m1\n\t" - "vmv.v.x v8, zero\n\t" - "lb zero, 15(%[sc])\n\t" - "vle8.v v1, (%[sc])\n\t" - "vle8.v v2, (%[bsums])\n\t" - "addi %[tmp], %[bsums], 16\n\t" - "vand.vi v0, v1, 0xF\n\t" - "vsrl.vi v1, v1, 4\n\t" - "vle8.v v3, (%[tmp])\n\t" - "vse8.v v0, (%[scale])\n\t" - "vsetivli zero, 16, e16, m2\n\t" - "vzext.vf2 v0, v1\n\t" - "vwmul.vv v4, v0, v2\n\t" - "vsetivli zero, 16, e32, m4\n\t" - "vredsum.vs v8, v4, v8\n\t" - "vmv.x.s %[vsums], v8" - : [tmp] "=&r" (tmp), [vsums] "=&r" (vsums) - : [sc] "r" (sc), [scale] "r" (atmp), [bsums] "r" (y[i].bsums) - : "memory" - , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" - , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" - , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" - , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" - ); - sumf += dmin * vsums; - int isum = 0; - - for (int j = 0; j < QK_K/128; ++j) { - __asm__ __volatile__( - "lb zero, 31(%[q2])\n\t" - "addi %[tmp], %[q2], 16\n\t" - "addi %[t1], %[q8], 16\n\t" - "vsetivli zero, 16, e8, m1\n\t" - "vle8.v v0, (%[q2])\n\t" - "vle8.v v1, (%[tmp])\n\t" - "vsrl.vi v2, v0, 2\n\t" - "vsrl.vi v3, v1, 2\n\t" - "vsrl.vi v4, v0, 4\n\t" - "addi %[tmp], %[q8], 32\n\t" - "vle8.v v8, (%[q8])\n\t" - "vle8.v v9, (%[t1])\n\t" - "addi %[t1], %[t1], 32\n\t" - "vsrl.vi v5, v1, 4\n\t" - "vsrl.vi v6, v0, 6\n\t" - "vsrl.vi v7, v1, 6\n\t" - "vle8.v v10, (%[tmp])\n\t" - "vle8.v v11, (%[t1])\n\t" - "addi %[tmp], %[tmp], 32\n\t" - "addi %[t1], %[t1], 32\n\t" - "vand.vi v0, v0, 0x3\n\t" - "vand.vi v1, v1, 0x3\n\t" - "vand.vi v2, v2, 0x3\n\t" - "vle8.v v12, (%[tmp])\n\t" - "vle8.v v13, (%[t1])\n\t" - "addi %[tmp], %[tmp], 32\n\t" - "addi %[t1], %[t1], 32\n\t" - "vand.vi v3, v3, 0x3\n\t" - "vand.vi v4, v4, 0x3\n\t" - "vand.vi v5, v5, 0x3\n\t" - "vle8.v v14, (%[tmp])\n\t" - "vle8.v v15, (%[t1])\n\t" - "vwmul.vv v16, v0, v8\n\t" - "vwmul.vv v18, v1, v9\n\t" - "vwmul.vv v20, v2, v10\n\t" - "vwmul.vv v22, v3, v11\n\t" - "vwmul.vv v24, v4, v12\n\t" - "vwmul.vv v26, v5, v13\n\t" - "vwmul.vv v28, v6, v14\n\t" - "vwmul.vv v30, v7, v15\n\t" - "vsetivli zero, 8, e16, m1\n\t" - "vmv.v.x v0, zero\n\t" - "lbu %[tmp], 0(%[scale])\n\t" - "vwredsum.vs v8, v16, v0\n\t" - "vwredsum.vs v9, v18, v0\n\t" - "lbu %[t1], 1(%[scale])\n\t" - "vwredsum.vs v10, v20, v0\n\t" - "vwredsum.vs v11, v22, v0\n\t" - "lbu %[t2], 2(%[scale])\n\t" - "vwredsum.vs v12, v24, v0\n\t" - "vwredsum.vs v13, v26, v0\n\t" - "lbu %[t3], 3(%[scale])\n\t" - "vwredsum.vs v14, v28, v0\n\t" - "vwredsum.vs v15, v30, v0\n\t" - "lbu %[t4], 4(%[scale])\n\t" - "vwredsum.vs v8, v17, v8\n\t" - "vwredsum.vs v9, v19, v9\n\t" - "lbu %[t5], 5(%[scale])\n\t" - "vwredsum.vs v10, v21, v10\n\t" - "vwredsum.vs v11, v23, v11\n\t" - "lbu %[t6], 6(%[scale])\n\t" - "vwredsum.vs v12, v25, v12\n\t" - "vwredsum.vs v13, v27, v13\n\t" - "lbu %[t7], 7(%[scale])\n\t" - "vwredsum.vs v14, v29, v14\n\t" - "vwredsum.vs v15, v31, v15\n\t" - "vsetivli zero, 4, e32, m1\n\t" - "vmul.vx v0, v8, %[tmp]\n\t" - "vmul.vx v1, v9, %[t1]\n\t" - "vmacc.vx v0, %[t2], v10\n\t" - "vmacc.vx v1, %[t3], v11\n\t" - "vmacc.vx v0, %[t4], v12\n\t" - "vmacc.vx v1, %[t5], v13\n\t" - "vmacc.vx v0, %[t6], v14\n\t" - "vmacc.vx v1, %[t7], v15\n\t" - "vmv.x.s %[tmp], v0\n\t" - "vmv.x.s %[t1], v1\n\t" - "add %[isum], %[isum], %[tmp]\n\t" - "add %[isum], %[isum], %[t1]" - : [tmp] "=&r" (tmp), [t1] "=&r" (t1), [t2] "=&r" (t2), [t3] "=&r" (t3) - , [t4] "=&r" (t4), [t5] "=&r" (t5), [t6] "=&r" (t6), [t7] "=&r" (t7) - , [isum] "+&r" (isum) - : [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8) - : "memory" - , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" - , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" - , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" - , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" - ); - q2 += 32; q8 += 128; patmp += 8; - } + // load Q8 + vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); + vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8 + 32, vl); + vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8 + 64, vl); + vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8 + 96, vl); - sumf += dall * isum; + vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl); + vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl); + vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl); + vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl); + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl); + + isum += __riscv_vmv_x_s_i32m1_i32(isum1); + + q2 += 32; + q8 += 128; + is = 8; } - break; - default: - assert(false && "Unsupported vector length"); - break; + + sumf += dall * isum; } *s = sumf; +} +#endif +void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_xtheadvector + ggml_vec_dot_q2_K_q8_K_xtheadvector(n, s, bs, vx, bx, vy, by, nrc); +#elif defined __riscv_v + switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_q2_K_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_q2_K_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + } #else - - UNUSED(x); - UNUSED(y); - UNUSED(nb); - ggml_vec_dot_q2_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } -void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_xtheadvector +void ggml_vec_dot_q3_K_q8_K_xtheadvector(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); @@ -941,8 +975,6 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const int nb = n / QK_K; -#if defined __riscv_xtheadvector - uint32_t utmp[4]; float sumf = 0; @@ -1068,257 +1100,274 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi } *s = sumf; +} +#endif -#elif defined __riscv_v +#if defined __riscv_v +void ggml_vec_dot_q3_K_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + const block_q3_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; uint32_t utmp[4]; float sumf = 0; uint32_t aux[3]; - const int vector_length = __riscv_vlenb() * 8; - switch (vector_length) { - case 256: - for (int i = 0; i < nb; ++i) { + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict qh = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + + int8_t * scale = (int8_t *)utmp; + int tmp, t1, t2, t3, t4, t5, t6, t7; + __asm__ __volatile__( + "vsetivli zero, 12, e8, m1\n\t" + "vle8.v v0, (%[s6b])\n\t" + "vmv1r.v v2, v0\n\t" + "vsetivli zero, 2, e64, m1\n\t" + "vmv.v.x v9, %[sh]\n\t"\ + "vslidedown.vi v1, v0, 1\n\t" + "vslide1up.vx v8, v9, zero\n\t" // {0, 0, 4, 4} + "vslideup.vi v0, v2, 1\n\t" // {aux[0], aux[1], aux[0], aux[1]} + "vsetivli zero, 4, e32, m1\n\t" + "vid.v v9\n\t" + "vmv.x.s %[tmp], v1\n\t" + "vsll.vi v9, v9, 1\n\t" // {0, 2, 4, 6} + "vmv.v.x v1, %[tmp]\n\t" // {aux[2], aux[2], aux[2], aux[2]} + "vsrl.vv v4, v1, v9\n\t" + "vsrl.vv v2, v0, v8\n\t" + "vand.vx v5, v4, %[kmask1]\n\t" + "vand.vx v3, v2, %[kmask2]\n\t" + "vsll.vi v6, v5, 4\n\t" + "vor.vv v7, v6, v3\n\t" + "vsetivli zero, 16, e8, m1\n\t" + "vsub.vx v0, v7, %[c]\n\t" + "vse8.v v0, (%[scale])" + : [tmp] "=&r" (tmp) + : [sh] "r" (0x0000000400000004), [s6b] "r" (x[i].scales), [c] "r" (32) + , [scale] "r" (scale), [kmask1] "r" (kmask1), [kmask2] "r" (kmask2) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + + uint8_t m = 1; + int isum = 0; + for (int j = 0; j < QK_K; j += 128) { + __asm__ __volatile__( + "lb zero, 31(%[q3])\n\t" + "vsetvli zero, %[vl32], e8, m2, ta, mu\n\t" + "vle8.v v8, (%[q3])\n\t" + "vsrl.vi v10, v8, 2\n\t" + "vsrl.vi v12, v8, 4\n\t" + "vsrl.vi v14, v8, 6\n\t" + "lb zero, 64(%[q8])\n\t" + "vand.vi v8, v8, 3\n\t" + "vand.vi v10, v10, 3\n\t" + "vand.vi v12, v12, 3\n\t" + "vle8.v v2, (%[qh])\n\t" + "lb zero, 127(%[q8])\n\t" + "vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "vmseq.vx v0, v4, zero\n\t" + "vadd.vi v8, v8, -4, v0.t\n\t" + "lb zero, 0(%[q8])\n\t" + "vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "vmseq.vx v0, v4, zero\n\t" + "vadd.vi v10, v10, -4, v0.t\n\t" + "vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "vmseq.vx v0, v4, zero\n\t" + "vadd.vi v12, v12, -4, v0.t\n\t" + "vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "vmseq.vx v0, v4, zero\n\t" + "vadd.vi v14, v14, -4, v0.t\n\t" + "vsetvli zero, %[vl128], e8, m8\n\t" + "vle8.v v0, (%[q8])\n\t" + "lb %[tmp], 0(%[scale])\n\t" + "lb %[t1], 1(%[scale])\n\t" + "lb %[t2], 2(%[scale])\n\t" + "lb %[t3], 3(%[scale])\n\t" + "vsetvli zero, %[vl64], e8, m4\n\t" + "vwmul.vv v16, v0, v8\n\t" + "vwmul.vv v24, v4, v12\n\t" + "vsetivli zero, 16, e16, m2\n\t" + "vmv.v.x v0, zero\n\t" + "vwredsum.vs v8, v16, v0\n\t" + "lb %[t4], 4(%[scale])\n\t" + "lb %[t5], 5(%[scale])\n\t" + "vwredsum.vs v9, v18, v0\n\t" + "vwredsum.vs v10, v20, v0\n\t" + "vwredsum.vs v11, v22, v0\n\t" + "vwredsum.vs v12, v24, v0\n\t" + "lb %[t6], 6(%[scale])\n\t" + "lb %[t7], 7(%[scale])\n\t" + "vwredsum.vs v13, v26, v0\n\t" + "vwredsum.vs v14, v28, v0\n\t" + "vwredsum.vs v15, v30, v0\n\t" + "vsetivli zero, 4, e32, m1\n\t" + "vmul.vx v0, v8, %[tmp]\n\t" + "vmul.vx v1, v9, %[t1]\n\t" + "vmacc.vx v0, %[t2], v10\n\t" + "vmacc.vx v1, %[t3], v11\n\t" + "vmacc.vx v0, %[t4], v12\n\t" + "vmacc.vx v1, %[t5], v13\n\t" + "vmacc.vx v0, %[t6], v14\n\t" + "vmacc.vx v1, %[t7], v15\n\t" + "vmv.x.s %[tmp], v0\n\t" + "vmv.x.s %[t1], v1\n\t" + "add %[isum], %[isum], %[tmp]\n\t" + "add %[isum], %[isum], %[t1]" + : [tmp] "=&r" (tmp), [t1] "=&r" (t1), [t2] "=&r" (t2), [t3] "=&r" (t3) + , [t4] "=&r" (t4), [t5] "=&r" (t5), [t6] "=&r" (t6), [t7] "=&r" (t7) + , [m] "+&r" (m), [isum] "+&r" (isum) + : [vl128] "r" (128), [vl64] "r" (64), [vl32] "r" (32) + , [q3] "r" (q3), [qh] "r" (qh), [scale] "r" (scale), [q8] "r" (q8) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + q3 += 32; q8 += 128; scale += 8; + } + + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + sumf += d * isum; + } - const uint8_t * GGML_RESTRICT q3 = x[i].qs; - const uint8_t * GGML_RESTRICT qh = x[i].hmask; - const int8_t * GGML_RESTRICT q8 = y[i].qs; + *s = sumf; +} - memcpy(aux, x[i].scales, 12); - utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); - utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); - utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); - utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); +void ggml_vec_dot_q3_K_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - int8_t * scale = (int8_t *)utmp; - for (int j = 0; j < 16; ++j) scale[j] -= 32; + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + const block_q3_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; - size_t vl = 32; - uint8_t m = 1; + const int nb = n / QK_K; + uint32_t utmp[4]; + float sumf = 0; + uint32_t aux[3]; - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl); + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].hmask; + const int8_t * GGML_RESTRICT q8 = y[i].qs; - int sum_t = 0; + memcpy(aux, x[i].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); - for (int j = 0; j < QK_K; j += 128) { + int8_t * scale = (int8_t *)utmp; + for (int j = 0; j < 16; ++j) scale[j] -= 32; - vl = 32; - // load Q3 - vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl); + size_t vl = 32; + uint8_t m = 1; - vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl)); - vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl)); - vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl)); - vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl)); + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl); - // compute mask for subtraction - vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl); - vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl); - m <<= 1; + int sum_t = 0; - vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl); - vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu(vmask_1, q3_1, q3_1, 0x4, vl); - m <<= 1; + for (int j = 0; j < QK_K; j += 128) { - vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl); - vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_mu(vmask_2, q3_2, q3_2, 0x4, vl); - m <<= 1; + vl = 32; - vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl); - vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu(vmask_3, q3_3, q3_3, 0x4, vl); - m <<= 1; + // load Q3 + vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl); - // load Q8 and take product with Q3 - vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl); - vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl); - vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl); - vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl); + vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl)); + vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl)); + vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl)); + vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl)); - vl = 16; + // compute mask for subtraction + vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl); + vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl); + m <<= 1; - // retrieve lane to multiply with scale - vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl); - vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl); - vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl); - vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl); - vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl); - vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl); - vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl); - vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl); + vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl); + vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu(vmask_1, q3_1, q3_1, 0x4, vl); + m <<= 1; - vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl); - vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl); - vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl); + vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl); + vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_mu(vmask_2, q3_2, q3_2, 0x4, vl); + m <<= 1; - sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); + vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl); + vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu(vmask_3, q3_3, q3_3, 0x4, vl); + m <<= 1; - q3 += 32; q8 += 128; scale += 8; + // load Q8 and take product with Q3 + vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl); + vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl); + vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl); + vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl); - } + vl = 16; - const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + // retrieve lane to multiply with scale + vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl); + vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl); + vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl); + vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl); + vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl); + vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl); + vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl); + vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl); - sumf += d*sum_t; + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl); + vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl); + vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl); - } - break; - case 128: - for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q3 = x[i].qs; - const uint8_t * restrict qh = x[i].hmask; - const int8_t * restrict q8 = y[i].qs; - - int8_t * scale = (int8_t *)utmp; - int tmp, t1, t2, t3, t4, t5, t6, t7; - __asm__ __volatile__( - "vsetivli zero, 12, e8, m1\n\t" - "vle8.v v0, (%[s6b])\n\t" - "vmv1r.v v2, v0\n\t" - "vsetivli zero, 2, e64, m1\n\t" - "vmv.v.x v9, %[sh]\n\t"\ - "vslidedown.vi v1, v0, 1\n\t" - "vslide1up.vx v8, v9, zero\n\t" // {0, 0, 4, 4} - "vslideup.vi v0, v2, 1\n\t" // {aux[0], aux[1], aux[0], aux[1]} - "vsetivli zero, 4, e32, m1\n\t" - "vid.v v9\n\t" - "vmv.x.s %[tmp], v1\n\t" - "vsll.vi v9, v9, 1\n\t" // {0, 2, 4, 6} - "vmv.v.x v1, %[tmp]\n\t" // {aux[2], aux[2], aux[2], aux[2]} - "vsrl.vv v4, v1, v9\n\t" - "vsrl.vv v2, v0, v8\n\t" - "vand.vx v5, v4, %[kmask1]\n\t" - "vand.vx v3, v2, %[kmask2]\n\t" - "vsll.vi v6, v5, 4\n\t" - "vor.vv v7, v6, v3\n\t" - "vsetivli zero, 16, e8, m1\n\t" - "vsub.vx v0, v7, %[c]\n\t" - "vse8.v v0, (%[scale])" - : [tmp] "=&r" (tmp) - : [sh] "r" (0x0000000400000004), [s6b] "r" (x[i].scales), [c] "r" (32) - , [scale] "r" (scale), [kmask1] "r" (kmask1), [kmask2] "r" (kmask2) - : "memory" - , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" - , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" - , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" - , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" - ); + sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); - uint8_t m = 1; - int isum = 0; - for (int j = 0; j < QK_K; j += 128) { - __asm__ __volatile__( - "lb zero, 31(%[q3])\n\t" - "vsetvli zero, %[vl32], e8, m2, ta, mu\n\t" - "vle8.v v8, (%[q3])\n\t" - "vsrl.vi v10, v8, 2\n\t" - "vsrl.vi v12, v8, 4\n\t" - "vsrl.vi v14, v8, 6\n\t" - "lb zero, 64(%[q8])\n\t" - "vand.vi v8, v8, 3\n\t" - "vand.vi v10, v10, 3\n\t" - "vand.vi v12, v12, 3\n\t" - "vle8.v v2, (%[qh])\n\t" - "lb zero, 127(%[q8])\n\t" - "vand.vx v4, v2, %[m]\n\t" - "slli %[m], %[m], 1\n\t" - "vmseq.vx v0, v4, zero\n\t" - "vadd.vi v8, v8, -4, v0.t\n\t" - "lb zero, 0(%[q8])\n\t" - "vand.vx v4, v2, %[m]\n\t" - "slli %[m], %[m], 1\n\t" - "vmseq.vx v0, v4, zero\n\t" - "vadd.vi v10, v10, -4, v0.t\n\t" - "vand.vx v4, v2, %[m]\n\t" - "slli %[m], %[m], 1\n\t" - "vmseq.vx v0, v4, zero\n\t" - "vadd.vi v12, v12, -4, v0.t\n\t" - "vand.vx v4, v2, %[m]\n\t" - "slli %[m], %[m], 1\n\t" - "vmseq.vx v0, v4, zero\n\t" - "vadd.vi v14, v14, -4, v0.t\n\t" - "vsetvli zero, %[vl128], e8, m8\n\t" - "vle8.v v0, (%[q8])\n\t" - "lb %[tmp], 0(%[scale])\n\t" - "lb %[t1], 1(%[scale])\n\t" - "lb %[t2], 2(%[scale])\n\t" - "lb %[t3], 3(%[scale])\n\t" - "vsetvli zero, %[vl64], e8, m4\n\t" - "vwmul.vv v16, v0, v8\n\t" - "vwmul.vv v24, v4, v12\n\t" - "vsetivli zero, 16, e16, m2\n\t" - "vmv.v.x v0, zero\n\t" - "vwredsum.vs v8, v16, v0\n\t" - "lb %[t4], 4(%[scale])\n\t" - "lb %[t5], 5(%[scale])\n\t" - "vwredsum.vs v9, v18, v0\n\t" - "vwredsum.vs v10, v20, v0\n\t" - "vwredsum.vs v11, v22, v0\n\t" - "vwredsum.vs v12, v24, v0\n\t" - "lb %[t6], 6(%[scale])\n\t" - "lb %[t7], 7(%[scale])\n\t" - "vwredsum.vs v13, v26, v0\n\t" - "vwredsum.vs v14, v28, v0\n\t" - "vwredsum.vs v15, v30, v0\n\t" - "vsetivli zero, 4, e32, m1\n\t" - "vmul.vx v0, v8, %[tmp]\n\t" - "vmul.vx v1, v9, %[t1]\n\t" - "vmacc.vx v0, %[t2], v10\n\t" - "vmacc.vx v1, %[t3], v11\n\t" - "vmacc.vx v0, %[t4], v12\n\t" - "vmacc.vx v1, %[t5], v13\n\t" - "vmacc.vx v0, %[t6], v14\n\t" - "vmacc.vx v1, %[t7], v15\n\t" - "vmv.x.s %[tmp], v0\n\t" - "vmv.x.s %[t1], v1\n\t" - "add %[isum], %[isum], %[tmp]\n\t" - "add %[isum], %[isum], %[t1]" - : [tmp] "=&r" (tmp), [t1] "=&r" (t1), [t2] "=&r" (t2), [t3] "=&r" (t3) - , [t4] "=&r" (t4), [t5] "=&r" (t5), [t6] "=&r" (t6), [t7] "=&r" (t7) - , [m] "+&r" (m), [isum] "+&r" (isum) - : [vl128] "r" (128), [vl64] "r" (64), [vl32] "r" (32) - , [q3] "r" (q3), [qh] "r" (qh), [scale] "r" (scale), [q8] "r" (q8) - : "memory" - , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" - , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" - , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" - , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" - ); - q3 += 32; q8 += 128; scale += 8; - } + q3 += 32; q8 += 128; scale += 8; - const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; - sumf += d * isum; } - break; - default: - assert(false && "Unsupported vector length"); - break; - } - *s = sumf; - -#else + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; - UNUSED(kmask1); - UNUSED(kmask2); - UNUSED(x); - UNUSED(y); - UNUSED(nb); + sumf += d*sum_t; - ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); -#endif + } + *s = sumf; } -void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +void ggml_vec_dot_q3_K_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); @@ -1326,27 +1375,289 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi UNUSED(by); UNUSED(bs); - const block_q4_K * GGML_RESTRICT x = vx; + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + const block_q3_K * GGML_RESTRICT x = vx; const block_q8_K * GGML_RESTRICT y = vy; const int nb = n / QK_K; - static const uint32_t kmask1 = 0x3f3f3f3f; - static const uint32_t kmask2 = 0x0f0f0f0f; - static const uint32_t kmask3 = 0x03030303; + // mask for processing 16 elements per prod register + const vuint16m1_t va_index = __riscv_vid_v_u16m1(32); + const vbool16_t va_mask = __riscv_vmsgtu_vx_u16m1_b16(va_index, 15, 32); uint32_t utmp[4]; - -#if defined __riscv_xtheadvector - - const uint8_t * scales = (const uint8_t*)&utmp[0]; - const uint8_t * mins = (const uint8_t*)&utmp[2]; - float sumf = 0; + uint32_t aux[3]; for (int i = 0; i < nb; ++i) { - const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].hmask; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + memcpy(aux, x[i].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + + int8_t * scale = (int8_t *)utmp; + for (int j = 0; j < 16; ++j) scale[j] -= 32; + + + size_t vl = 32; + uint8_t m = 1; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8mf2_t vqh = __riscv_vle8_v_u8mf2(qh, vl); + + int sum_t = 0; + + vint32m2_t vaux_0 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t vaux_1 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t vaux_2 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t vaux_3 = __riscv_vmv_v_x_i32m2(0, vl); + + for (int j = 0; j < QK_K; j += 128) { + + vl = 32; + + // load Q3 + vuint8mf2_t q3_x = __riscv_vle8_v_u8mf2(q3, vl); + + vint8mf2_t q3_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q3_x, 0x03, vl)); + vint8mf2_t q3_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 0x2, vl), 0x03 , vl)); + vint8mf2_t q3_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 0x4, vl), 0x03 , vl)); + vint8mf2_t q3_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 0x6, vl), 0x03 , vl)); + + // compute mask for subtraction + vuint8mf2_t qh_m0 = __riscv_vand_vx_u8mf2(vqh, m, vl); + vbool16_t vmask_0 = __riscv_vmseq_vx_u8mf2_b16(qh_m0, 0, vl); + vint8mf2_t q3_m0 = __riscv_vsub_vx_i8mf2_mu(vmask_0, q3_0, q3_0, 0x4, vl); + m <<= 1; + + vuint8mf2_t qh_m1 = __riscv_vand_vx_u8mf2(vqh, m, vl); + vbool16_t vmask_1 = __riscv_vmseq_vx_u8mf2_b16(qh_m1, 0, vl); + vint8mf2_t q3_m1 = __riscv_vsub_vx_i8mf2_mu(vmask_1, q3_1, q3_1, 0x4, vl); + m <<= 1; + + vuint8mf2_t qh_m2 = __riscv_vand_vx_u8mf2(vqh, m, vl); + vbool16_t vmask_2 = __riscv_vmseq_vx_u8mf2_b16(qh_m2, 0, vl); + vint8mf2_t q3_m2 = __riscv_vsub_vx_i8mf2_mu(vmask_2, q3_2, q3_2, 0x4, vl); + m <<= 1; + + vuint8mf2_t qh_m3 = __riscv_vand_vx_u8mf2(vqh, m, vl); + vbool16_t vmask_3 = __riscv_vmseq_vx_u8mf2_b16(qh_m3, 0, vl); + vint8mf2_t q3_m3 = __riscv_vsub_vx_i8mf2_mu(vmask_3, q3_3, q3_3, 0x4, vl); + m <<= 1; + + // load Q8 and take product + vint16m1_t va_q_0 = __riscv_vwmul_vv_i16m1(q3_m0, __riscv_vle8_v_i8mf2(q8, vl), vl); + vint16m1_t va_q_1 = __riscv_vwmul_vv_i16m1(q3_m1, __riscv_vle8_v_i8mf2(q8+32, vl), vl); + vint16m1_t va_q_2 = __riscv_vwmul_vv_i16m1(q3_m2, __riscv_vle8_v_i8mf2(q8+64, vl), vl); + vint16m1_t va_q_3 = __riscv_vwmul_vv_i16m1(q3_m3, __riscv_vle8_v_i8mf2(q8+96, vl), vl); + + // accumulate + vaux_0 = __riscv_vwmacc_vx_i32m2(vaux_0, scale[0], va_q_0, 16); + vaux_1 = __riscv_vwmacc_vx_i32m2(vaux_1, scale[2], va_q_1, 16); + vaux_2 = __riscv_vwmacc_vx_i32m2(vaux_2, scale[4], va_q_2, 16); + vaux_3 = __riscv_vwmacc_vx_i32m2(vaux_3, scale[6], va_q_3, 16); + // + vaux_0 = __riscv_vwmacc_vx_i32m2_m(va_mask, vaux_0, scale[1], va_q_0, vl); + vaux_1 = __riscv_vwmacc_vx_i32m2_m(va_mask, vaux_1, scale[3], va_q_1, vl); + vaux_2 = __riscv_vwmacc_vx_i32m2_m(va_mask, vaux_2, scale[5], va_q_2, vl); + vaux_3 = __riscv_vwmacc_vx_i32m2_m(va_mask, vaux_3, scale[7], va_q_3, vl); + + q3 += 32; q8 += 128; scale += 8; + } + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); + + sum_t += __riscv_vmv_x_s_i32m1_i32(isum1); + + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + sumf += d*sum_t; + } + + *s = sumf; +} + +void ggml_vec_dot_q3_K_q8_K_vl1024(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + const block_q3_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + // mask for processing 16 elements per prod register + const vuint16mf2_t va_index = __riscv_vid_v_u16mf2(32); + const vbool32_t va_mask = __riscv_vmsgtu_vx_u16mf2_b32(va_index, 15, 32); + + uint32_t utmp[4]; + float sumf = 0; + uint32_t aux[3]; + + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].hmask; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + memcpy(aux, x[i].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + + int8_t * scale = (int8_t *)utmp; + for (int j = 0; j < 16; ++j) scale[j] -= 32; + + + size_t vl = 32; + uint8_t m = 1; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8mf4_t vqh = __riscv_vle8_v_u8mf4(qh, vl); + + int sum_t = 0; + + vint32m1_t vaux_0 = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t vaux_1 = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t vaux_2 = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t vaux_3 = __riscv_vmv_v_x_i32m1(0, vl); + + for (int j = 0; j < QK_K; j += 128) { + + vl = 32; + + // load Q3 + vuint8mf4_t q3_x = __riscv_vle8_v_u8mf4(q3, vl); + + vint8mf4_t q3_0 = __riscv_vreinterpret_v_u8mf4_i8mf4(__riscv_vand_vx_u8mf4(q3_x, 0x03, vl)); + vint8mf4_t q3_1 = __riscv_vreinterpret_v_u8mf4_i8mf4(__riscv_vand_vx_u8mf4(__riscv_vsrl_vx_u8mf4(q3_x, 0x2, vl), 0x03 , vl)); + vint8mf4_t q3_2 = __riscv_vreinterpret_v_u8mf4_i8mf4(__riscv_vand_vx_u8mf4(__riscv_vsrl_vx_u8mf4(q3_x, 0x4, vl), 0x03 , vl)); + vint8mf4_t q3_3 = __riscv_vreinterpret_v_u8mf4_i8mf4(__riscv_vand_vx_u8mf4(__riscv_vsrl_vx_u8mf4(q3_x, 0x6, vl), 0x03 , vl)); + + // compute mask for subtraction + vuint8mf4_t qh_m0 = __riscv_vand_vx_u8mf4(vqh, m, vl); + vbool32_t vmask_0 = __riscv_vmseq_vx_u8mf4_b32(qh_m0, 0, vl); + vint8mf4_t q3_m0 = __riscv_vsub_vx_i8mf4_mu(vmask_0, q3_0, q3_0, 0x4, vl); + m <<= 1; + + vuint8mf4_t qh_m1 = __riscv_vand_vx_u8mf4(vqh, m, vl); + vbool32_t vmask_1 = __riscv_vmseq_vx_u8mf4_b32(qh_m1, 0, vl); + vint8mf4_t q3_m1 = __riscv_vsub_vx_i8mf4_mu(vmask_1, q3_1, q3_1, 0x4, vl); + m <<= 1; + + vuint8mf4_t qh_m2 = __riscv_vand_vx_u8mf4(vqh, m, vl); + vbool32_t vmask_2 = __riscv_vmseq_vx_u8mf4_b32(qh_m2, 0, vl); + vint8mf4_t q3_m2 = __riscv_vsub_vx_i8mf4_mu(vmask_2, q3_2, q3_2, 0x4, vl); + m <<= 1; + + vuint8mf4_t qh_m3 = __riscv_vand_vx_u8mf4(vqh, m, vl); + vbool32_t vmask_3 = __riscv_vmseq_vx_u8mf4_b32(qh_m3, 0, vl); + vint8mf4_t q3_m3 = __riscv_vsub_vx_i8mf4_mu(vmask_3, q3_3, q3_3, 0x4, vl); + m <<= 1; + + // load Q8 and take product + vint16mf2_t va_q_0 = __riscv_vwmul_vv_i16mf2(q3_m0, __riscv_vle8_v_i8mf4(q8, vl), vl); + vint16mf2_t va_q_1 = __riscv_vwmul_vv_i16mf2(q3_m1, __riscv_vle8_v_i8mf4(q8+32, vl), vl); + vint16mf2_t va_q_2 = __riscv_vwmul_vv_i16mf2(q3_m2, __riscv_vle8_v_i8mf4(q8+64, vl), vl); + vint16mf2_t va_q_3 = __riscv_vwmul_vv_i16mf2(q3_m3, __riscv_vle8_v_i8mf4(q8+96, vl), vl); + + // accumulate + vaux_0 = __riscv_vwmacc_vx_i32m1(vaux_0, scale[0], va_q_0, 16); + vaux_1 = __riscv_vwmacc_vx_i32m1(vaux_1, scale[2], va_q_1, 16); + vaux_2 = __riscv_vwmacc_vx_i32m1(vaux_2, scale[4], va_q_2, 16); + vaux_3 = __riscv_vwmacc_vx_i32m1(vaux_3, scale[6], va_q_3, 16); + // + vaux_0 = __riscv_vwmacc_vx_i32m1_m(va_mask, vaux_0, scale[1], va_q_0, vl); + vaux_1 = __riscv_vwmacc_vx_i32m1_m(va_mask, vaux_1, scale[3], va_q_1, vl); + vaux_2 = __riscv_vwmacc_vx_i32m1_m(va_mask, vaux_2, scale[5], va_q_2, vl); + vaux_3 = __riscv_vwmacc_vx_i32m1_m(va_mask, vaux_3, scale[7], va_q_3, vl); + + q3 += 32; q8 += 128; scale += 8; + } + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m1_i32m1(__riscv_vadd_vv_i32m1(vaux_0, vaux_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m1_i32m1(__riscv_vadd_vv_i32m1(vaux_2, vaux_3, vl), isum0, vl); + + sum_t += __riscv_vmv_x_s_i32m1_i32(isum1); + + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + sumf += d*sum_t; + } + + *s = sumf; +} +#endif + +void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_xtheadvector + ggml_vec_dot_q3_K_q8_K_xtheadvector(n, s, bs, vx, bx, vy, by, nrc); +#elif defined __riscv_v + switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_q3_K_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; + case 256: + ggml_vec_dot_q3_K_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + case 512: + ggml_vec_dot_q3_K_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; + case 1024: + ggml_vec_dot_q3_K_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +#if defined __riscv_xtheadvector +static NOINLINE void ggml_vec_dot_q4_K_q8_K_xtheadvector(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); int tmp, tmp2, sumi; __asm__ __volatile__( @@ -1452,277 +1763,317 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi } *s = sumf; +} +#endif -#elif defined __riscv_v +#if defined __riscv_v +static NOINLINE void ggml_vec_dot_q4_K_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; const uint8_t * scales = (const uint8_t*)&utmp[0]; const uint8_t * mins = (const uint8_t*)&utmp[2]; float sumf = 0; - const int vector_length = __riscv_vlenb() * 8; + for (int i = 0; i < nb; ++i) { + const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); + + float ftmp, ft2; + const uint8_t * restrict q40; + const uint8_t * restrict q41; + const uint8_t * restrict q42; + const uint8_t * restrict q43; + const int8_t * restrict q80; + const int8_t * restrict q81; + const int8_t * restrict q82; + const int8_t * restrict q83; + int s0, s1, s2, s3; + + __asm__ __volatile__( + "li %[s1], 8\n\t" + "vsetivli zero, 4, e32, m1, ta, ma\n\t" + "vle32.v v1, (%[s6b])\n\t" + "vslide1down.vx v1, v1, zero\n\t" + "vmv.v.x v16, zero\n\t" + "vslidedown.vi v2, v1, 2\n\t" + "vmv1r.v v3, v2\n\t" + "vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]} + "vsetivli zero, 2, e32, m1, ta, ma\n\t" + "vmv.v.i v4, 4\n\t" + "vand.vx v8, v1, %[kmask1]\n\t" + "vslide1up.vx v5, v4, zero\n\t" // {0, 4} + "vsrl.vi v6, v1, 6\n\t" + "vsrl.vv v7, v2, v5\n\t" + "vsse32.v v8, (%[utmp]), %[s1]\n\t" + "vand.vx v0, v6, %[kmask3]\n\t" + "vand.vx v2, v7, %[kmask2]\n\t" + "vsll.vi v6, v0, 4\n\t" + "addi %[s0], %[utmp], 4\n\t" + "vor.vv v1, v6, v2\n\t" + "vsse32.v v1, (%[s0]), %[s1]\n\t" + "vsetivli zero, 8, e16, m1, ta, ma\n\t" + "vle32.v v2, (%[bsums])\n\t" + "vnsrl.wi v0, v2, 0\n\t" + "vnsrl.wi v1, v2, 16\n\t" + "vadd.vv v2, v0, v1\n\t" + "vle8.v v3, (%[mins])\n\t" + "vzext.vf2 v4, v3\n\t" + "vwmul.vv v6, v4, v2\n\t" + "vsetivli zero, 4, e32, m1, ta, ma\n\t" + "vredsum.vs v0, v6, v16\n\t" + "vredsum.vs v0, v7, v0\n\t" + "vfcvt.f.x.v v0, v0\n\t" + "vfmv.f.s %[ftmp], v0\n\t" + "vsetivli zero, 16, e8, m1, ta, ma\n\t" + "vle8.v v0, (%[xs])\n\t" + "fnmsub.s %[sumf], %[dmin], %[ftmp], %[sumf]\n\t" + "addi %[q40], %[xs], 64\n\t" + "addi %[q41], %[xs], 16\n\t" + "addi %[q42], %[xs], 32\n\t" + "addi %[q43], %[xs], 48\n\t" + "addi %[q80], %[ys], 64\n\t" + "vle8.v v1, (%[q41])\n\t" + "vle8.v v2, (%[q42])\n\t" + "addi %[q81], %[ys], 16\n\t" + "addi %[q41], %[q41], 64\n\t" + "addi %[q82], %[ys], 32\n\t" + "vle8.v v3, (%[q43])\n\t" + "vle8.v v8, (%[ys])\n\t" + "addi %[q42], %[q42], 64\n\t" + "addi %[q83], %[ys], 48\n\t" + "addi %[q43], %[q43], 64\n\t" + "vsrl.vi v4, v0, 4\n\t" + "vle8.v v9, (%[q81])\n\t" + "vle8.v v10, (%[q82])\n\t" + "vand.vi v0, v0, 0xF\n\t" + "addi %[q81], %[q81], 64\n\t" + "vsrl.vi v5, v1, 4\n\t" + "addi %[q82], %[q82], 64\n\t" + "vle8.v v11, (%[q83])\n\t" + "vle8.v v12, (%[q80])\n\t" + "vand.vi v1, v1, 0xF\n\t" + "addi %[q83], %[q83], 64\n\t" + "vsrl.vi v6, v2, 4\n\t" + "addi %[q80], %[q80], 64\n\t" + "vle8.v v13, (%[q81])\n\t" + "vle8.v v14, (%[q82])\n\t" + "vand.vi v2, v2, 0xF\n\t" + "addi %[q81], %[q81], 64\n\t" + "vsrl.vi v7, v3, 4\n\t" + "addi %[q82], %[q82], 64\n\t" + "vwmul.vv v16, v0, v8\n\t" + "vle8.v v15, (%[q83])\n\t" + "vle8.v v0, (%[q40])\n\t" + "vand.vi v3, v3, 0xF\n\t" + "addi %[q83], %[q83], 64\n\t" + "vwmul.vv v24, v2, v12\n\t" + "vwmul.vv v20, v4, v10\n\t" + "vwmul.vv v28, v6, v14\n\t" + "vwmacc.vv v16, v1, v9\n\t" + "vle8.v v1, (%[q41])\n\t" + "vle8.v v2, (%[q42])\n\t" + "vwmacc.vv v24, v3, v13\n\t" + "vwmacc.vv v20, v5, v11\n\t" + "vwmacc.vv v28, v7, v15\n\t" + "addi %[q40], %[q80], 64\n\t" + "addi %[q41], %[q81], 64\n\t" + "vle8.v v3, (%[q43])\n\t" + "vle8.v v8, (%[q80])\n\t" + "addi %[q42], %[q82], 64\n\t" + "addi %[q43], %[q83], 64\n\t" + "vsrl.vi v4, v0, 4\n\t" + "vle8.v v9, (%[q81])\n\t" + "vle8.v v10, (%[q82])\n\t" + "vand.vi v0, v0, 0xF\n\t" + "vsrl.vi v5, v1, 4\n\t" + "vsrl.vi v7, v3, 4\n\t" + "vand.vi v3, v3, 0xF\n\t" + "vle8.v v11, (%[q83])\n\t" + "vle8.v v12, (%[q40])\n\t" + "vand.vi v1, v1, 0xF\n\t" + "vsrl.vi v6, v2, 4\n\t" + "vand.vi v2, v2, 0xF\n\t" + "vwmul.vv v18, v0, v8\n\t" + "vle8.v v13, (%[q41])\n\t" + "vle8.v v14, (%[q42])\n\t" + "vwmul.vv v26, v2, v12\n\t" + "vwmul.vv v22, v4, v10\n\t" + "vwmul.vv v30, v6, v14\n\t" + "vwmacc.vv v18, v1, v9\n\t" + "vle8.v v15, (%[q43])\n\t" + "vwmacc.vv v26, v3, v13\n\t" + "vwmacc.vv v22, v5, v11\n\t" + "vwmacc.vv v30, v7, v15\n\t" + "vmv.v.x v0, zero\n\t" + "vsetivli zero, 16, e16, m2, ta, ma\n\t" + "vwredsum.vs v4, v16, v0\n\t" + "lbu %[s0], 0(%[scale])\n\t" + "vwredsum.vs v5, v20, v0\n\t" + "lbu %[s1], 1(%[scale])\n\t" + "vwredsum.vs v6, v24, v0\n\t" + "lbu %[s2], 2(%[scale])\n\t" + "vwredsum.vs v7, v28, v0\n\t" + "lbu %[s3], 3(%[scale])\n\t" + "vwredsum.vs v8, v18, v0\n\t" + "lbu %[q40], 4(%[scale])\n\t" + "vwredsum.vs v9, v22, v0\n\t" + "lbu %[q41], 5(%[scale])\n\t" + "vwredsum.vs v10, v26, v0\n\t" + "lbu %[q42], 6(%[scale])\n\t" + "vwredsum.vs v11, v30, v0\n\t" + "lbu %[q43], 7(%[scale])\n\t" + "vsetivli zero, 4, e32, m1, ta, ma\n\t" + "vmul.vx v0, v4, %[s0]\n\t" + "vmul.vx v1, v8, %[q40]\n\t" + "vmacc.vx v0, %[s1], v5\n\t" + "vmacc.vx v1, %[q41], v9\n\t" + "vmacc.vx v0, %[s2], v6\n\t" + "vmacc.vx v1, %[q42], v10\n\t" + "vmacc.vx v0, %[s3], v7\n\t" + "vmacc.vx v1, %[q43], v11\n\t" + "vfcvt.f.x.v v0, v0\n\t" + "vfcvt.f.x.v v1, v1\n\t" + "vfmv.f.s %[ft2], v0\n\t" + "vfmv.f.s %[ftmp], v1\n\t" + "fadd.s %[ft2], %[ft2], %[ftmp]\n\t" + "fmadd.s %[sumf], %[d], %[ft2], %[sumf]" + : [ftmp] "=&f" (ftmp), [sumf] "+&f" (sumf), [ft2] "=&f" (ft2) + , [s0] "=&r" (s0), [s1] "=&r" (s1), [s2] "=&r" (s2), [s3] "=&r" (s3) + , [q40] "=&r" (q40), [q41] "=&r" (q41), [q42] "=&r" (q42), [q43] "=&r" (q43) + , [q80] "=&r" (q80), [q81] "=&r" (q81), [q82] "=&r" (q82), [q83] "=&r" (q83) + : [d] "f" (d), [ys] "r" (y[i].qs), [xs] "r" (x[i].qs), [scale] "r" (scales) + , [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp) + , [s6b] "r" (&x[i]), [kmask1] "r" (kmask1), [dmin] "f" (dmin) + , [kmask2] "r" (kmask2), [kmask3] "r" (kmask3) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + } - switch (vector_length) { - case 256: - for (int i = 0; i < nb; ++i) { + *s = sumf; +} - size_t vl = 8; +static NOINLINE void ggml_vec_dot_q4_K_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); + const block_q4_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; - vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); - vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); - vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); + const int nb = n / QK_K; - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; - vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); - vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); - vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); + uint32_t utmp[4]; - vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); - sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; - const uint8_t * GGML_RESTRICT q4 = x[i].qs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; + float sumf = 0; + for (int i = 0; i < nb; ++i) { + size_t vl = 8; - vl = 32; + const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); - int32_t sum_1 = 0; - int32_t sum_2 = 0; + vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); + vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); + vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); - vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; - for (int j = 0; j < QK_K/64; ++j) { - // load Q4 - vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl); + vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); + vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); + vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); - // load Q8 and multiply it with lower Q4 nibble - vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); - vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl)); - vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl); - vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl); + vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); - sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0]; + const uint8_t * GGML_RESTRICT q4 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; - // load Q8 and multiply it with upper Q4 nibble - vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); - vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl)); - vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl); - vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl); + vl = 32; - sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1]; + int32_t sum_1 = 0; + int32_t sum_2 = 0; - q4 += 32; q8 += 64; + vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); - } + for (int j = 0; j < QK_K/64; ++j) { + // load Q4 + vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl); - sumf += d*(sum_1 + sum_2); + // load Q8 and multiply it with lower Q4 nibble + vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); + vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl)); + vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl); + vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl); - } - break; - case 128: - for (int i = 0; i < nb; ++i) { - const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); - - float ftmp, ft2; - const uint8_t * restrict q40; - const uint8_t * restrict q41; - const uint8_t * restrict q42; - const uint8_t * restrict q43; - const int8_t * restrict q80; - const int8_t * restrict q81; - const int8_t * restrict q82; - const int8_t * restrict q83; - int s0, s1, s2, s3; + sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0]; + + // load Q8 and multiply it with upper Q4 nibble + vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); + vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl)); + vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl); + vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl); + + sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1]; + + q4 += 32; q8 += 64; - __asm__ __volatile__( - "li %[s1], 8\n\t" - "vsetivli zero, 4, e32, m1, ta, ma\n\t" - "vle32.v v1, (%[s6b])\n\t" - "vslide1down.vx v1, v1, zero\n\t" - "vmv.v.x v16, zero\n\t" - "vslidedown.vi v2, v1, 2\n\t" - "vmv1r.v v3, v2\n\t" - "vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]} - "vsetivli zero, 2, e32, m1, ta, ma\n\t" - "vmv.v.i v4, 4\n\t" - "vand.vx v8, v1, %[kmask1]\n\t" - "vslide1up.vx v5, v4, zero\n\t" // {0, 4} - "vsrl.vi v6, v1, 6\n\t" - "vsrl.vv v7, v2, v5\n\t" - "vsse32.v v8, (%[utmp]), %[s1]\n\t" - "vand.vx v0, v6, %[kmask3]\n\t" - "vand.vx v2, v7, %[kmask2]\n\t" - "vsll.vi v6, v0, 4\n\t" - "addi %[s0], %[utmp], 4\n\t" - "vor.vv v1, v6, v2\n\t" - "vsse32.v v1, (%[s0]), %[s1]\n\t" - "vsetivli zero, 8, e16, m1, ta, ma\n\t" - "vle32.v v2, (%[bsums])\n\t" - "vnsrl.wi v0, v2, 0\n\t" - "vnsrl.wi v1, v2, 16\n\t" - "vadd.vv v2, v0, v1\n\t" - "vle8.v v3, (%[mins])\n\t" - "vzext.vf2 v4, v3\n\t" - "vwmul.vv v6, v4, v2\n\t" - "vsetivli zero, 4, e32, m1, ta, ma\n\t" - "vredsum.vs v0, v6, v16\n\t" - "vredsum.vs v0, v7, v0\n\t" - "vfcvt.f.x.v v0, v0\n\t" - "vfmv.f.s %[ftmp], v0\n\t" - "vsetivli zero, 16, e8, m1, ta, ma\n\t" - "vle8.v v0, (%[xs])\n\t" - "fnmsub.s %[sumf], %[dmin], %[ftmp], %[sumf]\n\t" - "addi %[q40], %[xs], 64\n\t" - "addi %[q41], %[xs], 16\n\t" - "addi %[q42], %[xs], 32\n\t" - "addi %[q43], %[xs], 48\n\t" - "addi %[q80], %[ys], 64\n\t" - "vle8.v v1, (%[q41])\n\t" - "vle8.v v2, (%[q42])\n\t" - "addi %[q81], %[ys], 16\n\t" - "addi %[q41], %[q41], 64\n\t" - "addi %[q82], %[ys], 32\n\t" - "vle8.v v3, (%[q43])\n\t" - "vle8.v v8, (%[ys])\n\t" - "addi %[q42], %[q42], 64\n\t" - "addi %[q83], %[ys], 48\n\t" - "addi %[q43], %[q43], 64\n\t" - "vsrl.vi v4, v0, 4\n\t" - "vle8.v v9, (%[q81])\n\t" - "vle8.v v10, (%[q82])\n\t" - "vand.vi v0, v0, 0xF\n\t" - "addi %[q81], %[q81], 64\n\t" - "vsrl.vi v5, v1, 4\n\t" - "addi %[q82], %[q82], 64\n\t" - "vle8.v v11, (%[q83])\n\t" - "vle8.v v12, (%[q80])\n\t" - "vand.vi v1, v1, 0xF\n\t" - "addi %[q83], %[q83], 64\n\t" - "vsrl.vi v6, v2, 4\n\t" - "addi %[q80], %[q80], 64\n\t" - "vle8.v v13, (%[q81])\n\t" - "vle8.v v14, (%[q82])\n\t" - "vand.vi v2, v2, 0xF\n\t" - "addi %[q81], %[q81], 64\n\t" - "vsrl.vi v7, v3, 4\n\t" - "addi %[q82], %[q82], 64\n\t" - "vwmul.vv v16, v0, v8\n\t" - "vle8.v v15, (%[q83])\n\t" - "vle8.v v0, (%[q40])\n\t" - "vand.vi v3, v3, 0xF\n\t" - "addi %[q83], %[q83], 64\n\t" - "vwmul.vv v24, v2, v12\n\t" - "vwmul.vv v20, v4, v10\n\t" - "vwmul.vv v28, v6, v14\n\t" - "vwmacc.vv v16, v1, v9\n\t" - "vle8.v v1, (%[q41])\n\t" - "vle8.v v2, (%[q42])\n\t" - "vwmacc.vv v24, v3, v13\n\t" - "vwmacc.vv v20, v5, v11\n\t" - "vwmacc.vv v28, v7, v15\n\t" - "addi %[q40], %[q80], 64\n\t" - "addi %[q41], %[q81], 64\n\t" - "vle8.v v3, (%[q43])\n\t" - "vle8.v v8, (%[q80])\n\t" - "addi %[q42], %[q82], 64\n\t" - "addi %[q43], %[q83], 64\n\t" - "vsrl.vi v4, v0, 4\n\t" - "vle8.v v9, (%[q81])\n\t" - "vle8.v v10, (%[q82])\n\t" - "vand.vi v0, v0, 0xF\n\t" - "vsrl.vi v5, v1, 4\n\t" - "vsrl.vi v7, v3, 4\n\t" - "vand.vi v3, v3, 0xF\n\t" - "vle8.v v11, (%[q83])\n\t" - "vle8.v v12, (%[q40])\n\t" - "vand.vi v1, v1, 0xF\n\t" - "vsrl.vi v6, v2, 4\n\t" - "vand.vi v2, v2, 0xF\n\t" - "vwmul.vv v18, v0, v8\n\t" - "vle8.v v13, (%[q41])\n\t" - "vle8.v v14, (%[q42])\n\t" - "vwmul.vv v26, v2, v12\n\t" - "vwmul.vv v22, v4, v10\n\t" - "vwmul.vv v30, v6, v14\n\t" - "vwmacc.vv v18, v1, v9\n\t" - "vle8.v v15, (%[q43])\n\t" - "vwmacc.vv v26, v3, v13\n\t" - "vwmacc.vv v22, v5, v11\n\t" - "vwmacc.vv v30, v7, v15\n\t" - "vmv.v.x v0, zero\n\t" - "vsetivli zero, 16, e16, m2, ta, ma\n\t" - "vwredsum.vs v4, v16, v0\n\t" - "lbu %[s0], 0(%[scale])\n\t" - "vwredsum.vs v5, v20, v0\n\t" - "lbu %[s1], 1(%[scale])\n\t" - "vwredsum.vs v6, v24, v0\n\t" - "lbu %[s2], 2(%[scale])\n\t" - "vwredsum.vs v7, v28, v0\n\t" - "lbu %[s3], 3(%[scale])\n\t" - "vwredsum.vs v8, v18, v0\n\t" - "lbu %[q40], 4(%[scale])\n\t" - "vwredsum.vs v9, v22, v0\n\t" - "lbu %[q41], 5(%[scale])\n\t" - "vwredsum.vs v10, v26, v0\n\t" - "lbu %[q42], 6(%[scale])\n\t" - "vwredsum.vs v11, v30, v0\n\t" - "lbu %[q43], 7(%[scale])\n\t" - "vsetivli zero, 4, e32, m1, ta, ma\n\t" - "vmul.vx v0, v4, %[s0]\n\t" - "vmul.vx v1, v8, %[q40]\n\t" - "vmacc.vx v0, %[s1], v5\n\t" - "vmacc.vx v1, %[q41], v9\n\t" - "vmacc.vx v0, %[s2], v6\n\t" - "vmacc.vx v1, %[q42], v10\n\t" - "vmacc.vx v0, %[s3], v7\n\t" - "vmacc.vx v1, %[q43], v11\n\t" - "vfcvt.f.x.v v0, v0\n\t" - "vfcvt.f.x.v v1, v1\n\t" - "vfmv.f.s %[ft2], v0\n\t" - "vfmv.f.s %[ftmp], v1\n\t" - "fadd.s %[ft2], %[ft2], %[ftmp]\n\t" - "fmadd.s %[sumf], %[d], %[ft2], %[sumf]" - : [ftmp] "=&f" (ftmp), [sumf] "+&f" (sumf), [ft2] "=&f" (ft2) - , [s0] "=&r" (s0), [s1] "=&r" (s1), [s2] "=&r" (s2), [s3] "=&r" (s3) - , [q40] "=&r" (q40), [q41] "=&r" (q41), [q42] "=&r" (q42), [q43] "=&r" (q43) - , [q80] "=&r" (q80), [q81] "=&r" (q81), [q82] "=&r" (q82), [q83] "=&r" (q83) - : [d] "f" (d), [ys] "r" (y[i].qs), [xs] "r" (x[i].qs), [scale] "r" (scales) - , [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp) - , [s6b] "r" (&x[i]), [kmask1] "r" (kmask1), [dmin] "f" (dmin) - , [kmask2] "r" (kmask2), [kmask3] "r" (kmask3) - : "memory" - , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" - , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" - , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" - , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" - ); } - break; - default: - assert(false && "Unsupported vector length"); - break; + + sumf += d*(sum_1 + sum_2); + } *s = sumf; +} +#endif +void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_xtheadvector + ggml_vec_dot_q4_K_q8_K_xtheadvector(n, s, bs, vx, bx, vy, by, nrc); +#elif defined __riscv_v + switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_q4_K_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; + default: // 256 and above + ggml_vec_dot_q4_K_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + } #else - - UNUSED(x); - UNUSED(y); - UNUSED(kmask1); - UNUSED(kmask2); - UNUSED(kmask3); - UNUSED(nb); - UNUSED(utmp); - ggml_vec_dot_q4_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } @@ -1823,7 +2174,6 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi aux32 += __riscv_vmv_x_s_i32m1_i32(vacc2); q5 += 32; q8 += 64; - } sums += aux32 * d; @@ -1846,7 +2196,8 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi #endif } -void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_xtheadvector +static NOINLINE void ggml_vec_dot_q6_K_q8_K_xtheadvector(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); @@ -1859,8 +2210,6 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const int nb = n / QK_K; -#if defined __riscv_xtheadvector - float sumf = 0; for (int i = 0; i < nb; ++i) { @@ -1939,224 +2288,462 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi } *s = sumf; +} +#endif -#elif defined __riscv_v +#if defined __riscv_v +static NOINLINE void ggml_vec_dot_q6_K_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q6_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + __builtin_prefetch(&x[i + 1].d, 0, 1); + + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const int8_t * restrict scale = x[i].scales; + + int q6h; + float ftmp; + + for (int j = 0; j < QK_K/128; ++j) { + __asm__ __volatile__( + "addi %[q6h], %[q6], 32\n\t" + "ld t0, 0(%[scale])\n\t" + "addi %[scale], %[scale], 8\n\t" + "slli t6, t0, 1 * 8\n\t" + "lb zero, 0(%[q6])\n\t" + "slli t5, t0, 2 * 8\n\t" + "slli t4, t0, 3 * 8\n\t" + "lb zero, 0(%[q6h])\n\t" + "slli t3, t0, 4 * 8\n\t" + "slli t2, t0, 5 * 8\n\t" + "lb zero, 0(%[qh])\n\t" + "lb zero, 31(%[q6h])\n\t" + "slli t1, t0, 6 * 8\n\t" + "srai a7, t0, 56\n\t" + "vsetvli zero, %[vl32], e8, m2\n\t" + "vle8.v v8, (%[q6])\n\t" + "srai t6, t6, 56\n\t" + "srai t5, t5, 56\n\t" + "srai t4, t4, 56\n\t" + "srai t3, t3, 56\n\t" + "vle8.v v10, (%[q6h])\n\t" + "addi %[q6], %[q6], 64\n\t" + "slli t0, t0, 7 * 8\n\t" + "srai t2, t2, 56\n\t" + "srai t1, t1, 56\n\t" + "srai t0, t0, 56\n\t" + "vle8.v v4, (%[qh])\n\t" + "vsrl.vi v12, v8, 4\n\t" + "vsrl.vi v14, v10, 4\n\t" + "lb zero, 0(%[q8])\n\t" + "vand.vi v8, v8, 0xF\n\t" + "vand.vi v10, v10, 0xF\n\t" + "lb zero, 32(%[q8])\n\t" + "vsll.vi v0, v4, 4\n\t" + "vsll.vi v2, v4, 2\n\t" + "lb zero, 64(%[q8])\n\t" + "vsrl.vi v6, v4, 2\n\t" + "vand.vx v0, v0, %[mask]\n\t" + "lb zero, 96(%[q8])\n\t" + "vand.vx v2, v2, %[mask]\n\t" + "vand.vx v4, v4, %[mask]\n\t" + "vand.vx v6, v6, %[mask]\n\t" + "vor.vv v8, v8, v0\n\t" + "lb zero, 127(%[q8])\n\t" + "vor.vv v10, v10, v2\n\t" + "vor.vv v12, v12, v4\n\t" + "vor.vv v14, v14, v6\n\t" + "vsetvli zero, %[vl128], e8, m8\n\t" + "vle8.v v0, (%[q8])\n\t" + "vsub.vx v8, v8, %[vl32]\n\t" + "vsetvli zero, %[vl64], e8, m4\n\t" + "vwmul.vv v16, v0, v8\n\t" + "vwmul.vv v24, v4, v12\n\t" + "vsetivli zero, 16, e16, m2\n\t" + "vmv.v.x v0, zero\n\t" + "vwredsum.vs v10, v16, v0\n\t" + "vwredsum.vs v9, v18, v0\n\t" + "vwredsum.vs v8, v20, v0\n\t" + "vwredsum.vs v7, v22, v0\n\t" + "vwredsum.vs v11, v24, v0\n\t" + "vwredsum.vs v12, v26, v0\n\t" + "vwredsum.vs v13, v28, v0\n\t" + "vwredsum.vs v14, v30, v0\n\t" + "vsetivli zero, 4, e32, m1\n\t" + "vmul.vx v0, v10, t0\n\t" + "vmul.vx v1, v9, t1\n\t" + "vmacc.vx v0, t2, v8\n\t" + "vmacc.vx v1, t3, v7\n\t" + "vmacc.vx v0, t4, v11\n\t" + "vmacc.vx v1, t5, v12\n\t" + "vmacc.vx v0, t6, v13\n\t" + "vmacc.vx v1, a7, v14\n\t" + "vadd.vv v0, v0, v1\n\t" + "vfcvt.f.x.v v0, v0\n\t" + "vfmv.f.s %[ftmp], v0\n\t" + "fmadd.s %[sumf], %[d], %[ftmp], %[sumf]" + : [q6] "+&r" (q6), [q6h] "=&r" (q6h) + , [scale] "+&r" (scale) + , [sumf] "+&f" (sumf), [ftmp] "=&f" (ftmp) + : [qh] "r" (qh), [q8] "r" (q8) + , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128) + , [mask] "r" (0x30), [d] "f" (d) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + , "t0", "t1", "t2", "t3", "t4", "t5", "t6", "a7" + , "a6", "a5", "a4", "a3" + ); + qh += 32; q8 += 128; + } + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_q6_K_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q6_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; float sumf = 0; - const int vector_length = __riscv_vlenb() * 8; + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; - switch (vector_length) { - case 256: - for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q6 = x[i].ql; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; - const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + const int8_t * GGML_RESTRICT scale = x[i].scales; - const uint8_t * GGML_RESTRICT q6 = x[i].ql; - const uint8_t * GGML_RESTRICT qh = x[i].qh; - const int8_t * GGML_RESTRICT q8 = y[i].qs; + size_t vl; - const int8_t * GGML_RESTRICT scale = x[i].scales; + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - size_t vl; + int sum_t = 0; + int is = 0; - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + for (int j = 0; j < QK_K/128; ++j) { + + vl = 32; + + // load qh + vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl); + + // load Q6 + vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl); + vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl); + + vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl); + vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl); + vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl); + vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl); + + vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl); + vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl); + vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl); + vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl); + + vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl); + vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl); + vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl); + vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl); + + vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl); + vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl); + vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl); + vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl); + + // load Q8 and take product + vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl); + vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl); + vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl); + vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl); + + vl = 16; + + vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl); + vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl); + vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl); + vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl); + vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl); + vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl); + vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl); + vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl); + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); + vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl); + vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl); - int sum_t = 0; - int is = 0; + sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); + + q6 += 64; qh += 32; q8 += 128; is=8; + + } + + sumf += d * sum_t; + + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_q6_K_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q6_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + // mask for processing 16 elements per prod register + const vuint16m1_t va_index = __riscv_vid_v_u16m1(32); + const vbool16_t va_mask = __riscv_vmsgtu_vx_u16m1_b16(va_index, 15, 32); + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT q6 = x[i].ql; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const int8_t * GGML_RESTRICT scale = x[i].scales; + + size_t vl = 32; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + + int sum_t = 0; + int is = 0; + + vint32m2_t vaux_0 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t vaux_1 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t vaux_2 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t vaux_3 = __riscv_vmv_v_x_i32m2(0, vl); + + for (int j = 0; j < QK_K/128; ++j) { + // load qh + vuint8mf2_t qh_x = __riscv_vle8_v_u8mf2(qh, vl); + + // load Q6 + vuint8mf2_t q6_0 = __riscv_vle8_v_u8mf2(q6, vl); + vuint8mf2_t q6_1 = __riscv_vle8_v_u8mf2(q6+32, vl); + + vuint8mf2_t q6a_0 = __riscv_vand_vx_u8mf2(q6_0, 0x0F, vl); + vuint8mf2_t q6a_1 = __riscv_vand_vx_u8mf2(q6_1, 0x0F, vl); + vuint8mf2_t q6s_0 = __riscv_vsrl_vx_u8mf2(q6_0, 0x04, vl); + vuint8mf2_t q6s_1 = __riscv_vsrl_vx_u8mf2(q6_1, 0x04, vl); + + vuint8mf2_t qh_0 = __riscv_vand_vx_u8mf2(qh_x, 0x03, vl); + vuint8mf2_t qh_1 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl), 0x03 , vl); + vuint8mf2_t qh_2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x4, vl), 0x03 , vl); + vuint8mf2_t qh_3 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x6, vl), 0x03 , vl); + + vuint8mf2_t qhi_0 = __riscv_vor_vv_u8mf2(q6a_0, __riscv_vsll_vx_u8mf2(qh_0, 0x04, vl), vl); + vuint8mf2_t qhi_1 = __riscv_vor_vv_u8mf2(q6a_1, __riscv_vsll_vx_u8mf2(qh_1, 0x04, vl), vl); + vuint8mf2_t qhi_2 = __riscv_vor_vv_u8mf2(q6s_0, __riscv_vsll_vx_u8mf2(qh_2, 0x04, vl), vl); + vuint8mf2_t qhi_3 = __riscv_vor_vv_u8mf2(q6s_1, __riscv_vsll_vx_u8mf2(qh_3, 0x04, vl), vl); + + vint8mf2_t a_0 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(qhi_0), 32, vl); + vint8mf2_t a_1 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(qhi_1), 32, vl); + vint8mf2_t a_2 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(qhi_2), 32, vl); + vint8mf2_t a_3 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(qhi_3), 32, vl); + + // load Q8 and take product + vint16m1_t va_q_0 = __riscv_vwmul_vv_i16m1(a_0, __riscv_vle8_v_i8mf2(q8, vl), vl); + vint16m1_t va_q_1 = __riscv_vwmul_vv_i16m1(a_1, __riscv_vle8_v_i8mf2(q8+32, vl), vl); + vint16m1_t va_q_2 = __riscv_vwmul_vv_i16m1(a_2, __riscv_vle8_v_i8mf2(q8+64, vl), vl); + vint16m1_t va_q_3 = __riscv_vwmul_vv_i16m1(a_3, __riscv_vle8_v_i8mf2(q8+96, vl), vl); + + // accumulate + vaux_0 = __riscv_vwmacc_vx_i32m2(vaux_0, scale[is+0], va_q_0, 16); + vaux_1 = __riscv_vwmacc_vx_i32m2(vaux_1, scale[is+2], va_q_1, 16); + vaux_2 = __riscv_vwmacc_vx_i32m2(vaux_2, scale[is+4], va_q_2, 16); + vaux_3 = __riscv_vwmacc_vx_i32m2(vaux_3, scale[is+6], va_q_3, 16); + // + vaux_0 = __riscv_vwmacc_vx_i32m2_m(va_mask, vaux_0, scale[is+1], va_q_0, vl); + vaux_1 = __riscv_vwmacc_vx_i32m2_m(va_mask, vaux_1, scale[is+3], va_q_1, vl); + vaux_2 = __riscv_vwmacc_vx_i32m2_m(va_mask, vaux_2, scale[is+5], va_q_2, vl); + vaux_3 = __riscv_vwmacc_vx_i32m2_m(va_mask, vaux_3, scale[is+7], va_q_3, vl); + + q6 += 64; qh += 32; q8 += 128; is=8; + } + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); + + sum_t += __riscv_vmv_x_s_i32m1_i32(isum1); + + sumf += d * sum_t; + + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_q6_K_q8_K_vl1024(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q6_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; - for (int j = 0; j < QK_K/128; ++j) { + const int nb = n / QK_K; - vl = 32; + // mask for processing 16 elements per prod register + const vuint16mf2_t va_index = __riscv_vid_v_u16mf2(32); + const vbool32_t va_mask = __riscv_vmsgtu_vx_u16mf2_b32(va_index, 15, 32); - // load qh - vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl); + float sumf = 0; - // load Q6 - vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl); - vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl); + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; - vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl); - vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl); - vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl); - vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl); + const uint8_t * GGML_RESTRICT q6 = x[i].ql; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; - vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl); - vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl); - vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl); - vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl); + const int8_t * GGML_RESTRICT scale = x[i].scales; - vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl); - vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl); - vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl); - vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl); + size_t vl = 32; - vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl); - vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl); - vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl); - vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl); + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - // load Q8 and take product - vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl); - vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl); - vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl); - vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl); + int sum_t = 0; + int is = 0; - vl = 16; + vint32m1_t vaux_0 = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t vaux_1 = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t vaux_2 = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t vaux_3 = __riscv_vmv_v_x_i32m1(0, vl); - vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl); - vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl); - vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl); - vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl); - vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl); - vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl); - vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl); - vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl); + for (int j = 0; j < QK_K/128; ++j) { + // load qh + vuint8mf4_t qh_x = __riscv_vle8_v_u8mf4(qh, vl); + + // load Q6 + vuint8mf4_t q6_0 = __riscv_vle8_v_u8mf4(q6, vl); + vuint8mf4_t q6_1 = __riscv_vle8_v_u8mf4(q6+32, vl); + + vuint8mf4_t q6a_0 = __riscv_vand_vx_u8mf4(q6_0, 0x0F, vl); + vuint8mf4_t q6a_1 = __riscv_vand_vx_u8mf4(q6_1, 0x0F, vl); + vuint8mf4_t q6s_0 = __riscv_vsrl_vx_u8mf4(q6_0, 0x04, vl); + vuint8mf4_t q6s_1 = __riscv_vsrl_vx_u8mf4(q6_1, 0x04, vl); + + vuint8mf4_t qh_0 = __riscv_vand_vx_u8mf4(qh_x, 0x03, vl); + vuint8mf4_t qh_1 = __riscv_vand_vx_u8mf4(__riscv_vsrl_vx_u8mf4(qh_x, 0x2, vl), 0x03 , vl); + vuint8mf4_t qh_2 = __riscv_vand_vx_u8mf4(__riscv_vsrl_vx_u8mf4(qh_x, 0x4, vl), 0x03 , vl); + vuint8mf4_t qh_3 = __riscv_vand_vx_u8mf4(__riscv_vsrl_vx_u8mf4(qh_x, 0x6, vl), 0x03 , vl); + + vuint8mf4_t qhi_0 = __riscv_vor_vv_u8mf4(q6a_0, __riscv_vsll_vx_u8mf4(qh_0, 0x04, vl), vl); + vuint8mf4_t qhi_1 = __riscv_vor_vv_u8mf4(q6a_1, __riscv_vsll_vx_u8mf4(qh_1, 0x04, vl), vl); + vuint8mf4_t qhi_2 = __riscv_vor_vv_u8mf4(q6s_0, __riscv_vsll_vx_u8mf4(qh_2, 0x04, vl), vl); + vuint8mf4_t qhi_3 = __riscv_vor_vv_u8mf4(q6s_1, __riscv_vsll_vx_u8mf4(qh_3, 0x04, vl), vl); + + vint8mf4_t a_0 = __riscv_vsub_vx_i8mf4(__riscv_vreinterpret_v_u8mf4_i8mf4(qhi_0), 32, vl); + vint8mf4_t a_1 = __riscv_vsub_vx_i8mf4(__riscv_vreinterpret_v_u8mf4_i8mf4(qhi_1), 32, vl); + vint8mf4_t a_2 = __riscv_vsub_vx_i8mf4(__riscv_vreinterpret_v_u8mf4_i8mf4(qhi_2), 32, vl); + vint8mf4_t a_3 = __riscv_vsub_vx_i8mf4(__riscv_vreinterpret_v_u8mf4_i8mf4(qhi_3), 32, vl); + + // load Q8 and take product + vint16mf2_t va_q_0 = __riscv_vwmul_vv_i16mf2(a_0, __riscv_vle8_v_i8mf4(q8, vl), vl); + vint16mf2_t va_q_1 = __riscv_vwmul_vv_i16mf2(a_1, __riscv_vle8_v_i8mf4(q8+32, vl), vl); + vint16mf2_t va_q_2 = __riscv_vwmul_vv_i16mf2(a_2, __riscv_vle8_v_i8mf4(q8+64, vl), vl); + vint16mf2_t va_q_3 = __riscv_vwmul_vv_i16mf2(a_3, __riscv_vle8_v_i8mf4(q8+96, vl), vl); + + // accumulate + vaux_0 = __riscv_vwmacc_vx_i32m1(vaux_0, scale[is+0], va_q_0, 16); + vaux_1 = __riscv_vwmacc_vx_i32m1(vaux_1, scale[is+2], va_q_1, 16); + vaux_2 = __riscv_vwmacc_vx_i32m1(vaux_2, scale[is+4], va_q_2, 16); + vaux_3 = __riscv_vwmacc_vx_i32m1(vaux_3, scale[is+6], va_q_3, 16); + // + vaux_0 = __riscv_vwmacc_vx_i32m1_m(va_mask, vaux_0, scale[is+1], va_q_0, vl); + vaux_1 = __riscv_vwmacc_vx_i32m1_m(va_mask, vaux_1, scale[is+3], va_q_1, vl); + vaux_2 = __riscv_vwmacc_vx_i32m1_m(va_mask, vaux_2, scale[is+5], va_q_2, vl); + vaux_3 = __riscv_vwmacc_vx_i32m1_m(va_mask, vaux_3, scale[is+7], va_q_3, vl); - vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); - vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl); - vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl); + q6 += 64; qh += 32; q8 += 128; is=8; - sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); + } - q6 += 64; qh += 32; q8 += 128; is=8; + vint32m1_t isum0 = __riscv_vredsum_vs_i32m1_i32m1(__riscv_vadd_vv_i32m1(vaux_0, vaux_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m1_i32m1(__riscv_vadd_vv_i32m1(vaux_2, vaux_3, vl), isum0, vl); - } + sum_t += __riscv_vmv_x_s_i32m1_i32(isum1); - sumf += d * sum_t; + sumf += d * sum_t; - } - break; - case 128: - for (int i = 0; i < nb; ++i) { - - __builtin_prefetch(&x[i + 1].d, 0, 1); - - const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; - - const uint8_t * restrict q6 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - - const int8_t * restrict scale = x[i].scales; - - int q6h; - float ftmp; - - for (int j = 0; j < QK_K/128; ++j) { - __asm__ __volatile__( - "addi %[q6h], %[q6], 32\n\t" - "ld t0, 0(%[scale])\n\t" - "addi %[scale], %[scale], 8\n\t" - "slli t6, t0, 1 * 8\n\t" - "lb zero, 0(%[q6])\n\t" - "slli t5, t0, 2 * 8\n\t" - "slli t4, t0, 3 * 8\n\t" - "lb zero, 0(%[q6h])\n\t" - "slli t3, t0, 4 * 8\n\t" - "slli t2, t0, 5 * 8\n\t" - "lb zero, 0(%[qh])\n\t" - "lb zero, 31(%[q6h])\n\t" - "slli t1, t0, 6 * 8\n\t" - "srai a7, t0, 56\n\t" - "vsetvli zero, %[vl32], e8, m2\n\t" - "vle8.v v8, (%[q6])\n\t" - "srai t6, t6, 56\n\t" - "srai t5, t5, 56\n\t" - "srai t4, t4, 56\n\t" - "srai t3, t3, 56\n\t" - "vle8.v v10, (%[q6h])\n\t" - "addi %[q6], %[q6], 64\n\t" - "slli t0, t0, 7 * 8\n\t" - "srai t2, t2, 56\n\t" - "srai t1, t1, 56\n\t" - "srai t0, t0, 56\n\t" - "vle8.v v4, (%[qh])\n\t" - "vsrl.vi v12, v8, 4\n\t" - "vsrl.vi v14, v10, 4\n\t" - "lb zero, 0(%[q8])\n\t" - "vand.vi v8, v8, 0xF\n\t" - "vand.vi v10, v10, 0xF\n\t" - "lb zero, 32(%[q8])\n\t" - "vsll.vi v0, v4, 4\n\t" - "vsll.vi v2, v4, 2\n\t" - "lb zero, 64(%[q8])\n\t" - "vsrl.vi v6, v4, 2\n\t" - "vand.vx v0, v0, %[mask]\n\t" - "lb zero, 96(%[q8])\n\t" - "vand.vx v2, v2, %[mask]\n\t" - "vand.vx v4, v4, %[mask]\n\t" - "vand.vx v6, v6, %[mask]\n\t" - "vor.vv v8, v8, v0\n\t" - "lb zero, 127(%[q8])\n\t" - "vor.vv v10, v10, v2\n\t" - "vor.vv v12, v12, v4\n\t" - "vor.vv v14, v14, v6\n\t" - "vsetvli zero, %[vl128], e8, m8\n\t" - "vle8.v v0, (%[q8])\n\t" - "vsub.vx v8, v8, %[vl32]\n\t" - "vsetvli zero, %[vl64], e8, m4\n\t" - "vwmul.vv v16, v0, v8\n\t" - "vwmul.vv v24, v4, v12\n\t" - "vsetivli zero, 16, e16, m2\n\t" - "vmv.v.x v0, zero\n\t" - "vwredsum.vs v10, v16, v0\n\t" - "vwredsum.vs v9, v18, v0\n\t" - "vwredsum.vs v8, v20, v0\n\t" - "vwredsum.vs v7, v22, v0\n\t" - "vwredsum.vs v11, v24, v0\n\t" - "vwredsum.vs v12, v26, v0\n\t" - "vwredsum.vs v13, v28, v0\n\t" - "vwredsum.vs v14, v30, v0\n\t" - "vsetivli zero, 4, e32, m1\n\t" - "vmul.vx v0, v10, t0\n\t" - "vmul.vx v1, v9, t1\n\t" - "vmacc.vx v0, t2, v8\n\t" - "vmacc.vx v1, t3, v7\n\t" - "vmacc.vx v0, t4, v11\n\t" - "vmacc.vx v1, t5, v12\n\t" - "vmacc.vx v0, t6, v13\n\t" - "vmacc.vx v1, a7, v14\n\t" - "vadd.vv v0, v0, v1\n\t" - "vfcvt.f.x.v v0, v0\n\t" - "vfmv.f.s %[ftmp], v0\n\t" - "fmadd.s %[sumf], %[d], %[ftmp], %[sumf]" - : [q6] "+&r" (q6), [q6h] "=&r" (q6h) - , [scale] "+&r" (scale) - , [sumf] "+&f" (sumf), [ftmp] "=&f" (ftmp) - : [qh] "r" (qh), [q8] "r" (q8) - , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128) - , [mask] "r" (0x30), [d] "f" (d) - : "memory" - , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" - , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" - , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" - , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" - , "t0", "t1", "t2", "t3", "t4", "t5", "t6", "a7" - , "a6", "a5", "a4", "a3" - ); - qh += 32; q8 += 128; - } - } - break; - default: - assert(false && "Unsupported vector length"); - break; } *s = sumf; +} +#endif +void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_xtheadvector + ggml_vec_dot_q6_K_q8_K_xtheadvector(n, s, bs, vx, bx, vy, by, nrc); +#elif defined __riscv_v + switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_q6_K_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; + case 256: + ggml_vec_dot_q6_K_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + case 512: + ggml_vec_dot_q6_K_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; + case 1024: + ggml_vec_dot_q6_K_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + break; + } #else - - UNUSED(x); - UNUSED(y); - UNUSED(nb); - ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static NOINLINE void ggml_vec_dot_iq1_s_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); @@ -2364,10 +2951,190 @@ static NOINLINE void ggml_vec_dot_iq1_s_q8_K_vl256(int n, float * GGML_RESTRICT *s = sumf; } + +static NOINLINE void ggml_vec_dot_iq1_s_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + // Load qh once for the entire superblock. + vuint16mf4_t qh = __riscv_vle16_v_u16mf4(x[i].qh, 8); + + // Calculate ls. + vuint16mf4_t temp = __riscv_vsrl_vx_u16mf4(qh, 12, 8); + temp = __riscv_vand_vx_u16mf4(temp, 7, 8); + vint32mf2_t ls = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vwmulu_vx_u32mf2(temp, 2, 8)); + ls = __riscv_vadd_vx_i32mf2(ls, 1, 8); + + // Calculate delta. + vbool64_t mask = __riscv_vmseq_vx_u16mf4_b64(__riscv_vand_vx_u16mf4(qh, 0x8000, 8), 0, 8); + vint32mf2_t delta_neg = __riscv_vmv_v_x_i32mf2(-1, 8); + vint32mf2_t delta_pos = __riscv_vmv_v_x_i32mf2(1, 8); + vint32mf2_t delta = __riscv_vmerge_vvm_i32mf2(delta_neg, delta_pos, mask, 8); + + // Load qs. + vuint8mf2_t qs = __riscv_vle8_v_u8mf2(x[i].qs, 32); + + // Prepare the indices. + const uint64_t shift = 0x0009000600030000; + vuint16m1_t qh_shift = __riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(shift, 8)); + vuint16m1_t qh_gather_index = __riscv_vreinterpret_v_i16m1_u16m1( + __riscv_vdiv_vx_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vid_v_u16m1(32)), 4, 32)); + vuint16m1_t qh_ext = __riscv_vlmul_ext_v_u16mf2_u16m1(__riscv_vlmul_ext_v_u16mf4_u16mf2(qh)); + vuint16m1_t qh_index = __riscv_vrgather_vv_u16m1(qh_ext, qh_gather_index, 32); + qh_index = __riscv_vsrl_vv_u16m1(qh_index, qh_shift, 32); + qh_index = __riscv_vand_vx_u16m1(qh_index, 7, 32); + qh_index = __riscv_vsll_vx_u16m1(qh_index, 8, 32); + qh_index = __riscv_vor_vv_u16m1(qh_index, __riscv_vzext_vf2_u16m1(qs, 32), 32); + vuint16m1_t index = __riscv_vsll_vx_u16m1(qh_index, 3, 32); + + // Final lsums. + int32_t lsums_s[8]; + vint32m1_t one_scalar = __riscv_vmv_v_x_i32m1(0, 1); + + // Sub-blocks 1-8 + { + vint8m4_t grid0 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, index, 32)); + vint8m4_t q80 = __riscv_vle8_v_i8m4(y[i].qs, 256); + vint16m8_t lsum0 = __riscv_vwmul_vv_i16m8(grid0, q80, 256); + lsums_s[0] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(lsum0, 0), one_scalar, 32)); + lsums_s[1] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(lsum0, 1), one_scalar, 32)); + lsums_s[2] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(lsum0, 2), one_scalar, 32)); + lsums_s[3] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(lsum0, 3), one_scalar, 32)); + lsums_s[4] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(lsum0, 4), one_scalar, 32)); + lsums_s[5] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(lsum0, 5), one_scalar, 32)); + lsums_s[6] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(lsum0, 6), one_scalar, 32)); + lsums_s[7] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(lsum0, 7), one_scalar, 32)); + } + __asm__ __volatile__("" ::: "memory"); + vint32mf2_t lsums = __riscv_vle32_v_i32mf2(&lsums_s[0], 8); + + // Calculate the bsums. + vint16mf2_t bsums_0 = __riscv_vle16_v_i16mf2(y[i].bsums, 16); + const vuint32mf2_t bsums_i32 = __riscv_vreinterpret_v_u16mf2_u32mf2(__riscv_vreinterpret_v_i16mf2_u16mf2(bsums_0)); + const vint16mf4_t bsums_i32_0 = __riscv_vreinterpret_v_u16mf4_i16mf4(__riscv_vnsrl_wx_u16mf4(bsums_i32, 0, 8)); + const vint16mf4_t bsums_i32_1 = __riscv_vreinterpret_v_u16mf4_i16mf4(__riscv_vnsrl_wx_u16mf4(bsums_i32, 16, 8)); + const vint32mf2_t bsums = __riscv_vwadd_vv_i32mf2(bsums_i32_0, bsums_i32_1, 8); + + // Accumulation. + vint32mf2_t sumi_v = __riscv_vmul_vv_i32mf2(ls, lsums, 8); + vint32mf2_t sumi1_v = __riscv_vmul_vv_i32mf2(__riscv_vmul_vv_i32mf2(ls, delta, 8), bsums, 8); + + // Update sumf. + int sumi = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32mf2_i32m1(sumi_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8)); + int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32mf2_i32m1(sumi1_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8)); + sumf += GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1); + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_iq1_s_q8_K_vl1024(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + // Mask for processing 32 elements per lsum register. + vuint16m1_t l_index = __riscv_vid_v_u16m1(64); + vbool16_t l_mask = __riscv_vmsgtu_vx_u16m1_b16(l_index, 31, 64); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + // Load qh once for the entire superblock. + vuint16mf4_t qh = __riscv_vle16_v_u16mf4(x[i].qh, 8); + + // Calculate ls. + vuint16mf4_t temp = __riscv_vsrl_vx_u16mf4(qh, 12, 8); + temp = __riscv_vand_vx_u16mf4(temp, 7, 8); + vint32mf2_t ls = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vwmulu_vx_u32mf2(temp, 2, 8)); + ls = __riscv_vadd_vx_i32mf2(ls, 1, 8); + + // Calculate delta. + vbool64_t mask = __riscv_vmseq_vx_u16mf4_b64(__riscv_vand_vx_u16mf4(qh, 0x8000, 8), 0, 8); + vint32mf2_t delta_neg = __riscv_vmv_v_x_i32mf2(-1, 8); + vint32mf2_t delta_pos = __riscv_vmv_v_x_i32mf2(1, 8); + vint32mf2_t delta = __riscv_vmerge_vvm_i32mf2(delta_neg, delta_pos, mask, 8); + + // Load qs. + vuint8mf2_t qs = __riscv_vle8_v_u8mf2(x[i].qs, 32); + + // Prepare the indices. + const uint64_t shift = 0x0009000600030000; + vuint16m1_t qh_shift = __riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(shift, 8)); + vuint16m1_t qh_gather_index = __riscv_vreinterpret_v_i16m1_u16m1( + __riscv_vdiv_vx_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vid_v_u16m1(32)), 4, 32)); + vuint16m1_t qh_ext = __riscv_vlmul_ext_v_u16mf2_u16m1(__riscv_vlmul_ext_v_u16mf4_u16mf2(qh)); + vuint16m1_t qh_index = __riscv_vrgather_vv_u16m1(qh_ext, qh_gather_index, 32); + qh_index = __riscv_vsrl_vv_u16m1(qh_index, qh_shift, 32); + qh_index = __riscv_vand_vx_u16m1(qh_index, 7, 32); + qh_index = __riscv_vsll_vx_u16m1(qh_index, 8, 32); + qh_index = __riscv_vor_vv_u16m1(qh_index, __riscv_vzext_vf2_u16m1(qs, 32), 32); + vuint16mf2_t index = __riscv_vlmul_trunc_v_u16m1_u16mf2(__riscv_vsll_vx_u16m1(qh_index, 3, 32)); + + // Final lsums. + int32_t lsums_s[8]; + vint32m1_t one_scalar = __riscv_vmv_v_x_i32m1(0, 1); + + // Sub-blocks 1-8 + { + vint8m2_t grid0 = __riscv_vreinterpret_v_i64m2_i8m2(__riscv_vluxei16_v_i64m2((const int64_t*)iq1s_grid, index, 32)); + vint8m2_t q80 = __riscv_vle8_v_i8m2(y[i].qs, 256); + vint16m4_t lsum0 = __riscv_vwmul_vv_i16m4(grid0, q80, 256); + + // Reduce. + lsums_s[0] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(lsum0, 0), one_scalar, 32)); + lsums_s[1] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(l_mask, __riscv_vget_v_i16m4_i16m1(lsum0, 0), one_scalar, 64)); + lsums_s[2] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(lsum0, 1), one_scalar, 32)); + lsums_s[3] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(l_mask, __riscv_vget_v_i16m4_i16m1(lsum0, 1), one_scalar, 64)); + lsums_s[4] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(lsum0, 2), one_scalar, 32)); + lsums_s[5] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(l_mask, __riscv_vget_v_i16m4_i16m1(lsum0, 2), one_scalar, 64)); + lsums_s[6] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(lsum0, 3), one_scalar, 32)); + lsums_s[7] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(l_mask, __riscv_vget_v_i16m4_i16m1(lsum0, 3), one_scalar, 64)); + } + __asm__ __volatile__("" ::: "memory"); + vint32mf2_t lsums = __riscv_vle32_v_i32mf2(&lsums_s[0], 8); + + // Calculate the bsums. + vint16mf2_t bsums_0 = __riscv_vle16_v_i16mf2(y[i].bsums, 16); + const vuint32mf2_t bsums_i32 = __riscv_vreinterpret_v_u16mf2_u32mf2(__riscv_vreinterpret_v_i16mf2_u16mf2(bsums_0)); + const vint16mf4_t bsums_i32_0 = __riscv_vreinterpret_v_u16mf4_i16mf4(__riscv_vnsrl_wx_u16mf4(bsums_i32, 0, 8)); + const vint16mf4_t bsums_i32_1 = __riscv_vreinterpret_v_u16mf4_i16mf4(__riscv_vnsrl_wx_u16mf4(bsums_i32, 16, 8)); + const vint32mf2_t bsums = __riscv_vwadd_vv_i32mf2(bsums_i32_0, bsums_i32_1, 8); + + // Accumulation. + vint32mf2_t sumi_v = __riscv_vmul_vv_i32mf2(ls, lsums, 8); + vint32mf2_t sumi1_v = __riscv_vmul_vv_i32mf2(__riscv_vmul_vv_i32mf2(ls, delta, 8), bsums, 8); + + // Update sumf. + int sumi = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32mf2_i32m1(sumi_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8)); + int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32mf2_i32m1(sumi1_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8)); + sumf += GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1); + } + + *s = sumf; +} #endif void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: ggml_vec_dot_iq1_s_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); @@ -2375,6 +3142,12 @@ void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo case 256: ggml_vec_dot_iq1_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; + case 512: + ggml_vec_dot_iq1_s_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; + case 1024: + ggml_vec_dot_iq1_s_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc); + break; default: ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); break; @@ -2384,7 +3157,7 @@ void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static NOINLINE void ggml_vec_dot_iq1_m_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); @@ -2664,10 +3437,287 @@ static NOINLINE void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT *s = sumf; } + +static NOINLINE void ggml_vec_dot_iq1_m_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_m * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + iq1m_scale_t scale; + + // Mask for processing 16 elements per lsum register. + const vuint16m1_t l_index = __riscv_vid_v_u16m1(32); + const vbool16_t l_mask = __riscv_vmsgtu_vx_u16m1_b16(l_index, 15, 32); + + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint16_t * sc = (const uint16_t *)x[i].scales; + + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + + // Accumulators. + vint32m2_t acc1 = __riscv_vmv_v_x_i32m2(0, 32); + vint32m2_t acc2 = __riscv_vmv_v_x_i32m2(0, 32); + + // We process all the sub-blocks together. + #pragma GCC unroll 1 + for (int ib = 0; ib < QK_K/256; ib++) { + // Load qh for all 16 sub-blocks. + const vuint8mf4_t qh_8 = __riscv_vle8_v_u8mf4(qh, 16); + const vuint16mf2_t qh_16_lo = __riscv_vzext_vf2_u16mf2(qh_8, 16); + const vuint16mf2_t qh_16_hi = __riscv_vsll_vx_u16mf2(qh_16_lo, 8, 16); + const vuint16m1_t qhb = __riscv_vzext_vf2_u16m1( + __riscv_vreinterpret_v_u16mf2_u8mf2(__riscv_vor_vv_u16mf2(qh_16_lo, qh_16_hi, 16)), 32); + __asm__ __volatile__("" ::: "memory"); + + // Prepare grid indices. + const vuint16m1_t qsb = __riscv_vzext_vf2_u16m1(__riscv_vle8_v_u8mf2(&qs[0], 32), 32); + const vuint16m1_t shift = __riscv_vreinterpret_v_u32m1_u16m1(__riscv_vmv_v_x_u32m1(0x00040008, 16)); + vuint16m1_t index = __riscv_vor_vv_u16m1(qsb, __riscv_vand_vx_u16m1(__riscv_vsll_vv_u16m1(qhb, shift, 32), 0x700, 32), 32); + index = __riscv_vsll_vx_u16m1(index, 3, 32); + __asm__ __volatile__("" ::: "memory"); + + // Load the grid. + const vint8m4_t iq1b = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vreinterpret_v_u64m4_i64m4( + __riscv_vluxei16_v_u64m4(iq1s_grid, index, 32))); + + // Prepare the deltas. + const vbool16_t mask = __riscv_vmsgtu_vx_u16m1_b16( + __riscv_vand_vv_u16m1(qhb, __riscv_vreinterpret_v_u32m1_u16m1(__riscv_vmv_v_x_u32m1(0x00800008, 16)), 32), 0, 32); + const vint64m4_t delta_pos = __riscv_vmv_v_x_i64m4(0x0101010101010101, 32); + const vint8m4_t delta = __riscv_vreinterpret_v_i64m4_i8m4( + __riscv_vmerge_vxm_i64m4(delta_pos, 0xffffffffffffffff, mask, 32)); + + // Load q8 for sub-blocks. + const vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 256); + + // Calculate the lsums. + const vint16m8_t lsum1 = __riscv_vwmul_vv_i16m8(iq1b, q8b, 256); + const vint16m8_t lsum2 = __riscv_vwmul_vv_i16m8(delta, q8b, 256); + + // Prepare the scales. + const int16_t ls_0 = 2*((sc[0] >> 0) & 0x7) + 1; + const int16_t ls_1 = 2*((sc[0] >> 3) & 0x7) + 1; + const int16_t ls_2 = 2*((sc[0] >> 6) & 0x7) + 1; + const int16_t ls_3 = 2*((sc[0] >> 9) & 0x7) + 1; + const int16_t ls_4 = 2*((sc[1] >> 0) & 0x7) + 1; + const int16_t ls_5 = 2*((sc[1] >> 3) & 0x7) + 1; + const int16_t ls_6 = 2*((sc[1] >> 6) & 0x7) + 1; + const int16_t ls_7 = 2*((sc[1] >> 9) & 0x7) + 1; + const int16_t ls_8 = 2*((sc[2] >> 0) & 0x7) + 1; + const int16_t ls_9 = 2*((sc[2] >> 3) & 0x7) + 1; + const int16_t ls_10 = 2*((sc[2] >> 6) & 0x7) + 1; + const int16_t ls_11 = 2*((sc[2] >> 9) & 0x7) + 1; + const int16_t ls_12 = 2*((sc[3] >> 0) & 0x7) + 1; + const int16_t ls_13 = 2*((sc[3] >> 3) & 0x7) + 1; + const int16_t ls_14 = 2*((sc[3] >> 6) & 0x7) + 1; + const int16_t ls_15 = 2*((sc[3] >> 9) & 0x7) + 1; + + // Accumulate in acc0 and acc1 for each sub-block. + acc1 = __riscv_vwmacc_vx_i32m2( acc1, ls_0, __riscv_vget_v_i16m8_i16m1(lsum1, 0), 16); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc1, ls_1, __riscv_vget_v_i16m8_i16m1(lsum1, 0), 32); + acc2 = __riscv_vwmacc_vx_i32m2( acc2, ls_0, __riscv_vget_v_i16m8_i16m1(lsum2, 0), 16); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc2, ls_1, __riscv_vget_v_i16m8_i16m1(lsum2, 0), 32); + // + acc1 = __riscv_vwmacc_vx_i32m2( acc1, ls_2, __riscv_vget_v_i16m8_i16m1(lsum1, 1), 16); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc1, ls_3, __riscv_vget_v_i16m8_i16m1(lsum1, 1), 32); + acc2 = __riscv_vwmacc_vx_i32m2( acc2, ls_2, __riscv_vget_v_i16m8_i16m1(lsum2, 1), 16); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc2, ls_3, __riscv_vget_v_i16m8_i16m1(lsum2, 1), 32); + // + acc1 = __riscv_vwmacc_vx_i32m2( acc1, ls_4, __riscv_vget_v_i16m8_i16m1(lsum1, 2), 16); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc1, ls_5, __riscv_vget_v_i16m8_i16m1(lsum1, 2), 32); + acc2 = __riscv_vwmacc_vx_i32m2( acc2, ls_4, __riscv_vget_v_i16m8_i16m1(lsum2, 2), 16); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc2, ls_5, __riscv_vget_v_i16m8_i16m1(lsum2, 2), 32); + // + acc1 = __riscv_vwmacc_vx_i32m2( acc1, ls_6, __riscv_vget_v_i16m8_i16m1(lsum1, 3), 16); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc1, ls_7, __riscv_vget_v_i16m8_i16m1(lsum1, 3), 32); + acc2 = __riscv_vwmacc_vx_i32m2( acc2, ls_6, __riscv_vget_v_i16m8_i16m1(lsum2, 3), 16); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc2, ls_7, __riscv_vget_v_i16m8_i16m1(lsum2, 3), 32); + // + acc1 = __riscv_vwmacc_vx_i32m2( acc1, ls_8, __riscv_vget_v_i16m8_i16m1(lsum1, 4), 16); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc1, ls_9, __riscv_vget_v_i16m8_i16m1(lsum1, 4), 32); + acc2 = __riscv_vwmacc_vx_i32m2( acc2, ls_8, __riscv_vget_v_i16m8_i16m1(lsum2, 4), 16); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc2, ls_9, __riscv_vget_v_i16m8_i16m1(lsum2, 4), 32); + // + acc1 = __riscv_vwmacc_vx_i32m2( acc1, ls_10, __riscv_vget_v_i16m8_i16m1(lsum1, 5), 16); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc1, ls_11, __riscv_vget_v_i16m8_i16m1(lsum1, 5), 32); + acc2 = __riscv_vwmacc_vx_i32m2( acc2, ls_10, __riscv_vget_v_i16m8_i16m1(lsum2, 5), 16); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc2, ls_11, __riscv_vget_v_i16m8_i16m1(lsum2, 5), 32); + // + acc1 = __riscv_vwmacc_vx_i32m2( acc1, ls_12, __riscv_vget_v_i16m8_i16m1(lsum1, 6), 16); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc1, ls_13, __riscv_vget_v_i16m8_i16m1(lsum1, 6), 32); + acc2 = __riscv_vwmacc_vx_i32m2( acc2, ls_12, __riscv_vget_v_i16m8_i16m1(lsum2, 6), 16); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc2, ls_13, __riscv_vget_v_i16m8_i16m1(lsum2, 6), 32); + // + acc1 = __riscv_vwmacc_vx_i32m2( acc1, ls_14, __riscv_vget_v_i16m8_i16m1(lsum1, 7), 16); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc1, ls_15, __riscv_vget_v_i16m8_i16m1(lsum1, 7), 32); + acc2 = __riscv_vwmacc_vx_i32m2( acc2, ls_14, __riscv_vget_v_i16m8_i16m1(lsum2, 7), 16); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc2, ls_15, __riscv_vget_v_i16m8_i16m1(lsum2, 7), 32); + + __asm__ __volatile__("" ::: "memory"); + } + + // Reduce and accumulate in `sumf`. + vint32m1_t one = __riscv_vmv_v_x_i32m1(0, 1); + int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc1, one, 32)); + int sumi2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc2, one, 32)); + sumf += y[i].d * GGML_CPU_FP16_TO_FP32(scale.f16) * (sumi1 + IQ1M_DELTA * sumi2); + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_iq1_m_q8_K_vl1024(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_m * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + iq1m_scale_t scale; + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint16_t * sc = (const uint16_t *)x[i].scales; + + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + + // Accumulators. + vint32m2_t acc1 = __riscv_vmv_v_x_i32m2(0, 64); + vint32m2_t acc2 = __riscv_vmv_v_x_i32m2(0, 64); + + // We process all the sub-blocks together. + #pragma GCC unroll 1 + for (int ib = 0; ib < QK_K/256; ib++) { + // Load qh for all 16 sub-blocks. + const vuint8mf8_t qh_8 = __riscv_vle8_v_u8mf8(qh, 16); + const vuint16mf4_t qh_16_lo = __riscv_vzext_vf2_u16mf4(qh_8, 16); + const vuint16mf4_t qh_16_hi = __riscv_vsll_vx_u16mf4(qh_16_lo, 8, 16); + const vuint16mf2_t qhb = __riscv_vzext_vf2_u16mf2( + __riscv_vreinterpret_v_u16mf4_u8mf4(__riscv_vor_vv_u16mf4(qh_16_lo, qh_16_hi, 16)), 32); + __asm__ __volatile__("" ::: "memory"); + + // Prepare grid indices. + const vuint16mf2_t qsb = __riscv_vzext_vf2_u16mf2(__riscv_vle8_v_u8mf4(&qs[0], 32), 32); + const vuint16mf2_t shift = __riscv_vreinterpret_v_u32mf2_u16mf2(__riscv_vmv_v_x_u32mf2(0x00040008, 16)); + vuint16mf2_t index = __riscv_vor_vv_u16mf2(qsb, __riscv_vand_vx_u16mf2(__riscv_vsll_vv_u16mf2(qhb, shift, 32), 0x700, 32), 32); + index = __riscv_vsll_vx_u16mf2(index, 3, 32); + __asm__ __volatile__("" ::: "memory"); + + // Load the grid. + const vint8m2_t iq1b = __riscv_vreinterpret_v_i64m2_i8m2(__riscv_vreinterpret_v_u64m2_i64m2( + __riscv_vluxei16_v_u64m2(iq1s_grid, index, 32))); + + // Prepare the deltas. + const vbool32_t mask = __riscv_vmsgtu_vx_u16mf2_b32( + __riscv_vand_vv_u16mf2(qhb, __riscv_vreinterpret_v_u32mf2_u16mf2(__riscv_vmv_v_x_u32mf2(0x00800008, 16)), 32), 0, 32); + const vint64m2_t delta_pos = __riscv_vmv_v_x_i64m2(0x0101010101010101, 32); + const vint8m2_t delta = __riscv_vreinterpret_v_i64m2_i8m2( + __riscv_vmerge_vxm_i64m2(delta_pos, 0xffffffffffffffff, mask, 32)); + + // Load q8 for sub-blocks. + const vint8m2_t q8b = __riscv_vle8_v_i8m2(q8, 256); + + // Calculate the lsums. + const vint16m4_t lsum1 = __riscv_vwmul_vv_i16m4(iq1b, q8b, 256); + const vint16m4_t lsum2 = __riscv_vwmul_vv_i16m4(delta, q8b, 256); + + // Prepare the scales. + const int16_t ls_0 = 2*((sc[0] >> 0) & 0x7) + 1; + const int16_t ls_1 = 2*((sc[0] >> 3) & 0x7) + 1; + const int16_t ls_2 = 2*((sc[0] >> 6) & 0x7) + 1; + const int16_t ls_3 = 2*((sc[0] >> 9) & 0x7) + 1; + const int16_t ls_4 = 2*((sc[1] >> 0) & 0x7) + 1; + const int16_t ls_5 = 2*((sc[1] >> 3) & 0x7) + 1; + const int16_t ls_6 = 2*((sc[1] >> 6) & 0x7) + 1; + const int16_t ls_7 = 2*((sc[1] >> 9) & 0x7) + 1; + const int16_t ls_8 = 2*((sc[2] >> 0) & 0x7) + 1; + const int16_t ls_9 = 2*((sc[2] >> 3) & 0x7) + 1; + const int16_t ls_10 = 2*((sc[2] >> 6) & 0x7) + 1; + const int16_t ls_11 = 2*((sc[2] >> 9) & 0x7) + 1; + const int16_t ls_12 = 2*((sc[3] >> 0) & 0x7) + 1; + const int16_t ls_13 = 2*((sc[3] >> 3) & 0x7) + 1; + const int16_t ls_14 = 2*((sc[3] >> 6) & 0x7) + 1; + const int16_t ls_15 = 2*((sc[3] >> 9) & 0x7) + 1; + + // Mask for processing 16 elements per lsum register. + const vuint16m1_t l_index = __riscv_vid_v_u16m1(64); + + // Accumulate in acc1 and acc2 for each sub-block. + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_0, __riscv_vget_v_i16m4_i16m1(lsum1, 0), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_0, __riscv_vget_v_i16m4_i16m1(lsum2, 0), 16); + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_4, __riscv_vget_v_i16m4_i16m1(lsum1, 1), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_4, __riscv_vget_v_i16m4_i16m1(lsum2, 1), 16); + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_8, __riscv_vget_v_i16m4_i16m1(lsum1, 2), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_8, __riscv_vget_v_i16m4_i16m1(lsum2, 2), 16); + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_12, __riscv_vget_v_i16m4_i16m1(lsum1, 3), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_12, __riscv_vget_v_i16m4_i16m1(lsum2, 3), 16); + // + const vbool16_t l_mask_16_32 = __riscv_vmsgtu_vx_u16m1_b16(l_index, 15, 64); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_16_32, acc1, ls_1, __riscv_vget_v_i16m4_i16m1(lsum1, 0), 32); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_16_32, acc2, ls_1, __riscv_vget_v_i16m4_i16m1(lsum2, 0), 32); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_16_32, acc1, ls_5, __riscv_vget_v_i16m4_i16m1(lsum1, 1), 32); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_16_32, acc2, ls_5, __riscv_vget_v_i16m4_i16m1(lsum2, 1), 32); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_16_32, acc1, ls_9, __riscv_vget_v_i16m4_i16m1(lsum1, 2), 32); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_16_32, acc2, ls_9, __riscv_vget_v_i16m4_i16m1(lsum2, 2), 32); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_16_32, acc1, ls_13, __riscv_vget_v_i16m4_i16m1(lsum1, 3), 32); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_16_32, acc2, ls_13, __riscv_vget_v_i16m4_i16m1(lsum2, 3), 32); + // + const vbool16_t l_mask_32_48 = __riscv_vmsgtu_vx_u16m1_b16(l_index, 31, 64); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_32_48, acc1, ls_2, __riscv_vget_v_i16m4_i16m1(lsum1, 0), 48); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_32_48, acc2, ls_2, __riscv_vget_v_i16m4_i16m1(lsum2, 0), 48); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_32_48, acc1, ls_6, __riscv_vget_v_i16m4_i16m1(lsum1, 1), 48); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_32_48, acc2, ls_6, __riscv_vget_v_i16m4_i16m1(lsum2, 1), 48); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_32_48, acc1, ls_10, __riscv_vget_v_i16m4_i16m1(lsum1, 2), 48); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_32_48, acc2, ls_10, __riscv_vget_v_i16m4_i16m1(lsum2, 2), 48); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_32_48, acc1, ls_14, __riscv_vget_v_i16m4_i16m1(lsum1, 3), 48); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_32_48, acc2, ls_14, __riscv_vget_v_i16m4_i16m1(lsum2, 3), 48); + // + const vbool16_t l_mask_48_64 = __riscv_vmsgtu_vx_u16m1_b16(l_index, 47, 64); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_48_64, acc1, ls_3, __riscv_vget_v_i16m4_i16m1(lsum1, 0), 64); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_48_64, acc2, ls_3, __riscv_vget_v_i16m4_i16m1(lsum2, 0), 64); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_48_64, acc1, ls_7, __riscv_vget_v_i16m4_i16m1(lsum1, 1), 64); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_48_64, acc2, ls_7, __riscv_vget_v_i16m4_i16m1(lsum2, 1), 64); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_48_64, acc1, ls_11, __riscv_vget_v_i16m4_i16m1(lsum1, 2), 64); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_48_64, acc2, ls_11, __riscv_vget_v_i16m4_i16m1(lsum2, 2), 64); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_48_64, acc1, ls_15, __riscv_vget_v_i16m4_i16m1(lsum1, 3), 64); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_48_64, acc2, ls_15, __riscv_vget_v_i16m4_i16m1(lsum2, 3), 64); + + __asm__ __volatile__("" ::: "memory"); + } + + // Reduce and accumulate in `sumf`. + vint32m1_t one = __riscv_vmv_v_x_i32m1(0, 1); + int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc1, one, 64)); + int sumi2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc2, one, 64)); + sumf += y[i].d * GGML_CPU_FP16_TO_FP32(scale.f16) * (sumi1 + IQ1M_DELTA * sumi2); + } + + *s = sumf; +} #endif void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: ggml_vec_dot_iq1_m_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); @@ -2675,6 +3725,12 @@ void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo case 256: ggml_vec_dot_iq1_m_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; + case 512: + ggml_vec_dot_iq1_m_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; + case 1024: + ggml_vec_dot_iq1_m_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc); + break; default: ggml_vec_dot_iq1_m_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); break; @@ -2684,7 +3740,7 @@ void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static const uint8_t sign_gather_indices_arr[64] = { 0,0,0,0,0,0,0,0, 1,1,1,1,1,1,1,1, 2,2,2,2,2,2,2,2, 3,3,3,3,3,3,3,3, 4,4,4,4,4,4,4,4, 5,5,5,5,5,5,5,5, 6,6,6,6,6,6,6,6, 7,7,7,7,7,7,7,7 @@ -2887,10 +3943,275 @@ static NOINLINE void ggml_vec_dot_iq2_s_q8_K_vl256(int n, float * GGML_RESTRICT } *s = 0.125f * sumf; } + +static NOINLINE void ggml_vec_dot_iq2_s_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + + const block_iq2_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * grid64 = (const uint64_t *)iq2s_grid; + + vuint8m2_t v_ids = __riscv_vid_v_u8m2(128); + vuint8m2_t v_sign_gather_indices = __riscv_vsrl_vx_u8m2(v_ids, 3, 128); + + vuint8m2_t v_ones = __riscv_vmv_v_x_u8m2(1, 128); + vuint8m2_t v_shift_amts = __riscv_vand_vx_u8m2(v_ids, 7, 128); + vuint8m2_t v_sign_masks = __riscv_vsll_vv_u8m2(v_ones, v_shift_amts, 128); + + uint16_t gather_qh_arr[16] = {0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3}; + vuint16mf2_t v_gather_qh = __riscv_vle16_v_u16mf2(gather_qh_arr, 16); + + uint16_t shift_qh_arr[16] = {11, 9, 7, 5, 11, 9, 7, 5, 11, 9, 7, 5, 11, 9, 7, 5}; + vuint16mf2_t v_shift_qh = __riscv_vle16_v_u16mf2(shift_qh_arr, 16); + + // Masks for selecting lower/upper 16 lanes within a 32-lane i16m1 register + vuint16m1_t v_ids16 = __riscv_vid_v_u16m1(32); + vbool16_t m_hi16 = __riscv_vmsgeu_vx_u16m1_b16(v_ids16, 16, 32); + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const uint8_t * signs_ptr = qs + 32; + + float sum_block = 0.0f; + + for (int ib = 0; ib < 2; ++ib) { + vuint8mf4_t v_qs_u8 = __riscv_vle8_v_u8mf4(qs, 16); + qs += 16; + + vuint8mf8_t v_qh_raw = __riscv_vle8_v_u8mf8(qh, 4); + qh += 4; + + vuint16mf4_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16mf4(v_qh_raw, 4); + vuint16mf2_t v_qh_u16_ext = __riscv_vlmul_ext_v_u16mf4_u16mf2(v_qh_u16); + vuint16mf2_t v_qh_expanded = __riscv_vrgather_vv_u16mf2(v_qh_u16_ext, v_gather_qh, 16); + v_qh_expanded = __riscv_vsll_vv_u16mf2(v_qh_expanded, v_shift_qh, 16); + v_qh_expanded = __riscv_vand_vx_u16mf2(v_qh_expanded, 0x1800, 16); + + vuint16mf2_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16mf2(v_qs_u8, 16); + v_qs_u16 = __riscv_vsll_vx_u16mf2(v_qs_u16, 3, 16); + + vuint16mf2_t v_grid_offsets = __riscv_vor_vv_u16mf2(v_qs_u16, v_qh_expanded, 16); + vuint64m2_t v_grid_vals = __riscv_vluxei16_v_u64m2(grid64, v_grid_offsets, 16); + vuint8m2_t v_grid_u8 = __riscv_vreinterpret_v_u64m2_u8m2(v_grid_vals); + vint8m2_t v_grid_i8 = __riscv_vreinterpret_v_u8m2_i8m2(v_grid_u8); + + vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs_ptr, 16); + signs_ptr += 16; + + vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw); + vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 128); + vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 128); + vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 128); + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 128); + q8 += 128; + + vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 128); + vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_grid_i8, v_q8_signed, 128); + + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + vint16m1_t v0 = __riscv_vget_v_i16m4_i16m1(v_dot, 0); + vint16m1_t v1 = __riscv_vget_v_i16m4_i16m1(v_dot, 1); + vint16m1_t v2 = __riscv_vget_v_i16m4_i16m1(v_dot, 2); + vint16m1_t v3 = __riscv_vget_v_i16m4_i16m1(v_dot, 3); + + int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(v0, v_zero, 16)); + int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(m_hi16, v0, v_zero, 32)); + int32_t s2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(v1, v_zero, 16)); + int32_t s3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(m_hi16, v1, v_zero, 32)); + int32_t s4 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(v2, v_zero, 16)); + int32_t s5 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(m_hi16, v2, v_zero, 32)); + int32_t s6 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( v3, v_zero, 16)); + int32_t s7 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(m_hi16, v3, v_zero, 32)); + + uint8_t sc0 = scales[0]; + uint8_t sc1 = scales[1]; + uint8_t sc2 = scales[2]; + uint8_t sc3 = scales[3]; + scales += 4; + + sum_block += s0 * (2 * (sc0 & 0xF) + 1); + sum_block += s1 * (2 * (sc0 >> 4) + 1); + sum_block += s2 * (2 * (sc1 & 0xF) + 1); + sum_block += s3 * (2 * (sc1 >> 4) + 1); + sum_block += s4 * (2 * (sc2 & 0xF) + 1); + sum_block += s5 * (2 * (sc2 >> 4) + 1); + sum_block += s6 * (2 * (sc3 & 0xF) + 1); + sum_block += s7 * (2 * (sc3 >> 4) + 1); + } + + sumf += sum_block * combined_scale; + } + *s = 0.125f * sumf; +} + +static NOINLINE void ggml_vec_dot_iq2_s_q8_K_vl1024(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + + const block_iq2_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * grid64 = (const uint64_t *)iq2s_grid; + vuint8m2_t v_ids = __riscv_vid_v_u8m2(256); + vuint8m2_t v_sign_gather_indices = __riscv_vsrl_vx_u8m2(v_ids, 3, 256); + + vuint8m2_t v_ones = __riscv_vmv_v_x_u8m2(1, 256); + vuint8m2_t v_shift_amts = __riscv_vand_vx_u8m2(v_ids, 7, 256); + vuint8m2_t v_sign_masks = __riscv_vsll_vv_u8m2(v_ones, v_shift_amts, 256); + + uint16_t gather_qh_arr[32] = { + 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, + 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7 + }; + vuint16mf2_t v_gather_qh = __riscv_vle16_v_u16mf2(gather_qh_arr, 32); + + uint16_t shift_qh_arr[32] = { + 11, 9, 7, 5, 11, 9, 7, 5, 11, 9, 7, 5, 11, 9, 7, 5, + 11, 9, 7, 5, 11, 9, 7, 5, 11, 9, 7, 5, 11, 9, 7, 5 + }; + vuint16mf2_t v_shift_qh = __riscv_vle16_v_u16mf2(shift_qh_arr, 32); + + // Masks for 4 groups of 16 lanes within a 64-lane i16m4 chunk + vuint16m4_t v_ids64 = __riscv_vid_v_u16m4(64); + vbool4_t m_g0 = __riscv_vmsltu_vx_u16m4_b4(v_ids64, 16, 64); + vbool4_t m_g1 = __riscv_vmand_mm_b4( + __riscv_vmsgeu_vx_u16m4_b4(v_ids64, 16, 64), + __riscv_vmsltu_vx_u16m4_b4(v_ids64, 32, 64), 64); + vbool4_t m_g2 = __riscv_vmand_mm_b4( + __riscv_vmsgeu_vx_u16m4_b4(v_ids64, 32, 64), + __riscv_vmsltu_vx_u16m4_b4(v_ids64, 48, 64), 64); + vbool4_t m_g3 = __riscv_vmsgeu_vx_u16m4_b4(v_ids64, 48, 64); + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const uint8_t * signs_ptr = qs + 32; + + float sum_block = 0.0f; + + vuint8mf4_t v_qs_u8 = __riscv_vle8_v_u8mf4(qs, 32); + qs += 32; + + vuint8mf8_t v_qh_raw = __riscv_vle8_v_u8mf8(qh, 8); + qh += 8; + + vuint16mf4_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16mf4(v_qh_raw, 8); + vuint16mf2_t v_qh_u16_ext = __riscv_vlmul_ext_v_u16mf4_u16mf2(v_qh_u16); + vuint16mf2_t v_qh_expanded = __riscv_vrgather_vv_u16mf2(v_qh_u16_ext, v_gather_qh, 32); + v_qh_expanded = __riscv_vsll_vv_u16mf2(v_qh_expanded, v_shift_qh, 32); + v_qh_expanded = __riscv_vand_vx_u16mf2(v_qh_expanded, 0x1800, 32); + + vuint16mf2_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16mf2(v_qs_u8, 32); + v_qs_u16 = __riscv_vsll_vx_u16mf2(v_qs_u16, 3, 32); + + vuint16mf2_t v_grid_offsets = __riscv_vor_vv_u16mf2(v_qs_u16, v_qh_expanded, 32); + vuint64m2_t v_grid_vals = __riscv_vluxei16_v_u64m2(grid64, v_grid_offsets, 32); + vuint8m2_t v_grid_u8 = __riscv_vreinterpret_v_u64m2_u8m2(v_grid_vals); + vint8m2_t v_grid_i8 = __riscv_vreinterpret_v_u8m2_i8m2(v_grid_u8); + + //loading signs + vuint8mf2_t v_signs_raw = __riscv_vle8_v_u8mf2(signs_ptr, 32); + signs_ptr += 32; + + vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf2_u8m2(v_signs_raw); + vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 256); + vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 256); + vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 256); + + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 256); + q8 += 256; + + vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 256); + vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_grid_i8, v_q8_signed, 256); + + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + + vint16m4_t c = v_dot; + + int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g0, c, v_zero, 64)); + int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g1, c, v_zero, 64)); + int32_t s2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g2, c, v_zero, 64)); + int32_t s3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g3, c, v_zero, 64)); + + c = __riscv_vslidedown_vx_i16m4(c, 64, 256); + int32_t s4 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g0, c, v_zero, 64)); + int32_t s5 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g1, c, v_zero, 64)); + int32_t s6 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g2, c, v_zero, 64)); + int32_t s7 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g3, c, v_zero, 64)); + + c = __riscv_vslidedown_vx_i16m4(c, 64, 256); + int32_t s8 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g0, c, v_zero, 64)); + int32_t s9 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g1, c, v_zero, 64)); + int32_t s10 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g2, c, v_zero, 64)); + int32_t s11 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g3, c, v_zero, 64)); + + c = __riscv_vslidedown_vx_i16m4(c, 64, 256); + int32_t s12 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g0, c, v_zero, 64)); + int32_t s13 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g1, c, v_zero, 64)); + int32_t s14 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g2, c, v_zero, 64)); + int32_t s15 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g3, c, v_zero, 64)); + + int32_t sums_arr[16] = { s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14, s15 }; + + // Load 8 scale bytes and split into 16 nibbles + vuint8mf2_t v_sc8 = __riscv_vle8_v_u8mf2(scales, 8); + scales += 8; + + vuint8mf2_t v_lo8 = __riscv_vand_vx_u8mf2(v_sc8, 0x0F, 8); + vuint8mf2_t v_hi8 = __riscv_vsrl_vx_u8mf2(v_sc8, 4, 8); + + vuint8m1_t v_idx16 = __riscv_vid_v_u8m1(16); + vuint8m1_t v_half = __riscv_vsrl_vx_u8m1(v_idx16, 1, 16); + vbool8_t m_even = __riscv_vmseq_vx_u8m1_b8(__riscv_vand_vx_u8m1(v_idx16, 1, 16), 0, 16); + + vuint8m1_t v_lo_ext = __riscv_vlmul_ext_v_u8mf2_u8m1(v_lo8); + vuint8m1_t v_hi_ext = __riscv_vlmul_ext_v_u8mf2_u8m1(v_hi8); + vuint8m1_t v_lo_g = __riscv_vrgather_vv_u8m1(v_lo_ext, v_half, 16); + vuint8m1_t v_hi_g = __riscv_vrgather_vv_u8m1(v_hi_ext, v_half, 16); + vuint8m1_t v_nib = __riscv_vmerge_vvm_u8m1(v_lo_g, v_hi_g, m_even, 16); + + static const uint8_t iq2s_scale_lut_16_local[16] = { + 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31 + }; + vuint8m1_t v_lut = __riscv_vle8_v_u8m1(iq2s_scale_lut_16_local, 16); + vuint8m1_t v_sc8v = __riscv_vrgather_vv_u8m1(v_lut, v_nib, 16); + + vint32m4_t v_sums = __riscv_vle32_v_i32m4(sums_arr, 16); + vuint16m2_t v_sc16 = __riscv_vwcvtu_x_x_v_u16m2(v_sc8v, 16); + vuint32m4_t v_sc32u = __riscv_vwcvtu_x_x_v_u32m4(v_sc16, 16); + vint32m4_t v_sc32 = __riscv_vreinterpret_v_u32m4_i32m4(v_sc32u); + vint32m4_t v_prod = __riscv_vmul_vv_i32m4(v_sums, v_sc32, 16); + + vint32m1_t v_zero32 = __riscv_vmv_v_x_i32m1(0, 1); + int32_t sum_part = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m4_i32m1(v_prod, v_zero32, 16)); + sum_block += sum_part; + + sumf += sum_block * combined_scale; + } + *s = 0.125f * sumf; +} #endif void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: ggml_vec_dot_iq2_s_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); @@ -2898,8 +4219,11 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo case 256: ggml_vec_dot_iq2_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; + case 512: + ggml_vec_dot_iq2_s_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; default: - ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + ggml_vec_dot_iq2_s_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc); break; } #else @@ -2907,7 +4231,7 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static const int8_t keven_signs_q2xs[1024] = { 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1, @@ -3045,59 +4369,140 @@ static NOINLINE void ggml_vec_dot_iq2_xs_q8_K_vl256(int n, float * GGML_RESTRICT int32_t sum_int = 0; - // Loop over 4 subblocks of 64 elements (QK_K = 256) - for (int ib64 = 0; ib64 < QK_K / 64; ++ib64) { - // Load 8 uint16 indices (controls 64 values) - vuint16mf2_t v_qs = __riscv_vle16_v_u16mf2(qs, 8); - qs += 8; + for (int ib128 = 0; ib128 < 2; ++ib128) { + + vuint16m1_t v_qs = __riscv_vle16_v_u16m1(qs, 16); + qs += 16; - // Extract indices for grid (low 9 bits) and signs (high 7 bits) - // Multiply by 8 (<< 3) for byte offsets into the uint64 tables - vuint16mf2_t vidx_grid = __riscv_vsll_vx_u16mf2(__riscv_vand_vx_u16mf2(v_qs, 511, 8), 3, 8); - vuint16mf2_t vidx_sign = __riscv_vsll_vx_u16mf2(__riscv_vsrl_vx_u16mf2(v_qs, 9, 8), 3, 8); + // Prepare offsets for grid and signs + vuint16m1_t vidx_grid = __riscv_vsll_vx_u16m1(__riscv_vand_vx_u16m1(v_qs, 511, 16), 3, 16); + vuint16m1_t vidx_sign = __riscv_vsll_vx_u16m1(__riscv_vsrl_vx_u16m1(v_qs, 9, 16), 3, 16); - vuint64m2_t vq2_64 = __riscv_vluxei16_v_u64m2(grid64, vidx_grid, 8); - vuint64m2_t vs2_64 = __riscv_vluxei16_v_u64m2(signs64, vidx_sign, 8); + // Indexed load 128 weights (16 x 8-byte chunks) + vuint64m4_t vq2_64 = __riscv_vluxei16_v_u64m4(grid64, vidx_grid, 16); + vuint64m4_t vs2_64 = __riscv_vluxei16_v_u64m4(signs64, vidx_sign, 16); - vint8m2_t q2u = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vq2_64)); - vint8m2_t q2s = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vs2_64)); + vint8m4_t q2u = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vq2_64)); + vint8m4_t q2s = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vs2_64)); - vint8m2_t q2_final = __riscv_vmul_vv_i8m2(q2u, q2s, 64); + // Apply signs to get dequantized IQ2 values + vint8m4_t q2_final = __riscv_vmul_vv_i8m4(q2u, q2s, 128); + asm volatile("" ::: "memory"); - vint8m2_t q8v = __riscv_vle8_v_i8m2(q8, 64); - q8 += 64; + // Load corresponding Q8 weights + vint8m4_t q8v = __riscv_vle8_v_i8m4(q8, 128); + q8 += 128; + + vint16m8_t prod = __riscv_vwmul_vv_i16m8(q2_final, q8v, 128); + asm volatile("" ::: "memory"); - vint16m4_t prod = __riscv_vwmul_vv_i16m4(q2_final, q8v, 64); + uint8_t sc0 = scales[0]; + uint8_t sc1 = scales[1]; + uint8_t sc2 = scales[2]; + uint8_t sc3 = scales[3]; + scales += 4; vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1); - int32_t sum0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( - __riscv_vget_v_i16m4_i16m1(prod, 0), zero_vec, 16)); - int32_t sum1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( - __riscv_vget_v_i16m4_i16m1(prod, 1), zero_vec, 16)); - int32_t sum2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( - __riscv_vget_v_i16m4_i16m1(prod, 2), zero_vec, 16)); - int32_t sum3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( - __riscv_vget_v_i16m4_i16m1(prod, 3), zero_vec, 16)); + // 9. Reduce each 16-element chunk and apply corresponding nibble scale - const uint8_t scale_byte_1 = scales[0]; - const uint8_t scale_byte_2 = scales[1]; - scales += 2; + int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 0), zero_vec, 16)); + sum_int += s0 * ((sc0 & 0x0F) * 2 + 1); - sum_int += sum0 * ((scale_byte_1 & 0x0F) * 2 + 1); - sum_int += sum1 * ((scale_byte_1 >> 4) * 2 + 1); - sum_int += sum2 * ((scale_byte_2 & 0x0F) * 2 + 1); - sum_int += sum3 * ((scale_byte_2 >> 4) * 2 + 1); + int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 1), zero_vec, 16)); + sum_int += s1 * ((sc0 >> 4) * 2 + 1); + + int32_t s2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 2), zero_vec, 16)); + sum_int += s2 * ((sc1 & 0x0F) * 2 + 1); + + int32_t s3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 3), zero_vec, 16)); + sum_int += s3 * ((sc1 >> 4) * 2 + 1); + + int32_t s4 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 4), zero_vec, 16)); + sum_int += s4 * ((sc2 & 0x0F) * 2 + 1); + + int32_t s5 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 5), zero_vec, 16)); + sum_int += s5 * ((sc2 >> 4) * 2 + 1); + + int32_t s6 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 6), zero_vec, 16)); + sum_int += s6 * ((sc3 & 0x0F) * 2 + 1); + + int32_t s7 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 7), zero_vec, 16)); + sum_int += s7 * ((sc3 >> 4) * 2 + 1); } - sumf += d * sum_int; + sumf += d * (float)sum_int; + } + *s = 0.125f * sumf; +} + +static NOINLINE void ggml_vec_dot_iq2_xs_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + + const block_iq2_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const uint64_t * grid64 = (const uint64_t *)iq2xs_grid; + + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint16_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + vint8m4_t q8_all = __riscv_vle8_v_i8m4(q8, 256); + + // Load indices --- + vuint16m1_t v_qs = __riscv_vle16_v_u16m1(qs, 32); + + // Extract low 9 bits and multiply by 8 (shift left 3) for byte offset into uint64 table + vuint16m1_t vidx_grid = __riscv_vsll_vx_u16m1(__riscv_vand_vx_u16m1(v_qs, 511, 32), 3, 32); + + // Extract high 7 bits (shift right 9) and multiply by 8 (shift left 3) for byte offset + vuint16m1_t vidx_sign = __riscv_vsll_vx_u16m1(__riscv_vsrl_vx_u16m1(v_qs, 9, 32), 3, 32); + + vuint64m4_t vq2_64 = __riscv_vluxei16_v_u64m4(grid64, vidx_grid, 32); + vuint64m4_t vs2_64 = __riscv_vluxei16_v_u64m4(signs64, vidx_sign, 32); + + vint8m4_t q2_all = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vq2_64)); + vint8m4_t s2_all = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vs2_64)); + + vint8m4_t q2_signed = __riscv_vmul_vv_i8m4(q2_all, s2_all, 256); + vint16m8_t dot_all = __riscv_vwmul_vv_i16m8(q2_signed, q8_all, 256); + float sum = 0.0f; + vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1); + +#pragma GCC unroll 1 + for (int j = 0; j < 8; ++j) { + uint8_t sc = scales[j]; + int16_t sc_lo = 2 * (sc & 0x0F) + 1; + int16_t sc_hi = 2 * (sc >> 4) + 1; + + vint32m1_t sum_v0 = __riscv_vwredsum_vs_i16m8_i32m1( + __riscv_vslidedown_vx_i16m8(dot_all, j * 32, 16), zero_vec, 16); + int32_t isum0 = __riscv_vmv_x_s_i32m1_i32(sum_v0); + + vint32m1_t sum_v1 = __riscv_vwredsum_vs_i16m8_i32m1( + __riscv_vslidedown_vx_i16m8(dot_all, j * 32 + 16, 16), zero_vec, 16); + int32_t isum1 = __riscv_vmv_x_s_i32m1_i32(sum_v1); + + sum += (float)isum0 * sc_lo + (float)isum1 * sc_hi; + } + + sumf += sum * combined_scale; } *s = 0.125f * sumf; } #endif void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: ggml_vec_dot_iq2_xs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); @@ -3105,8 +4510,8 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v case 256: ggml_vec_dot_iq2_xs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; - default: - ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + default: // 512 and above + ggml_vec_dot_iq2_xs_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); break; } #else @@ -3114,7 +4519,7 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static NOINLINE void ggml_vec_dot_iq2_xxs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); @@ -3299,24 +4704,99 @@ static NOINLINE void ggml_vec_dot_iq2_xxs_q8_K_vl256(int n, float * GGML_RESTRIC } *s = 0.125f * sumf; } + +static NOINLINE void ggml_vec_dot_iq2_xxs_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + + const block_iq2_xxs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const uint64_t * grid64 = (const uint64_t *)iq2xxs_grid; + // Shift pattern {0,7,14,21} repeated 8 times for all 8 sub-blocks + uint8_t shift_arr[32] = { + 0, 7, 14, 21, 0, 7, 14, 21, 0, 7, 14, 21, 0, 7, 14, 21, + 0, 7, 14, 21, 0, 7, 14, 21, 0, 7, 14, 21, 0, 7, 14, 21 + }; + vuint8mf2_t v_shifts = __riscv_vle8_v_u8mf2(shift_arr, 32); + + // Gather pattern to broadcast the 8 sub-block scales across the 32 lookup slots + uint8_t gather_arr[32] = { + 0,0,0,0, 1,1,1,1, 2,2,2,2, 3,3,3,3, + 4,4,4,4, 5,5,5,5, 6,6,6,6, 7,7,7,7 + }; + vuint8mf2_t v_sign_gather_idx = __riscv_vle8_v_u8mf2(gather_arr, 32); + + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT q2_ptr = (const uint8_t *) x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + vint8m4_t q8_all = __riscv_vle8_v_i8m4(q8, 256); + + // De-interleave all 8 Index/Scale pairs for the 8x32-element sub-blocks + vuint32mf2x2_t tuple = __riscv_vlseg2e32_v_u32mf2x2((const uint32_t*)q2_ptr, 8); + vuint32mf2_t v_ind32 = __riscv_vget_v_u32mf2x2_u32mf2(tuple, 0); + vuint32mf2_t v_sc32 = __riscv_vget_v_u32mf2x2_u32mf2(tuple, 1); + + vuint8mf2_t v_raw_q2 = __riscv_vreinterpret_v_u32mf2_u8mf2(v_ind32); + vuint16m1_t vidx_q2 = __riscv_vwcvtu_x_x_v_u16m1(v_raw_q2, 32); + vidx_q2 = __riscv_vsll_vx_u16m1(vidx_q2, 3, 32); + + vuint32m2_t v_s = __riscv_vrgatherei16_vv_u32m2(__riscv_vlmul_ext_v_u32mf2_u32m2(v_sc32), __riscv_vwcvtu_x_x_v_u16m1(v_sign_gather_idx,32), 32); + v_s = __riscv_vsrl_vv_u32m2(v_s, __riscv_vwcvtu_x_x_v_u32m2(__riscv_vwcvtu_x_x_v_u16m1(v_shifts,32),32), 32); + v_s = __riscv_vand_vx_u32m2(v_s, 127, 32); + vuint16m1_t vidx_s2 = __riscv_vsll_vx_u16m1(__riscv_vncvt_x_x_w_u16m1(v_s, 32), 3, 32); + + vuint64m4_t vq2_64 = __riscv_vluxei16_v_u64m4(grid64, vidx_q2, 32); + vuint64m4_t vs2_64 = __riscv_vluxei16_v_u64m4(signs64, vidx_s2, 32); + vint8m4_t q2_all = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vq2_64)); + vint8m4_t s2_all = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vs2_64)); + + vint8m4_t q8s_all = __riscv_vmul_vv_i8m4(q8_all, s2_all, 256); + vint16m8_t dot_all = __riscv_vwmul_vv_i16m8(q8s_all, q2_all, 256); + + float sum = 0.0f; + vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1); + + for (int j = 0; j < 8; ++j) { + uint32_t s_p = __riscv_vmv_x_s_u32mf2_u32(__riscv_vslidedown_vx_u32mf2(v_sc32, j, 8)); + int16_t sc = 2 * ((s_p >> 28) & 0xF) + 1; + dot_all=__riscv_vslidedown_vx_i16m8(dot_all,j*32,32); + vint32m1_t sum_v = __riscv_vwredsum_vs_i16m8_i32m1(dot_all, zero_vec, 32); + int32_t isum = __riscv_vmv_x_s_i32m1_i32(sum_v); + sum += (float)isum * sc; + } + + sumf += sum * combined_scale; + } + *s = 0.125f * sumf; +} #endif void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: ggml_vec_dot_iq2_xxs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); break; - default: // 256 and above + case 256: ggml_vec_dot_iq2_xxs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; + default: // 512 and above + ggml_vec_dot_iq2_xxs_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; } #else ggml_vec_dot_iq2_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static NOINLINE void ggml_vec_dot_iq3_s_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); @@ -3506,19 +4986,108 @@ static NOINLINE void ggml_vec_dot_iq3_s_q8_K_vl256(int n, float * GGML_RESTRICT } *s = sumf; } + +static NOINLINE void ggml_vec_dot_iq3_s_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + const block_iq3_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint32_t * grid32 = (const uint32_t *)iq3s_grid; + + // Generate Constants + vuint8mf2_t v_id_32 = __riscv_vid_v_u8mf2(32); + vuint8mf2_t v_qh_gather = __riscv_vsrl_vx_u8mf2(v_id_32, 3, 32); + vuint8mf2_t v_qh_shifts = __riscv_vand_vx_u8mf2(v_id_32, 7, 32); + vuint8m2_t v_id_128 = __riscv_vid_v_u8m2(128); + vuint8m2_t v_sign_gather = __riscv_vsrl_vx_u8m2(v_id_128, 3, 128); // byte index + vuint8m2_t v_sign_shift_amts = __riscv_vand_vx_u8m2(v_id_128, 7, 128); // bit shift + vuint8m2_t v_one_128 = __riscv_vmv_v_x_u8m2(1, 128); + vuint8m2_t v_sign_masks = __riscv_vsll_vv_u8m2(v_one_128, v_sign_shift_amts, 128); + vuint8m2_t v_scale_indices = __riscv_vsrl_vx_u8m2(v_id_128, 5, 128); + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + const uint8_t * GGML_RESTRICT signs = x[i].signs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + float sum_block = 0.0f; + for (int ib = 0; ib < 2; ++ib) { + vuint8mf2_t v_qs_u8 = __riscv_vle8_v_u8mf2(qs, 32); + qs += 32; + vuint8mf2_t v_qh_loaded = __riscv_vle8_v_u8mf2(qh, 4); + qh += 4; + vuint8mf2_t v_qh_expanded = __riscv_vrgather_vv_u8mf2(v_qh_loaded, v_qh_gather, 32); + v_qh_expanded = __riscv_vsrl_vv_u8mf2(v_qh_expanded, v_qh_shifts, 32); + v_qh_expanded = __riscv_vand_vx_u8mf2(v_qh_expanded, 1, 32); + vuint16m1_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16m1(v_qs_u8, 32); + v_qs_u16 = __riscv_vsll_vx_u16m1(v_qs_u16, 2, 32); // * 4 + + vuint16m1_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16m1(v_qh_expanded, 32); + v_qh_u16 = __riscv_vsll_vx_u16m1(v_qh_u16, 10, 32); // * 256 * 4 + + vuint16m1_t v_grid_offsets = __riscv_vor_vv_u16m1(v_qs_u16, v_qh_u16, 32); + vuint32m2_t v_grid_packed = __riscv_vluxei16_v_u32m2(grid32, v_grid_offsets, 32); + vuint8m2_t v_grid_u8 = __riscv_vreinterpret_v_u32m2_u8m2(v_grid_packed); + vuint8mf2_t v_signs_raw = __riscv_vle8_v_u8mf2(signs, 16); + signs += 16; + + vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf2_u8m2(v_signs_raw); + vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather, 128); + vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 128); + vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 128); + + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 128); + q8 += 128; + + vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 128); + vint16m4_t v_dot = __riscv_vwmulsu_vv_i16m4(v_q8_signed, v_grid_u8, 128); + uint16_t sc_raw; + memcpy(&sc_raw, scales, 2); + scales += 2; // Advance 2 bytes + + uint8_t sc_unpacked[4]; + sc_unpacked[0] = (sc_raw & 0xF); + sc_unpacked[1] = (sc_raw >> 4) & 0xF; + sc_unpacked[2] = (sc_raw >> 8) & 0xF; + sc_unpacked[3] = (sc_raw >> 12) & 0xF; + + vuint8mf2_t v_sc_4 = __riscv_vle8_v_u8mf2(sc_unpacked, 4); + v_sc_4 = __riscv_vmul_vx_u8mf2(v_sc_4, 2, 4); + v_sc_4 = __riscv_vadd_vx_u8mf2(v_sc_4, 1, 4); + vuint8m2_t v_sc_4_expanded = __riscv_vlmul_ext_v_u8mf2_u8m2(v_sc_4); + vuint8m2_t v_scales_bcast = __riscv_vrgather_vv_u8m2(v_sc_4_expanded, v_scale_indices, 128); + vint16m4_t v_scales_i16 = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vwcvtu_x_x_v_u16m4(v_scales_bcast, 128)); + vint32m8_t v_weighted_sum = __riscv_vwmul_vv_i32m8(v_dot, v_scales_i16, 128); + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + int32_t s_val = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m8_i32m1(v_weighted_sum, v_zero, 128)); + + sum_block += s_val; + } + sumf += sum_block * combined_scale; + } + *s = sumf; +} #endif void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: - ggml_vec_dot_iq3_s_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + ggml_vec_dot_iq3_s_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); break; case 256: ggml_vec_dot_iq3_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; - default: - ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + default: // 512 and above + ggml_vec_dot_iq3_s_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); break; } #else @@ -3526,7 +5095,7 @@ void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static NOINLINE void ggml_vec_dot_iq3_xxs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); @@ -3712,10 +5281,181 @@ static NOINLINE void ggml_vec_dot_iq3_xxs_q8_K_vl256(int n, float * GGML_RESTRIC } *s = 0.25f * sumf; } + +static NOINLINE void ggml_vec_dot_iq3_xxs_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + const block_iq3_xxs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + const int nb = n / QK_K; + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const uint32_t * grid32 = (const uint32_t *)iq3xxs_grid; + + // generate constants for unpacking metadata words into sign indices + vuint32m1_t v_shifts; + { + vuint32m1_t v_base = __riscv_vid_v_u32m1(16); + vuint32m1_t v_mod4 = __riscv_vand_vx_u32m1(v_base, 3, 16); + v_shifts = __riscv_vmul_vx_u32m1(v_mod4, 7, 16); + } + + vuint16mf2_t v_gather_idx; + { + vuint16mf2_t v_idx = __riscv_vid_v_u16mf2(16); + v_gather_idx = __riscv_vsrl_vx_u16mf2(v_idx, 2, 16); + } + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT q3_indices = x[i].qs; + const uint8_t * GGML_RESTRICT metadata = x[i].qs + QK_K/4; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + float block_sum = 0.0f; + for (int ib128 = 0; ib128 < 2; ++ib128) { + + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 128); + q8 += 128; + vuint8mf2_t v_q3_idx_u8 = __riscv_vle8_v_u8mf2(q3_indices, 32); + q3_indices += 32; + + vuint16m1_t v_q3_idx_u16 = __riscv_vwmulu_vx_u16m1(v_q3_idx_u8, 4, 32); + vuint32m2_t v_q3_mag_u32 = __riscv_vluxei16_v_u32m2(grid32, v_q3_idx_u16, 32); + vint8m2_t v_q3_magnitudes = __riscv_vreinterpret_v_u8m2_i8m2( + __riscv_vreinterpret_v_u32m2_u8m2(v_q3_mag_u32)); + vuint32m1_t v_aux = __riscv_vreinterpret_v_u8m1_u32m1(__riscv_vle8_v_u8m1(metadata, 16)); + metadata += 4 * sizeof(uint32_t); + + vuint32m1_t v_aux_expanded = __riscv_vrgatherei16_vv_u32m1(v_aux, v_gather_idx, 16); + + vuint32m1_t v_s_raw = __riscv_vand_vx_u32m1( + __riscv_vsrl_vv_u32m1(v_aux_expanded, v_shifts, 16), 127, 16); + vuint16mf2_t sign_byte_offset = __riscv_vsll_vx_u16mf2( + __riscv_vncvt_x_x_w_u16mf2(v_s_raw, 16), 3, 16); + vuint64m2_t v_s_u64 = __riscv_vluxei16_v_u64m2(signs64, sign_byte_offset, 16); + vint8m2_t v_signs = __riscv_vreinterpret_v_u8m2_i8m2( + __riscv_vreinterpret_v_u64m2_u8m2(v_s_u64)); + vint8m2_t v_q3_signed = __riscv_vmul_vv_i8m2(v_q3_magnitudes, v_signs, 128); + vint16m4_t prod = __riscv_vwmul_vv_i16m4(v_q3_signed, v_q8, 128); + + vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1); + int32_t group0_sum = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(prod, 0), zero_vec, 32)); + int32_t group1_sum = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(prod, 1), zero_vec, 32)); + int32_t group2_sum = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(prod, 2), zero_vec, 32)); + int32_t group3_sum = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(prod, 3), zero_vec, 32)); + + vuint32m1_t v_scales_raw = __riscv_vsrl_vx_u32m1(v_aux, 28, 4); + vuint32m1_t v_scales = __riscv_vadd_vx_u32m1( + __riscv_vsll_vx_u32m1(v_scales_raw, 1, 4), + 1, 4); + int32_t scale0 = (int32_t)__riscv_vmv_x_s_u32m1_u32(v_scales); + int32_t scale1 = (int32_t)__riscv_vmv_x_s_u32m1_u32(__riscv_vslidedown_vx_u32m1(v_scales, 1, 4)); + int32_t scale2 = (int32_t)__riscv_vmv_x_s_u32m1_u32(__riscv_vslidedown_vx_u32m1(v_scales, 2, 4)); + int32_t scale3 = (int32_t)__riscv_vmv_x_s_u32m1_u32(__riscv_vslidedown_vx_u32m1(v_scales, 3, 4)); + + block_sum += (float)(group0_sum * scale0 + group1_sum * scale1 + + group2_sum * scale2 + group3_sum * scale3); + } + + sumf += d * block_sum; + } + *s = 0.25f * sumf; +} + +static NOINLINE void ggml_vec_dot_iq3_xxs_q8_K_vl1024(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + + const block_iq3_xxs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + const int nb = n / QK_K; + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const uint32_t * grid32 = (const uint32_t *)iq3xxs_grid; + + vuint32m1_t v_shifts; + { + vuint32m1_t v_id = __riscv_vid_v_u32m1(32); + vuint32m1_t v_mod4 = __riscv_vand_vx_u32m1(v_id, 3, 32); + v_shifts = __riscv_vmul_vx_u32m1(v_mod4, 7, 32); + } + vuint16mf2_t v_gather_idx; + { + vuint16mf2_t v_id_16 = __riscv_vid_v_u16mf2(32); + v_gather_idx = __riscv_vsrl_vx_u16mf2(v_id_16, 2, 32); + } + + float sumf = 0.0f; + uint32_t aux32[8]; // Buffer for block metadata + + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT q3_indices = x[i].qs; + const uint8_t * GGML_RESTRICT metadata = x[i].qs + QK_K/4; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 256); + vuint8mf2_t v_q3_idx_raw = __riscv_vle8_v_u8mf2(q3_indices, 64); + vuint16m1_t v_q3_idx_u16 = __riscv_vwmulu_vx_u16m1(v_q3_idx_raw, 4, 64); + + vuint32m2_t v_q3_grid_vals = __riscv_vluxei16_v_u32m2(grid32, v_q3_idx_u16, 64); + + vint8m2_t v_q3_mags = __riscv_vreinterpret_v_u8m2_i8m2( + __riscv_vreinterpret_v_u32m2_u8m2(v_q3_grid_vals)); + + memcpy(aux32, metadata, 8 * sizeof(uint32_t)); + vuint32m1_t v_aux_8 = __riscv_vle32_v_u32m1(aux32, 8); + + vuint32m1_t v_aux_32 = __riscv_vrgatherei16_vv_u32m1(v_aux_8, v_gather_idx, 32); + + vuint32m1_t v_sign_idx_raw = __riscv_vand_vx_u32m1( + __riscv_vsrl_vv_u32m1(v_aux_32, v_shifts, 32), 127, 32); + + vuint16mf2_t v_sign_offsets = __riscv_vsll_vx_u16mf2( + __riscv_vncvt_x_x_w_u16mf2(v_sign_idx_raw, 32), 3, 32); + + vuint64m2_t v_signs_u64 = __riscv_vluxei16_v_u64m2(signs64, v_sign_offsets, 32); + + vint8m2_t v_signs = __riscv_vreinterpret_v_u8m2_i8m2( + __riscv_vreinterpret_v_u64m2_u8m2(v_signs_u64)); + + vint8m2_t v_q3_final = __riscv_vmul_vv_i8m2(v_q3_mags, v_signs, 256); + + vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_q8, v_q3_final, 256); + float block_sum = 0.0f; + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + vint16m4_t v_accum = v_dot; + + for (int j = 0; j < 8; ++j) { + float scale = (float)(2 * (aux32[j] >> 28) + 1); + + vint32m1_t v_partial_sum = __riscv_vwredsum_vs_i16m4_i32m1(v_accum, v_zero, 32); + + int32_t partial_sum_i = __riscv_vmv_x_s_i32m1_i32(v_partial_sum); + block_sum += partial_sum_i * scale; + v_accum = __riscv_vslidedown_vx_i16m4(v_accum, 32, 32); + + } + + sumf += d * block_sum; + } + *s = 0.25f * sumf; +} #endif void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: ggml_vec_dot_iq3_xxs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); @@ -3723,8 +5463,11 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const case 256: ggml_vec_dot_iq3_xxs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; - default: - ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + case 512: + ggml_vec_dot_iq3_xxs_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; + default: // 1024 and above + ggml_vec_dot_iq3_xxs_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc); break; } #else @@ -3732,7 +5475,7 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static NOINLINE void ggml_vec_dot_iq4_nl_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); @@ -3847,7 +5590,7 @@ static NOINLINE void ggml_vec_dot_iq4_nl_q8_0_vl256(int n, float * GGML_RESTRICT #endif void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: ggml_vec_dot_iq4_nl_q8_0_vl128(n, s, bs, vx, bx, vy, by, nrc); @@ -3861,7 +5604,7 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static NOINLINE void ggml_vec_dot_iq4_xs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); @@ -4007,10 +5750,205 @@ static NOINLINE void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT *s = sumf; } + +static NOINLINE void ggml_vec_dot_iq4_xs_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_K == 0); + + const block_iq4_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + const vint8m4_t values = __riscv_vle8_v_i8m4(kvalues_iq4nl, 16); + float sumf = 0; + + // Indices for re-ordering IQ4 data. + const uint16_t index[32] = { + 0, 1, 16, 17, + 2, 3, 18, 19, + 4, 5,20, 21, + 6, 7, 22, 23, + 8, 9, 24, 25, + 10, 11, 26, 27, + 12, 13,28, 29, + 14, 15, 30, 31, + }; + const vuint16m1_t i_vec = __riscv_vle16_v_u16m1(index, 32); + + for (int ibl = 0; ibl < nb; ++ibl) { + const int8_t * q8 = y[ibl].qs; + const uint8_t * iq4 = x[ibl].qs; + uint16_t h = x[ibl].scales_h; + + int sumi = 0; + + #pragma GCC unroll 1 + // Process the entire super-block together. + for (int ib = 0; ib < QK_K / 256; ++ib) { + // Weights and activations. + const vuint8m2_t iq4_packed = __riscv_vle8_v_u8m2(iq4, 128); + iq4 += 128; + + // Unpack the weight blocks. + const vuint8m2_t iq4bits_lo = __riscv_vand_vx_u8m2(iq4_packed, 0xf, 128); + const vuint8m2_t iq4bits_hi = __riscv_vsrl_vx_u8m2(iq4_packed, 4, 128); + const vuint8m4_t iq4bits = __riscv_vcreate_v_u8m2_u8m4(iq4bits_lo, iq4bits_hi); + const vuint8m4_t iq4bits_reorder = __riscv_vreinterpret_v_u64m4_u8m4(__riscv_vrgatherei16_vv_u64m4(__riscv_vreinterpret_v_u8m4_u64m4(iq4bits), i_vec, 32)); + const vint8m4_t iq4b = __riscv_vrgather_vv_i8m4(values, iq4bits_reorder, 256); + + __asm__ __volatile__("" ::: "memory"); + + // Multiply with activations. + const vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 256); + const vint16m8_t prod = __riscv_vwmul_vv_i16m8(iq4b, q8b, 256); + q8 += 256; + + // Reduce separately. + const int acc0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 0), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 1), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 2), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 3), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc4 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 4), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc5 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 5), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc6 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 6), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc7 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 7), __riscv_vmv_v_x_i32m1(0, 1), 32)); + + + const int ls0 = ((x[ibl].scales_l[0] & 0xf) | ((h << 4) & 0x30)) - 32; + const int ls1 = ((x[ibl].scales_l[0] >> 4) | ((h << 2) & 0x30)) - 32; + const int ls2 = ((x[ibl].scales_l[1] & 0xf) | ((h << 0) & 0x30)) - 32; + const int ls3 = ((x[ibl].scales_l[1] >> 4) | ((h >> 2) & 0x30)) - 32; + h >>= 8; + const int ls4 = ((x[ibl].scales_l[2] & 0xf) | ((h << 4) & 0x30)) - 32; + const int ls5 = ((x[ibl].scales_l[2] >> 4) | ((h << 2) & 0x30)) - 32; + const int ls6 = ((x[ibl].scales_l[3] & 0xf) | ((h << 0) & 0x30)) - 32; + const int ls7 = ((x[ibl].scales_l[3] >> 4) | ((h >> 2) & 0x30)) - 32; + + sumi += acc0 * ls0; + sumi += acc1 * ls1; + sumi += acc2 * ls2; + sumi += acc3 * ls3; + sumi += acc4 * ls4; + sumi += acc5 * ls5; + sumi += acc6 * ls6; + sumi += acc7 * ls7; + + __asm__ __volatile__("" ::: "memory"); + } + + sumf += GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi); + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_iq4_xs_q8_K_vl1024(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_K == 0); + + const block_iq4_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + const vint8m2_t values = __riscv_vle8_v_i8m2(kvalues_iq4nl, 16); + float sumf = 0; + + // Indices for re-ordering IQ4 data. + const uint16_t index[32] = { + 0, 1, 16, 17, + 2, 3, 18, 19, + 4, 5,20, 21, + 6, 7, 22, 23, + 8, 9, 24, 25, + 10, 11, 26, 27, + 12, 13,28, 29, + 14, 15, 30, 31, + }; + const vuint16mf2_t i_vec = __riscv_vle16_v_u16mf2(index, 32); + + for (int ibl = 0; ibl < nb; ++ibl) { + const int8_t * q8 = y[ibl].qs; + const uint8_t * iq4 = x[ibl].qs; + uint16_t h = x[ibl].scales_h; + + int sumi = 0; + + #pragma GCC unroll 1 + // Process the entire super-block together. + for (int ib = 0; ib < QK_K / 256; ++ib) { + // Weights and activations. + const vuint8m1_t iq4_packed = __riscv_vle8_v_u8m1(iq4, 128); + iq4 += 128; + + // Unpack the weight blocks. + const vuint8m1_t iq4bits_lo = __riscv_vand_vx_u8m1(iq4_packed, 0xf, 128); + const vuint8m1_t iq4bits_hi = __riscv_vsrl_vx_u8m1(iq4_packed, 4, 128); + const vuint8m2_t iq4bits = __riscv_vcreate_v_u8m1_u8m2(iq4bits_lo, iq4bits_hi); + const vuint8m2_t iq4bits_reorder = __riscv_vreinterpret_v_u64m2_u8m2(__riscv_vrgatherei16_vv_u64m2(__riscv_vreinterpret_v_u8m2_u64m2(iq4bits), i_vec, 32)); + const vint8m2_t iq4b = __riscv_vrgather_vv_i8m2(values, iq4bits_reorder, 256); + + __asm__ __volatile__("" ::: "memory"); + + // Multiply with activations. + const vint8m2_t q8b = __riscv_vle8_v_i8m2(q8, 256); + const vint16m4_t prod = __riscv_vwmul_vv_i16m4(iq4b, q8b, 256); + q8 += 256; + + // Mask for processing 32 elements per prod register. + const vuint16m1_t p_index = __riscv_vid_v_u16m1(64); + const vbool16_t p_mask = __riscv_vmsgtu_vx_u16m1_b16(p_index, 31, 64); + + // Reduce separately. + const int acc0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(prod, 0), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(p_mask, __riscv_vget_v_i16m4_i16m1(prod, 0), __riscv_vmv_v_x_i32m1(0, 1), 64)); + const int acc2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(prod, 1), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(p_mask, __riscv_vget_v_i16m4_i16m1(prod, 1), __riscv_vmv_v_x_i32m1(0, 1), 64)); + const int acc4 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(prod, 2), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc5 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(p_mask, __riscv_vget_v_i16m4_i16m1(prod, 2), __riscv_vmv_v_x_i32m1(0, 1), 64)); + const int acc6 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(prod, 3), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc7 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(p_mask, __riscv_vget_v_i16m4_i16m1(prod, 3), __riscv_vmv_v_x_i32m1(0, 1), 64)); + + const int ls0 = ((x[ibl].scales_l[0] & 0xf) | ((h << 4) & 0x30)) - 32; + const int ls1 = ((x[ibl].scales_l[0] >> 4) | ((h << 2) & 0x30)) - 32; + const int ls2 = ((x[ibl].scales_l[1] & 0xf) | ((h << 0) & 0x30)) - 32; + const int ls3 = ((x[ibl].scales_l[1] >> 4) | ((h >> 2) & 0x30)) - 32; + h >>= 8; + const int ls4 = ((x[ibl].scales_l[2] & 0xf) | ((h << 4) & 0x30)) - 32; + const int ls5 = ((x[ibl].scales_l[2] >> 4) | ((h << 2) & 0x30)) - 32; + const int ls6 = ((x[ibl].scales_l[3] & 0xf) | ((h << 0) & 0x30)) - 32; + const int ls7 = ((x[ibl].scales_l[3] >> 4) | ((h >> 2) & 0x30)) - 32; + + sumi += acc0 * ls0; + sumi += acc1 * ls1; + sumi += acc2 * ls2; + sumi += acc3 * ls3; + sumi += acc4 * ls4; + sumi += acc5 * ls5; + sumi += acc6 * ls6; + sumi += acc7 * ls7; + + __asm__ __volatile__("" ::: "memory"); + } + + sumf += GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi); + } + + *s = sumf; +} #endif void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: ggml_vec_dot_iq4_xs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); @@ -4018,6 +5956,12 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v case 256: ggml_vec_dot_iq4_xs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; + case 512: + ggml_vec_dot_iq4_xs_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; + case 1024: + ggml_vec_dot_iq4_xs_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc); + break; default: ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); break; @@ -4027,7 +5971,7 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static NOINLINE void ggml_vec_dot_tq1_0_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); @@ -4230,10 +6174,112 @@ static NOINLINE void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT *s = sumf; } + +static NOINLINE void ggml_vec_dot_tq1_0_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_tq1_0 * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + float sumf = 0.0f; + uint8_t pow[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27}; + + for (int i = 0; i < nb; i++) { + // First loop. + vint16m1_t suml1; + { + const int vl = 32; + vuint8mf2_t tq = __riscv_vle8_v_u8mf2(x[i].qs, vl); + + vuint16m1_t tq0 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(tq, 3, vl), 8, vl); + vuint16m1_t tq1 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 3, vl), 3, vl), 8, vl); + vuint16m1_t tq2 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 9, vl), 3, vl), 8, vl); + vuint16m1_t tq3 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 27, vl), 3, vl), 8, vl); + vuint16m1_t tq4 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 81, vl), 3, vl), 8, vl); + + vint16m1_t q80 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 0, vl), vl); + vint16m1_t q81 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 32, vl), vl); + vint16m1_t q82 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 64, vl), vl); + vint16m1_t q83 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 96, vl), vl); + vint16m1_t q84 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 128, vl), vl); + + vint16m1_t sum0 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq0, 1, vl)), q80, vl); + vint16m1_t sum1 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq1, 1, vl)), q81, vl); + vint16m1_t sum2 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq2, 1, vl)), q82, vl); + vint16m1_t sum3 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq3, 1, vl)), q83, vl); + vint16m1_t sum4 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq4, 1, vl)), q84, vl); + + vint16m1_t sumi0 = __riscv_vadd_vv_i16m1(sum0, sum1, vl); + vint16m1_t sumi1 = __riscv_vadd_vv_i16m1(sum2, sum3, vl); + suml1 = __riscv_vadd_vv_i16m1(sum4, __riscv_vadd_vv_i16m1(sumi0, sumi1, vl), vl); + } + + // Second loop. + vint16mf2_t suml2; + { + const int vl = 16; + vuint8mf4_t tq = __riscv_vle8_v_u8mf4(x[i].qs + 32, vl); + + vuint16mf2_t tq0 = __riscv_vsrl_vx_u16mf2(__riscv_vwmulu_vx_u16mf2(tq, 3 * 1, vl), 8, vl); + vuint16mf2_t tq1 = __riscv_vsrl_vx_u16mf2(__riscv_vwmulu_vx_u16mf2(__riscv_vmul_vx_u8mf4(tq, 3, vl), 3, vl), 8, vl); + vuint16mf2_t tq2 = __riscv_vsrl_vx_u16mf2(__riscv_vwmulu_vx_u16mf2(__riscv_vmul_vx_u8mf4(tq, 9, vl), 3, vl), 8, vl); + vuint16mf2_t tq3 = __riscv_vsrl_vx_u16mf2(__riscv_vwmulu_vx_u16mf2(__riscv_vmul_vx_u8mf4(tq, 27, vl), 3, vl), 8, vl); + vuint16mf2_t tq4 = __riscv_vsrl_vx_u16mf2(__riscv_vwmulu_vx_u16mf2(__riscv_vmul_vx_u8mf4(tq, 81, vl), 3, vl), 8, vl); + + vint16mf2_t q80 = __riscv_vwcvt_x_x_v_i16mf2(__riscv_vle8_v_i8mf4(y[i].qs + 160, vl), vl); + vint16mf2_t q81 = __riscv_vwcvt_x_x_v_i16mf2(__riscv_vle8_v_i8mf4(y[i].qs + 176, vl), vl); + vint16mf2_t q82 = __riscv_vwcvt_x_x_v_i16mf2(__riscv_vle8_v_i8mf4(y[i].qs + 192, vl), vl); + vint16mf2_t q83 = __riscv_vwcvt_x_x_v_i16mf2(__riscv_vle8_v_i8mf4(y[i].qs + 208, vl), vl); + vint16mf2_t q84 = __riscv_vwcvt_x_x_v_i16mf2(__riscv_vle8_v_i8mf4(y[i].qs + 224, vl), vl); + + vint16mf2_t sum0 = __riscv_vmul_vv_i16mf2(__riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vsub_vx_u16mf2(tq0, 1, vl)), q80, vl); + vint16mf2_t sum1 = __riscv_vmul_vv_i16mf2(__riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vsub_vx_u16mf2(tq1, 1, vl)), q81, vl); + vint16mf2_t sum2 = __riscv_vmul_vv_i16mf2(__riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vsub_vx_u16mf2(tq2, 1, vl)), q82, vl); + vint16mf2_t sum3 = __riscv_vmul_vv_i16mf2(__riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vsub_vx_u16mf2(tq3, 1, vl)), q83, vl); + vint16mf2_t sum4 = __riscv_vmul_vv_i16mf2(__riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vsub_vx_u16mf2(tq4, 1, vl)), q84, vl); + + vint16mf2_t sumi0 = __riscv_vadd_vv_i16mf2(sum0, sum1, vl); + vint16mf2_t sumi1 = __riscv_vadd_vv_i16mf2(sum2, sum3, vl); + suml2 = __riscv_vadd_vv_i16mf2(sum4, __riscv_vadd_vv_i16mf2(sumi0, sumi1, vl), vl); + } + + // Third loop. + vint16mf2_t suml3; + { + const int vl = 16; + + uint32_t qh; + memcpy(&qh, &x[i].qh[0], 4); + // Prevent fusion with vmv. + __asm__ __volatile__("" : "+r"(qh)); + vuint8mf4_t tq = __riscv_vlmul_trunc_v_u8mf2_u8mf4(__riscv_vreinterpret_v_u32mf2_u8mf2(__riscv_vmv_v_x_u32mf2(qh, vl / 4))); + + vuint8mf4_t p = __riscv_vle8_v_u8mf4(pow, vl); + + vuint16mf2_t tq0 = __riscv_vsrl_vx_u16mf2(__riscv_vwmulu_vx_u16mf2(__riscv_vmul_vv_u8mf4(tq, p, vl), 3, vl), 8, vl); + + vint16mf2_t q80 = __riscv_vwcvt_x_x_v_i16mf2(__riscv_vle8_v_i8mf4(y[i].qs + 240, vl), vl); + + suml3 = __riscv_vmul_vv_i16mf2(__riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vsub_vx_u16mf2(tq0, 1, vl)), q80, vl); + } + + vint32m1_t sum = __riscv_vwredsum_vs_i16m1_i32m1(suml1, __riscv_vmv_v_x_i32m1(0, 1), 32); + sum = __riscv_vwredsum_vs_i16mf2_i32m1(__riscv_vadd_vv_i16mf2(suml2, suml3, 16), sum, 16); + sumf += __riscv_vmv_x_s_i32m1_i32(sum) * y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + } + + *s = sumf; +} #endif void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: ggml_vec_dot_tq1_0_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); @@ -4241,8 +6287,8 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo case 256: ggml_vec_dot_tq1_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; - default: - ggml_vec_dot_tq1_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + default: // 512 and above + ggml_vec_dot_tq1_0_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); break; } #else @@ -4250,7 +6296,7 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static NOINLINE void ggml_vec_dot_tq2_0_q8_K_vl128(const int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); @@ -4406,24 +6452,21 @@ static NOINLINE void ggml_vec_dot_tq2_0_q8_K_vl256(int n, float * GGML_RESTRICT #endif void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: ggml_vec_dot_tq2_0_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); break; - case 256: + default: // 256 and above ggml_vec_dot_tq2_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; - default: - ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); - break; } #else ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static NOINLINE void ggml_vec_dot_mxfp4_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); @@ -4538,7 +6581,7 @@ static NOINLINE void ggml_vec_dot_mxfp4_q8_0_vl256(int n, float * GGML_RESTRICT #endif void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: ggml_vec_dot_mxfp4_q8_0_vl128(n, s, bs, vx, bx, vy, by, nrc); diff --git a/ggml/src/ggml-cpu/arch/wasm/quants.c b/ggml/src/ggml-cpu/arch/wasm/quants.c index 648c6fcaba7..0a7119b4e1f 100644 --- a/ggml/src/ggml-cpu/arch/wasm/quants.c +++ b/ggml/src/ggml-cpu/arch/wasm/quants.c @@ -355,6 +355,78 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi *s = sumf; } +void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_1; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_1 * GGML_RESTRICT x = vx; + const block_q8_1 * GGML_RESTRICT y = vy; + + float sumf = 0; + +#if defined __wasm_simd128__ + v128_t sumv = wasm_f32x4_splat(0.0f); + float summs = 0.0f; + + for (int ib = 0; ib < nb; ++ib) { + const block_q4_1 * GGML_RESTRICT x0 = &x[ib]; + const block_q8_1 * GGML_RESTRICT y0 = &y[ib]; + + summs += GGML_CPU_FP16_TO_FP32(x0->m) * GGML_CPU_FP16_TO_FP32(y0->s); + + const v128_t raw = wasm_v128_load(x0->qs); + const v128_t v0s = wasm_v128_and(raw, wasm_i8x16_splat(0x0F)); + const v128_t v1s = wasm_u8x16_shr(raw, 4); + + const v128_t ys_lo = wasm_v128_load(y0->qs); + const v128_t ys_hi = wasm_v128_load(y0->qs + 16); + + const v128_t v0s_l = wasm_u16x8_extend_low_u8x16(v0s); + const v128_t v0s_h = wasm_u16x8_extend_high_u8x16(v0s); + const v128_t ylo_l = wasm_i16x8_extend_low_i8x16(ys_lo); + const v128_t ylo_h = wasm_i16x8_extend_high_i8x16(ys_lo); + const v128_t v1s_l = wasm_u16x8_extend_low_u8x16(v1s); + const v128_t v1s_h = wasm_u16x8_extend_high_u8x16(v1s); + const v128_t yhi_l = wasm_i16x8_extend_low_i8x16(ys_hi); + const v128_t yhi_h = wasm_i16x8_extend_high_i8x16(ys_hi); + + const v128_t acc = wasm_i32x4_add( + wasm_i32x4_add( + wasm_i32x4_dot_i16x8(v0s_l, ylo_l), + wasm_i32x4_dot_i16x8(v0s_h, ylo_h)), + wasm_i32x4_add( + wasm_i32x4_dot_i16x8(v1s_l, yhi_l), + wasm_i32x4_dot_i16x8(v1s_h, yhi_h))); + + sumv = wasm_f32x4_add(sumv, + wasm_f32x4_mul( + wasm_f32x4_convert_i32x4(acc), + wasm_f32x4_splat(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d)))); + } + + sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + + wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs; + + *s = sumf; + +#else + UNUSED(nb); + UNUSED(x); + UNUSED(y); + UNUSED(sumf); + + ggml_vec_dot_q4_1_q8_1_generic( + n, s, bs, vx, bx, vy, by, nrc); +#endif +} + void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index cd5c61a8187..eb8341c9aec 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -1912,6 +1912,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_im2col_3d(params, tensor); } break; + case GGML_OP_COL2IM_1D: + { + ggml_compute_forward_col2im_1d(params, tensor); + } break; case GGML_OP_CONV_2D: { ggml_compute_forward_conv_2d(params, tensor); @@ -2343,6 +2347,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_CONV_2D: case GGML_OP_CONV_3D: case GGML_OP_CONV_2D_DW: + case GGML_OP_COL2IM_1D: case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_CONV_TRANSPOSE_2D: { @@ -2943,7 +2948,7 @@ struct ggml_cplan ggml_graph_plan( case GGML_OP_GATED_DELTA_NET: { const int64_t S_v = node->src[2]->ne[0]; - const int64_t K = node->src[5]->ne[1]; // state is (D, K, n_seqs) + const int64_t K = ggml_get_op_params_i32(node, 0); const int64_t per_thread = S_v + (K > 1 ? S_v * S_v : 0); cur = per_thread * sizeof(float) * n_tasks; } break; diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp index 0ecf7ae02ac..9e54b676b93 100644 --- a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp @@ -38,6 +38,7 @@ #include "kleidiai.h" #include "ggml-cpu.h" +#include "ggml-cpu-impl.h" #include "ggml-impl.h" #include "ggml-backend-impl.h" #include "ggml-threading.h" @@ -61,7 +62,8 @@ struct ggml_kleidiai_context { ggml_kleidiai_kernels * kernels_q8; int sme_thread_cap; // <= 0 means “SME disabled/unknown”; int thread_hint; // <= 0 means “no hint” -} static ctx = { CPU_FEATURE_NONE, nullptr, nullptr, 0, -1 }; + int chunk_multiplier; +} static ctx = { CPU_FEATURE_NONE, nullptr, nullptr, 0, -1, 4 }; static const char* cpu_feature_to_string(cpu_feature f) { if (f == CPU_FEATURE_NONE) { @@ -186,8 +188,9 @@ static void init_kleidiai_context(void) { if (!initialized) { initialized = true; - const char *env_sme = getenv("GGML_KLEIDIAI_SME"); - const char *env_threads = getenv("GGML_TOTAL_THREADS"); + const char *env_sme = getenv("GGML_KLEIDIAI_SME"); + const char *env_threads = getenv("GGML_TOTAL_THREADS"); + const char *env_chunk_mult = getenv("GGML_KLEIDIAI_CHUNK_MULTIPLIER"); const bool cpu_has_sme = ggml_cpu_has_sme(); size_t detected_smcus = 0; @@ -204,6 +207,14 @@ static void init_kleidiai_context(void) { } } + if (env_chunk_mult) { + bool ok = false; + int multiplier = parse_uint_env(env_chunk_mult, "GGML_KLEIDIAI_CHUNK_MULTIPLIER", &ok); + if (ok && multiplier > 0) { + ctx.chunk_multiplier = multiplier; + } + } + // SME policy: // - If CPU doesn't support SME: SME always off. // - Else: @@ -296,6 +307,50 @@ static inline size_t align_up(size_t value, size_t alignment) { return remainder == 0 ? value : value + (alignment - remainder); } +static inline size_t gcd_size(size_t a, size_t b) { + while (b != 0) { + const size_t t = a % b; + a = b; + b = t; + } + return a; +} + +static inline bool lcm_size(size_t a, size_t b, size_t & result) { + if (a == 0 || b == 0) { + result = 0; + return false; + } + const size_t g = gcd_size(a, b); + const size_t q = a / g; + if (q > SIZE_MAX / b) { + return false; + } + result = q * b; + return true; +} + +static inline size_t ceil_div_size(size_t a, size_t b) { + return b == 0 ? 0 : (a + b - 1) / b; +} + +struct kleidiai_block_args { + size_t lhs_bl; + size_t rhs_bl; + size_t pack_bl; +}; + +static inline kleidiai_block_args kleidiai_get_block_args(ggml_type rhs_type) { + switch (rhs_type) { + case GGML_TYPE_Q4_0: + return { QK4_0, QK4_0, QK4_0 }; + case GGML_TYPE_Q8_0: + return { 0, 0, QK8_0 }; + default: + return { 0, 0, 0 }; + } +} + static inline bool kleidiai_pack_fallback_allowed() { if (ctx.sme_thread_cap <= 0) { return false; @@ -746,8 +801,10 @@ class tensor_traits : public ggml::cpu::tensor_traits { size_t n_step; size_t lhs_packed_size; size_t lhs_offset; - size_t n_offset; - size_t n_cols; + size_t lhs_bl; + size_t rhs_bl; + size_t pack_bl; + size_t lhs_packed_offset0; int assigned_threads; int thread_begin; int thread_end; @@ -772,6 +829,8 @@ class tensor_traits : public ggml::cpu::tensor_traits { continue; } + const kleidiai_block_args block_args = kleidiai_get_block_args(kernels->rhs_type); + runtime[runtime_count] = { slot, kernels, @@ -784,7 +843,9 @@ class tensor_traits : public ggml::cpu::tensor_traits { kinfo->get_n_step(), 0, 0, - 0, + block_args.lhs_bl, + block_args.rhs_bl, + block_args.pack_bl, 0, 0, 0, @@ -795,45 +856,8 @@ class tensor_traits : public ggml::cpu::tensor_traits { } if (runtime_count == 0) { - ggml_kleidiai_kernels * fallback = ggml_kleidiai_select_kernels(ctx.features, dst); - if (!fallback) { - return false; - } - kernel_info * kinfo = is_gemv ? &fallback->gemv : &fallback->gemm; - lhs_packing_info * linfo = is_gemv ? &fallback->gemv_lhs_info : &fallback->gemm_lhs_info; - rhs_packing_info * rinfo = &fallback->rhs_info; - if (!kinfo || !linfo || !linfo->packed_size_ex || !linfo->pack_func_ex || - !kinfo->get_rhs_packed_offset_ex || !kinfo->run_kernel_ex || !kinfo->get_dst_offset || - !rinfo || !rinfo->pack_func_ex || !rinfo->packed_size_ex) { - return false; - } - kernel_chain[0] = fallback; - runtime[0] = { - 0, - fallback, - kinfo, - linfo, - kinfo->get_mr(), - kinfo->get_nr(), - kinfo->get_kr(), - kinfo->get_sr(), - kinfo->get_n_step(), - 0, - 0, - 0, - 0, - 0, - 0, - 0, - nullptr - }; - size_t rhs_size_fallback = 0; - const uint8_t * rhs_base = weight_for_slot(0, rhs_size_fallback); - if (!rhs_base) { - rhs_base = static_cast<const uint8_t *>(src0->data); - } - runtime[0].rhs_base = rhs_base; - runtime_count = 1; + GGML_LOG_WARN("kleidiai: no runtime kernel slot available for supported op %s\n", dst->name); + return false; } const int nth_total = params->nth > 0 ? params->nth : 1; @@ -846,6 +870,13 @@ class tensor_traits : public ggml::cpu::tensor_traits { break; } } + int non_sme_slot = -1; + for (int i = 0; i < runtime_count; ++i) { + if ((runtime[i].kernels->required_cpu & CPU_FEATURE_SME) != CPU_FEATURE_SME) { + non_sme_slot = i; + break; + } + } const int sme_cap_limit = ctx.sme_thread_cap; const bool use_hybrid = sme_cap_limit > 0 && @@ -864,12 +895,15 @@ class tensor_traits : public ggml::cpu::tensor_traits { if (!hybrid_enabled) { int chosen_slot = 0; if (too_small_for_hybrid && sme_slot != -1) { - chosen_slot = sme_slot; + chosen_slot = nth_total > sme_cap_limit && non_sme_slot != -1 ? non_sme_slot : sme_slot; } else if (runtime_count > 1 && ctx.sme_thread_cap > 0 && nth_total > ctx.sme_thread_cap) { chosen_slot = 1; } if (chosen_slot != 0 && chosen_slot < runtime_count) { runtime[0] = runtime[chosen_slot]; + runtime[0].assigned_threads = 0; + runtime[0].thread_begin = 0; + runtime[0].thread_end = 0; } runtime_count = runtime_count > 0 ? 1 : 0; @@ -896,6 +930,8 @@ class tensor_traits : public ggml::cpu::tensor_traits { int fallback_indices[GGML_KLEIDIAI_MAX_KERNEL_SLOTS]; int fallback_count = 0; + // The current hybrid chain is bounded to SME + one non-SME fallback slot. + GGML_ASSERT(GGML_KLEIDIAI_MAX_KERNEL_SLOTS == 2); for (int i = 0; i < runtime_count; ++i) { if (i == sme_slot) { continue; @@ -952,73 +988,67 @@ class tensor_traits : public ggml::cpu::tensor_traits { size_t cursor = 0; for (int i = 0; i < runtime_count; ++i) { - const ggml_type slot_rhs_type = runtime[i].kernels->rhs_type; - const size_t slot_pack_size_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 : - slot_rhs_type == GGML_TYPE_Q8_0 ? QK8_0 : 0; - runtime[i].lhs_packed_size = runtime[i].lhs_info->packed_size_ex(m, k, slot_pack_size_arg, runtime[i].mr, runtime[i].kr, runtime[i].sr); + runtime[i].lhs_packed_size = runtime[i].lhs_info->packed_size_ex(m, k, runtime[i].pack_bl, runtime[i].mr, runtime[i].kr, runtime[i].sr); cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN); runtime[i].lhs_offset = cursor; + runtime[i].lhs_packed_offset0 = runtime[i].lhs_info->get_packed_offset_ex(0, k, runtime[i].lhs_bl, runtime[i].mr, runtime[i].kr, runtime[i].sr); cursor += runtime[i].lhs_packed_size; } GGML_ASSERT(cursor <= params->wsize); uint8_t * scratch = static_cast<uint8_t *>(params->wdata); - size_t assigned_cols = 0; - uint64_t weighted_total = 0; - if (runtime_count > 1 && sme_slot != -1) { - for (int i = 0; i < runtime_count; ++i) { - const uint64_t weight = (i == sme_slot) ? (sme_cap << 1) : 1; - weighted_total += (uint64_t)runtime[i].assigned_threads * weight; - } - } + size_t common_step = 1; for (int i = 0; i < runtime_count; ++i) { - runtime[i].n_offset = assigned_cols; if (runtime[i].assigned_threads == 0) { - runtime[i].n_cols = 0; continue; } - const size_t remaining_cols = n - assigned_cols; - if (remaining_cols == 0) { - runtime[i].n_cols = 0; - continue; - } - const size_t step = runtime[i].n_step ? runtime[i].n_step : 1; - size_t target = 0; - if (weighted_total > 0) { - const uint64_t weight = (i == sme_slot) ? (sme_cap << 1) : 1; - target = (size_t)(((uint64_t)n * runtime[i].assigned_threads * weight) / weighted_total); - } else { - target = (size_t)(((uint64_t)n * runtime[i].assigned_threads) / nth_total); - } - target = std::min(target, remaining_cols); - size_t aligned = round_down(target, step); - if (aligned == 0 && remaining_cols >= step) { - aligned = step; + size_t next_step = 0; + if (!lcm_size(common_step, runtime[i].n_step ? runtime[i].n_step : 1, next_step)) { + return false; } - runtime[i].n_cols = aligned; - assigned_cols += aligned; + common_step = next_step; } - - if (assigned_cols < n) { - for (int i = runtime_count - 1; i >= 0; --i) { - if (runtime[i].assigned_threads > 0) { - runtime[i].n_cols += n - assigned_cols; - break; - } - } + GGML_ASSERT(common_step > 0); + + const bool disable_chunking = ggml_is_numa(); + const size_t chunk_multiplier = std::max(1, ctx.chunk_multiplier); + const size_t chunk_divisor = (nth_total == 1 || disable_chunking) ? (size_t)nth_total : (size_t)nth_total * chunk_multiplier; + size_t chunk_cols = align_up(std::max<size_t>(1, ceil_div_size(n, chunk_divisor)), common_step); + if (chunk_cols == 0) { + chunk_cols = common_step; } + // If common_step is larger than n, the loop below runs one valid tail chunk + // with cols == n. + const size_t nchunk_size = std::max<size_t>(1, ceil_div_size(n, chunk_cols)); + GGML_ASSERT(nchunk_size <= (size_t)INT_MAX); + const int nchunk = (int)nchunk_size; const size_t dst_stride = dst->nb[1]; + auto run_chunk = [&](runtime_slot & slot, size_t global_start, size_t cols, uint8_t * dst_batch_base) { + const size_t rhs_packed_offset = slot.kernel->get_rhs_packed_offset_ex(global_start, k, slot.rhs_bl); + const size_t dst_offset = slot.kernel->get_dst_offset(0, global_start, dst_stride); + + const uint8_t * lhs_ptr = scratch + slot.lhs_offset + slot.lhs_packed_offset0; + const uint8_t * rhs_ptr = slot.rhs_base + rhs_packed_offset; + float * dst_ptr = reinterpret_cast<float *>(dst_batch_base + dst_offset); + + slot.kernel->run_kernel_ex(m, cols, k, slot.rhs_bl, + lhs_ptr, + rhs_ptr, + dst_ptr, + dst_stride, + sizeof(float), + -FLT_MAX, + FLT_MAX); + }; + for (int64_t batch_idx = 0; batch_idx < ne12; ++batch_idx) { const uint8_t * lhs_batch_base = static_cast<const uint8_t *>(src1->data) + batch_idx * src1->nb[2]; uint8_t * dst_batch_base = static_cast<uint8_t *>(dst->data) + batch_idx * dst->nb[2]; if (runtime[local_slot].assigned_threads > 0) { runtime_slot & slot = runtime[local_slot]; - const ggml_type slot_rhs_type = slot.kernels->rhs_type; - const size_t slot_lhs_exec_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 : - slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0; const int64_t m_roundup_mr = kai_roundup((int64_t)m, (int64_t)slot.mr); int64_t max_threads = slot.mr ? (m_roundup_mr / (int64_t)slot.mr) : slot.assigned_threads; max_threads = std::max<int64_t>(1, max_threads); @@ -1031,8 +1061,8 @@ class tensor_traits : public ggml::cpu::tensor_traits { const int64_t m_start = (int64_t)local_ith * num_m_per_thread0; const int64_t m_count = (local_ith == use_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0; - const size_t base_packed_off = slot.lhs_info->get_packed_offset_ex(m_start, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr); - const size_t next_block_off = slot.lhs_info->get_packed_offset_ex(m_start + slot.mr, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr); + const size_t base_packed_off = slot.lhs_info->get_packed_offset_ex(m_start, k, slot.lhs_bl, slot.mr, slot.kr, slot.sr); + const size_t next_block_off = slot.lhs_info->get_packed_offset_ex(m_start + slot.mr, k, slot.lhs_bl, slot.mr, slot.kr, slot.sr); const size_t row_stride_bytes = slot.mr ? (next_block_off - base_packed_off) / slot.mr : 0; int64_t remaining = m_count; @@ -1049,7 +1079,7 @@ class tensor_traits : public ggml::cpu::tensor_traits { const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes; void * dst_ptr = lhs_packed + dst_off; - slot.lhs_info->pack_func_ex(take, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr, 0, src_ptr, src1->nb[1], dst_ptr); + slot.lhs_info->pack_func_ex(take, k, slot.lhs_bl, slot.mr, slot.kr, slot.sr, 0, src_ptr, src1->nb[1], dst_ptr); cur += take; remaining -= take; @@ -1057,49 +1087,29 @@ class tensor_traits : public ggml::cpu::tensor_traits { } } + if (ith_total == 0) { + ggml_threadpool_chunk_set(params->threadpool, nth_total); + } + + // Publishes both LHS packing and the initialized dynamic chunk queue. ggml_barrier(params->threadpool); runtime_slot & slot = runtime[local_slot]; - if (slot.n_cols > 0 && slot.assigned_threads > 0) { - int64_t active_threads = slot.assigned_threads; - const int64_t max_threads = slot.n_step ? (slot.n_cols / slot.n_step) : slot.assigned_threads; - if (max_threads > 0) { - active_threads = std::min<int64_t>(active_threads, std::max<int64_t>(1, max_threads)); + int current_chunk = ith_total; + while (current_chunk < nchunk) { + const size_t global_start = (size_t)current_chunk * chunk_cols; + if (global_start >= n) { + break; } - active_threads = std::max<int64_t>(1, active_threads); - - if (local_ith < active_threads) { - const size_t step = slot.n_step ? slot.n_step : 1; - const size_t chunk0 = round_down((size_t)(slot.n_cols / active_threads), step); - const size_t chunkN = slot.n_cols - (active_threads - 1) * chunk0; - const size_t local_start = (size_t)local_ith * chunk0; - const size_t cols = (local_ith == active_threads - 1) ? chunkN : chunk0; - - if (cols > 0) { - const ggml_type slot_rhs_type = slot.kernels->rhs_type; - const size_t slot_lhs_exec_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 : - slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0; - const size_t slot_rhs_block_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 : - slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0; - const size_t global_start = slot.n_offset + local_start; - const size_t lhs_packed_offset = slot.lhs_info->get_packed_offset_ex(0, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr); - const size_t rhs_packed_offset = slot.kernel->get_rhs_packed_offset_ex(global_start, k, slot_rhs_block_arg); - const size_t dst_offset = slot.kernel->get_dst_offset(0, global_start, dst_stride); - - const uint8_t * lhs_ptr = scratch + slot.lhs_offset + lhs_packed_offset; - const uint8_t * rhs_ptr = slot.rhs_base + rhs_packed_offset; - float * dst_ptr = reinterpret_cast<float *>(dst_batch_base + dst_offset); - - slot.kernel->run_kernel_ex(m, cols, k, slot_rhs_block_arg, - lhs_ptr, - rhs_ptr, - dst_ptr, - dst_stride, - sizeof(float), - -FLT_MAX, - FLT_MAX); - } + + const size_t cols = std::min(chunk_cols, n - global_start); + if (cols > 0) { + // KleidiAI GEMM/GEMV kernels accept arbitrary final tail widths; + // only non-tail chunks are guaranteed to be n_step-aligned. + run_chunk(slot, global_start, cols, dst_batch_base); } + + current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1); } if (batch_idx != ne12 - 1) { diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 7485ba4fc86..74611dce7f1 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -2235,8 +2235,42 @@ static void ggml_compute_forward_fill_f32(const ggml_compute_params * params, gg } } +static void ggml_compute_forward_fill_f16(const ggml_compute_params * params, ggml_tensor * dst) { + const ggml_fp16_t c = GGML_CPU_FP32_TO_FP16(ggml_get_op_params_f32(dst, 0)); + + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); + GGML_TENSOR_LOCALS(size_t, nb, dst, nb); + + const auto [ir0, ir1] = get_thread_range(params, dst); + + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir/(ne2*ne1); + const int64_t i02 = (ir - i03*ne2*ne1)/ne1; + const int64_t i01 = (ir - i03*ne2*ne1 - i02*ne1); + + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1); + + ggml_vec_set_f16(ne0, dst_ptr, c); + } +} + void ggml_compute_forward_fill(const ggml_compute_params * params, ggml_tensor * dst) { - ggml_compute_forward_fill_f32(params, dst); + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_fill_f32(params, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_fill_f16(params, dst); + } break; + default: + { + GGML_ABORT("unsupported type for ggml_compute_forward_fill: %s", ggml_type_name(src0->type)); + } + } } // ggml_compute_tri @@ -3974,12 +4008,12 @@ static void ggml_compute_forward_rms_norm_back_f32( // dx := scale(dx, rrms) float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); - // dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps) - ggml_vec_cpy_f32 (ne00, dx, x); - // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps); - ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps); - ggml_vec_acc_f32 (ne00, dx, dz); - ggml_vec_scale_f32(ne00, dx, rrms); + // dx[i00] = (dz + x*(-sum_xdz/sum_eps)) * rrms + // note: https://github.com/ggml-org/ggml/issues/1491 + const float scale_x = (float) (-sum_xdz) / sum_eps; + for (int64_t i00 = 0; i00 < ne00; i00++) { + dx[i00] = (dz[i00] + x[i00] * scale_x) * rrms; + } } } } @@ -6696,6 +6730,78 @@ static inline int64_t ggml_wrap_around(int64_t coord, int64_t size) { return (coord + size) % size; // adding size avoids negative number weirdness } +// ggml_compute_forward_col2im_1d +// +// Scatter-add columns [K*OC, T_in] -> signal [T_out, OC] +// where T_out = (T_in - 1)*s + K - 2*p. Gather approach: each output reads ceil(K/s) inputs. +// Parallelized over the time axis so the split stays balanced whatever OC is. +// Supports F32, F16, BF16 input/output (same type), F32 accumulator. + +template <typename elem_t> +static void ggml_compute_forward_col2im_1d_impl( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src = dst->src[0]; // [K*OC, T_in] + + GGML_ASSERT(ggml_is_contiguous(src)); + GGML_ASSERT(ggml_is_contiguous(dst)); + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t OC = ((const int32_t *)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; + + const int64_t K_OC = src->ne[0]; + const int64_t T_in = src->ne[1]; + const int64_t K = K_OC / OC; + const int64_t T_out = dst->ne[0]; + + const elem_t * col_data = (const elem_t *) src->data; + elem_t * dst_data = (elem_t *) dst->data; + + const int ith = params->ith; + const int nth = params->nth; + + // Parallelize over the time axis: the split stays balanced whatever OC is, + // down to OC = 1 for mono audio, and threads read disjoint column bands + const int64_t dr = (T_out + nth - 1) / nth; + const int64_t it0 = dr * ith; + const int64_t it1 = it0 + dr < T_out ? it0 + dr : T_out; + + for (int64_t oc = 0; oc < OC; oc++) { + for (int64_t t_out = it0; t_out < it1; t_out++) { + const int64_t t_abs = t_out + p0; // absolute position in uncropped signal + // Gather: find all (t_in, k) where t_in * s + k == t_abs, 0 <= k < K + int64_t t_in_min = (t_abs - K + 1 + s0 - 1) / s0; // ceil((t_abs-K+1)/s) + if (t_in_min < 0) t_in_min = 0; + int64_t t_in_max = t_abs / s0; + if (t_in_max >= T_in) t_in_max = T_in - 1; + + float sum = 0.0f; + for (int64_t t_in = t_in_min; t_in <= t_in_max; t_in++) { + int64_t k = t_abs - t_in * s0; + if (k >= 0 && k < K) { + // col layout: [K*OC, T_in], element (oc*K+k, t_in) + sum += type_conversion_table<elem_t>::to_f32(col_data[(oc * K + k) + t_in * K_OC]); + } + } + // dst layout: [T_out, OC], element (t_out, oc) + dst_data[t_out + oc * T_out] = type_conversion_table<elem_t>::from_f32(sum); + } + } +} + +void ggml_compute_forward_col2im_1d( + const ggml_compute_params * params, + ggml_tensor * dst) { + switch (dst->src[0]->type) { + case GGML_TYPE_F32: ggml_compute_forward_col2im_1d_impl<float> (params, dst); break; + case GGML_TYPE_F16: ggml_compute_forward_col2im_1d_impl<ggml_fp16_t>(params, dst); break; + case GGML_TYPE_BF16: ggml_compute_forward_col2im_1d_impl<ggml_bf16_t>(params, dst); break; + default: GGML_ABORT("col2im_1d: unsupported type %d", dst->src[0]->type); + } +} + // ggml_compute_forward_conv_2d @@ -8921,7 +9027,12 @@ static void ggml_compute_forward_flash_attn_ext_f16( k->type == v->type && neq1 >= Q_TILE_SZ); #ifdef GGML_SIMD - use_tiled &= (DV % GGML_F32_EPR == 0); +#if defined(__ARM_FEATURE_SVE) + const int64_t f32_epr = svcntw(); +#else + const int64_t f32_epr = GGML_F32_EPR; +#endif + use_tiled &= (DV % f32_epr == 0); #endif int current_chunk = ith; @@ -10513,11 +10624,11 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( const bool kda = (neg0 == S_v); - // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. - const int64_t K = src_state->ne[1]; + // K (snapshot slot count) is an op param; state holds s0 only [S_v, S_v, H, n_seqs]. + const int64_t K = ggml_get_op_params_i32(dst, 0); GGML_ASSERT(K >= 1); - // per-seq stride in floats (slot 0 of seq s lives at state + s * seq_stride) - const int64_t state_seq_stride = src_state->nb[2] / sizeof(float); + // per-seq stride in floats (seq s starts at state + s * seq_stride) + const int64_t state_seq_stride = src_state->nb[3] / sizeof(float); const int64_t per_thread = S_v + (K > 1 ? S_v * S_v : 0); const int ith = params->ith; @@ -10533,9 +10644,8 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( float * attn_out_base = (float *)dst->data; float * state_out_base = (float *)dst->data + attn_score_elems; - // snapshot slot mapping: target_slot = t - shift. When n_tokens < K only the last - // n_tokens slots are written; earlier slots are left untouched (caller-owned). - const int64_t shift = n_tokens - K; + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + // When n_tokens < K only slots 0..n_tokens-1 are written; older slots are caller-owned. const float * state_in_base = (const float *)src_state->data; @@ -10563,7 +10673,7 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( : state_out_base + (iv3 * H + iv1) * S_v * S_v; // copy input state into the working buffer and operate in-place - // state layout (D, K, n_seqs): slot 0 of seq iv3 starts at iv3 * state_seq_stride. + // state layout [S_v, S_v, H, n_seqs]: seq iv3 starts at iv3 * state_seq_stride. const float * s_in = state_in_base + iv3 * state_seq_stride + iv1 * S_v * S_v; memcpy(s_out, s_in, S_v * S_v * sizeof(float)); @@ -10616,7 +10726,7 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( attn_data += S_v * H; // advance to next token if (K > 1) { - const int64_t target_slot = t - shift; + const int64_t target_slot = n_tokens - 1 - t; if (target_slot >= 0 && target_slot < K) { float * curr_state_o = state_out_base + target_slot * state_size_per_snap + (iv3 * H + iv1) * S_v * S_v; @@ -11324,7 +11434,11 @@ static void ggml_compute_forward_fwht_f32(const ggml_compute_params * params, gg // Scalar passes #if defined(GGML_SIMD) +#if defined(__ARM_FEATURE_SVE) + const int step = svcntw(); +#else const int step = GGML_F32_EPR; +#endif #else const int step = n; #endif diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h index 7398e561894..a8e18c716db 100644 --- a/ggml/src/ggml-cpu/ops.h +++ b/ggml/src/ggml-cpu/ops.h @@ -68,6 +68,7 @@ void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * p void ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_im2col_back_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_im2col_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_col2im_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_conv_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_conv_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_conv_transpose_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst); diff --git a/ggml/src/ggml-cpu/simd-mappings.h b/ggml/src/ggml-cpu/simd-mappings.h index 0deda930985..62e687201ef 100644 --- a/ggml/src/ggml-cpu/simd-mappings.h +++ b/ggml/src/ggml-cpu/simd-mappings.h @@ -1125,25 +1125,12 @@ static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) { #define GGML_F16_EPR 4 static inline __m128 __lsx_f16x4_load(const ggml_fp16_t * x) { - float tmp[4]; - - tmp[0] = GGML_CPU_FP16_TO_FP32(x[0]); - tmp[1] = GGML_CPU_FP16_TO_FP32(x[1]); - tmp[2] = GGML_CPU_FP16_TO_FP32(x[2]); - tmp[3] = GGML_CPU_FP16_TO_FP32(x[3]); - - return (__m128)__lsx_vld(tmp, 0); + return __lsx_vfcvtl_s_h(__lsx_vld((const void *)x, 0)); } static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) { - float arr[4]; - - __lsx_vst(y, arr, 0); - - x[0] = GGML_CPU_FP32_TO_FP16(arr[0]); - x[1] = GGML_CPU_FP32_TO_FP16(arr[1]); - x[2] = GGML_CPU_FP32_TO_FP16(arr[2]); - x[3] = GGML_CPU_FP32_TO_FP16(arr[3]); + __m128i a = __lsx_vfcvt_h_s(y, y); + memcpy(x, &a, sizeof(ggml_fp16_t) * 4); } #define GGML_F32Cx4 __m128 diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp index d0e4001338a..67b6b05cac8 100644 --- a/ggml/src/ggml-cpu/vec.cpp +++ b/ggml/src/ggml-cpu/vec.cpp @@ -273,67 +273,51 @@ void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * G #if defined(GGML_SIMD) #if defined(__ARM_FEATURE_SVE) - const int sve_register_length = svcntb() * 8; //get vector length - const int ggml_f16_epr = sve_register_length / 16; // running when 16 - const int ggml_f16_step = 8 * ggml_f16_epr; // choose 8 SVE registers - - const int np= (n & ~(ggml_f16_step - 1)); - svfloat16_t sum1 = svdup_n_f16(0.0f); - svfloat16_t sum2 = svdup_n_f16(0.0f); - svfloat16_t sum3 = svdup_n_f16(0.0f); - svfloat16_t sum4 = svdup_n_f16(0.0f); - - svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8; - svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8; - for (int i = 0; i < np; i += ggml_f16_step) { - ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0); - ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); - sum1 = GGML_F16x_VEC_FMA(sum1, ax1, ay1); - - ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1); - ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); - sum2 = GGML_F16x_VEC_FMA(sum2, ax2, ay2); - - ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2); - ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2); - sum3 = GGML_F16x_VEC_FMA(sum3, ax3, ay3); - - ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3); - ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3); - sum4 = GGML_F16x_VEC_FMA(sum4, ax4, ay4); - - ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4); - ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4); - sum1 = GGML_F16x_VEC_FMA(sum1, ax5, ay5); + const int ggml_f16_epr = svcnth(); + const int ggml_f16_step = 8 * ggml_f16_epr; + const int np = n - (n % ggml_f16_step); + const int np2 = n - (n % ggml_f16_epr); + + svfloat32_t sum1_lo = svdup_n_f32(0.0f); + svfloat32_t sum1_hi = svdup_n_f32(0.0f); + svfloat32_t sum2_lo = svdup_n_f32(0.0f); + svfloat32_t sum2_hi = svdup_n_f32(0.0f); + svfloat32_t sum3_lo = svdup_n_f32(0.0f); + svfloat32_t sum3_hi = svdup_n_f32(0.0f); + svfloat32_t sum4_lo = svdup_n_f32(0.0f); + svfloat32_t sum4_hi = svdup_n_f32(0.0f); - ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5); - ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5); - sum2 = GGML_F16x_VEC_FMA(sum2, ax6, ay6); - - ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6); - ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6); - sum3 = GGML_F16x_VEC_FMA(sum3, ax7, ay7); - - ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7); - ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7); - sum4 = GGML_F16x_VEC_FMA(sum4, ax8, ay8); + for (int i = 0; i < np; i += ggml_f16_step) { + ggml_sve_f16_fma_widened(&sum1_lo, &sum1_hi, GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0), GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0)); + ggml_sve_f16_fma_widened(&sum2_lo, &sum2_hi, GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1), GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1)); + ggml_sve_f16_fma_widened(&sum3_lo, &sum3_hi, GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2), GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2)); + ggml_sve_f16_fma_widened(&sum4_lo, &sum4_hi, GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3), GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3)); + ggml_sve_f16_fma_widened(&sum1_lo, &sum1_hi, GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4), GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4)); + ggml_sve_f16_fma_widened(&sum2_lo, &sum2_hi, GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5), GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5)); + ggml_sve_f16_fma_widened(&sum3_lo, &sum3_hi, GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6), GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6)); + ggml_sve_f16_fma_widened(&sum4_lo, &sum4_hi, GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7), GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7)); } - const int np2 = (n & ~(ggml_f16_epr - 1)); // round down to multiple of 8 - for (int k = np; k < np2; k += ggml_f16_epr) { - svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0); - svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0); - sum1 = GGML_F16x_VEC_FMA(sum1, rx, ry); + for (int i = np; i < np2; i += ggml_f16_epr) { + ggml_sve_f16_fma_widened(&sum1_lo, &sum1_hi, GGML_F16x_VEC_LOAD(x + i, 0), GGML_F16x_VEC_LOAD(y + i, 0)); } if (np2 < n) { - svbool_t pg = svwhilelt_b16(np2, n); - svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2)); - svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2)); + const svbool_t pg = svwhilelt_b16(np2, n); + const svfloat16_t rx = svld1_f16(pg, (const __fp16 *)(x + np2)); + const svfloat16_t ry = svld1_f16(pg, (const __fp16 *)(y + np2)); - sum1 = svmad_f16_x(pg, hx, hy, sum1); + ggml_sve_f16_fma_widened(&sum1_lo, &sum1_hi, rx, ry); } - GGML_F16x_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4); + + sum1_lo = svadd_f32_m(DEFAULT_PG32, sum1_lo, sum2_lo); + sum1_hi = svadd_f32_m(DEFAULT_PG32, sum1_hi, sum2_hi); + sum3_lo = svadd_f32_m(DEFAULT_PG32, sum3_lo, sum4_lo); + sum3_hi = svadd_f32_m(DEFAULT_PG32, sum3_hi, sum4_hi); + sum1_lo = svadd_f32_m(DEFAULT_PG32, sum1_lo, sum3_lo); + sum1_hi = svadd_f32_m(DEFAULT_PG32, sum1_hi, sum3_hi); + + sumf = ggml_sve_sum_f32x2(sum1_lo, sum1_hi); #elif defined(__riscv_v_intrinsic) #if defined(__riscv_zvfh) int vl = __riscv_vsetvlmax_e32m2(); diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index bcd68da9aa9..5de9cb5b7e0 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -14,6 +14,35 @@ // floating point type used to accumulate sums typedef double ggml_float; +#if defined(__ARM_FEATURE_SVE) +inline static void ggml_sve_f16_fma_widened( + svfloat32_t * acc_lo, + svfloat32_t * acc_hi, + svfloat16_t x, + svfloat16_t y) { +#if defined(__ARM_FEATURE_SVE2) + *acc_lo = svmlalb_f32(*acc_lo, x, y); + *acc_hi = svmlalt_f32(*acc_hi, x, y); +#else + // Plain SVE fallback path if SVE2 instructions not available + svfloat16_t x_even = svtrn1_f16(x, x); + svfloat16_t x_odd = svtrn2_f16(x, x); + + svfloat16_t y_even = svtrn1_f16(y, y); + svfloat16_t y_odd = svtrn2_f16(y, y); + + svbool_t pg = svptrue_b32(); + + *acc_lo = svmla_f32_x(pg, *acc_lo, svcvt_f32_f16_x(pg, x_even), svcvt_f32_f16_x(pg, y_even)); + *acc_hi = svmla_f32_x(pg, *acc_hi, svcvt_f32_f16_x(pg, x_odd), svcvt_f32_f16_x(pg, y_odd)); +#endif +} + +inline static ggml_float ggml_sve_sum_f32x2(svfloat32_t sum_lo, svfloat32_t sum_hi) { + return (ggml_float) (svaddv_f32(svptrue_b32(), sum_lo) + svaddv_f32(svptrue_b32(), sum_hi)); +} +#endif + #define GGML_GELU_FP16 #define GGML_GELU_QUICK_FP16 @@ -122,108 +151,61 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG #if defined(GGML_SIMD) #if defined(__ARM_FEATURE_SVE) - const int sve_register_length = svcntb() * 8; - const int ggml_f16_epr = sve_register_length / 16; // running when 16 - const int ggml_f16_step = 8 * ggml_f16_epr; // choose 8 SVE registers - - int np = (n & ~(ggml_f16_step - 1)); - - svfloat16_t sum_00 = svdup_n_f16(0.0f); - svfloat16_t sum_01 = svdup_n_f16(0.0f); - svfloat16_t sum_02 = svdup_n_f16(0.0f); - svfloat16_t sum_03 = svdup_n_f16(0.0f); + const int ggml_f16_epr = svcnth(); + const int ggml_f16_step = 2 * ggml_f16_epr; + int np = n - (n % ggml_f16_step); + int np2 = n - (n % ggml_f16_epr); - svfloat16_t sum_10 = svdup_n_f16(0.0f); - svfloat16_t sum_11 = svdup_n_f16(0.0f); - svfloat16_t sum_12 = svdup_n_f16(0.0f); - svfloat16_t sum_13 = svdup_n_f16(0.0f); - - svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8; - svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8; + svfloat32_t sum_0_0_lo = svdup_n_f32(0.0f); + svfloat32_t sum_0_0_hi = svdup_n_f32(0.0f); + svfloat32_t sum_0_1_lo = svdup_n_f32(0.0f); + svfloat32_t sum_0_1_hi = svdup_n_f32(0.0f); + svfloat32_t sum_1_0_lo = svdup_n_f32(0.0f); + svfloat32_t sum_1_0_hi = svdup_n_f32(0.0f); + svfloat32_t sum_1_1_lo = svdup_n_f32(0.0f); + svfloat32_t sum_1_1_hi = svdup_n_f32(0.0f); for (int i = 0; i < np; i += ggml_f16_step) { - ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); // 8 elements - - ax1 = GGML_F16x_VEC_LOAD(x[0] + i + 0*ggml_f16_epr, 0); // 8 elements - sum_00 = GGML_F16x_VEC_FMA(sum_00, ax1, ay1); // sum_00 = sum_00+ax1*ay1 - ax1 = GGML_F16x_VEC_LOAD(x[1] + i + 0*ggml_f16_epr, 0); // 8 elements - sum_10 = GGML_F16x_VEC_FMA(sum_10, ax1, ay1); - - ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); // next 8 elements + const svfloat16_t ay0 = GGML_F16x_VEC_LOAD(y + i, 0); + const svfloat16_t ax00 = GGML_F16x_VEC_LOAD(x[0] + i, 0); + const svfloat16_t ax01 = GGML_F16x_VEC_LOAD(x[1] + i, 0); - ax2 = GGML_F16x_VEC_LOAD(x[0] + i + 1*ggml_f16_epr, 1); // next 8 elements - sum_01 = GGML_F16x_VEC_FMA(sum_01, ax2, ay2); - ax2 = GGML_F16x_VEC_LOAD(x[1] + i + 1*ggml_f16_epr, 1); - sum_11 = GGML_F16x_VEC_FMA(sum_11, ax2, ay2); + ggml_sve_f16_fma_widened(&sum_0_0_lo, &sum_0_0_hi, ax00, ay0); + ggml_sve_f16_fma_widened(&sum_1_0_lo, &sum_1_0_hi, ax01, ay0); - ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2); + const svfloat16_t ay1 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 0); + const svfloat16_t ax10 = GGML_F16x_VEC_LOAD(x[0] + i + 1 * ggml_f16_epr, 0); + const svfloat16_t ax11 = GGML_F16x_VEC_LOAD(x[1] + i + 1 * ggml_f16_epr, 0); - ax3 = GGML_F16x_VEC_LOAD(x[0] + i + 2*ggml_f16_epr, 2); - sum_02 = GGML_F16x_VEC_FMA(sum_02, ax3, ay3); - ax3 = GGML_F16x_VEC_LOAD(x[1] + i + 2*ggml_f16_epr, 2); - sum_12 = GGML_F16x_VEC_FMA(sum_12, ax3, ay3); - - ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3); - - ax4 = GGML_F16x_VEC_LOAD(x[0] + i + 3*ggml_f16_epr, 3); - sum_03 = GGML_F16x_VEC_FMA(sum_03, ax4, ay4); - ax4 = GGML_F16x_VEC_LOAD(x[1] + i + 3*ggml_f16_epr, 3); - sum_13 = GGML_F16x_VEC_FMA(sum_13, ax4, ay4); - - ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4); - - ax5 = GGML_F16x_VEC_LOAD(x[0] + i + 4*ggml_f16_epr, 4); - - sum_00 = GGML_F16x_VEC_FMA(sum_00, ax5, ay5); - ax5 = GGML_F16x_VEC_LOAD(x[1] + i + 4*ggml_f16_epr, 4); - sum_10 = GGML_F16x_VEC_FMA(sum_10, ax5, ay5); - - ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5); - - ax6 = GGML_F16x_VEC_LOAD(x[0] + i + 5*ggml_f16_epr, 5); - - sum_01 = GGML_F16x_VEC_FMA(sum_01, ax6, ay6); - ax6 = GGML_F16x_VEC_LOAD(x[1] + i + 5*ggml_f16_epr, 5); - sum_11 = GGML_F16x_VEC_FMA(sum_11, ax6, ay6); - - ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6); - - ax7 = GGML_F16x_VEC_LOAD(x[0] + i + 6*ggml_f16_epr, 6); - - sum_02 = GGML_F16x_VEC_FMA(sum_02, ax7, ay7); - ax7 = GGML_F16x_VEC_LOAD(x[1] + i + 6*ggml_f16_epr, 6); - sum_12 = GGML_F16x_VEC_FMA(sum_12, ax7, ay7); - - ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7); - - ax8 = GGML_F16x_VEC_LOAD(x[0] + i + 7*ggml_f16_epr, 7); - - sum_03 = GGML_F16x_VEC_FMA(sum_03, ax8, ay8); - ax8 = GGML_F16x_VEC_LOAD(x[1] + i + 7*ggml_f16_epr, 7); - sum_13 = GGML_F16x_VEC_FMA(sum_13, ax8, ay8); + ggml_sve_f16_fma_widened(&sum_0_1_lo, &sum_0_1_hi, ax10, ay1); + ggml_sve_f16_fma_widened(&sum_1_1_lo, &sum_1_1_hi, ax11, ay1); } - const int np2 = (n & ~(ggml_f16_epr - 1)); - for (int k = np; k < np2; k += ggml_f16_epr) { - svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0); + for (int i = np; i < np2; i += ggml_f16_epr) { + const svfloat16_t ry = GGML_F16x_VEC_LOAD(y + i, 0); + const svfloat16_t rx0 = GGML_F16x_VEC_LOAD(x[0] + i, 0); + const svfloat16_t rx1 = GGML_F16x_VEC_LOAD(x[1] + i, 0); - svfloat16_t rx = GGML_F16x_VEC_LOAD(x[0] + k, 0); - sum_00 = GGML_F16x_VEC_FMA(sum_00, rx, ry); - rx = GGML_F16x_VEC_LOAD(x[1] + k, 0); - sum_10 = GGML_F16x_VEC_FMA(sum_10, rx, ry); + ggml_sve_f16_fma_widened(&sum_0_0_lo, &sum_0_0_hi, rx0, ry); + ggml_sve_f16_fma_widened(&sum_1_0_lo, &sum_1_0_hi, rx1, ry); } if (np2 < n) { - svbool_t pg = svwhilelt_b16(np2, n); - svfloat16_t hx_0 = svld1_f16(pg, (const __fp16 *)(x[0] + np2)); - svfloat16_t hx_1 = svld1_f16(pg, (const __fp16 *)(x[1] + np2)); - svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2)); + const svbool_t pg = svwhilelt_b16(np2, n); + const svfloat16_t ay = svld1_f16(pg, (const __fp16 *)(y + np2)); + const svfloat16_t ax0 = svld1_f16(pg, (const __fp16 *)(x[0] + np2)); + const svfloat16_t ax1 = svld1_f16(pg, (const __fp16 *)(x[1] + np2)); - sum_00 = svmad_f16_x(pg, hx_0, hy, sum_00); - sum_10 = svmad_f16_x(pg, hx_1, hy, sum_10); + ggml_sve_f16_fma_widened(&sum_0_0_lo, &sum_0_0_hi, ax0, ay); + ggml_sve_f16_fma_widened(&sum_1_0_lo, &sum_1_0_hi, ax1, ay); } - GGML_F16x_VEC_REDUCE(sumf[0], sum_00, sum_01, sum_02, sum_03); - GGML_F16x_VEC_REDUCE(sumf[1], sum_10, sum_11, sum_12, sum_13); + + svfloat32_t sum_0_lo = svadd_f32_x(DEFAULT_PG32, sum_0_0_lo, sum_0_1_lo); + svfloat32_t sum_0_hi = svadd_f32_x(DEFAULT_PG32, sum_0_0_hi, sum_0_1_hi); + svfloat32_t sum_1_lo = svadd_f32_x(DEFAULT_PG32, sum_1_0_lo, sum_1_1_lo); + svfloat32_t sum_1_hi = svadd_f32_x(DEFAULT_PG32, sum_1_0_hi, sum_1_1_hi); + sumf[0] = ggml_sve_sum_f32x2(sum_0_lo, sum_0_hi); + sumf[1] = ggml_sve_sum_f32x2(sum_1_lo, sum_1_hi); np = n; #elif defined(__riscv_v_intrinsic) #if defined(__riscv_zvfh) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index d396c2d7441..e6e50e04119 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -7,6 +7,7 @@ #include <cstdint> #include <cstdlib> #include <memory> +#include <mutex> #if defined(GGML_USE_HIP) #define GGML_COMMON_DECL_HIP @@ -110,13 +111,14 @@ # define GGML_CUDA_USE_CUB #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070 -// PDL host-side support (cudaLaunchKernelEx) requires CUDART >= 12.0 and excludes HIP/MUSA. -// (The C++ cudaLaunchKernelEx template errors "undefined" on the CUDA 11.8 Windows -// toolchain; it builds clean on 12.x. Upstream only CI-tests CUDA 12.x. QVAC-19386.) +// PDL host-side support (cudaLaunchKernelEx) requires CUDART >= 11.8. +// However, this has been bugged in CTK < 12.3 for MSVC builds, see +// https://github.com/ggml-org/llama.cpp/pull/22522#discussion_r3302393293 // __CUDA_ARCH__ is undefined in host passes; GPU arch check happens in device-side code. -#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 12000 +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && \ + (CUDART_VERSION >= 12030 || (!(defined(_MSC_VER) && !defined(__clang__)) && CUDART_VERSION >= 11080)) # define GGML_CUDA_USE_PDL -#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 12000 +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && (CUDART_VERSION >= 12030 || (!(defined(_MSC_VER) && !defined(__clang__)) && CUDART_VERSION >= 11080)) static __device__ __forceinline__ void ggml_cuda_pdl_sync() { #if defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER @@ -1551,8 +1553,70 @@ struct ggml_cuda_pdl_config { ggml_cuda_pdl_config& operator=(ggml_cuda_pdl_config&&) = delete; }; + +static bool ggml_cuda_kernel_can_use_pdl(const void * kernel) { + const int device = ggml_cuda_get_device(); + + struct cache_key { + int device; + const void * kernel; + + bool operator==(const cache_key & other) const { return device == other.device && kernel == other.kernel; } + }; + + struct cache_key_hash { + // MurmurHash3 mixing function for better hash distribution (vs. just std::hash which in some implementations simply returns the identity) + static size_t hash_mix(size_t x) { + std::uint64_t y = x; + const std::uint64_t m = 0xe9846af9b1a615d; + + y ^= y >> 32; + y *= m; + y ^= y >> 32; + y *= m; + y ^= y >> 28; + + return static_cast<size_t>(y); + } + + size_t operator()(const cache_key & key) const { + // Use a nonzero seed to avoid mapping all-zero keys to zero + size_t h = 42; + h = hash_mix(h + key.device); + h = hash_mix(h + reinterpret_cast<size_t>(key.kernel)); + return h; + } + }; + + static std::mutex cache_mutex; + static std::unordered_map<cache_key, bool, cache_key_hash> cache; + + const cache_key key = { device, kernel }; + std::lock_guard<std::mutex> lock(cache_mutex); + const auto it = cache.find(key); + if (it != cache.end()) { + return it->second; + } + + cudaFuncAttributes attr = {}; + CUDA_CHECK(cudaFuncGetAttributes(&attr, kernel)); + + // PDL device-side primitives are emitted only for PTX versions >= 90. + // We have to guard on a loaded kernel's PTX version so a kernel forward-JIT'ed + // from pre-Hopper PTX to a Hopper-or-newer GPU does not opt into PDL. + const bool can_use_pdl = attr.ptxVersion >= 90; + cache.emplace(key, can_use_pdl); + return can_use_pdl; +} + #endif //defined(GGML_CUDA_USE_PDL) +// PDL and __restrict__ need to be mutually exclusive, see https://github.com/ggml-org/llama.cpp/pull/24030 +# if (defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER) +# define GGML_CUDA_RESTRICT +# else +# define GGML_CUDA_RESTRICT __restrict__ +# endif // defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER template<typename Kernel, typename... Args> static __inline__ void ggml_cuda_kernel_launch(Kernel kernel, const ggml_cuda_kernel_launch_params & launch_params, Args&&... args) { @@ -1563,8 +1627,7 @@ static __inline__ void ggml_cuda_kernel_launch(Kernel kernel, const ggml_cuda_ke return env == nullptr || std::atoi(env) != 0; }(); - const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; - if (env_pdl_enabled && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_HOPPER) { + if (env_pdl_enabled && ggml_cuda_kernel_can_use_pdl(reinterpret_cast<const void *>(kernel))) { auto pdl_cfg = ggml_cuda_pdl_config(launch_params); CUDA_CHECK(cudaLaunchKernelEx(&pdl_cfg.cfg, kernel, std::forward<Args>(args)... )); diff --git a/ggml/src/ggml-cuda/concat.cu b/ggml/src/ggml-cuda/concat.cu index adba4d522a4..8d557092b2b 100644 --- a/ggml/src/ggml-cuda/concat.cu +++ b/ggml/src/ggml-cuda/concat.cu @@ -1,16 +1,18 @@ #include "concat.cuh" +#include <stdint.h> + // contiguous kernels -template <int dim> -static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) concat_f32_cont(const float * x, - const float * y, - float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne0, - int64_t ne1, - int64_t ne2) { +template <typename T, int dim> +static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) concat_cont(const T * x, + const T * y, + T * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne0, + int64_t ne1, + int64_t ne2) { static_assert(dim >= 0 && dim <= 2, "dim must be in [0, 2]"); const int64_t n = ne0 * ne1 * ne2; @@ -50,37 +52,37 @@ static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) concat_f32_cont } } -static void concat_f32_cuda(const float * x, - const float * y, - float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne0, - int64_t ne1, - int64_t ne2, - int dim, - cudaStream_t stream) { +template <typename T> +static void concat_cont_cuda(const T * x, + const T * y, + T * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int dim, + cudaStream_t stream) { const int64_t n = ne0 * ne1 * ne2; const int num_blocks = (n + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE; if (dim == 0) { const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream); - ggml_cuda_kernel_launch(concat_f32_cont<0>, launch_params,x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); + ggml_cuda_kernel_launch(concat_cont<T, 0>, launch_params, x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); return; } if (dim == 1) { - concat_f32_cont<1> - <<<num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); + concat_cont<T, 1><<<num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); return; } - concat_f32_cont<2><<<num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); + concat_cont<T, 2><<<num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); } // non-contiguous kernel (slow) -template <int dim> +template <typename T, int dim> static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) - concat_f32_non_cont( + concat_non_cont( const char * src0, const char * src1, char * dst, @@ -107,61 +109,49 @@ static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) uint64_t nb0, uint64_t nb1, uint64_t nb2, - uint64_t nb3){ + uint64_t nb3) { static_assert(dim >= 0 && dim <= 3, "dim must be in [0, 3]"); const int64_t i3 = blockIdx.z; const int64_t i2 = blockIdx.y; const int64_t i1 = blockIdx.x; - const float * x; + const T * x; for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) { if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { - x = (const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00); + x = (const T *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); } else { if constexpr (dim == 0) { - x = (const float *) (src1 + i3 * nb13 + i2 * nb12 + i1 * nb11 + (i0 - ne00) * nb10); + x = (const T *)(src1 + i3*nb13 + i2*nb12 + i1*nb11 + (i0 - ne00)*nb10); } else if constexpr (dim == 1) { - x = (const float *) (src1 + i3 * nb13 + i2 * nb12 + (i1 - ne01) * nb11 + i0 * nb10); + x = (const T *)(src1 + i3*nb13 + i2*nb12 + (i1 - ne01)*nb11 + i0*nb10); } else if constexpr (dim == 2) { - x = (const float *) (src1 + i3 * nb13 + (i2 - ne02) * nb12 + i1 * nb11 + i0 * nb10); + x = (const T *)(src1 + i3*nb13 + (i2 - ne02)*nb12 + i1*nb11 + i0*nb10); } else if constexpr (dim == 3) { - x = (const float *) (src1 + (i3 - ne03) * nb13 + i2 * nb12 + i1 * nb11 + i0 * nb10); + x = (const T *)(src1 + (i3 - ne03)*nb13 + i2*nb12 + i1*nb11 + i0*nb10); } } - float * y = (float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + T * y = (T *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); *y = *x; } } - -void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const ggml_tensor * src1 = dst->src[1]; - - cudaStream_t stream = ctx.stream(); - - const int32_t dim = ((int32_t *) dst->op_params)[0]; - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); - +template <typename T> +static void concat_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, int dim, cudaStream_t stream) { if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { - const float * src0_d = (const float *)src0->data; - const float * src1_d = (const float *)src1->data; - - float * dst_d = (float *)dst->data; + const T * src0_d = (const T *) src0->data; + const T * src1_d = (const T *) src1->data; + T * dst_d = (T *) dst->data; if (dim != 3) { - for (int i3 = 0; i3 < dst->ne[3]; i3++) { - concat_f32_cuda( - src0_d + i3 * (src0->nb[3] / 4), - src1_d + i3 * (src1->nb[3] / 4), - dst_d + i3 * ( dst->nb[3] / 4), + for (int64_t i3 = 0; i3 < dst->ne[3]; i3++) { + concat_cont_cuda( + src0_d + i3*(src0->nb[3] / sizeof(T)), + src1_d + i3*(src1->nb[3] / sizeof(T)), + dst_d + i3*( dst->nb[3] / sizeof(T)), src0->ne[0], src0->ne[1], src0->ne[2], dst->ne[0], dst->ne[1], dst->ne[2], dim, stream); } @@ -169,13 +159,13 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const size_t size0 = ggml_nbytes(src0); const size_t size1 = ggml_nbytes(src1); - CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream)); - CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync((char *) dst->data, src0->data, size0, cudaMemcpyDeviceToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync((char *) dst->data + size0, src1->data, size1, cudaMemcpyDeviceToDevice, stream)); } } else { dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]); auto launch_kernel = [&](auto dim) { - concat_f32_non_cont<dim><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>( + concat_non_cont<T, dim><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>( (const char *) src0->data, (const char *) src1->data, (char *) dst->data, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], @@ -203,3 +193,35 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { } } } + +void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + cudaStream_t stream = ctx.stream(); + + const int32_t dim = ((int32_t *) dst->op_params)[0]; + + GGML_ASSERT(src0->type == src1->type); + GGML_ASSERT(dst->type == src0->type); + GGML_ASSERT(!ggml_is_quantized(src0->type)); + GGML_ASSERT(ggml_blck_size(src0->type) == 1); + + switch (ggml_type_size(src0->type)) { + case 1: + concat_cuda<uint8_t>(src0, src1, dst, dim, stream); + break; + case 2: + concat_cuda<uint16_t>(src0, src1, dst, dim, stream); + break; + case 4: + concat_cuda<uint32_t>(src0, src1, dst, dim, stream); + break; + case 8: + concat_cuda<uint64_t>(src0, src1, dst, dim, stream); + break; + default: + GGML_ABORT("Unsupported type size: %zu", ggml_type_size(src0->type)); + break; + } +} diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index debcb6e5447..8dfa51ad1e8 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -44,6 +44,46 @@ typedef void (* fattn_kernel_t)( typedef float (*vec_dot_KQ_t)( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds); +struct ggml_cuda_flash_attn_ext_f16_extra_data { + uintptr_t K; + uintptr_t V; + uintptr_t end; +}; + +static inline ggml_cuda_flash_attn_ext_f16_extra_data ggml_cuda_flash_attn_ext_get_f16_extra_data( + const ggml_tensor * dst, const bool need_f16_K, const bool need_f16_V) { + GGML_ASSERT(dst->op == GGML_OP_FLASH_ATTN_EXT); + + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + + GGML_ASSERT(K != nullptr); + GGML_ASSERT(V != nullptr); + + const bool V_is_K_view = V->view_src && (V->view_src == K || (V->view_src == K->view_src && V->view_offs == K->view_offs)); + + ggml_cuda_flash_attn_ext_f16_extra_data data = {}; + data.end = (uintptr_t) dst->data + ggml_nbytes(dst); + + if (need_f16_K && K->type != GGML_TYPE_F16) { + data.end = GGML_PAD(data.end, 128); + data.K = data.end; + data.end += ggml_nelements(K)*ggml_type_size(GGML_TYPE_F16); + } + + if (need_f16_V && V->type != GGML_TYPE_F16) { + if (V_is_K_view) { + data.V = data.K; + } else { + data.end = GGML_PAD(data.end, 128); + data.V = data.end; + data.end += ggml_nelements(V)*ggml_type_size(GGML_TYPE_F16); + } + } + + return data; +} + template <int D, int nthreads> static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) { @@ -678,8 +718,8 @@ static __global__ void flash_attn_mask_to_KV_max( template<int D, int ncols1, int ncols2> // D == head size __launch_bounds__(D, 1) static __global__ void flash_attn_stream_k_fixup_uniform( - float * __restrict__ dst, - const float2 * __restrict__ dst_fixup, + float * dst_ptr, + const float2 * dst_fixup_ptr, const int ne01, const int ne02, const int ne12, const int nblocks_stream_k, const int gqa_ratio, @@ -689,6 +729,8 @@ static __global__ void flash_attn_stream_k_fixup_uniform( const uint3 fd_iter_j) { constexpr int ncols = ncols1*ncols2; ggml_cuda_pdl_lc(); + float * GGML_CUDA_RESTRICT dst = dst_ptr; + const float2 * GGML_CUDA_RESTRICT dst_fixup = dst_fixup_ptr; const int tile_idx = blockIdx.x; // One block per output tile. const int j = blockIdx.y; @@ -760,8 +802,8 @@ static __global__ void flash_attn_stream_k_fixup_uniform( template <int D, int ncols1, int ncols2> // D == head size __launch_bounds__(D, 1) static __global__ void flash_attn_stream_k_fixup_general( - float * __restrict__ dst, - const float2 * __restrict__ dst_fixup, + float * dst_ptr, + const float2 * dst_fixup_ptr, const int ne01, const int ne02, const int gqa_ratio, const int total_work, @@ -769,6 +811,8 @@ static __global__ void flash_attn_stream_k_fixup_general( const uint3 fd_iter_k_j_z, const uint3 fd_iter_k_j, const uint3 fd_iter_k) { + float * GGML_CUDA_RESTRICT dst = dst_ptr; + const float2 * GGML_CUDA_RESTRICT dst_fixup = dst_fixup_ptr; constexpr int ncols = ncols1*ncols2; const int bidx0 = blockIdx.x; @@ -867,11 +911,14 @@ static __global__ void flash_attn_stream_k_fixup_general( template<int D> // D == head size __launch_bounds__(D, 1) static __global__ void flash_attn_combine_results( - const float * __restrict__ VKQ_parts, - const float2 * __restrict__ VKQ_meta, - float * __restrict__ dst, + const float * VKQ_parts_ptr, + const float2 * VKQ_meta_ptr, + float * dst_ptr, const int parallel_blocks) { ggml_cuda_pdl_lc(); + const float * GGML_CUDA_RESTRICT VKQ_parts = VKQ_parts_ptr; + const float2 * GGML_CUDA_RESTRICT VKQ_meta = VKQ_meta_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; // Dimension 0: threadIdx.x // Dimension 1: blockIdx.x // Dimension 2: blockIdx.y @@ -952,8 +999,9 @@ void launch_fattn( const int cc = ggml_cuda_info().devices[id].cc; const int nsm = ggml_cuda_info().devices[id].nsm; - ggml_cuda_pool_alloc<half> K_f16(pool); - ggml_cuda_pool_alloc<half> V_f16(pool); + const ggml_cuda_flash_attn_ext_f16_extra_data f16_extra = + ggml_cuda_flash_attn_ext_get_f16_extra_data(KQV, need_f16_K, need_f16_V); + ggml_cuda_pool_alloc<int> KV_max(pool); ggml_cuda_pool_alloc<float> dst_tmp(pool); ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool); @@ -972,10 +1020,11 @@ void launch_fattn( const size_t bs = ggml_blck_size(K->type); const size_t ts = ggml_type_size(K->type); - K_f16.alloc(ggml_nelements(K)); + GGML_ASSERT(f16_extra.K != 0); + half * K_f16 = (half *) f16_extra.K; if (ggml_is_contiguously_allocated(K)) { to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type); - to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream); + to_fp16(K_data, K_f16, ggml_nelements(K), main_stream); nb11 = nb11*bs*sizeof(half)/ts; nb12 = nb12*bs*sizeof(half)/ts; @@ -986,13 +1035,13 @@ void launch_fattn( const int64_t s01 = nb11 / ts; const int64_t s02 = nb12 / ts; const int64_t s03 = nb13 / ts; - to_fp16(K_data, K_f16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream); + to_fp16(K_data, K_f16, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream); nb11 = K->ne[0] * sizeof(half); nb12 = K->ne[1] * nb11; nb13 = K->ne[2] * nb12; } - K_data = (char *) K_f16.ptr; + K_data = (char *) K_f16; } if (need_f16_V && V->type != GGML_TYPE_F16) { @@ -1005,11 +1054,12 @@ void launch_fattn( const size_t bs = ggml_blck_size(V->type); const size_t ts = ggml_type_size(V->type); - V_f16.alloc(ggml_nelements(V)); + GGML_ASSERT(f16_extra.V != 0); + half * V_f16 = (half *) f16_extra.V; if (ggml_is_contiguously_allocated(V)) { to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type); - to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream); - V_data = (char *) V_f16.ptr; + to_fp16(V_data, V_f16, ggml_nelements(V), main_stream); + V_data = (char *) V_f16; nb21 = nb21*bs*sizeof(half)/ts; nb22 = nb22*bs*sizeof(half)/ts; @@ -1020,13 +1070,13 @@ void launch_fattn( const int64_t s01 = nb21 / ts; const int64_t s02 = nb22 / ts; const int64_t s03 = nb23 / ts; - to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream); + to_fp16(V_data, V_f16, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream); nb21 = V->ne[0] * sizeof(half); nb22 = V->ne[1] * nb21; nb23 = V->ne[2] * nb22; } - V_data = (char *) V_f16.ptr; + V_data = (char *) V_f16; } } @@ -1153,8 +1203,8 @@ void launch_fattn( GGML_ASSERT(block_dim.x % warp_size == 0); - const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks_num, block_dim, nbytes_shared, main_stream); - ggml_cuda_kernel_launch(fattn_kernel, launch_params, + ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks_num, block_dim, nbytes_shared, main_stream); + ggml_cuda_kernel_launch(fattn_kernel, launch_params, (const char *) Q->data, K_data, V_data, diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 4871b90df86..83478a02cb6 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -472,7 +472,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( const int i = 8 * (threadIdx.x % (nbatch_fa/8)); - cp_async_cg_16<preload>(tile_mask_32 + j_sram*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half), mask_h + j_vram*stride_mask + i); + cp_async_cg_16<preload>(tile_mask_32 + j_sram*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half), mask_h + int64_t(j_vram)*stride_mask + i); } } else if constexpr (oob_check) { #pragma unroll @@ -488,7 +488,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( for (int i0 = 0; i0 < nbatch_fa; i0 += warp_size) { const int i = i0 + threadIdx.x; - tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[j_vram*stride_mask + i] : half(0.0f); + tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[int64_t(j_vram)*stride_mask + i] : half(0.0f); } } } else if constexpr (nbatch_fa < 2*warp_size) { @@ -505,7 +505,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( const int i = threadIdx.x % (warp_size/cols_per_warp); - ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + j_vram*stride_mask + 2*i); + ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + int64_t(j_vram)*stride_mask + 2*i); } } else { #pragma unroll @@ -521,7 +521,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( for (int i0 = 0; i0 < nbatch_fa; i0 += 2*warp_size) { const int i = i0 + 2*threadIdx.x; - ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + j_vram*stride_mask + i); + ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + int64_t(j_vram)*stride_mask + i); } } } @@ -568,7 +568,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( constexpr bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols); constexpr int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2); - constexpr int stride_tile_Q = DKQ/2 + 4; constexpr int stride_tile_K = nbatch_K2 + 4; constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4; @@ -604,9 +603,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #pragma unroll for (int k0_start = (DKQ/2-1) - (DKQ/2-1) % nbatch_K2; k0_start >= 0; k0_start -= nbatch_K2) { const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2; - const int k0_diff = k0_stop - k0_start; if constexpr (nstages <= 1) { + const int k0_diff = k0_stop - k0_start; constexpr bool use_cp_async = nstages == 1; flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, nbatch_fa, use_cp_async, oob_check> (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K, k_VKQ_sup); @@ -640,6 +639,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } } else { + constexpr int stride_tile_Q = DKQ/2 + 4; #pragma unroll for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) { load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q); @@ -954,9 +954,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( for (int i0_start = 0; i0_start < DV; i0_start += 2*nbatch_V2) { static_assert(DV % (2*nbatch_V2) == 0, "bad loop size"); const int i0_stop = i0_start + 2*nbatch_V2; - const int i0_diff = i0_stop - i0_start; if constexpr (nstages <= 1) { + const int i0_diff = i0_stop - i0_start; if (!V_is_K_view || i0_stop > 2*nbatch_K2) { constexpr bool use_cp_async = nstages == 1; flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, nbatch_fa, use_cp_async, oob_check> @@ -1703,14 +1703,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool V_is_K_view> __launch_bounds__(ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_mma_get_occupancy(DKQ, DV, ncols1*ncols2)) static __global__ void flash_attn_ext_f16( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - const char * __restrict__ sinks, - const int * __restrict__ KV_max, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, + const char * Q_ptr, + const char * K_ptr, + const char * V_ptr, + const char * mask_ptr, + const char * sinks_ptr, + const int * KV_max_ptr, + float * dst_ptr, + float2 * dst_meta_ptr, const float scale, const float max_bias, const float m0, @@ -1726,6 +1726,14 @@ static __global__ void flash_attn_ext_f16( const int32_t nb31, const int32_t nb32, const int64_t nb33) { ggml_cuda_pdl_sync(); // TODO optimize placement #if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)) + const char * GGML_CUDA_RESTRICT Q = Q_ptr; + const char * GGML_CUDA_RESTRICT K = K_ptr; + const char * GGML_CUDA_RESTRICT V = V_ptr; + const char * GGML_CUDA_RESTRICT mask = mask_ptr; + const char * GGML_CUDA_RESTRICT sinks = sinks_ptr; + const int * GGML_CUDA_RESTRICT KV_max = KV_max_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; + float2 * GGML_CUDA_RESTRICT dst_meta = dst_meta_ptr; // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(DKQ == 128 || DKQ == 256 || DKQ == 512)) { @@ -1871,7 +1879,7 @@ static __global__ void flash_attn_ext_f16( (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop); #else - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + GGML_UNUSED_VARS(Q_ptr, K_ptr, V_ptr, mask_ptr, sinks_ptr, KV_max_ptr, dst_ptr, dst_meta_ptr, scale, max_bias, m0, m1, n_head_log2, logit_softcap, ne00, ne01, ne02, ne03, nb01, nb02, nb03, diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index fac76f13593..0a099810e14 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -788,14 +788,14 @@ static __device__ __forceinline__ void flash_attn_tile_iter( template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap> // D == head size __launch_bounds__(ggml_cuda_fattn_tile_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_tile_get_occupancy(DKQ, DV, ncols1*ncols2)) static __global__ void flash_attn_tile( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - const char * __restrict__ sinks, - const int * __restrict__ KV_max, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, + const char * Q_ptr, + const char * K_ptr, + const char * V_ptr, + const char * mask_ptr, + const char * sinks_ptr, + const int * KV_max_ptr, + float * dst_ptr, + float2 * dst_meta_ptr, const float scale, const float max_bias, const float m0, @@ -810,6 +810,14 @@ static __global__ void flash_attn_tile( const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { #ifdef FLASH_ATTN_AVAILABLE + const char * GGML_CUDA_RESTRICT Q = Q_ptr; + const char * GGML_CUDA_RESTRICT K = K_ptr; + const char * GGML_CUDA_RESTRICT V = V_ptr; + const char * GGML_CUDA_RESTRICT mask = mask_ptr; + const char * GGML_CUDA_RESTRICT sinks = sinks_ptr; + const int * GGML_CUDA_RESTRICT KV_max = KV_max_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; + float2 * GGML_CUDA_RESTRICT dst_meta = dst_meta_ptr; // Skip unused kernel variants for faster compilation: @@ -1126,7 +1134,7 @@ static __global__ void flash_attn_tile( } } #else - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + GGML_UNUSED_VARS(Q_ptr, K_ptr, V_ptr, mask_ptr, sinks_ptr, KV_max_ptr, dst_ptr, dst_meta_ptr, scale, max_bias, m0, m1, n_head_log2, logit_softcap, ne00, ne01, ne02, ne03, nb01, nb02, nb03, diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh index b0a6cf67f1a..69dd9368624 100644 --- a/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ggml/src/ggml-cuda/fattn-vec.cuh @@ -19,14 +19,14 @@ static constexpr __device__ int ggml_cuda_fattn_vec_get_nthreads_device() { template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size __launch_bounds__(ggml_cuda_fattn_vec_get_nthreads_device(), 1) static __global__ void flash_attn_ext_vec( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - const char * __restrict__ sinks, - const int * __restrict__ KV_max, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, + const char * Q_ptr, + const char * K_ptr, + const char * V_ptr, + const char * mask_ptr, + const char * sinks_ptr, + const int * KV_max_ptr, + float * dst_ptr, + float2 * dst_meta_ptr, const float scale, const float max_bias, const float m0, @@ -42,6 +42,14 @@ static __global__ void flash_attn_ext_vec( const int32_t nb31, const int32_t nb32, const int64_t nb33) { ggml_cuda_pdl_lc(); #ifdef FLASH_ATTN_AVAILABLE + const char * GGML_CUDA_RESTRICT Q = Q_ptr; + const char * GGML_CUDA_RESTRICT K = K_ptr; + const char * GGML_CUDA_RESTRICT V = V_ptr; + const char * GGML_CUDA_RESTRICT mask = mask_ptr; + const char * GGML_CUDA_RESTRICT sinks = sinks_ptr; + const int * GGML_CUDA_RESTRICT KV_max = KV_max_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; + float2 * GGML_CUDA_RESTRICT dst_meta = dst_meta_ptr; // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { @@ -506,7 +514,7 @@ static __global__ void flash_attn_ext_vec( dst_meta[((sequence*int(ne01.z) + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(KQ_max[tid], KQ_sum[tid]); } #else - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + GGML_UNUSED_VARS(Q_ptr, K_ptr, V_ptr, mask_ptr, sinks_ptr, KV_max_ptr, dst_ptr, dst_meta_ptr, scale, max_bias, m0, m1, n_head_log2, logit_softcap, ne00, ne01, ne02, ne03, nb01, nb02, nb03, diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index 4b6f6501094..6850716fc0d 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -24,14 +24,14 @@ namespace wmma = rocwmma; template<int D, int ncols, int nwarps, int VKQ_stride, typename KQ_acc_t, bool use_logit_softcap> __launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), 1) static __global__ void flash_attn_ext_f16( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - const char * __restrict__ sinks, - const int * __restrict__ KV_max, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, + const char * Q_ptr, + const char * K_ptr, + const char * V_ptr, + const char * mask_ptr, + const char * sinks_ptr, + const int * KV_max_ptr, + float * dst_ptr, + float2 * dst_meta_ptr, const float scale, const float max_bias, const float m0, @@ -46,6 +46,14 @@ static __global__ void flash_attn_ext_f16( const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { #if defined(FLASH_ATTN_AVAILABLE) && (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)) + const char * GGML_CUDA_RESTRICT Q = Q_ptr; + const char * GGML_CUDA_RESTRICT K = K_ptr; + const char * GGML_CUDA_RESTRICT V = V_ptr; + const char * GGML_CUDA_RESTRICT mask = mask_ptr; + const char * GGML_CUDA_RESTRICT sinks = sinks_ptr; + const int * GGML_CUDA_RESTRICT KV_max = KV_max_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; + float2 * GGML_CUDA_RESTRICT dst_meta = dst_meta_ptr; // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { NO_DEVICE_CODE; @@ -494,7 +502,7 @@ static __global__ void flash_attn_ext_f16( dst_meta[j_dst_unrolled] = dst_meta_val; } #else - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + GGML_UNUSED_VARS(Q_ptr, K_ptr, V_ptr, mask_ptr, sinks_ptr, KV_max_ptr, dst_ptr, dst_meta_ptr, scale, max_bias, m0, m1, n_head_log2, logit_softcap, ne00, ne01, ne02, ne03, nb01, nb02, nb03, diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 1c7777e8a71..d6c501b1d7e 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -537,6 +537,41 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_TILE; } +size_t ggml_cuda_flash_attn_ext_get_alloc_size(int device, const ggml_tensor * dst) { + GGML_ASSERT(dst->op == GGML_OP_FLASH_ATTN_EXT); + + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + + GGML_ASSERT(K != nullptr); + GGML_ASSERT(V != nullptr); + + const best_fattn_kernel kernel = ggml_cuda_get_best_fattn_kernel(device, dst); + + bool need_f16_K = false; + bool need_f16_V = false; + + switch (kernel) { + case BEST_FATTN_KERNEL_TILE: + case BEST_FATTN_KERNEL_WMMA_F16: + case BEST_FATTN_KERNEL_MMA_F16: + need_f16_K = true; + need_f16_V = true; + break; + case BEST_FATTN_KERNEL_VEC: + need_f16_K = K->type == GGML_TYPE_F32; + need_f16_V = V->type == GGML_TYPE_F32; + break; + case BEST_FATTN_KERNEL_NONE: + break; + } + + const ggml_cuda_flash_attn_ext_f16_extra_data f16_extra = + ggml_cuda_flash_attn_ext_get_f16_extra_data(dst, need_f16_K, need_f16_V); + + return f16_extra.end - (uintptr_t) dst->data; +} + void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_cuda_set_device(ctx.device); switch (ggml_cuda_get_best_fattn_kernel(ggml_cuda_get_device(), dst)) { diff --git a/ggml/src/ggml-cuda/fattn.cuh b/ggml/src/ggml-cuda/fattn.cuh index 78705d59951..f9a7e15fbd6 100644 --- a/ggml/src/ggml-cuda/fattn.cuh +++ b/ggml/src/ggml-cuda/fattn.cuh @@ -3,3 +3,5 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst); bool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * dst); + +size_t ggml_cuda_flash_attn_ext_get_alloc_size(int device, const ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/fwht.cu b/ggml/src/ggml-cuda/fwht.cu new file mode 100644 index 00000000000..184dc254c72 --- /dev/null +++ b/ggml/src/ggml-cuda/fwht.cu @@ -0,0 +1,101 @@ +#include "common.cuh" +#include "fwht.cuh" + +template <int N> +__launch_bounds__(4*ggml_cuda_get_physical_warp_size(), 1) +__global__ void fwht_cuda(const float * src, float * dst, const int64_t n_rows, const float scale) { + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + const int64_t r = (int64_t) blockIdx.x * blockDim.y + threadIdx.y; + + if (r >= n_rows) { + return; + } + + src += r * N; + dst += r * N; + + static constexpr int el_w = N / warp_size; + float reg[el_w]; + const int lane = threadIdx.x; + + ggml_cuda_pdl_sync(); +#pragma unroll + for (int i = 0; i < el_w; ++i) { + reg[i] = src[i * warp_size + lane] * scale; + } + +#pragma unroll + for (int h = 1; h < warp_size; h *= 2) { +#pragma unroll + for (int j = 0; j < el_w; j++) { + const float val = reg[j]; + const float val2 = __shfl_xor_sync(0xFFFFFFFF, val, h, warp_size); + + reg[j] = (lane & h) == 0 ? val + val2 : val2 - val; + } + } + +#pragma unroll + for (int h = warp_size; h < N; h *= 2) { + const int step = h / warp_size; +#pragma unroll + for (int j = 0; j < el_w; j += 2 * step) { +#pragma unroll + for (int k = 0; k < step; k++) { + const float x = reg[j + k]; + const float y = reg[j + k + step]; + + reg[j + k] = x + y; + reg[j + k + step] = x - y; + } + } + } + +#pragma unroll + for (int i = 0; i < el_w; ++i) { + dst[i * warp_size + lane] = reg[i]; + } +} + +bool ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src, dst)); + if (!ggml_is_contiguous(src) || !ggml_is_contiguous(dst)) { + return false; + } + const int n = src->ne[0]; + const int64_t rows = ggml_nrows(src); + + const float * src_d = (const float *) src->data; + float * dst_d = (float *) dst->data; + + const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size; + const int rows_per_block = 4; + + const int64_t num_blocks = (rows + rows_per_block - 1) / rows_per_block; + + cudaStream_t stream = ctx.stream(); + dim3 grid_dims(num_blocks, 1, 1); + dim3 block_dims(warp_size, rows_per_block, 1); + const ggml_cuda_kernel_launch_params launch_params = + ggml_cuda_kernel_launch_params(grid_dims, block_dims, 0, stream); + + const float scale = 1 / sqrtf(n); + + switch (n) { + case 64: + ggml_cuda_kernel_launch(fwht_cuda<64>, launch_params, src_d, dst_d, rows, scale); + return true; + case 128: + ggml_cuda_kernel_launch(fwht_cuda<128>, launch_params, src_d, dst_d, rows, scale); + return true; + case 256: + ggml_cuda_kernel_launch(fwht_cuda<256>, launch_params, src_d, dst_d, rows, scale); + return true; + case 512: + ggml_cuda_kernel_launch(fwht_cuda<512>, launch_params, src_d, dst_d, rows, scale); + return true; + default: + return false; + } +} diff --git a/ggml/src/ggml-cuda/fwht.cuh b/ggml/src/ggml-cuda/fwht.cuh new file mode 100644 index 00000000000..cf3df94cafa --- /dev/null +++ b/ggml/src/ggml-cuda/fwht.cuh @@ -0,0 +1,4 @@ +#include "common.cuh" + +// Returns whether the Fast Walsh-Hadamard transform could be used. +bool ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index 018d5d37d47..a547360eb06 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -39,11 +39,10 @@ gated_delta_net_cuda(const float * q, float * attn_data = dst; float * state = dst + attn_score_elems; - // input state layout (D, K, n_seqs) — seq stride is K * D = K * H * S_v * S_v. + // input state holds s0 only: [S_v, S_v, H, n_seqs] — seq stride is D = H * S_v * S_v. // output state layout (per-slot D * n_seqs) — same per-(seq,head) offset as before. - const int64_t state_in_offset = sequence * K * H * S_v * S_v + h_idx * S_v * S_v; + const int64_t state_in_offset = sequence * H * S_v * S_v + h_idx * S_v * S_v; const int64_t state_out_offset = (sequence * H + h_idx) * S_v * S_v; - const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output state += state_out_offset; curr_state += state_in_offset + col * S_v; attn_data += (sequence * n_tokens * H + h_idx) * S_v; @@ -61,10 +60,6 @@ gated_delta_net_cuda(const float * q, s_shard[r] = curr_state[i]; } - // slot mapping: target_slot = t - shift. When n_tokens < K only the last n_tokens slots - // are written; earlier slots are left untouched (caller-owned). - const int shift = (int) n_tokens - K; - for (int t = 0; t < n_tokens; t++) { const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1; const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1; @@ -148,7 +143,10 @@ gated_delta_net_cuda(const float * q, attn_data += S_v * H; if constexpr (keep_rs_t) { - const int target_slot = t - shift; + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + // When n_tokens < K only slots 0..n_tokens-1 are written; older slots are caller-owned. + const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output + const int target_slot = (int) n_tokens - 1 - t; if (target_slot >= 0 && target_slot < K) { float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset; #pragma unroll @@ -286,8 +284,8 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * cudaStream_t stream = ctx.stream(); - // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. - const int K = (int) src_state->ne[1]; + // K (snapshot slot count) is an op param; state holds s0 only [S_v, S_v, H, n_seqs]. + const int K = ggml_get_op_params_i32(dst, 0); const bool keep_rs = K > 1; if (kda) { diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu index 457b695eb2a..eb157b8baf2 100644 --- a/ggml/src/ggml-cuda/getrows.cu +++ b/ggml/src/ggml-cuda/getrows.cu @@ -42,7 +42,7 @@ static __global__ void k_get_rows( template<typename src0_t, typename dst_t> static __global__ void k_get_rows_float( - const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst, + const src0_t * src0_ptr, const int32_t * src1_ptr, dst_t * dst_ptr, const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/ /*const int64_t ne10,*/ const int64_t ne11, const uint3 ne12_fdv, /*const int64_t ne13,*/ /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3, @@ -50,6 +50,9 @@ static __global__ void k_get_rows_float( const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { ggml_cuda_pdl_lc(); + const src0_t * GGML_CUDA_RESTRICT src0 = src0_ptr; + const int32_t * GGML_CUDA_RESTRICT src1 = src1_ptr; + dst_t * GGML_CUDA_RESTRICT dst = dst_ptr; ggml_cuda_pdl_sync(); for (int64_t z = blockIdx.z; z < ne11*(int64_t)ne12_fdv.z; z += gridDim.z) { for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) { diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index e25be3592fd..61041bdc16b 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -24,6 +24,7 @@ #include "ggml-cuda/diagmask.cuh" #include "ggml-cuda/diag.cuh" #include "ggml-cuda/fattn.cuh" +#include "ggml-cuda/fwht.cuh" #include "ggml-cuda/getrows.cuh" #include "ggml-cuda/im2col.cuh" #include "ggml-cuda/mmf.cuh" @@ -621,6 +622,18 @@ ggml_backend_cuda_context::~ggml_backend_cuda_context() { // cuda buffer +struct ggml_backend_cuda_device_context { + int device; + std::string name; + std::string description; + std::string pci_bus_id; + int op_offload_min_batch_size; +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + std::mutex device_mutex; + int active_count = 0; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +}; + struct ggml_backend_cuda_buffer_context { int device; void * dev_ptr = nullptr; @@ -638,6 +651,13 @@ struct ggml_backend_cuda_buffer_context { static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) { ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; + +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buffer->buft->device->context; + std::lock_guard<std::mutex> lock(dev_ctx->device_mutex); + dev_ctx->active_count--; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + delete ctx; } @@ -790,6 +810,12 @@ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_bac ggml_backend_cuda_buffer_context * ctx = new ggml_backend_cuda_buffer_context(buft_ctx->device, dev_ptr); +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buft->device->context; + std::lock_guard<std::mutex> lock(dev_ctx->device_mutex); + dev_ctx->active_count++; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + return ggml_backend_buffer_init(buft, ggml_backend_cuda_buffer_interface, ctx, size); } @@ -800,7 +826,11 @@ static size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_ty } static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { - size_t size = ggml_nbytes(tensor); + ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *) buft->context; + + size_t size = tensor->op == GGML_OP_FLASH_ATTN_EXT + ? ggml_cuda_flash_attn_ext_get_alloc_size(buft_ctx->device, tensor) + : ggml_nbytes(tensor); int64_t ne0 = tensor->ne[0]; if (ggml_is_quantized(tensor->type)) { @@ -811,8 +841,6 @@ static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_t } return size; - - GGML_UNUSED(buft); } static const ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface = { @@ -1487,6 +1515,12 @@ static bool ggml_backend_buft_is_cuda_host(ggml_backend_buffer_type_t buft) { } static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buffer->buft->device->context; + std::lock_guard<std::mutex> lock(dev_ctx->device_mutex); + dev_ctx->active_count--; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + CUDA_CHECK(cudaFreeHost(buffer->context)); } @@ -1495,6 +1529,8 @@ static void * ggml_cuda_host_malloc(size_t size) { return nullptr; } + ggml_cuda_set_device(0); // cudaMallocHost can create the implicit CUDA device context, make sure that this is consistently done on device 0. + void * ptr = nullptr; cudaError_t err = cudaMallocHost((void **) &ptr, size); if (err != cudaSuccess) { @@ -1520,6 +1556,12 @@ static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggm buffer->buft = buft; buffer->iface.free_buffer = ggml_backend_cuda_host_buffer_free_buffer; +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buft->device->context; + std::lock_guard<std::mutex> lock(dev_ctx->device_mutex); + dev_ctx->active_count++; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + return buffer; } @@ -2569,6 +2611,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1], /*n_experts=*/0); use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false); use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]); + use_mul_mat_vec_q = use_mul_mat_vec_q && ggml_cuda_should_use_mmvq(src0->type, cc, src1->ne[1]); any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc); } } else { @@ -2577,6 +2620,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1], /*n_experts=*/0); use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false); use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]); + use_mul_mat_vec_q = use_mul_mat_vec_q && ggml_cuda_should_use_mmvq(src0->type, cc, src1->ne[1]); any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc); } @@ -2594,6 +2638,11 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc); bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32; + const int32_t hint = ggml_get_op_params_i32(dst, 1); + if (hint == GGML_HINT_SRC0_IS_HADAMARD && !split && ggml_cuda_op_fwht(ctx, src1, dst)) { + return; + } + if (!split && use_mul_mat_vec_f) { // the custom F16 vector kernel can be used over batched cuBLAS GEMM // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention) @@ -3130,6 +3179,12 @@ static const char * ggml_backend_cuda_get_name(ggml_backend_t backend) { static void ggml_backend_cuda_free(ggml_backend_t backend) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) backend->device->context; + std::lock_guard<std::mutex> lock(dev_ctx->device_mutex); + dev_ctx->active_count--; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + delete cuda_ctx; delete backend; } @@ -4861,14 +4916,6 @@ void ggml_backend_cuda_unregister_host_buffer(void * buffer) { // backend device -struct ggml_backend_cuda_device_context { - int device; - std::string name; - std::string description; - std::string pci_bus_id; - int op_offload_min_batch_size; -}; - static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) { ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; return ctx->name.c_str(); @@ -4957,6 +5004,11 @@ static bool ggml_backend_cuda_get_available_uma_memory(long * available_memory_k static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; + +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + std::lock_guard<std::mutex> lock(ctx->device_mutex); +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + ggml_cuda_set_device(ctx->device); CUDA_CHECK(cudaMemGetInfo(free, total)); @@ -4983,11 +5035,24 @@ static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * } #endif // defined(__linux__) +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + // If no backends or buffers are active, the cudaMemGetInfo call above lazily created a CUDA + // context that permanently consumes VRAM. Reset the device to free it. + if (ctx->active_count == 0) { + CUDA_CHECK(cudaDeviceReset()); + } +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) } static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend_dev_t dev) { - GGML_UNUSED(dev); - return GGML_BACKEND_DEVICE_TYPE_GPU; + ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *) dev->context; + + cudaDeviceProp prop; + CUDA_CHECK(cudaGetDeviceProperties(&prop, ctx->device)); + + return prop.integrated + ? GGML_BACKEND_DEVICE_TYPE_IGPU + : GGML_BACKEND_DEVICE_TYPE_GPU; } static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { @@ -5280,7 +5345,15 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_CONCAT: { ggml_type src0_type = op->src[0]->type; - return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16; + ggml_type src1_type = op->src[1]->type; + return src0_type == src1_type && + src0_type == op->type && + !ggml_is_quantized(src0_type) && + ggml_blck_size(src0_type) == 1 && + (ggml_type_size(src0_type) == 1 || + ggml_type_size(src0_type) == 2 || + ggml_type_size(src0_type) == 4 || + ggml_type_size(src0_type) == 8); } break; case GGML_OP_CONV_TRANSPOSE_1D: { @@ -5671,13 +5744,21 @@ ggml_backend_t ggml_backend_cuda_init(int device) { return nullptr; } + ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device); + ggml_backend_t cuda_backend = new ggml_backend { /* .guid = */ ggml_backend_cuda_guid(), /* .iface = */ ggml_backend_cuda_interface, - /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device), + /* .device = */ dev, /* .context = */ ctx, }; +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context; + std::lock_guard<std::mutex> lock(dev_ctx->device_mutex); + dev_ctx->active_count++; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + return cuda_backend; } diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh index c2a8d54c95a..d55cc1ec7b5 100644 --- a/ggml/src/ggml-cuda/mmf.cuh +++ b/ggml/src/ggml-cuda/mmf.cuh @@ -91,7 +91,7 @@ static __global__ void mul_mat_f( const int row0 = blockIdx.x * rows_per_block; int expert_idx = 0; - int col_base = 0; + [[maybe_unused]] int col_base = 0; const int channel_dst = has_ids ? 0 : blockIdx.y; @@ -122,12 +122,12 @@ static __global__ void mul_mat_f( ids += col_offset * stride_row_id; } - const float2 * y2 = (const float2 *) y; + [[maybe_unused]] const float2 * y2 = (const float2 *) y; extern __shared__ char data_mmv[]; char * shmem_base = data_mmv; - int * slot_map = (int *) shmem_base; + [[maybe_unused]] int * slot_map = (int *) shmem_base; char * compute_base = has_ids ? (shmem_base + GGML_PAD(cols_per_block, 16) * sizeof(int)) : shmem_base; tile_C C[ntA][ntB]; diff --git a/ggml/src/ggml-cuda/mmvf.cu b/ggml/src/ggml-cuda/mmvf.cu index 09d95f309b4..d7dbc8b9928 100644 --- a/ggml/src/ggml-cuda/mmvf.cu +++ b/ggml/src/ggml-cuda/mmvf.cu @@ -6,11 +6,15 @@ template <typename T, typename type_acc, int ncols_dst, int block_size, bool has_fusion = false, bool is_multi_token_id = false> static __global__ void mul_mat_vec_f( - const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst, + const T * x_ptr, const float * y_ptr, const int32_t * ids_ptr, const ggml_cuda_mm_fusion_args_device fusion, float * dst_ptr, const int ncols2, const uint3 nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst, const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, const int ids_stride) { + const T * GGML_CUDA_RESTRICT x = x_ptr; + const float * GGML_CUDA_RESTRICT y = y_ptr; + const int32_t * GGML_CUDA_RESTRICT ids = ids_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; const int row = blockIdx.x; // for MUL_MAT_ID - blockIdx.y = n_expert_used, blockIdx.z = ncols_dst (tokens) const int channel_dst = blockIdx.y; @@ -80,9 +84,8 @@ static __global__ void mul_mat_vec_f( gate_x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row; } - const int channel_bias = ids ? channel_x : channel_dst; - if constexpr (has_fusion) { + const int channel_bias = ids ? channel_x : channel_dst; if (use_bias) { x_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst; } @@ -95,7 +98,7 @@ static __global__ void mul_mat_vec_f( extern __shared__ char data_mmv[]; float * buf_iw = (float *) data_mmv; - float * buf_iw_gate = nullptr; + [[maybe_unused]] float * buf_iw_gate = nullptr; if constexpr (has_fusion) { buf_iw_gate = (float *) (data_mmv + warp_size*sizeof(float)); } @@ -123,7 +126,7 @@ static __global__ void mul_mat_vec_f( if constexpr (std::is_same_v<T, float>) { const float2 * x2 = (const float2 *) x; - const float2 * gate_x2 = nullptr; + [[maybe_unused]] const float2 * gate_x2 = nullptr; if constexpr (has_fusion) { if (use_gate) { gate_x2 = (const float2 *) gate_x; @@ -155,7 +158,7 @@ static __global__ void mul_mat_vec_f( } } else if constexpr (std::is_same_v<T, half>) { const half2 * x2 = (const half2 *) x; - const half2 * gate_x2 = nullptr; + [[maybe_unused]] const half2 * gate_x2 = nullptr; if constexpr (has_fusion) { if (use_gate) { gate_x2 = (const half2 *) gate_x; @@ -266,7 +269,7 @@ static __global__ void mul_mat_vec_f( } #else const nv_bfloat162 * x2 = (const nv_bfloat162 *) x; - const nv_bfloat162 * gate_x2 = nullptr; + [[maybe_unused]] const nv_bfloat162 * gate_x2 = nullptr; if constexpr (has_fusion) { if (use_gate) { gate_x2 = (const nv_bfloat162 *) gate_x; @@ -274,7 +277,7 @@ static __global__ void mul_mat_vec_f( } for (int col2 = tid; col2 < ncols2; col2 += block_size) { const nv_bfloat162 tmpx = x2[col2]; - nv_bfloat162 tmpx_gate; + [[maybe_unused]] nv_bfloat162 tmpx_gate; if constexpr (has_fusion) { if (use_gate) { tmpx_gate = gate_x2[col2]; diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 13b8b855282..fe44a58da91 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -63,6 +63,7 @@ static constexpr __host__ __device__ int get_vdr_mmvq(ggml_type type) { enum mmvq_parameter_table_id { MMVQ_PARAMETERS_GENERIC = 0, + MMVQ_PARAMETERS_TURING, MMVQ_PARAMETERS_GCN, MMVQ_PARAMETERS_RDNA2, MMVQ_PARAMETERS_RDNA3_0, @@ -78,6 +79,8 @@ static constexpr __device__ mmvq_parameter_table_id get_device_table_id() { return MMVQ_PARAMETERS_RDNA2; #elif defined(GCN) || defined(CDNA) return MMVQ_PARAMETERS_GCN; +#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING && __CUDA_ARCH__ < GGML_CUDA_CC_AMPERE + return MMVQ_PARAMETERS_TURING; #else return MMVQ_PARAMETERS_GENERIC; #endif @@ -96,6 +99,9 @@ static __host__ mmvq_parameter_table_id get_device_table_id(int cc) { if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) { return MMVQ_PARAMETERS_GCN; } + if (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING && ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_AMPERE) { + return MMVQ_PARAMETERS_TURING; + } return MMVQ_PARAMETERS_GENERIC; } @@ -271,6 +277,53 @@ int get_mmvq_mmid_max_batch(ggml_type type, int cc) { return MMVQ_MAX_BATCH_SIZE; } +bool ggml_cuda_should_use_mmvq(enum ggml_type type, int cc, int64_t ne11) { + if (GGML_CUDA_CC_IS_CDNA(cc)) { + if (GGML_CUDA_CC_IS_CDNA1(cc)) { + switch (type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + return ne11 <= 7; + case GGML_TYPE_Q5_1: + return ne11 <= 7; + case GGML_TYPE_Q8_0: + return ne11 <= 6; + case GGML_TYPE_Q2_K: + return ne11 <= 4; + case GGML_TYPE_Q3_K: + return ne11 <= 3; + case GGML_TYPE_Q4_K: + return ne11 <= 2; + case GGML_TYPE_Q5_K: + return ne11 <= 3; + case GGML_TYPE_Q6_K: + return ne11 <= 4; + case GGML_TYPE_IQ1_S: + return ne11 <= 5; + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_XS: + return ne11 <= 6; + default: + return ne11 <= MMVQ_MAX_BATCH_SIZE; + } + } + switch (type) { // tuned for CDNA2 + case GGML_TYPE_Q2_K: + return ne11 <= 5; + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + return ne11 <= 3; + case GGML_TYPE_Q6_K: + return ne11 <= 5; + default: + return ne11 <= MMVQ_MAX_BATCH_SIZE; + } + } + return ne11 <= MMVQ_MAX_BATCH_SIZE; +} + // Device constexpr: returns the max batch size for the current arch+type at compile time. template <ggml_type type> static constexpr __device__ int get_mmvq_mmid_max_batch_for_device() { @@ -358,7 +411,6 @@ static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_d case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: - case GGML_TYPE_Q4_K: return 8; case GGML_TYPE_Q6_K: return 2; @@ -370,11 +422,38 @@ static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_d } return 1; } + if (table_id == MMVQ_PARAMETERS_TURING) { + if (ncols_dst == 1) { + switch (type) { + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + return 2; + default: + return 4; + } + } + switch (ncols_dst) { + case 2: + case 3: + case 4: + return 4; + case 5: + case 6: + case 7: + case 8: + return 2; + default: + return 1; + } + } return 1; } static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id, bool small_k = false, int nwarps = 1) { - if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) { + if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN || table_id == MMVQ_PARAMETERS_TURING) { switch (ncols_dst) { case 1: return small_k ? nwarps : 1; @@ -396,12 +475,16 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int template <ggml_type type, int ncols_dst, bool has_fusion, bool small_k = false> __launch_bounds__(calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) static __global__ void mul_mat_vec_q( - const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst, + const void * vx_ptr, const void * vy_ptr, const int32_t * ids_ptr, const ggml_cuda_mm_fusion_args_device fusion, float * dst_ptr, const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio, const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst, const uint32_t ids_stride) { + const void * GGML_CUDA_RESTRICT vx = vx_ptr; + const void * GGML_CUDA_RESTRICT vy = vy_ptr; + const int32_t * GGML_CUDA_RESTRICT ids = ids_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; constexpr int qk = ggml_cuda_type_traits<type>::qk; constexpr int qi = ggml_cuda_type_traits<type>::qi; @@ -435,7 +518,7 @@ static __global__ void mul_mat_vec_q( bool use_gate = false; bool use_bias = false; bool use_gate_bias = false; - const void * vgate = nullptr; + [[maybe_unused]] const void * vgate = nullptr; const float * x_bias = nullptr; const float * gate_bias = nullptr; ggml_glu_op active_glu; @@ -451,8 +534,8 @@ static __global__ void mul_mat_vec_q( } - float x_biases[ncols_dst] = { 0.0f }; - float gate_biases[ncols_dst] = { 0.0f }; + [[maybe_unused]] float x_biases[ncols_dst] = { 0.0f }; + [[maybe_unused]] float gate_biases[ncols_dst] = { 0.0f }; if constexpr (has_fusion) { const uint32_t channel_bias = ids ? channel_x : channel_dst; if (use_bias) { @@ -509,12 +592,7 @@ static __global__ void mul_mat_vec_q( } __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size]; - __shared__ float tmp_shared_gate[(has_fusion && (nwarps-1 > 0)) ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size]; - if constexpr (!has_fusion) { - (void) tmp_shared_gate; - } else if (!use_gate) { - (void) tmp_shared_gate; - } + [[maybe_unused]] __shared__ float tmp_shared_gate[(has_fusion && (nwarps-1 > 0)) ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size]; if (threadIdx.y > 0) { #pragma unroll @@ -603,12 +681,16 @@ static __global__ void mul_mat_vec_q( template <ggml_type type, int c_rows_per_block> __launch_bounds__(get_mmvq_mmid_max_batch_for_device<type>()*ggml_cuda_get_physical_warp_size(), 1) static __global__ void mul_mat_vec_q_moe( - const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, - float * __restrict__ dst, + const void * vx_ptr, const void * vy_ptr, const int32_t * ids_ptr, + float * dst_ptr, const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t nrows_x, const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst, const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint32_t ncols_dst, const uint32_t ids_stride) { + const void * GGML_CUDA_RESTRICT vx = vx_ptr; + const void * GGML_CUDA_RESTRICT vy = vy_ptr; + const int32_t * GGML_CUDA_RESTRICT ids = ids_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; constexpr int qk = ggml_cuda_type_traits<type>::qk; constexpr int qi = ggml_cuda_type_traits<type>::qi; @@ -628,6 +710,7 @@ static __global__ void mul_mat_vec_q_moe( return; } + ggml_cuda_pdl_sync(); const uint32_t channel_x = ids[channel_dst + token_idx * ids_stride]; const uint32_t channel_y = fastmodulo(channel_dst, nchannels_y); @@ -647,6 +730,8 @@ static __global__ void mul_mat_vec_q_moe( } } + ggml_cuda_pdl_lc(); + // Warp-level reduction only - no shared memory needed #pragma unroll for (int i = 0; i < c_rows_per_block; ++i) { @@ -715,8 +800,9 @@ static void mul_mat_vec_q_moe_launch( const int64_t nblocks_rows = (nrows_x + rows_per_block - 1) / rows_per_block; const dim3 block_nums(nblocks_rows, nchannels_dst); const dim3 block_dims(warp_size, ncols_dst); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); - mul_mat_vec_q_moe<type, rows_per_block><<<block_nums, block_dims, 0, stream>>>( + ggml_cuda_kernel_launch(mul_mat_vec_q_moe<type, rows_per_block>, launch_params, vx, vy, ids, dst, ncols_x, nchannels_y, nrows_x, stride_row_x, stride_col_y, stride_col_dst, stride_channel_x, stride_channel_y, stride_channel_dst, diff --git a/ggml/src/ggml-cuda/mmvq.cuh b/ggml/src/ggml-cuda/mmvq.cuh index 6bf0a8e8677..5605bf7a4e6 100644 --- a/ggml/src/ggml-cuda/mmvq.cuh +++ b/ggml/src/ggml-cuda/mmvq.cuh @@ -2,6 +2,8 @@ #define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels. +bool ggml_cuda_should_use_mmvq(enum ggml_type type, int cc, int64_t ne11); + // Returns the maximum batch size for which MMVQ should be used for MUL_MAT_ID, // based on the quantization type and GPU architecture (compute capability). int get_mmvq_mmid_max_batch(ggml_type type, int cc); diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu index 49516965cad..39a500a1704 100644 --- a/ggml/src/ggml-cuda/quantize.cu +++ b/ggml/src/ggml-cuda/quantize.cu @@ -3,10 +3,12 @@ __launch_bounds__(CUDA_QUANTIZE_BLOCK_SIZE, 1) static __global__ void quantize_q8_1( - const float * __restrict__ x, void * __restrict__ vy, + const float * x_ptr, void * vy_ptr, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, const int64_t ne0, const uint32_t ne1, const uint3 ne2) { ggml_cuda_pdl_lc(); + const float * GGML_CUDA_RESTRICT x = x_ptr; + void * GGML_CUDA_RESTRICT vy = vy_ptr; const int64_t i0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; if (i0 >= ne0) { diff --git a/ggml/src/ggml-cuda/reduce_rows.cuh b/ggml/src/ggml-cuda/reduce_rows.cuh index 5895d3bf8e5..968c47aa20a 100644 --- a/ggml/src/ggml-cuda/reduce_rows.cuh +++ b/ggml/src/ggml-cuda/reduce_rows.cuh @@ -2,7 +2,9 @@ // Row reduction kernel template - compute sum (norm=false) or mean (norm=true) template <bool norm> -static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols) { +static __global__ void reduce_rows_f32(const float * x_ptr, float * dst_ptr, const int ncols) { + const float * GGML_CUDA_RESTRICT x = x_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; const int row = blockIdx.x; const int col = threadIdx.x; diff --git a/ggml/src/ggml-cuda/set-rows.cu b/ggml/src/ggml-cuda/set-rows.cu index e14f96b824c..3b4f004c946 100644 --- a/ggml/src/ggml-cuda/set-rows.cu +++ b/ggml/src/ggml-cuda/set-rows.cu @@ -111,9 +111,9 @@ static void set_rows_cuda_quant( } template <typename src_t, typename idx_t, typename dst_t> -static __global__ void k_set_rows(const src_t * __restrict__ src0, - const idx_t * __restrict__ src1, - dst_t * __restrict__ dst, +static __global__ void k_set_rows(const src_t * src0_ptr, + const idx_t * src1_ptr, + dst_t * dst_ptr, const int64_t ne_total, const int64_t ne10, const int64_t ne11, @@ -133,6 +133,9 @@ static __global__ void k_set_rows(const src_t * __restrict__ src0, const uint3 ne02, const uint3 ne11_fd, const uint3 ne12_fd) { + const src_t * GGML_CUDA_RESTRICT src0 = src0_ptr; + const idx_t * GGML_CUDA_RESTRICT src1 = src1_ptr; + dst_t * GGML_CUDA_RESTRICT dst = dst_ptr; const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x; if (i >= ne_total) { diff --git a/ggml/src/ggml-cuda/ssm-conv.cu b/ggml/src/ggml-cuda/ssm-conv.cu index 48787b4b890..1463169cf78 100644 --- a/ggml/src/ggml-cuda/ssm-conv.cu +++ b/ggml/src/ggml-cuda/ssm-conv.cu @@ -3,12 +3,16 @@ #include "unary.cuh" template <bool apply_silu, size_t split_d_inner, size_t d_conv> -static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1, - const float * __restrict__ bias, +static __global__ void ssm_conv_f32(const float * src0_ptr, const float * src1_ptr, + const float * bias_ptr, const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, - float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2, + float * dst_ptr, const int dst_nb0, const int dst_nb1, const int dst_nb2, const int64_t n_t) { ggml_cuda_pdl_lc(); + const float * GGML_CUDA_RESTRICT src0 = src0_ptr; + const float * GGML_CUDA_RESTRICT src1 = src1_ptr; + const float * GGML_CUDA_RESTRICT bias = bias_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; GGML_UNUSED(src0_nb0); const int tid = threadIdx.x; const int bidx = blockIdx.x; diff --git a/ggml/src/ggml-cuda/ssm-scan.cu b/ggml/src/ggml-cuda/ssm-scan.cu index 412980376ac..3022249c77d 100644 --- a/ggml/src/ggml-cuda/ssm-scan.cu +++ b/ggml/src/ggml-cuda/ssm-scan.cu @@ -17,14 +17,22 @@ using namespace cub; #endif // __clang__ template <size_t splitD, size_t N, size_t L_template> __global__ void __launch_bounds__(splitD, 1) - ssm_scan_f32(const float *__restrict__ src0, const float *__restrict__ src1, const float *__restrict__ src2, - const float *__restrict__ src3, const float *__restrict__ src4, const float *__restrict__ src5, - const int32_t * __restrict__ src6, float * __restrict__ dst, + ssm_scan_f32(const float * src0_ptr, const float * src1_ptr, const float * src2_ptr, + const float * src3_ptr, const float * src4_ptr, const float * src5_ptr, + const int32_t * src6_ptr, float * dst_ptr, const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1, const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3, const int64_t s_off, const int64_t d_inner, const int64_t L_param) { + const float * GGML_CUDA_RESTRICT src0 = src0_ptr; + const float * GGML_CUDA_RESTRICT src1 = src1_ptr; + const float * GGML_CUDA_RESTRICT src2 = src2_ptr; + const float * GGML_CUDA_RESTRICT src3 = src3_ptr; + const float * GGML_CUDA_RESTRICT src4 = src4_ptr; + const float * GGML_CUDA_RESTRICT src5 = src5_ptr; + const int32_t * GGML_CUDA_RESTRICT src6 = src6_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; const size_t L = L_template == 0 ? L_param : L_template; ggml_cuda_pdl_sync(); const float *s0_block = (const float *)((const char *)src0 + src6[blockIdx.x] * src0_nb3 + blockIdx.y * splitD * src0_nb2); @@ -59,6 +67,7 @@ __global__ void __launch_bounds__(splitD, 1) __shared__ CubTempStorage cub_temp_storage; BlockLoad(cub_temp_storage.load_temp).Load(A_block, regA); + __syncthreads(); BlockLoad(cub_temp_storage.load_temp).Load(s0_block, regs0); #else const int stride_s0 = src0_nb2 / sizeof(float); @@ -97,6 +106,7 @@ __global__ void __launch_bounds__(splitD, 1) regs0[n] = state; } y_block[i * stride_y + threadIdx.x] = sumf; + __syncthreads(); } #ifdef USE_CUB @@ -118,13 +128,21 @@ __global__ void __launch_bounds__(splitD, 1) template <int c_factor, int d_state> __global__ void __launch_bounds__(d_state, 1) ssm_scan_f32_group( - const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2, - const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5, - const int32_t * __restrict__ src6, float * __restrict__ dst, + const float * src0_ptr, const float * src1_ptr, const float * src2_ptr, + const float * src3_ptr, const float * src4_ptr, const float * src5_ptr, + const int32_t * src6_ptr, float * dst_ptr, const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1, const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3, const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) { + const float * GGML_CUDA_RESTRICT src0 = src0_ptr; + const float * GGML_CUDA_RESTRICT src1 = src1_ptr; + const float * GGML_CUDA_RESTRICT src2 = src2_ptr; + const float * GGML_CUDA_RESTRICT src3 = src3_ptr; + const float * GGML_CUDA_RESTRICT src4 = src4_ptr; + const float * GGML_CUDA_RESTRICT src5 = src5_ptr; + const int32_t * GGML_CUDA_RESTRICT src6 = src6_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; const int warp = threadIdx.x / WARP_SIZE; const int lane = threadIdx.x % WARP_SIZE; @@ -233,9 +251,8 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa GGML_ASSERT(head_dim == 1); GGML_ASSERT(n_group == 1); const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1); - const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float); if (d_state == 16) { - const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, smem_size, stream); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, 0, stream); switch (n_tok) { case 1: diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu index da20c9aab7c..c4253bfa43b 100644 --- a/ggml/src/ggml-cuda/topk-moe.cu +++ b/ggml/src/ggml-cuda/topk-moe.cu @@ -134,7 +134,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * // selection_wt is only needed when bias is present (selection uses wt + bias) // when no bias, we use wt directly for both selection and weight values - float selection_wt[has_bias ? experts_per_thread : 1]; + [[maybe_unused]] float selection_wt[has_bias ? experts_per_thread : 1]; if constexpr (has_bias) { #pragma unroll diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 5e0e22c7fc2..a6115cd80dc 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -219,9 +219,9 @@ #define RDNA3 #endif // defined(__GFX11__) -#if defined(__gfx1150__) || defined(__gfx1151__) +#if defined(__gfx1150__) || defined(__gfx1151__) || defined(__gfx1152__) || defined(__gfx1153__) #define RDNA3_5 -#endif // defined(__gfx1150__) || defined(__gfx1151__) +#endif // defined(__gfx1150__) || defined(__gfx1151__) || defined(__gfx1152__) || defined(__gfx1153__) #if defined(RDNA3) && !defined(RDNA3_5) #define RDNA3_0 diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 9db99cb0f3a..49bd7e4331a 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -39,7 +39,7 @@ #include "ggml-hexagon.h" #include "ggml-impl.h" #include "ggml-quants.h" -#include "op-desc.h" +#include "htp-opnode.h" #include "htp-ops.h" #include "htp_iface.h" #include "htp-drv.h" @@ -68,6 +68,7 @@ static u32vec opt_pmu_evt { 0x3, 0x111, 0x100, 0x105, 0x240, 0x256, 0x7D, 0x8C } static int opt_opstage = HTP_OPSTAGE_QUEUE | HTP_OPSTAGE_COMPUTE; static int opt_opbatch = 1024; // max number of ops in a batch static int opt_opqueue = 16; // max number of pending batches +static int opt_oppoll = 0; // polling for batch completions static std::regex* opt_opfilter = NULL; // regex of ops to not claim @@ -101,23 +102,23 @@ static const char * status_to_str(uint32_t status) { // ** debug helpers -static void ggml_hexagon_dump_op_exec(const std::string &sess_name, const ggml_tensor * op, const uint32_t req_flags) { +static void ggml_hexagon_dump_op_exec(const std::string &sess_name, const htp_opnode & node, const uint32_t req_flags) { if (!opt_verbose) return; - op_desc desc(op); + htp_opformat fmt(node); GGML_LOG_DEBUG("ggml-hex: %s execute-op %s: %s : %s : %s : %s : %s : flags 0x%x\n", sess_name.c_str(), - ggml_op_desc(op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, req_flags); + node.op_name().c_str(), fmt.names, fmt.dims, fmt.types, fmt.strides, fmt.buffs, req_flags); } static void ggml_hexagon_dump_op_supp(const std::string &sess_name, const struct ggml_tensor * op, bool supp) { if (!opt_verbose) return; - op_desc desc(op); + htp_opformat fmt(htp_opformat(htp_opnode{const_cast<ggml_tensor*>(op), {}, HTP_OP_INVALID})); GGML_LOG_DEBUG("ggml-hex: %s supports-op %s: %s : %s : %s : %s : %s : %s\n", sess_name.c_str(), - ggml_op_desc(op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, supp ? "yes" : "no"); + ggml_op_desc(op), fmt.names, fmt.dims, fmt.types, fmt.strides, fmt.buffs, supp ? "yes" : "no"); } -static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const ggml_tensor * op, +static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const htp_opnode & node, uint32_t op_usec, uint32_t op_cycles, const uint32_t pmu[]) { if (!opt_profile) return; @@ -128,15 +129,16 @@ static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const ggml_t pmu[0], pmu[1], pmu[2], pmu[3], pmu[4], pmu[5], pmu[6], pmu[7]); } - op_desc desc(op); + htp_opformat fmt(node); GGML_LOG_DEBUG("ggml-hex: %s profile-op %s: %s : %s : %s : %s : usec %u cycles %u%s\n", sess_name.c_str(), - ggml_op_desc(op), desc.names, desc.dims, desc.types, desc.strides, op_usec, op_cycles, pmu_str); + node.op_name().c_str(), fmt.names, fmt.dims, fmt.types, fmt.strides, op_usec, op_cycles, pmu_str); } // ** backend sessions struct ggml_hexagon_opbatch; struct ggml_hexagon_opqueue; +struct htp_opnode; struct ggml_hexagon_session { std::string name; @@ -166,7 +168,7 @@ struct ggml_hexagon_session { void allocate(int dev_id) noexcept(false); void release() noexcept(true); - void enqueue_op(htp_op_code opcode, const ggml_tensor *op); + void enqueue_op(const htp_opnode & node); void flush(bool all = true); void flush_pending(bool all = false); @@ -550,7 +552,7 @@ static void repack_q4_0_q4x4x2(ggml_tensor * t, const void * data, size_t size) size_t row_size = ggml_row_size(t->type, t->ne[0]); size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); // extra elements for the pad - size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any) + size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) // Ensure we don't try to read more data than is available in the source buffer 'data' // or write more than the tensor can hold. @@ -611,7 +613,7 @@ static void repack_q4x4x2_q4_0(void * data, const ggml_tensor * t, size_t size) size_t row_size = ggml_row_size(t->type, t->ne[0]); size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); // extra elements for the pad - size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any) + size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) // Ensure we don't try to copy more data than the tensor actually contains. const size_t total_tensor_size = (size_t)nrows * row_size; @@ -660,6 +662,239 @@ static void repack_q4x4x2_q4_0(void * data, const ggml_tensor * t, size_t size) ggml_aligned_free(buf_rp, row_size_rp); } +static void unpack_q4_1_quants(uint8_t * qs, const block_q4_1 * x, unsigned int bi) { + static const int qk = QK4_1; + + for (unsigned int i = 0; i < qk / 2; ++i) { + const int x0 = (x->qs[i] & 0x0F); + const int x1 = (x->qs[i] >> 4); + qs[bi * qk + i + 0] = x0; + qs[bi * qk + i + qk / 2] = x1; + } +} + +static void pack_q4_1_quants(block_q4_1 * x, const uint8_t * qs, unsigned int bi) { + static const int qk = QK4_1; + + for (unsigned int i = 0; i < qk / 2; ++i) { + const uint8_t x0 = qs[bi * qk + i + 0]; + const uint8_t x1 = qs[bi * qk + i + qk / 2]; + x->qs[i] = x0 | (x1 << 4); + } +} + +static void repack_row_q4_1x4x2(uint8_t * y, const block_q4_1 * x, int64_t k) { + static const int qk = QK_Q4_0x4x2; + const int nb = (k + qk - 1) / qk; // number of blocks (padded) + const int nloe = k % qk; // leftovers + + const int dblk_size = 8 * 4; // 8x (d, m) __fp16 = 32 bytes + const int qblk_size = qk / 2; // int4 = 128 bytes + const int qrow_size = k / 2; // int4 (not padded to blocks) + + uint8_t * y_q = y + 0; // quants first + uint8_t * y_d = y + qrow_size; // then scales/offsets + + // Repack the quants + for (int i = 0; i < nb; i++) { + uint8_t qs[QK_Q4_0x4x2]; // unpacked quants + unpack_q4_1_quants(qs, &x[i * 8 + 0], 0); + unpack_q4_1_quants(qs, &x[i * 8 + 1], 1); + unpack_q4_1_quants(qs, &x[i * 8 + 2], 2); + unpack_q4_1_quants(qs, &x[i * 8 + 3], 3); + unpack_q4_1_quants(qs, &x[i * 8 + 4], 4); + unpack_q4_1_quants(qs, &x[i * 8 + 5], 5); + unpack_q4_1_quants(qs, &x[i * 8 + 6], 6); + unpack_q4_1_quants(qs, &x[i * 8 + 7], 7); + + bool partial = (nloe && i == nb-1); + + uint8_t * q = y_q + (i * qblk_size); + for (int j = 0; j < qk / 2; j++) { + q[j] = partial ? (qs[j*2+1] << 4) | qs[j*2+0] : (qs[j+128] << 4) | qs[j+000]; + } + } + + // Repack the scales and offsets + for (int i = 0; i < nb; i++) { + ggml_half * d_m = (ggml_half *) (y_d + i * dblk_size); + for (int j = 0; j < 8; j++) { + d_m[j * 2 + 0] = x[i * 8 + j].d; + d_m[j * 2 + 1] = x[i * 8 + j].m; + } + } +} + +static void unpack_row_q4_1x4x2(block_q4_1 * x, const uint8_t * y, int64_t k) { + static const int qk = QK_Q4_0x4x2; + const int nb = (k + qk - 1) / qk; // number of blocks (padded) + const int nloe = k % qk; // leftovers + + const int dblk_size = 8 * 4; // 8x (d, m) __fp16 = 32 bytes + const int qblk_size = qk / 2; // int4 = 128 bytes + const int qrow_size = k / 2; // int4 (not padded to blocks) + + const uint8_t * y_q = y + 0; // quants first + const uint8_t * y_d = y + qrow_size; // then scales/offsets + + // Unpack the quants + for (int i = 0; i < nb; i++) { + uint8_t qs[QK_Q4_0x4x2]; + bool partial = (nloe && i == nb-1); + + const uint8_t * q = y_q + (i * qblk_size); + for (int j = 0; j < qk / 2; j++) { + if (partial) { + qs[j*2+0] = q[j] & 0x0F; + qs[j*2+1] = q[j] >> 4; + } else { + qs[j+000] = q[j] & 0x0F; + qs[j+128] = q[j] >> 4; + } + } + + pack_q4_1_quants(&x[i * 8 + 0], qs, 0); + pack_q4_1_quants(&x[i * 8 + 1], qs, 1); + pack_q4_1_quants(&x[i * 8 + 2], qs, 2); + pack_q4_1_quants(&x[i * 8 + 3], qs, 3); + pack_q4_1_quants(&x[i * 8 + 4], qs, 4); + pack_q4_1_quants(&x[i * 8 + 5], qs, 5); + pack_q4_1_quants(&x[i * 8 + 6], qs, 6); + pack_q4_1_quants(&x[i * 8 + 7], qs, 7); + } + + // Unpack the scales and offsets + for (int i = 0; i < nb; i++) { + const ggml_half * d_m = (const ggml_half *) (y_d + i * dblk_size); + for (int j = 0; j < 8; j++) { + x[i * 8 + j].d = d_m[j * 2 + 0]; + x[i * 8 + j].m = d_m[j * 2 + 1]; + } + } +} + +static void init_row_q4_1x4x2(block_q4_1 * x, int64_t k) { + static const int qk = QK_Q4_0x4x2; + const int nb = (k + qk - 1) / qk; // number of blocks (padded) + + uint8_t qs[QK_Q4_0x4x2]; // unpacked quants + memset(qs, 0, sizeof(qs)); + + for (int i = 0; i < nb; i++) { + pack_q4_1_quants(&x[i * 8 + 0], qs, 0); + pack_q4_1_quants(&x[i * 8 + 1], qs, 1); + pack_q4_1_quants(&x[i * 8 + 2], qs, 2); + pack_q4_1_quants(&x[i * 8 + 3], qs, 3); + pack_q4_1_quants(&x[i * 8 + 4], qs, 4); + pack_q4_1_quants(&x[i * 8 + 5], qs, 5); + pack_q4_1_quants(&x[i * 8 + 6], qs, 6); + pack_q4_1_quants(&x[i * 8 + 7], qs, 7); + } + + for (int i = 0; i < nb; i++) { + for (int j = 0; j < 8; j++) { + x[i * 8 + j].d = 0; + x[i * 8 + j].m = 0; + } + } +} + +static void repack_q4_1_q4x4x2(ggml_tensor * t, const void * data, size_t size) { + int64_t nrows = ggml_nrows(t); + + size_t row_size = ggml_row_size(t->type, t->ne[0]); + size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); + size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) + + const size_t total_tensor_size = (size_t)nrows * row_size; + const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size; + + const int64_t n_full_rows = n_bytes_to_copy / row_size; + const size_t n_rem_bytes = n_bytes_to_copy % row_size; + + void * buf_pd = ggml_aligned_malloc(row_size_pd); + GGML_ASSERT(buf_pd != NULL); + + void * buf_rp = ggml_aligned_malloc(row_size_rp); + GGML_ASSERT(buf_rp != NULL); + + HEX_VERBOSE("ggml-hex: repack-q4_1-q4x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size, + t->ne[0], nrows, row_size); + + init_row_q4_1x4x2((block_q4_1 *) buf_pd, t->ne[0]); + + for (int64_t i = 0; i < n_full_rows; i++) { + const uint8_t * src = (const uint8_t *) data + (i * row_size); + uint8_t * dst = (uint8_t *) t->data + (i * row_size); + + memcpy(buf_pd, src, row_size); + repack_row_q4_1x4x2((uint8_t *) buf_rp, (const block_q4_1 *) buf_pd, t->ne[0]); + memcpy(dst, buf_rp, row_size); + } + + if (n_rem_bytes > 0) { + const int64_t i = n_full_rows; + const uint8_t * src = (const uint8_t *) data + (i * row_size); + uint8_t * dst = (uint8_t *) t->data + (i * row_size); + + init_row_q4_1x4x2((block_q4_1 *) buf_pd, t->ne[0]); + memcpy(buf_pd, src, n_rem_bytes); + repack_row_q4_1x4x2((uint8_t *) buf_rp, (const block_q4_1 *) buf_pd, t->ne[0]); + memcpy(dst, buf_rp, n_rem_bytes); + } + + ggml_aligned_free(buf_pd, row_size_pd); + ggml_aligned_free(buf_rp, row_size_rp); +} + +static void repack_q4x4x2_q4_1(void * data, const ggml_tensor * t, size_t size) { + int64_t nrows = ggml_nrows(t); + + size_t row_size = ggml_row_size(t->type, t->ne[0]); + size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); + size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) + + const size_t total_tensor_size = (size_t)nrows * row_size; + const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size; + + const int64_t n_full_rows = n_bytes_to_copy / row_size; + const size_t n_rem_bytes = n_bytes_to_copy % row_size; + + void * buf_pd = ggml_aligned_malloc(row_size_pd); + GGML_ASSERT(buf_pd != NULL); + + void * buf_rp = ggml_aligned_malloc(row_size_rp); + GGML_ASSERT(buf_rp != NULL); + + HEX_VERBOSE("ggml-hex: repack-q4x4x2-q4_1 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size, + t->ne[0], nrows, row_size); + + memset(buf_rp, 0, row_size_rp); // clear-out padded buffer to make sure the tail is all zeros + + for (int64_t i = 0; i < n_full_rows; i++) { + const uint8_t * src = (const uint8_t *) t->data + (i * row_size); + uint8_t * dst = (uint8_t *) data + (i * row_size); + + memcpy(buf_rp, src, row_size); + unpack_row_q4_1x4x2((block_q4_1 *) buf_pd, (const uint8_t *) buf_rp, t->ne[0]); + memcpy(dst, buf_pd, row_size); + } + + if (n_rem_bytes > 0) { + const int64_t i = n_full_rows; + const uint8_t * src = (const uint8_t *) t->data + (i * row_size); + uint8_t * dst = (uint8_t *) data + (i * row_size); + + // We still need to read and unpack the entire source row because quantization is block-based. + memcpy(buf_rp, src, row_size); + unpack_row_q4_1x4x2((block_q4_1 *) buf_pd, (const uint8_t *) buf_rp, t->ne[0]); + memcpy(dst, buf_pd, n_rem_bytes); + } + + ggml_aligned_free(buf_pd, row_size_pd); + ggml_aligned_free(buf_rp, row_size_rp); +} + // ======== Q8x4x2 ==================== static void dump_block_q8_0(const block_q8_0 * b, int i) { HEX_VERBOSE("ggml-hex: repack q8_0 %d: %d %d %d %d ... %d %d %d %d : %.6f\n", i, b->qs[0], b->qs[1], b->qs[2], @@ -876,7 +1111,7 @@ static void repack_q8_0_q8x4x2(ggml_tensor * t, const void * data, size_t size) size_t row_size = ggml_row_size(t->type, t->ne[0]); size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q8_0x4x2)); // extra elements for the pad - size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any) + size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size quants + scales) // Ensure we don't try to read more data than is available in the source buffer 'data' // or write more than the tensor can hold. @@ -937,7 +1172,7 @@ static void repack_q8x4x2_q8_0(void * data, const ggml_tensor * t, size_t size) size_t row_size = ggml_row_size(t->type, t->ne[0]); size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q8_0x4x2)); // extra elements for the pad - size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any) + size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size quants + scales) // Ensure we don't try to copy more data than the tensor actually contains. const size_t total_tensor_size = (size_t)nrows * row_size; @@ -1238,7 +1473,7 @@ static void repack_mxfp4_mxfp4x4x2(ggml_tensor * t, const void * data, size_t si size_t row_size = ggml_row_size(t->type, t->ne[0]); size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_MXFP4x4x2)); // extra elements for the pad - size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any) + size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) // Ensure we don't try to read more data than is available in the source buffer 'data' // or write more than the tensor can hold. @@ -1299,7 +1534,7 @@ static void repack_mxfp4x4x2_mxfp4(void * data, const ggml_tensor * t, size_t si size_t row_size = ggml_row_size(t->type, t->ne[0]); size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_MXFP4x4x2)); // extra elements for the pad - size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any) + size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) // Ensure we don't try to copy more data than the tensor actually contains. const size_t total_tensor_size = (size_t)nrows * row_size; @@ -1365,6 +1600,12 @@ static void ggml_backend_hexagon_buffer_set_tensor(ggml_backend_buffer_t buffer, repack_q4_0_q4x4x2(tensor, data, size); break; + case GGML_TYPE_Q4_1: + GGML_ASSERT(offset == 0); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); + repack_q4_1_q4x4x2(tensor, data, size); + break; + case GGML_TYPE_Q8_0: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); @@ -1407,6 +1648,12 @@ static void ggml_backend_hexagon_buffer_get_tensor(ggml_backend_buffer_t buffer, repack_q4x4x2_q4_0(data, tensor, size); break; + case GGML_TYPE_Q4_1: + GGML_ASSERT(offset == 0); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); + repack_q4x4x2_q4_1(data, tensor, size); + break; + case GGML_TYPE_Q8_0: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); @@ -1536,12 +1783,10 @@ static ggml_backend_buffer_type_i ggml_backend_hexagon_repack_buffer_type_interf /* .is_host = */ ggml_backend_hexagon_repack_buffer_type_is_host, }; -// Backend session implementation - struct ggml_hexagon_opbatch { ggml_hexagon_session* sess; - std::vector<const ggml_tensor*> ops; // pointers to original ops + std::vector<htp_opnode> ops; // htp_opnode of ops std::vector<htp_buf_desc> h_bufs; // htp buffer descriptors std::vector<htp_tensor> h_tens; // htp tensor descriptors @@ -1673,7 +1918,7 @@ struct ggml_hexagon_opbatch { return ti; } - bool fit_op(const struct ggml_tensor *t) const { + bool fit_op(const htp_opnode & node) const { if (n_ops >= n_ops_max ) return false; // check how much extras we will need @@ -1682,6 +1927,7 @@ struct ggml_hexagon_opbatch { size_t extra_tens = 0; auto fit_tensor = [&](const ggml_tensor *t) { + if (!t) return; if (!t_map.count(t)) { extra_tens++; @@ -1693,10 +1939,10 @@ struct ggml_hexagon_opbatch { } }; - for (unsigned int i=0; i < HTP_OP_MAX_INPUTS && t->src[i]; i++) { - fit_tensor(t->src[i]); + for (const auto * src : node.get_inputs()) { + fit_tensor(src); } - fit_tensor(t); + fit_tensor(node.dst()); if ((extra_bufs + n_bufs) > n_bufs_max) return false; if ((extra_tens + n_tens) > n_tens_max) return false; @@ -1706,29 +1952,30 @@ struct ggml_hexagon_opbatch { } // assumes that fit_op() was called first and returned true - void add_op(htp_op_code opcode, const struct ggml_tensor * t) { + void add_op(const htp_opnode & node) { // Add new op unsigned int n = n_ops++; GGML_ASSERT(n_ops <= n_ops_max); - ops[n] = t; + ops[n] = node; htp_op_desc &o = h_ops[n]; - memcpy(&o.params, &t->op_params, sizeof(t->op_params)); - o.opcode = opcode; + memcpy(&o.params, &node.node->op_params, sizeof(node.node->op_params)); + o.opcode = node.opcode; o.flags = 0; if (!(opt_opstage & HTP_OPSTAGE_COMPUTE)) { o.flags |= HTP_OPFLAGS_SKIP_COMPUTE; } - ggml_hexagon_dump_op_exec(sess->c_name(), t, o.flags); + ggml_hexagon_dump_op_exec(sess->c_name(), node, o.flags); + auto inputs = node.get_inputs(); for (unsigned int i=0; i < HTP_OP_MAX_INPUTS; i++) { - o.src[i] = t->src[i] ? add_tensor(t->src[i]) : 0xffff; + o.src[i] = (i < inputs.size() && inputs[i]) ? add_tensor(inputs[i]) : 0xffff; } - o.dst = add_tensor(t); + o.dst = add_tensor(node.dst()); } }; @@ -1737,7 +1984,7 @@ struct ggml_hexagon_opqueue { ggml_hexagon_shared_buffer *shm_buf; size_t shm_blk_size; - using opvec = std::vector<const ggml_tensor*>; + using opvec = std::vector<htp_opnode>; std::queue<unsigned int> done; // completed batch ids std::vector<opvec> op_cache; // per batch op cache @@ -1886,7 +2133,8 @@ void ggml_hexagon_session::flush_pending(bool all) { uint32_t n_dbufs; // Read response packet from queue - int err = dspqueue_read(this->queue, &flags, 1, &n_dbufs, &dbuf, sizeof(rsp), &rsp_size, (uint8_t *) &rsp, DSPQUEUE_TIMEOUT); + const uint32_t timeo = opt_oppoll ? 0 : DSPQUEUE_TIMEOUT; + int err = dspqueue_read(this->queue, &flags, 1, &n_dbufs, &dbuf, sizeof(rsp), &rsp_size, (uint8_t *) &rsp, timeo); if (err == AEE_EEXPIRED) { continue; } @@ -1935,11 +2183,11 @@ void ggml_hexagon_session::flush_batch() { } } -void ggml_hexagon_session::enqueue_op(htp_op_code opcode, const ggml_tensor *op) { - if (!op_batch->fit_op(op)) { +void ggml_hexagon_session::enqueue_op(const htp_opnode & node) { + if (!op_batch->fit_op(node)) { flush_batch(); } - op_batch->add_op(opcode, op); + op_batch->add_op(node); } // Flush HTP response queue i.e wait for all outstanding requests to complete @@ -2290,6 +2538,7 @@ static bool ggml_hexagon_supported_gated_delta_net(const struct ggml_hexagon_ses const int64_t H = v->ne[1]; const int64_t n_tokens = v->ne[2]; const int64_t n_seqs = v->ne[3]; + const int64_t K = ggml_get_op_params_i32(op, 0); if (S_v <= 0 || S_v > 128 || H <= 0 || n_tokens <= 0 || n_seqs <= 0) { return false; @@ -2302,10 +2551,11 @@ static bool ggml_hexagon_supported_gated_delta_net(const struct ggml_hexagon_ses if ((g->ne[0] != 1 && g->ne[0] != S_v) || beta->ne[0] != 1) { return false; } + // state holds s0 only [S_v, S_v, H, n_seqs]; K is op param 0. if (ggml_nelements(state) != S_v * S_v * H * n_seqs) { return false; } - if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs) { + if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs * K) { return false; } @@ -2327,6 +2577,7 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s switch (src0->type) { case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_Q8_0: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: @@ -2353,6 +2604,27 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: permuted F16 src0 not supported\n"); return false; } + if (src1->ne[2] < src0->ne[2] || src1->ne[3] < src0->ne[3]) { + GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: src1 broadcasting not supported\n"); + return false; + } + if (ggml_nrows(src1) > 1024) { + return false; // no huge batches (for now) + } + break; + + case GGML_TYPE_F32: + if (src1->type != GGML_TYPE_F32) { + return false; + } + if (src0->nb[1] < src0->nb[0]) { + GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: permuted F32 src0 not supported\n"); + return false; + } + if (src1->ne[2] < src0->ne[2] || src1->ne[3] < src0->ne[3]) { + GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: src1 broadcasting not supported\n"); + return false; + } if (ggml_nrows(src1) > 1024) { return false; // no huge batches (for now) } @@ -2377,6 +2649,7 @@ static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session switch (src0->type) { case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_Q8_0: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: @@ -2874,6 +3147,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) { case GGML_OP_NORM: return HTP_OP_NORM; case GGML_OP_L2_NORM: return HTP_OP_L2_NORM; case GGML_OP_RMS_NORM: return HTP_OP_RMS_NORM; + case GGML_OP_CONCAT: return HTP_OP_CONCAT; case GGML_OP_SCALE: return HTP_OP_SCALE; case GGML_OP_SQR: return HTP_OP_SQR; case GGML_OP_SQRT: return HTP_OP_SQRT; @@ -2891,13 +3165,14 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) { case GGML_OP_UNARY: switch (ggml_get_unary_op(t)) { - case GGML_UNARY_OP_SILU: return HTP_OP_UNARY_SILU; - case GGML_UNARY_OP_GELU: return HTP_OP_UNARY_GELU; - case GGML_UNARY_OP_SIGMOID: return HTP_OP_UNARY_SIGMOID; - case GGML_UNARY_OP_NEG: return HTP_OP_UNARY_NEG; - case GGML_UNARY_OP_EXP: return HTP_OP_UNARY_EXP; - case GGML_UNARY_OP_SOFTPLUS: return HTP_OP_UNARY_SOFTPLUS; - case GGML_UNARY_OP_TANH: return HTP_OP_UNARY_TANH; + case GGML_UNARY_OP_SILU: return HTP_OP_UNARY_SILU; + case GGML_UNARY_OP_GELU: return HTP_OP_UNARY_GELU; + case GGML_UNARY_OP_GELU_QUICK: return HTP_OP_UNARY_GELU; + case GGML_UNARY_OP_SIGMOID: return HTP_OP_UNARY_SIGMOID; + case GGML_UNARY_OP_NEG: return HTP_OP_UNARY_NEG; + case GGML_UNARY_OP_EXP: return HTP_OP_UNARY_EXP; + case GGML_UNARY_OP_SOFTPLUS: return HTP_OP_UNARY_SOFTPLUS; + case GGML_UNARY_OP_TANH: return HTP_OP_UNARY_TANH; default: break; } @@ -2928,10 +3203,43 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg HEX_VERBOSE("ggml-hex: %s graph-compute n_nodes %d\n", sess->c_name(), graph->n_nodes); + std::vector<htp_opnode> nodes; + nodes.reserve(graph->n_nodes); + + // Fusion for (int i = 0; i < graph->n_nodes; ++i) { ggml_tensor * n = graph->nodes[i]; - if (op_is_compute(n) && (opt_opstage & HTP_OPSTAGE_QUEUE)) { - sess->enqueue_op(op_remap_to_htp(n), n); + if (!op_is_compute(n)) { + continue; + } + + ggml_tensor * next_node = (i + 1 < graph->n_nodes) ? graph->nodes[i + 1] : nullptr; + + htp_opnode node = { + /*.node =*/ n, + /*.fused =*/ {}, + /*.opcode =*/ HTP_OP_INVALID + }; + + if (n->op == GGML_OP_RMS_NORM && next_node) { + if (next_node->op == GGML_OP_MUL && op_is_compute(next_node) && ggml_can_fuse(graph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { + node.add_fused(next_node); + node.opcode = HTP_OP_RMS_NORM_MUL; + i++; // skip the fused MUL node + } + } + + if (node.opcode == HTP_OP_INVALID) { + node.opcode = op_remap_to_htp(n); + } + + nodes.push_back(std::move(node)); + } + + // Queue and execute + if (opt_opstage & HTP_OPSTAGE_QUEUE) { + for (const auto & node : nodes) { + sess->enqueue_op(node); } } @@ -2950,51 +3258,7 @@ static void ggml_backend_hexagon_synchronize(ggml_backend_t backend) { sess->flush(); } -struct node_info { - ggml_tensor * node; - - std::vector<ggml_tensor *> fused; - - ggml_op op() const { - return node->op; - } - - const ggml_tensor * dst() const { - return fused.empty() ? node : fused.back(); - } - - const ggml_tensor * src0() const { - return node->src[0]; - } - - const ggml_tensor * src1() const { - return node->src[1]; - } - - bool is_empty() const { - return ggml_op_is_empty(node->op); - } - - void add_fused(ggml_tensor * t) { - fused.push_back(t); - } - - bool stackable() const { - switch (this->op()) { - case GGML_OP_MUL_MAT: - case GGML_OP_MUL_MAT_ID: - return ggml_is_quantized(this->src0()->type); - default: - return false; - } - } - - bool same_input(const node_info& n) const { - return n.src1() == this->src1(); - } -}; - -static std::vector<int> ggml_hexagon_graph_optimize_reorder(const std::vector<node_info> & nodes) { +static std::vector<int> ggml_hexagon_graph_optimize_reorder(const std::vector<htp_opnode> & nodes) { const int n = nodes.size(); std::vector<int> res; @@ -3048,14 +3312,14 @@ static void ggml_backend_hexagon_graph_optimize(ggml_backend_t backend, ggml_cgr enum ggml_op ops[MAX_FUSE]; - std::vector<node_info> nodes; + std::vector<htp_opnode> nodes; nodes.reserve(gf->n_nodes); // fuse nodes: // we don't want to make reorders that break fusing, so we first pack all fusable tensors // and perform the reorder over the fused nodes. after the reorder is done, we unfuse for (int i = 0; i < n; i++) { - node_info node = { + htp_opnode node = { /*.node =*/gf->nodes[i], /*.fused =*/{}, }; @@ -3286,6 +3550,25 @@ static bool ggml_hexagon_supported_repeat(const struct ggml_hexagon_session * se return true; } +static bool ggml_hexagon_supported_concat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + int dim = ((const int32_t *) op->op_params)[0]; + if (dim < 0 || dim >= GGML_MAX_DIMS) { + return false; + } + + for (int i = 0; i < GGML_MAX_SRC; ++i) { + const struct ggml_tensor * src = op->src[i]; + if (!src) { + continue; + } + if (src->type != GGML_TYPE_F32 && src->type != GGML_TYPE_I32 && src->type != GGML_TYPE_F16) { + return false; + } + } + + return true; +} + static bool ggml_hexagon_supported_fill(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { const struct ggml_tensor * dst = op; @@ -3371,6 +3654,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons break; case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_QUICK: supp = ggml_hexagon_supported_activations(sess, op); break; default: @@ -3434,6 +3718,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_cumsum(sess, op); break; + case GGML_OP_CONCAT: + supp = ggml_hexagon_supported_concat(sess, op); + break; + case GGML_OP_FILL: supp = ggml_hexagon_supported_fill(sess, op); break; @@ -3598,6 +3886,8 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { // Basic sanity checks to make sure definitions match static_assert((unsigned int) HTP_TYPE_Q4_0 == (unsigned int) GGML_TYPE_Q4_0, "please update hexagon_type to match ggml_type"); + static_assert((unsigned int) HTP_TYPE_Q4_1 == (unsigned int) GGML_TYPE_Q4_1, + "please update hexagon_type to match ggml_type"); static_assert((unsigned int) HTP_TYPE_Q8_0 == (unsigned int) GGML_TYPE_Q8_0, "please update hexagon_type to match ggml_type"); static_assert((unsigned int) HTP_TYPE_MXFP4 == (unsigned int) GGML_TYPE_MXFP4, @@ -3610,6 +3900,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { const char * str_opstage = getenv("GGML_HEXAGON_OPSTAGE"); const char * str_opbatch = getenv("GGML_HEXAGON_OPBATCH"); const char * str_opqueue = getenv("GGML_HEXAGON_OPQUEUE"); + const char * str_oppoll = getenv("GGML_HEXAGON_OPPOLL"); const char * str_opfilter = getenv("GGML_HEXAGON_OPFILTER"); const char * str_profile = getenv("GGML_HEXAGON_PROFILE"); const char * str_etm = getenv("GGML_HEXAGON_ETM"); @@ -3647,6 +3938,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { opt_opstage = str_opstage ? strtoul(str_opstage, NULL, 0) : opt_opstage; opt_opbatch = str_opbatch ? strtoul(str_opbatch, NULL, 0) : opt_opbatch; opt_opqueue = str_opqueue ? strtoul(str_opqueue, NULL, 0) : opt_opqueue; + opt_oppoll = str_oppoll ? strtoul(str_oppoll, NULL, 0) : opt_oppoll; opt_profile = str_profile ? atoi(str_profile) : 0; opt_etm = str_etm ? atoi(str_etm) : 0; opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx; diff --git a/ggml/src/ggml-hexagon/htp-opnode.h b/ggml/src/ggml-hexagon/htp-opnode.h new file mode 100644 index 00000000000..52c727c6206 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp-opnode.h @@ -0,0 +1,272 @@ +#ifndef HTP_OPNODE_H +#define HTP_OPNODE_H + +#define GGML_COMMON_IMPL_CPP +#include "ggml-backend-impl.h" +#include "ggml-common.h" + +#include <string> +#include <vector> +#include <stdio.h> +#include "htp-ops.h" + +struct htp_opnode { + ggml_tensor * node = nullptr; + + std::vector<ggml_tensor *> fused; + + htp_op_code opcode = HTP_OP_INVALID; + + ggml_op op() const { + return node->op; + } + + const ggml_tensor * dst() const { + return fused.empty() ? node : fused.back(); + } + + const ggml_tensor * src0() const { + return node->src[0]; + } + + const ggml_tensor * src1() const { + return node->src[1]; + } + + bool is_empty() const { + return ggml_op_is_empty(node->op); + } + + void add_fused(ggml_tensor * t) { + fused.push_back(t); + } + + bool stackable() const { + switch (this->op()) { + case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: + return ggml_is_quantized(this->src0()->type); + default: + return false; + } + } + + bool same_input(const htp_opnode& n) const { + return n.src1() == this->src1(); + } + + std::vector<const ggml_tensor *> get_inputs() const { + if (fused.empty()) { + int last_non_null = -1; + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (node->src[i]) { + last_non_null = i; + } + } + std::vector<const ggml_tensor *> inputs(last_non_null + 1, nullptr); + for (int i = 0; i <= last_non_null; i++) { + inputs[i] = node->src[i]; + } + return inputs; + } + + std::vector<const ggml_tensor *> inputs(GGML_MAX_SRC, nullptr); + std::vector<const ggml_tensor *> outputs; + outputs.push_back(node); + for (const auto * f : fused) { + outputs.push_back(f); + } + + auto contains = [&](const std::vector<const ggml_tensor *> & vec, const ggml_tensor * t) { + for (const auto * x : vec) { + if (x == t) return true; + } + return false; + }; + + int count = 0; + auto add_input = [&](const ggml_tensor * t) { + if (t && !contains(outputs, t) && !contains(inputs, t)) { + if (count < (int)inputs.size()) { + inputs[count++] = t; + } else { + inputs.push_back(t); + } + } + }; + + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (node->src[i]) { + add_input(node->src[i]); + } + } + for (const auto * f : fused) { + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (f->src[i]) { + add_input(f->src[i]); + } + } + } + + inputs.resize(count); + return inputs; + } + + std::string op_name() const { + if (fused.empty()) { + return ggml_op_desc(node); + } + std::string name = ggml_op_desc(node); + for (const auto * f : fused) { + name += "+"; + name += ggml_op_desc(f); + } + return name; + } +}; + +struct htp_opformat { + char strides[64 * GGML_MAX_SRC]; + char dims[64 * GGML_MAX_SRC]; + char types[16 * GGML_MAX_SRC]; + char buffs[64 * GGML_MAX_SRC]; + char names[64 * GGML_MAX_SRC]; + + int format_tensor_dims(char * str, const struct ggml_tensor * t) { + if (!t) { + return sprintf(str, "NONE"); + } + if (t->ne[2] == 1 && t->ne[3] == 1) { + return sprintf(str, "%d:%d", (int) t->ne[0], (int) t->ne[1]); + } else { + return sprintf(str, "%d:%d:%d:%d", (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]); + } + } + + void format_op_dims(char * str, const htp_opnode & node) { + char * p = str; + auto inputs = node.get_inputs(); + + if (!inputs.empty()) { + p += format_tensor_dims(p, inputs[0]); + + for (size_t i = 1; i < inputs.size(); i++) { + p += sprintf(p, " x "); + p += format_tensor_dims(p, inputs[i]); + } + + p += sprintf(p, " -> "); + } + + char self[64]; + format_tensor_dims(self, node.dst()); + p += sprintf(p, "%s", self); + } + + int format_tensor_strides(char * str, const struct ggml_tensor * t) { + if (!t) { + return sprintf(str, "NONE"); + } + const char * c = ggml_is_contiguous(t) ? "" : "!"; + + if (t->ne[2] == 1 && t->ne[3] == 1) { + return sprintf(str, "%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], c); + } else { + return sprintf(str, "%zu:%zu:%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], (size_t) t->nb[2], (size_t) t->nb[3], c); + } + } + + void format_op_strides(char * str, const htp_opnode & node) { + char * p = str; + auto inputs = node.get_inputs(); + + if (!inputs.empty()) { + p += format_tensor_strides(p, inputs[0]); + + for (size_t i = 1; i < inputs.size(); i++) { + p += sprintf(p, " x "); + p += format_tensor_strides(p, inputs[i]); + } + + p += sprintf(p, " -> "); + } + + char self[64]; + format_tensor_strides(self, node.dst()); + p += sprintf(p, "%s", self); + } + + void format_op_types(char * str, const htp_opnode & node) { + char * p = str; + auto inputs = node.get_inputs(); + + if (!inputs.empty()) { + p += sprintf(p, "%s", inputs[0] ? ggml_type_name(inputs[0]->type) : "NONE"); + + for (size_t i = 1; i < inputs.size(); i++) { + p += sprintf(p, " x "); + p += sprintf(p, "%s", inputs[i] ? ggml_type_name(inputs[i]->type) : "NONE"); + } + + p += sprintf(p, " -> "); + } + + p += sprintf(p, "%s", ggml_type_name(node.dst()->type)); + } + + const char * tensor_buff_name(const struct ggml_tensor * t) { + if (t && t->buffer) { + return ggml_backend_buffer_name(t->buffer); + } + return "NONE"; + } + + void format_op_buffs(char * str, const htp_opnode & node) { + char * p = str; + auto inputs = node.get_inputs(); + + if (!inputs.empty()) { + p += sprintf(p, "%s", tensor_buff_name(inputs[0])); + + for (size_t i = 1; i < inputs.size(); i++) { + p += sprintf(p, " x "); + p += sprintf(p, "%s", tensor_buff_name(inputs[i])); + } + + p += sprintf(p, " -> "); + } + + p += sprintf(p, "%s", tensor_buff_name(node.dst())); + } + + void format_op_names(char * str, const htp_opnode & node) { + char * p = str; + auto inputs = node.get_inputs(); + + if (!inputs.empty()) { + p += sprintf(p, "%s", inputs[0] ? inputs[0]->name : "NONE"); + + for (size_t i = 1; i < inputs.size(); i++) { + p += sprintf(p, " x "); + p += sprintf(p, "%s", inputs[i] ? inputs[i]->name : "NONE"); + } + + p += sprintf(p, " -> "); + } + + p += sprintf(p, "%s", node.dst()->name); + } + + void format(const htp_opnode & node) { + format_op_dims(dims, node); + format_op_strides(strides, node); + format_op_types(types, node); + format_op_buffs(buffs, node); + format_op_names(names, node); + } + + htp_opformat() {} + htp_opformat(const htp_opnode & node) { format(node); } +}; + +#endif // HTP_OPNODE_H diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index 36f923243cd..f4b44fe1a65 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -19,26 +19,6 @@ add_library(${HTP_LIB} SHARED htp_iface_skel.c worker-pool.c hex-dma.c - matmul-ops.c - binary-ops.c - unary-ops.c - sum-rows-ops.c - softmax-ops.c - act-ops.c - rope-ops.c - flash-attn-ops.c - set-rows-ops.c - get-rows-ops.c - cpy-ops.c - repeat-ops.c - argsort-ops.c - ssm-conv.c - cumsum-ops.c - fill-ops.c - diag-ops.c - solve-tri-ops.c - gated-delta-net-ops.c - pad-ops.c ) target_compile_definitions(${HTP_LIB} PRIVATE @@ -57,15 +37,16 @@ list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx) if (_hmx_idx GREATER_EQUAL 0) target_sources(${HTP_LIB} PRIVATE - hmx-queue.c hmx-matmul-ops.c hmx-flash-attn-ops.c + hmx-queue.c ) # -mhmx enables HMX instruction set (needed by files that include hmx-utils.h) set_source_files_properties( - hmx-matmul-ops.c hmx-flash-attn-ops.c + hmx-matmul-ops.c + hmx-queue.c PROPERTIES COMPILE_OPTIONS "-mhmx" ) @@ -74,6 +55,30 @@ endif() build_idl(htp_iface.idl ${HTP_LIB}) +target_sources(${HTP_LIB} PRIVATE + matmul-ops.c + binary-ops.c + unary-ops.c + sum-rows-ops.c + softmax-ops.c + act-ops.c + rope-ops.c + flash-attn-ops.c + set-rows-ops.c + get-rows-ops.c + cpy-ops.c + repeat-ops.c + argsort-ops.c + ssm-conv.c + cumsum-ops.c + fill-ops.c + concat-ops.c + diag-ops.c + solve-tri-ops.c + gated-delta-net-ops.c + pad-ops.c +) + set_target_properties(${HTP_LIB} PROPERTIES EXPORT_COMPILE_COMMANDS ON) install(TARGETS ${HTP_LIB}) diff --git a/ggml/src/ggml-hexagon/htp/argsort-ops.c b/ggml/src/ggml-hexagon/htp/argsort-ops.c index bdd0623615d..73af38a35ab 100644 --- a/ggml/src/ggml-hexagon/htp/argsort-ops.c +++ b/ggml/src/ggml-hexagon/htp/argsort-ops.c @@ -276,6 +276,7 @@ int op_argsort(struct htp_ops_context * octx) { octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.size = total_spad_size; octx->src0_spad.size_per_thread = spad_per_thread; + octx->src0_spad.src = NULL; FARF(HIGH, "argsort: %ux%ux%ux%u -> %ux%ux%ux%u (0x%x, 0x%x)", octx->src[0]->ne[0], octx->src[0]->ne[1], octx->src[0]->ne[2], octx->src[0]->ne[3], diff --git a/ggml/src/ggml-hexagon/htp/concat-ops.c b/ggml/src/ggml-hexagon/htp/concat-ops.c new file mode 100644 index 00000000000..f2a381313c5 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/concat-ops.c @@ -0,0 +1,277 @@ +#include "htp-ctx.h" +#include "htp-ops.h" +#include "hexagon_types.h" +#include "hexagon_protos.h" +#include "hvx_hexagon_protos.h" +#include "hex-dma.h" +#include "vtcm-utils.h" +#include "hvx-utils.h" +#include "hex-fastdiv.h" +#include <string.h> + +struct htp_concat_context { + struct htp_ops_context * octx; + uint32_t dim; + uint32_t nrows_per_thread; + struct fastdiv_values div_ne0; + struct fastdiv_values div_ne1; + struct fastdiv_values div_ne2; +}; + +static void concat_2d_f32_transposed(unsigned int nth, unsigned int ith, void * data) { + struct htp_concat_context * cctx = (struct htp_concat_context *) data; + struct htp_ops_context * octx = cctx->octx; + + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * dst = octx->dst; + + const uint32_t src0_ne0 = src0->ne[0]; + const uint32_t src1_ne0 = src1->ne[0]; + const uint32_t ne1 = dst->ne[1]; + + const uint32_t start_i = ith * cctx->nrows_per_thread; + const uint32_t end_i = (start_i + cctx->nrows_per_thread < ne1) ? (start_i + cctx->nrows_per_thread) : ne1; + if (start_i >= end_i) return; + + dma_queue * q = octx->ctx->dma[ith]; + + uint8_t * spad0_base = octx->src0_spad.data + ith * octx->src0_spad.size_per_thread; + uint8_t * spad1_base = octx->src1_spad.data + ith * octx->src1_spad.size_per_thread; + + const uint32_t block_i = 32; + const uint32_t spad1_stride = block_i * sizeof(float); + + int32_t offsets[32] __attribute__((aligned(128))); + for(int k=0; k<32; k++) { + offsets[k] = k * spad1_stride; + } + HVX_Vector vv = *(HVX_Vector*)offsets; + const uint32_t src1_ne0_padded = hex_round_up(src1_ne0, 32); + const uint32_t spad0_row_bytes = hex_round_up((src0_ne0 + src1_ne0_padded) * sizeof(float), VLEN); + uint32_t mu = src1_ne0_padded * spad1_stride; + + for (uint32_t i = start_i; i < end_i; i += block_i) { + uint32_t current_block_i = (end_i - i < block_i) ? (end_i - i) : block_i; + + uint32_t src1_width_bytes = current_block_i * sizeof(float); + uint8_t * src1_ptr = (uint8_t *)src1->data + i * src1->nb[1]; + dma_queue_push(q, dma_make_ptr(spad1_base, src1_ptr), spad1_stride, src1->nb[0], src1_width_bytes, src1_ne0); + + uint32_t src0_row_bytes = src0_ne0 * sizeof(float); + uint8_t * src0_ptr = (uint8_t *)src0->data + i * src0->nb[1]; + dma_queue_push(q, dma_make_ptr(spad0_base, src0_ptr), spad0_row_bytes, src0->nb[1], src0_row_bytes, current_block_i); + + dma_queue_pop(q); // src1 + + HVX_Vector * vtcm_tmp = (HVX_Vector *)(spad1_base + src1_ne0_padded * spad1_stride); + + for (uint32_t j = 0; j < src1_ne0_padded; j += 32) { + #pragma unroll(4) + for (uint32_t ii = 0; ii < current_block_i; ii++) { + size_t rt = (size_t)(spad1_base + j * spad1_stride + ii * sizeof(float)); + Q6_vgather_ARMVw(&vtcm_tmp[ii], rt, mu, vv); + uint8_t * dst_ptr = spad0_base + ii * spad0_row_bytes + (src0_ne0 + j) * sizeof(float); + hvx_vmemu(dst_ptr) = vtcm_tmp[ii]; + } + } + + dma_queue_pop(q); // src0 + + uint8_t * dst_ptr = (uint8_t *)dst->data + i * dst->nb[1]; + dma_queue_push(q, dma_make_ptr(dst_ptr, spad0_base), dst->nb[1], spad0_row_bytes, (src0_ne0 + src1_ne0) * sizeof(float), current_block_i); + + dma_queue_pop(q); + } +} + +static void concat_2d_f16_transposed(unsigned int nth, unsigned int ith, void * data) { + struct htp_concat_context * cctx = (struct htp_concat_context *) data; + struct htp_ops_context * octx = cctx->octx; + + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * dst = octx->dst; + + const uint32_t src0_ne0 = src0->ne[0]; + const uint32_t src1_ne0 = src1->ne[0]; + const uint32_t ne1 = dst->ne[1]; + + const uint32_t start_i = ith * cctx->nrows_per_thread; + const uint32_t end_i = (start_i + cctx->nrows_per_thread < ne1) ? (start_i + cctx->nrows_per_thread) : ne1; + if (start_i >= end_i) return; + + dma_queue * q = octx->ctx->dma[ith]; + + uint8_t * spad0_base = octx->src0_spad.data + ith * octx->src0_spad.size_per_thread; + uint8_t * spad1_base = octx->src1_spad.data + ith * octx->src1_spad.size_per_thread; + + const uint32_t block_i = 64; + const uint32_t spad1_stride = block_i * sizeof(__fp16); + + int16_t offsets[64] __attribute__((aligned(128))); + for(int k=0; k<64; k++) { + offsets[k] = k * spad1_stride; + } + HVX_Vector vv = *(HVX_Vector*)offsets; + const uint32_t src1_ne0_padded = hex_round_up(src1_ne0, 64); + const uint32_t spad0_row_bytes = hex_round_up((src0_ne0 + src1_ne0_padded) * sizeof(__fp16), VLEN); + uint32_t mu = src1_ne0_padded * spad1_stride; + + for (uint32_t i = start_i; i < end_i; i += block_i) { + uint32_t current_block_i = (end_i - i < block_i) ? (end_i - i) : block_i; + + uint32_t src1_width_bytes = current_block_i * sizeof(__fp16); + uint8_t * src1_ptr = (uint8_t *)src1->data + i * src1->nb[1]; + dma_queue_push(q, dma_make_ptr(spad1_base, src1_ptr), spad1_stride, src1->nb[0], src1_width_bytes, src1_ne0); + + uint32_t src0_row_bytes = src0_ne0 * sizeof(__fp16); + uint8_t * src0_ptr = (uint8_t *)src0->data + i * src0->nb[1]; + dma_queue_push(q, dma_make_ptr(spad0_base, src0_ptr), spad0_row_bytes, src0->nb[1], src0_row_bytes, current_block_i); + + dma_queue_pop(q); // src1 + + HVX_Vector * vtcm_tmp = (HVX_Vector *)(spad1_base + src1_ne0_padded * spad1_stride); + + for (uint32_t j = 0; j < src1_ne0_padded; j += 64) { + #pragma unroll(4) + for (uint32_t ii = 0; ii < current_block_i; ii++) { + size_t rt = (size_t)(spad1_base + j * spad1_stride + ii * sizeof(__fp16)); + Q6_vgather_ARMVh(&vtcm_tmp[ii], rt, mu, vv); + uint8_t * dst_ptr = spad0_base + ii * spad0_row_bytes + (src0_ne0 + j) * sizeof(__fp16); + hvx_vmemu(dst_ptr) = vtcm_tmp[ii]; + } + } + + dma_queue_pop(q); // src0 + + uint8_t * dst_ptr = (uint8_t *)dst->data + i * dst->nb[1]; + dma_queue_push(q, dma_make_ptr(dst_ptr, spad0_base), dst->nb[1], spad0_row_bytes, (src0_ne0 + src1_ne0) * sizeof(__fp16), current_block_i); + + dma_queue_pop(q); + } +} + +static void concat_generic(unsigned int nth, unsigned int ith, void * data) { + struct htp_concat_context * cctx = (struct htp_concat_context *) data; + struct htp_ops_context * octx = cctx->octx; + + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * dst = octx->dst; + + const int dim = cctx->dim; + const uint32_t type_size = (dst->type == HTP_TYPE_F32 || dst->type == HTP_TYPE_I32) ? 4 : 2; + + const uint32_t ne[4] = {dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3]}; + const uint32_t total_elements = ne[0] * ne[1] * ne[2] * ne[3]; + const uint32_t chunk_size = (total_elements + nth - 1) / nth; + + const uint32_t start_idx = MIN(ith * chunk_size, total_elements); + const uint32_t end_idx = MIN(start_idx + chunk_size, total_elements); + + // Naive scalar element-wise copy + for (uint32_t idx = start_idx; idx < end_idx; idx++) { + uint32_t idx_div_ne0 = fastdiv(idx, &cctx->div_ne0); + uint32_t i0 = idx - idx_div_ne0 * ne[0]; + + uint32_t idx_div_ne01 = fastdiv(idx_div_ne0, &cctx->div_ne1); + uint32_t i1 = idx_div_ne0 - idx_div_ne01 * ne[1]; + + uint32_t idx_div_ne012 = fastdiv(idx_div_ne01, &cctx->div_ne2); + uint32_t i2 = idx_div_ne01 - idx_div_ne012 * ne[2]; + uint32_t i3 = idx_div_ne012; + + uint8_t * dst_ptr = (uint8_t *)dst->data + i3 * dst->nb[3] + i2 * dst->nb[2] + i1 * dst->nb[1] + i0 * dst->nb[0]; + + uint32_t idx_dim = 0; + if (dim == 0) idx_dim = i0; + else if (dim == 1) idx_dim = i1; + else if (dim == 2) idx_dim = i2; + else if (dim == 3) idx_dim = i3; + + const struct htp_tensor * src = (idx_dim < src0->ne[dim]) ? src0 : src1; + + uint32_t s0 = i0; + uint32_t s1 = i1; + uint32_t s2 = i2; + uint32_t s3 = i3; + + if (dim == 0 && src == src1) s0 -= src0->ne[0]; + if (dim == 1 && src == src1) s1 -= src0->ne[1]; + if (dim == 2 && src == src1) s2 -= src0->ne[2]; + if (dim == 3 && src == src1) s3 -= src0->ne[3]; + + uint8_t * src_ptr = (uint8_t *)src->data + s3 * src->nb[3] + s2 * src->nb[2] + s1 * src->nb[1] + s0 * src->nb[0]; + + if (type_size == 4) { + *(float*)dst_ptr = *(float*)src_ptr; + } else { + *(__fp16*)dst_ptr = *(__fp16*)src_ptr; + } + } +} + +int op_concat(struct htp_ops_context * octx) { + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * dst = octx->dst; + + int dim = octx->op_params[0]; + + bool is_2d = dst->ne[2] == 1 && dst->ne[3] == 1; + + const uint32_t type_size = (dst->type == HTP_TYPE_F32 || dst->type == HTP_TYPE_I32) ? 4 : 2; + bool is_src1_transposed = (src1->nb[0] > src1->nb[1]); + bool is_src0_transposed = (src0->nb[0] > src0->nb[1]); + + uint32_t n_threads = octx->n_threads; + struct htp_concat_context cctx; + cctx.octx = octx; + cctx.dim = dim; + cctx.div_ne0 = init_fastdiv_values(dst->ne[0]); + cctx.div_ne1 = init_fastdiv_values(dst->ne[1]); + cctx.div_ne2 = init_fastdiv_values(dst->ne[2]); + + void (*worker_func)(unsigned int, unsigned int, void *) = concat_generic; + + if (dim == 0 && is_2d && is_src1_transposed && !is_src0_transposed) { + n_threads = MIN(dst->ne[1], n_threads); + if (n_threads < 1) { + n_threads = 1; + } + uint32_t block_i = (type_size == 4) ? 32 : 64; + + cctx.nrows_per_thread = hmx_ceil_div(dst->ne[1], n_threads); + + // Allocate VTCM + uint32_t spad1_stride = block_i * type_size; + + uint32_t src1_ne0_padded = hex_round_up(src1->ne[0], block_i); + uint32_t spad0_row_bytes = hex_round_up((src0->ne[0] + src1_ne0_padded) * type_size, VLEN); + + octx->src0_spad.size_per_thread = block_i * spad0_row_bytes; + octx->src1_spad.size_per_thread = src1_ne0_padded * spad1_stride + block_i * VLEN; + + octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread; + octx->src1_spad.size = n_threads * octx->src1_spad.size_per_thread; + + if (octx->src0_spad.size + octx->src1_spad.size > octx->ctx->vtcm_size) { + return HTP_STATUS_VTCM_TOO_SMALL; + } + + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->src0_spad.src = NULL; + octx->src1_spad.src = NULL; + + if (type_size == 4) { + worker_func = concat_2d_f32_transposed; + } else { + worker_func = concat_2d_f16_transposed; + } + } + + worker_pool_run_func(octx->ctx->worker_pool, worker_func, &cctx, n_threads); + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/cpy-ops.c b/ggml/src/ggml-hexagon/htp/cpy-ops.c index 5c040a32224..ae507effa51 100644 --- a/ggml/src/ggml-hexagon/htp/cpy-ops.c +++ b/ggml/src/ggml-hexagon/htp/cpy-ops.c @@ -28,158 +28,170 @@ struct htp_copy_context { uint32_t dst_blocks_per_row; uint32_t src0_nrows_per_thread; - - void (*copy)(struct htp_copy_context * ct, struct htp_ops_context * octx, int nth, int ith); }; #define cpy_preamble \ const struct htp_tensor *src0 = octx->src[0]; \ const struct htp_tensor *dst = octx->dst; \ \ - const uint32_t ne00 = src0->ne[0]; \ - const uint32_t ne01 = src0->ne[1]; \ - const uint32_t ne02 = src0->ne[2]; \ - const uint32_t ne03 = src0->ne[3]; \ - \ - const uint32_t nb00 = src0->nb[0]; \ - const uint32_t nb01 = src0->nb[1]; \ - const uint32_t nb02 = src0->nb[2]; \ - const uint32_t nb03 = src0->nb[3]; \ - \ - const uint32_t ne0 = dst->ne[0]; \ - const uint32_t ne1 = dst->ne[1]; \ - const uint32_t ne2 = dst->ne[2]; \ - const uint32_t ne3 = dst->ne[3]; \ - \ - const uint32_t nb0 = dst->nb[0]; \ - const uint32_t nb1 = dst->nb[1]; \ - const uint32_t nb2 = dst->nb[2]; \ - const uint32_t nb3 = dst->nb[3]; \ - \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; \ + \ const uint32_t nr = ne01; -static void cpy_thread_sametype_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) { - cpy_preamble; - - // parallelize by src0 rows - const uint32_t dr = ct->src0_nrows_per_thread; - const uint32_t ir0 = dr * ith; - const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr; - - // copy by rows - for (uint32_t i03 = 0; i03 < ne03; i03++) { - for (uint32_t i02 = 0; i02 < ne02; i02++) { - #pragma unroll(2) - for (uint32_t i01 = ir0; i01 < ir1; i01++) { - uint8_t* dst_ptr = (uint8_t*) dst->data + i01*nb1 + i02*nb2 + i03*nb3; - uint8_t* src0_ptr = (uint8_t*) src0->data + i01*nb01 + i02*nb02 + i03*nb03; - hex_l2fetch(src0_ptr, ne00 * ct->src0_type_size, nb01, 2); - hvx_copy_uu(dst_ptr, src0_ptr, ne00, ct->src0_type_size); - } - } - } +#define DEFINE_CPY_SAMESHAPE(NAME, ELEM_TYPE, ELEM_SIZE) \ +static void cpy_thread_##NAME##_sameshape(unsigned int nth, unsigned int ith, void * data) { \ + struct htp_copy_context * ct = (struct htp_copy_context *) data; \ + struct htp_ops_context * octx = ct->octx; \ + cpy_preamble; \ + const uint32_t dr = ct->src0_nrows_per_thread; \ + const uint32_t ir0 = dr * ith; \ + const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr; \ + if (ir0 >= nr) return; \ + for (uint32_t i03 = 0; i03 < ne03; i03++) { \ + for (uint32_t i02 = 0; i02 < ne02; i02++) { \ + _Pragma("unroll(4)") \ + for (uint32_t i01 = ir0; i01 < ir1; i01++) { \ + uint8_t* dst_ptr = (uint8_t*) dst->data + i01*nb1 + i02*nb2 + i03*nb3; \ + uint8_t* src0_ptr = (uint8_t*) src0->data + i01*nb01 + i02*nb02 + i03*nb03; \ + hex_l2fetch(src0_ptr, ne00 * ELEM_SIZE, nb01, 2); \ + hvx_copy_uu(dst_ptr, src0_ptr, ne00, ELEM_SIZE); \ + } \ + } \ + } \ } -static void cpy_thread_sametype_reshape(struct htp_copy_context * ct, struct htp_ops_context * octx, int nth, int ith) { - cpy_preamble; - - // parallelize by src0 rows - const uint32_t dr = ct->src0_nrows_per_thread; - const uint32_t ir0 = dr * ith; - const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr; - - // Fast path: when both src0 and dst are contiguous in memory - // Replace the element-by-element loop with a single bulk HVX copy per (i03, i02) slice. - const bool src0_contig = (nb00 == ct->src0_type_size) && - (nb01 == ne00 * nb00) && - (nb02 == ne01 * nb01) && - (nb03 == ne02 * nb02); - const bool dst_contig = (nb0 == ct->dst_type_size) && - (nb1 == ne0 * nb0) && - (nb2 == ne1 * nb1) && - (nb3 == ne2 * nb2); - - if (src0_contig && dst_contig) { - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - uint8_t * src_ptr = (uint8_t *) src0->data + i03*nb03 + i02*nb02 + ir0*nb01; - uint32_t flat = ((i03*ne02 + i02)*ne01 + ir0) * ne00; - uint8_t * dst_ptr = (uint8_t *) dst->data + flat * ct->src0_type_size; - hvx_copy_uu(dst_ptr, src_ptr, (ir1 - ir0) * ne00, ct->src0_type_size); - } - } - return; - } - - // dst counters - int64_t k10 = 0; - int64_t i11 = 0; - int64_t i12 = 0; - int64_t i13 = 0; - - // number of blocks in a row - const int64_t nk00 = ct->src0_blocks_per_row; - const int64_t nk0 = ct->dst_blocks_per_row; - - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - k10 += nk00 * ir0; - while (k10 >= nk0) { - k10 -= nk0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t k00 = 0; k00 < nk00; k00++) { - const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + k10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - memcpy(dst_ptr, src0_ptr, ct->dst_type_size); - - if (++k10 == nk0) { - k10 = 0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - k10 += nk00 * (ne01 - ir1); - while (k10 >= nk0) { - k10 -= nk0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } +DEFINE_CPY_SAMESHAPE(f32, float, 4) +DEFINE_CPY_SAMESHAPE(f16, __fp16, 2) + +#define DEFINE_CPY_RESHAPE(NAME, ELEM_TYPE, ELEM_SIZE) \ +static void cpy_thread_##NAME##_reshape(unsigned int nth, unsigned int ith, void * data) { \ + struct htp_copy_context * ct = (struct htp_copy_context *) data; \ + struct htp_ops_context * octx = ct->octx; \ + cpy_preamble; \ + const uint32_t dr = ct->src0_nrows_per_thread; \ + const uint32_t ir0 = dr * ith; \ + const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr; \ + if (ir0 >= nr) return; \ + const bool src0_contig = (nb00 == ELEM_SIZE) && \ + (nb01 == ne00 * nb00) && \ + (nb02 == ne01 * nb01) && \ + (nb03 == ne02 * nb02); \ + const bool dst_contig = (nb0 == ELEM_SIZE) && \ + (nb1 == ne0 * nb0) && \ + (nb2 == ne1 * nb1) && \ + (nb3 == ne2 * nb2); \ + if (src0_contig && dst_contig) { \ + for (int64_t i03 = 0; i03 < ne03; i03++) { \ + for (int64_t i02 = 0; i02 < ne02; i02++) { \ + uint8_t * src_ptr = (uint8_t *) src0->data + i03*nb03 + i02*nb02 + ir0*nb01; \ + uint32_t flat = ((i03*ne02 + i02)*ne01 + ir0) * ne00; \ + uint8_t * dst_ptr = (uint8_t *) dst->data + flat * ELEM_SIZE; \ + hvx_copy_uu(dst_ptr, src_ptr, (ir1 - ir0) * ne00, ELEM_SIZE); \ + } \ + } \ + return; \ + } \ + const bool reshape_flat_fast = (ne03 == 1 && ne2 == 1 && ne3 == 1) && \ + (ne0 == ne00 * ne01) && (ne1 == ne02) && \ + (nb00 == ELEM_SIZE) && (nb0 == ELEM_SIZE); \ + if (reshape_flat_fast) { \ + for (uint32_t i02 = 0; i02 < ne02; i02++) { \ + for (uint32_t i01 = ir0; i01 < ir1; i01++) { \ + uint8_t * src0_ptr = (uint8_t *) src0->data + i01 * nb01 + i02 * nb02; \ + uint8_t * dst_ptr = (uint8_t *) dst->data + i01 * ne00 * ELEM_SIZE + i02 * nb1; \ + hvx_copy_uu(dst_ptr, src0_ptr, ne00, ELEM_SIZE); \ + } \ + } \ + return; \ + } \ + int64_t k10 = 0; \ + int64_t i11 = 0; \ + int64_t i12 = 0; \ + int64_t i13 = 0; \ + const int64_t nk00 = ct->src0_blocks_per_row; \ + const int64_t nk0 = ct->dst_blocks_per_row; \ + for (int64_t i03 = 0; i03 < ne03; i03++) { \ + for (int64_t i02 = 0; i02 < ne02; i02++) { \ + k10 += nk00 * ir0; \ + while (k10 >= nk0) { \ + k10 -= nk0; \ + if (++i11 == ne1) { \ + i11 = 0; \ + if (++i12 == ne2) { \ + i12 = 0; \ + if (++i13 == ne3) { \ + i13 = 0; \ + } \ + } \ + } \ + } \ + for (int64_t i01 = ir0; i01 < ir1; i01++) { \ + for (int64_t k00 = 0; k00 < nk00; k00++) { \ + const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); \ + char * dst_ptr = ((char *) dst->data + k10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); \ + memcpy(dst_ptr, src0_ptr, ELEM_SIZE); \ + if (++k10 == nk0) { \ + k10 = 0; \ + if (++i11 == ne1) { \ + i11 = 0; \ + if (++i12 == ne2) { \ + i12 = 0; \ + if (++i13 == ne3) { \ + i13 = 0; \ + } \ + } \ + } \ + } \ + } \ + } \ + k10 += nk00 * (ne01 - ir1); \ + while (k10 >= nk0) { \ + k10 -= nk0; \ + if (++i11 == ne1) { \ + i11 = 0; \ + if (++i12 == ne2) { \ + i12 = 0; \ + if (++i13 == ne3) { \ + i13 = 0; \ + } \ + } \ + } \ + } \ + } \ + } \ } -static void cpy_thread_f16_f32_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) { +DEFINE_CPY_RESHAPE(f32, float, 4) +DEFINE_CPY_RESHAPE(f16, __fp16, 2) + +static void cpy_thread_f16_f32_sameshape(unsigned int nth, unsigned int ith, void * data) { + struct htp_copy_context * ct = (struct htp_copy_context *) data; + struct htp_ops_context * octx = ct->octx; cpy_preamble; // parallelize by src0 rows const uint32_t dr = ct->src0_nrows_per_thread; const uint32_t ir0 = dr * ith; const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr; + if (ir0 >= nr) return; // copy by rows for (uint32_t i03 = 0; i03 < ne03; i03++) { @@ -195,13 +207,16 @@ static void cpy_thread_f16_f32_sameshape(struct htp_copy_context * ct, struct ht } } -static void cpy_thread_f32_f16_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) { +static void cpy_thread_f32_f16_sameshape(unsigned int nth, unsigned int ith, void * data) { + struct htp_copy_context * ct = (struct htp_copy_context *) data; + struct htp_ops_context * octx = ct->octx; cpy_preamble; // parallelize by src0 rows const uint32_t dr = ct->src0_nrows_per_thread; const uint32_t ir0 = dr * ith; const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr; + if (ir0 >= nr) return; // copy by rows for (uint32_t i03 = 0; i03 < ne03; i03++) { @@ -217,11 +232,6 @@ static void cpy_thread_f32_f16_sameshape(struct htp_copy_context * ct, struct ht } } -static void cpy_work_func(unsigned int n, unsigned int i, void *data) { - struct htp_copy_context *ct = (struct htp_copy_context *) data; - ct->copy(ct, ct->octx, n, i); -} - int op_cpy(struct htp_ops_context * octx) { cpy_preamble; @@ -254,22 +264,32 @@ int op_cpy(struct htp_ops_context * octx) { ct.src0_nrows_per_thread = (nr + n_threads - 1) / n_threads; + worker_callback_t copy_fun; + if (sametype && sameshape) { - ct.copy = cpy_thread_sametype_sameshape; + if (src0->type == HTP_TYPE_F32) { + copy_fun = cpy_thread_f32_sameshape; + } else { + copy_fun = cpy_thread_f16_sameshape; + } } else if (sameshape) { /**/ if (dst->type == HTP_TYPE_F16 && src0->type == HTP_TYPE_F32) - ct.copy = cpy_thread_f16_f32_sameshape; + copy_fun = cpy_thread_f16_f32_sameshape; else if (dst->type == HTP_TYPE_F32 && src0->type == HTP_TYPE_F16) - ct.copy = cpy_thread_f32_f16_sameshape; + copy_fun = cpy_thread_f32_f16_sameshape; else return HTP_STATUS_NO_SUPPORT; } else if (sametype) { - ct.copy = cpy_thread_sametype_reshape; + if (src0->type == HTP_TYPE_F32) { + copy_fun = cpy_thread_f32_reshape; + } else { + copy_fun = cpy_thread_f16_reshape; + } } else { return HTP_STATUS_NO_SUPPORT; } - worker_pool_run_func(octx->ctx->worker_pool, cpy_work_func, &ct, n_threads); + worker_pool_run_func(octx->ctx->worker_pool, copy_fun, &ct, n_threads); return HTP_STATUS_OK; } diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index d95df6ac9d5..e996214691a 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -11,6 +11,7 @@ #include "hex-dma.h" #include "hvx-utils.h" #include "hvx-dump.h" +#include "hvx-flash-attn.h" #define GGML_COMMON_DECL_C #include "ggml-common.h" @@ -22,6 +23,16 @@ // Must be multiple of 32 #define FLASH_ATTN_BLOCK_SIZE (32 * 2) +#if __HVX_ARCH__ < 79 +#define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b)) +#define HVX_OP_SUB_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b)) +#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)) +#else +#define HVX_OP_ADD_F32(a, b) Q6_Vsf_vadd_VsfVsf(a, b) +#define HVX_OP_SUB_F32(a, b) Q6_Vsf_vsub_VsfVsf(a, b) +#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) +#endif + // This is a bit of a hack because the compiler is strugling to properly inline // the default hvx_vec_f32_to_f16 with output into the local array. static __attribute__((noinline)) void hvx_vec_f32_to_f16_a(void *ptr, HVX_Vector v0, HVX_Vector v1) @@ -54,8 +65,8 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x_hf, y_hf); } - HVX_Vector rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p))); - rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum))); + HVX_Vector rsum = HVX_OP_ADD_F32(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p)); + rsum = HVX_OP_MUL_F32(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum)); hvx_vec_store_u(r, 4, rsum); } @@ -105,10 +116,10 @@ static inline HVX_Vector hvx_dot_f16_f16_aa_rx4(const void * restrict y, rsum3_p = hvx_vec_mpyacc_f32_f16(rsum3_p, x3_hf, y_hf); } - HVX_Vector rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p))); - HVX_Vector rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p))); - HVX_Vector rsum2 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum2_p), Q6_V_hi_W(rsum2_p))); - HVX_Vector rsum3 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum3_p), Q6_V_hi_W(rsum3_p))); + HVX_Vector rsum0 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p)); + HVX_Vector rsum1 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p)); + HVX_Vector rsum2 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum2_p), Q6_V_hi_W(rsum2_p)); + HVX_Vector rsum3 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum3_p), Q6_V_hi_W(rsum3_p)); HVX_Vector_x4 rsum0123 = { .v = { rsum0, rsum1, rsum2, rsum3 } }; return hvx_vec_reduce_sum_f32x4(rsum0123); @@ -123,7 +134,7 @@ static inline HVX_Vector hvx_dot_f16_f16_aa_rx32(const void * restrict y, const size_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors const size_t nloe = n % VLEN_FP16; // leftover elements - HVX_Vector sums; // initialize at j = 0 + HVX_Vector sums = Q6_V_vzero(); const size_t stride_x_4 = stride_x * 4; for (uint32_t j = 0; j < VLEN_FP32; j += 4) { HVX_Vector sums_x4 = hvx_dot_f16_f16_aa_rx4(y, x, stride_x, nvec, nloe); @@ -132,8 +143,7 @@ static inline HVX_Vector hvx_dot_f16_f16_aa_rx32(const void * restrict y, x += stride_x_4; } - sums = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), sums); - return Q6_Vsf_equals_Vqf32(sums); + return HVX_OP_MUL_F32(hvx_vec_splat_f32(s), sums); } // MAD: y (F32) += x (F16) * s (F16) @@ -236,6 +246,7 @@ struct htp_fa_context { uint32_t n_head_log2; float m0; float m1; + float slopes[512]; uint32_t n_blocks; @@ -268,11 +279,10 @@ static inline void hvx_scale_vec_f32_aa(uint8_t * restrict dst, const uint8_t * uint32_t i = 0; #pragma unroll(4) for (; i < nvec; ++i) { - vdst[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs)); + vdst[i] = HVX_OP_MUL_F32(vsrc[i], vs); } if (nloe) { - HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); - hvx_vec_store_a(&vdst[i], nloe * sizeof(float), Q6_Vsf_equals_Vqf32(v)); + hvx_vec_store_a(&vdst[i], nloe * sizeof(float), HVX_OP_MUL_F32(vsrc[i], vs)); } } @@ -404,7 +414,7 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * } const uint32_t h = iq2; // head index - const float slope = (factx->max_bias > 0.0f) ? (h < factx->n_head_log2 ? powf(factx->m0, h + 1) : powf(factx->m1, 2*(h - factx->n_head_log2) + 1)) : 1.0f; + const float slope = factx->slopes[h]; HVX_Vector S_vec = hvx_vec_splat_f32(0.0f); HVX_Vector M_vec = hvx_vec_splat_f32(-INFINITY); @@ -438,25 +448,44 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * // Process in sub-blocks of 32 (VLEN_FP32) HVX_Vector sb_scores[FLASH_ATTN_BLOCK_SIZE / VLEN_FP32]; HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY); - for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) { + for (uint32_t iv = 0; ic < current_block_size; ic += VLEN_FP32, ++iv) { // 1. Compute scores HVX_Vector scores = hvx_dot_f16_f16_aa_rx32(q_ptr_vtcm, k_base + ic * factx->size_k_row_padded, factx->size_k_row_padded, DK, factx->scale); // 2. Softcap if (factx->logit_softcap != 0.0f) { scores = hvx_vec_tanh_f32(scores); - scores = Q6_Vqf32_vmpy_VsfVsf(scores, logit_cap); - scores = Q6_Vsf_equals_Vqf32(scores); + scores = HVX_OP_MUL_F32(scores, logit_cap); } // 3. Mask if (mask) { const __fp16 * mp = m_base + ic; HVX_Vector m_vals_f16 = *(const HVX_UVector *) mp; - HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec); - HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair); - scores = Q6_Vqf32_vadd_Vqf32Vsf(add_val, scores); - scores = Q6_Vsf_equals_Vqf32(scores); + + // Multiplying -INFINITY (0xFC00) by a slope in VhfVhf instructions can incorrectly produce NaN on v79. + // Clamp -INFINITY to the max negative fp16 finite value (-65504.0f). + HVX_Vector vinf = Q6_Vh_vsplat_R(0xFC00); + HVX_Vector vmin = Q6_Vh_vsplat_R(0xFBFF); + HVX_VectorPred is_inf = Q6_Q_vcmp_eq_VhVh(m_vals_f16, vinf); + m_vals_f16 = Q6_V_vmux_QVV(is_inf, vmin, m_vals_f16); + + #if __HVX_ARCH__ >= 79 + HVX_VectorPair m_vals_f32_pair = Q6_Wsf_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec); + HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair); + scores = Q6_Vsf_vadd_VsfVsf(add_val, scores); + #else + HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec); + HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair); + scores = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(add_val, scores)); + #endif + } + + // Mask out invalid lanes for leftover handling + uint32_t valid_lanes = current_block_size - ic; + if (valid_lanes < VLEN_FP32) { + HVX_VectorPred valid_pred = Q6_Q_vsetq_R(valid_lanes * 4); // 4 bytes per fp32 lane + scores = Q6_V_vmux_QVV(valid_pred, scores, hvx_vec_splat_f32(-INFINITY)); } sb_scores[iv] = scores; @@ -466,78 +495,55 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * { // 4. Online Softmax Update HVX_Vector M_new_vec = Q6_Vsf_vmax_VsfVsf(v_max, M_vec); - HVX_Vector diff_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(M_vec, M_new_vec)); + HVX_Vector diff_vec = HVX_OP_SUB_F32(M_vec, M_new_vec); HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec); M_vec = M_new_vec; hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec); HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f); - for (uint32_t ic2 = 0, iv = 0; ic2 + VLEN_FP32 <= current_block_size; ic2 += VLEN_FP32, ++iv) { + for (uint32_t ic2 = 0, iv = 0; ic2 < current_block_size; ic2 += VLEN_FP32, ++iv) { HVX_Vector scores = sb_scores[iv]; - HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_vec); - HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted)); + HVX_Vector scores_shifted = HVX_OP_SUB_F32(scores, M_vec); + HVX_Vector P = hvx_vec_exp_f32(scores_shifted); - p_sum_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(p_sum_vec, P)); + p_sum_vec = HVX_OP_ADD_F32(p_sum_vec, P); // 5. Accumulate V __fp16 __attribute__((aligned(VLEN))) p_arr[VLEN_FP16]; hvx_vec_f32_to_f16_a(p_arr, P, hvx_vec_splat_f32(0)); + float __attribute__((aligned(128))) P_arr[VLEN_FP32]; + hvx_vec_store_a(P_arr, 128, P); + for (uint32_t j = 0; j < VLEN_FP32; j += 2) { - const uint32_t cur_ic = ic2 + j; - const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded; + const uint32_t cur_ic = ic2 + j; + if (cur_ic >= current_block_size) { + break; + } + + if (cur_ic + 1 == current_block_size) { + // Odd leftover, process single row + if (P_arr[j] != 0.0f) { + const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded; + hvx_mad_f32_f16_aa(VKQ32, v_ptr, (p_arr + j), DV); + } + break; + } + + // Avoid NaN * 0.0 = NaN for uninitialized V cache rows. + // Check the f32 values to safely avoid strict aliasing violations. + if (P_arr[j] == 0.0f && P_arr[j + 1] == 0.0f) { + continue; + } + + const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded; hvx_mad_f32_f16_aa_rx2(VKQ32, v_ptr, v_ptr + factx->size_v_row_padded, (p_arr + j), (p_arr + j + 1), DV); } } p_sum_vec = hvx_vec_reduce_sum_f32(p_sum_vec); - S_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(S_vec, ms_vec)), p_sum_vec)); - } - - if (ic < current_block_size) { - // Sync scalars for leftover/next block if needed - float M = hvx_vec_get_f32(M_vec); - float S = hvx_vec_get_f32(S_vec); - - // Leftover - for (; ic < current_block_size; ++ic) { - float s_val; - const uint8_t * k_ptr = k_base + ic * factx->size_k_row_padded; - hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, factx->scale); - if (factx->logit_softcap != 0.0f) { - s_val = factx->logit_softcap * tanhf(s_val); - } - - if (mask) { - const float m_val = m_base[ic]; - s_val += slope * m_val; - } - - const float Mold = M; - __fp16 vs = 1.0f; - - if (s_val > M) { - M = s_val; - HVX_Vector diff_vec = hvx_vec_splat_f32(Mold - M); - HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec); - hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec); - - float ms = hvx_vec_get_f32(ms_vec); - S = S * ms + vs; - } else { - HVX_Vector diff_vec = hvx_vec_splat_f32(s_val - M); - vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec)); - S += vs; - } - - const uint8_t * v_ptr = v_base + ic * factx->size_v_row_padded; - - hvx_mad_f32_f16_aa(VKQ32, v_ptr, &vs, DV); - } - - M_vec = hvx_vec_splat_f32(M); - S_vec = hvx_vec_splat_f32(S); + S_vec = HVX_OP_ADD_F32(HVX_OP_MUL_F32(S_vec, ms_vec), p_sum_vec); } // Issue DMA for next+1 block (if exists) @@ -599,8 +605,9 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * const int i2 = iq2; const int i3 = iq3; - // dst is permuted - uint8_t * dst_ptr = (uint8_t *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1) * nb1; + // dst is permuted: [DV, n_heads, n_tokens, n_seq] + // head stride is nb[1], token stride is nb[2], batch stride is nb[3] + uint8_t * dst_ptr = (uint8_t *) dst->data + i2 * dst->nb[1] + i1 * dst->nb[2] + i3 * dst->nb[3]; if (dst->type == HTP_TYPE_F32) { hvx_copy_f32_ua(dst_ptr, (uint8_t *) VKQ32, DV); @@ -623,8 +630,8 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { } #ifdef HTP_HAS_HMX - // HMX path: prefill (neq1 >= 32), head_dim multiple of 32, F16 KV - if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 32 == 0 && q->ne[1] >= 32) { + // HMX path: head_dim multiple of 64, F16 KV, and no sinks + if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 64 == 0 && v->ne[0] % 64 == 0 && octx->src[4] == NULL) { int ret = hmx_flash_attn_ext(octx); if (ret == HTP_STATUS_OK) { return ret; @@ -684,6 +691,13 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { factx.m0 = powf(2.0f, -(max_bias ) / factx.n_head_log2); factx.m1 = powf(2.0f, -(max_bias / 2.0f) / factx.n_head_log2); + if (n_head > 512) { + return HTP_STATUS_NO_SUPPORT; + } + for (uint32_t h = 0; h < n_head; ++h) { + factx.slopes[h] = (max_bias > 0.0f) ? alibi_slope(h, factx.n_head_log2, factx.m0, factx.m1) : 1.0f; + } + // total rows in q const uint32_t neq0 = q->ne[0]; const uint32_t neq1 = q->ne[1]; diff --git a/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c b/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c index 2e84badc9b7..35518e6111c 100644 --- a/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +++ b/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c @@ -3,6 +3,7 @@ #include <string.h> #include "hvx-utils.h" +#include "hex-fastdiv.h" #define GGML_COMMON_DECL_C #include "ggml-common.h" @@ -14,106 +15,103 @@ #define HTP_GDN_MAX_SV 128 + struct htp_gdn_context { struct htp_ops_context * octx; uint32_t rows_per_thread; - size_t state_bytes; - bool use_vtcm; - uint8_t * vtcm_state_base; - size_t vtcm_state_per_thread; + size_t state_bytes; + uint8_t * vtcm_base; + size_t vtcm_per_thread; }; -static inline float gdn_mul_dot_f32(float * restrict dst, const float * restrict mul, - const float * restrict dot, uint32_t n) { +static inline HVX_Vector gdn_mul_dot_f32(float * restrict dst, const float * restrict mul, const float * restrict dot, uint32_t n) { HVX_Vector acc = Q6_V_vzero(); - const uint32_t epv = 128 / sizeof(float); + const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { - HVX_Vector vd = hvx_vmemu(dst + i * epv); - HVX_Vector vm = hvx_vmem(mul + i * epv); + HVX_Vector vd = hvx_vmemu(dst + i * epv); + HVX_Vector vm = hvx_vmem(mul + i * epv); HVX_Vector vdot = hvx_vmem(dot + i * epv); - HVX_Vector out = hvx_vec_mul_f32_f32(vd, vm); + HVX_Vector out = hvx_vec_mul_f32_f32(vd, vm); hvx_vmemu(dst + i * epv) = out; acc = hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; - HVX_Vector vd = hvx_vmemu(dst + off); - HVX_Vector vm = hvx_vmem(mul + off); + HVX_Vector vd = hvx_vmemu(dst + off); + HVX_Vector vm = hvx_vmem(mul + off); HVX_Vector vdot = hvx_vmem(dot + off); - HVX_Vector out = hvx_vec_mul_f32_f32(vd, vm); - hvx_vec_store_u(dst + off, tail * sizeof(float), out); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_Vector out = hvx_vec_mul_f32_f32(vd, vm); + hvx_vec_store_u(dst + off, nloe * sizeof(float), out); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector prod = hvx_vec_mul_f32_f32(out, vdot); acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero())); } - return hvx_vec_get_f32(hvx_vec_reduce_sum_f32(acc)); + return hvx_vec_reduce_sum_f32(acc); } -static inline float gdn_mul_scalar_dot_f32(float * restrict dst, float mul, - const float * restrict dot, uint32_t n) { +static inline HVX_Vector gdn_mul_scalar_dot_f32(float * restrict dst, float mul, const float * restrict dot, uint32_t n) { HVX_Vector acc = Q6_V_vzero(); const HVX_Vector vmul = hvx_vec_splat_f32(mul); - const uint32_t epv = 128 / sizeof(float); + const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { - HVX_Vector vd = hvx_vmemu(dst + i * epv); + HVX_Vector vd = hvx_vmemu(dst + i * epv); HVX_Vector vdot = hvx_vmem(dot + i * epv); - HVX_Vector out = hvx_vec_mul_f32_f32(vd, vmul); + HVX_Vector out = hvx_vec_mul_f32_f32(vd, vmul); hvx_vmemu(dst + i * epv) = out; acc = hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; - HVX_Vector vd = hvx_vmemu(dst + off); + HVX_Vector vd = hvx_vmemu(dst + off); HVX_Vector vdot = hvx_vmem(dot + off); - HVX_Vector out = hvx_vec_mul_f32_f32(vd, vmul); - hvx_vec_store_u(dst + off, tail * sizeof(float), out); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_Vector out = hvx_vec_mul_f32_f32(vd, vmul); + hvx_vec_store_u(dst + off, nloe * sizeof(float), out); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector prod = hvx_vec_mul_f32_f32(out, vdot); acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero())); } - return hvx_vec_get_f32(hvx_vec_reduce_sum_f32(acc)); + return hvx_vec_reduce_sum_f32(acc); } -static inline float gdn_add_scaled_dot_f32(float * restrict dst, const float * restrict src, - float scale, const float * restrict dot, uint32_t n) { +static inline HVX_Vector gdn_add_scaled_dot_f32(float * restrict dst, const float * restrict src, + HVX_Vector vscale, const float * restrict dot, uint32_t n) { HVX_Vector acc = Q6_V_vzero(); - const HVX_Vector vscale = hvx_vec_splat_f32(scale); - const uint32_t epv = 128 / sizeof(float); + const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { - HVX_Vector vd = hvx_vmemu(dst + i * epv); - HVX_Vector vs = hvx_vmem(src + i * epv); + HVX_Vector vd = hvx_vmemu(dst + i * epv); + HVX_Vector vs = hvx_vmem(src + i * epv); HVX_Vector vdot = hvx_vmem(dot + i * epv); - HVX_Vector out = hvx_vec_add_f32_f32(vd, hvx_vec_mul_f32_f32(vs, vscale)); + HVX_Vector out = hvx_vec_add_f32_f32(vd, hvx_vec_mul_f32_f32(vs, vscale)); hvx_vmemu(dst + i * epv) = out; acc = hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; - HVX_Vector vd = hvx_vmemu(dst + off); - HVX_Vector vs = hvx_vmem(src + off); + HVX_Vector vd = hvx_vmemu(dst + off); + HVX_Vector vs = hvx_vmem(src + off); HVX_Vector vdot = hvx_vmem(dot + off); - HVX_Vector out = hvx_vec_add_f32_f32(vd, hvx_vec_mul_f32_f32(vs, vscale)); - hvx_vec_store_u(dst + off, tail * sizeof(float), out); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_Vector out = hvx_vec_add_f32_f32(vd, hvx_vec_mul_f32_f32(vs, vscale)); + hvx_vec_store_u(dst + off, nloe * sizeof(float), out); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector prod = hvx_vec_mul_f32_f32(out, vdot); acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero())); } - return hvx_vec_get_f32(hvx_vec_reduce_sum_f32(acc)); + return hvx_vec_reduce_sum_f32(acc); } static inline void gdn_mul_dot4_f32(float * restrict dst0, float * restrict dst1, @@ -126,7 +124,7 @@ static inline void gdn_mul_dot4_f32(float * restrict dst0, float * restrict dst1 const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { HVX_Vector vm = hvx_vmem(mul + i * epv); HVX_Vector vdot = hvx_vmem(dot + i * epv); @@ -147,11 +145,11 @@ static inline void gdn_mul_dot4_f32(float * restrict dst0, float * restrict dst1 acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; - HVX_Vector vm = hvx_vmem(mul + off); + HVX_Vector vm = hvx_vmem(mul + off); HVX_Vector vdot = hvx_vmem(dot + off); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector zero = Q6_V_vzero(); HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vm); @@ -159,10 +157,10 @@ static inline void gdn_mul_dot4_f32(float * restrict dst0, float * restrict dst1 HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vm); HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vm); - hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); - hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); - hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); - hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); + hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3); acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); @@ -185,7 +183,7 @@ static inline void gdn_mul_scalar_dot4_f32(float * restrict dst0, float * restri const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { HVX_Vector vdot = hvx_vmem(dot + i * epv); @@ -205,10 +203,10 @@ static inline void gdn_mul_scalar_dot4_f32(float * restrict dst0, float * restri acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; HVX_Vector vdot = hvx_vmem(dot + off); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector zero = Q6_V_vzero(); HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vmul); @@ -216,10 +214,10 @@ static inline void gdn_mul_scalar_dot4_f32(float * restrict dst0, float * restri HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vmul); HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vmul); - hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); - hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); - hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); - hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); + hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3); acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); @@ -246,7 +244,7 @@ static inline void gdn_add_scaled_dot4_f32(float * restrict dst0, float * restri const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { HVX_Vector vs = hvx_vmem(src + i * epv); HVX_Vector vdot = hvx_vmem(dot + i * epv); @@ -267,11 +265,11 @@ static inline void gdn_add_scaled_dot4_f32(float * restrict dst0, float * restri acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; HVX_Vector vs = hvx_vmem(src + off); HVX_Vector vdot = hvx_vmem(dot + off); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector zero = Q6_V_vzero(); HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + off), hvx_vec_mul_f32_f32(vs, scale0)); @@ -279,10 +277,10 @@ static inline void gdn_add_scaled_dot4_f32(float * restrict dst0, float * restri HVX_Vector out2 = hvx_vec_add_f32_f32(hvx_vmemu(dst2 + off), hvx_vec_mul_f32_f32(vs, scale2)); HVX_Vector out3 = hvx_vec_add_f32_f32(hvx_vmemu(dst3 + off), hvx_vec_mul_f32_f32(vs, scale3)); - hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); - hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); - hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); - hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); + hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3); acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); @@ -310,7 +308,7 @@ static inline void gdn_mul_dot8_f32(float * restrict dst0, float * restrict dst1 const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { HVX_Vector vm = hvx_vmem(mul + i * epv); HVX_Vector vdot = hvx_vmem(dot + i * epv); @@ -343,11 +341,11 @@ static inline void gdn_mul_dot8_f32(float * restrict dst0, float * restrict dst1 acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; HVX_Vector vm = hvx_vmem(mul + off); HVX_Vector vdot = hvx_vmem(dot + off); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector zero = Q6_V_vzero(); HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vm); @@ -359,14 +357,14 @@ static inline void gdn_mul_dot8_f32(float * restrict dst0, float * restrict dst1 HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + off), vm); HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + off), vm); - hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); - hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); - hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); - hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); - hvx_vec_store_u(dst4 + off, tail * sizeof(float), out4); - hvx_vec_store_u(dst5 + off, tail * sizeof(float), out5); - hvx_vec_store_u(dst6 + off, tail * sizeof(float), out6); - hvx_vec_store_u(dst7 + off, tail * sizeof(float), out7); + hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3); + hvx_vec_store_u(dst4 + off, nloe * sizeof(float), out4); + hvx_vec_store_u(dst5 + off, nloe * sizeof(float), out5); + hvx_vec_store_u(dst6 + off, nloe * sizeof(float), out6); + hvx_vec_store_u(dst7 + off, nloe * sizeof(float), out7); acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); @@ -400,7 +398,7 @@ static inline void gdn_mul_scalar_dot8_f32(float * restrict dst0, float * restri const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { HVX_Vector vdot = hvx_vmem(dot + i * epv); @@ -432,10 +430,10 @@ static inline void gdn_mul_scalar_dot8_f32(float * restrict dst0, float * restri acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; HVX_Vector vdot = hvx_vmem(dot + off); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector zero = Q6_V_vzero(); HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vmul); @@ -447,14 +445,14 @@ static inline void gdn_mul_scalar_dot8_f32(float * restrict dst0, float * restri HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + off), vmul); HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + off), vmul); - hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); - hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); - hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); - hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); - hvx_vec_store_u(dst4 + off, tail * sizeof(float), out4); - hvx_vec_store_u(dst5 + off, tail * sizeof(float), out5); - hvx_vec_store_u(dst6 + off, tail * sizeof(float), out6); - hvx_vec_store_u(dst7 + off, tail * sizeof(float), out7); + hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3); + hvx_vec_store_u(dst4 + off, nloe * sizeof(float), out4); + hvx_vec_store_u(dst5 + off, nloe * sizeof(float), out5); + hvx_vec_store_u(dst6 + off, nloe * sizeof(float), out6); + hvx_vec_store_u(dst7 + off, nloe * sizeof(float), out7); acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); @@ -496,7 +494,7 @@ static inline void gdn_add_scaled_dot8_f32(float * restrict dst0, float * restri const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { HVX_Vector vs = hvx_vmem(src + i * epv); HVX_Vector vdot = hvx_vmem(dot + i * epv); @@ -529,11 +527,11 @@ static inline void gdn_add_scaled_dot8_f32(float * restrict dst0, float * restri acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; HVX_Vector vs = hvx_vmem(src + off); HVX_Vector vdot = hvx_vmem(dot + off); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector zero = Q6_V_vzero(); HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + off), hvx_vec_mul_f32_f32(vs, scale0)); @@ -545,14 +543,14 @@ static inline void gdn_add_scaled_dot8_f32(float * restrict dst0, float * restri HVX_Vector out6 = hvx_vec_add_f32_f32(hvx_vmemu(dst6 + off), hvx_vec_mul_f32_f32(vs, scale6)); HVX_Vector out7 = hvx_vec_add_f32_f32(hvx_vmemu(dst7 + off), hvx_vec_mul_f32_f32(vs, scale7)); - hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); - hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); - hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); - hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); - hvx_vec_store_u(dst4 + off, tail * sizeof(float), out4); - hvx_vec_store_u(dst5 + off, tail * sizeof(float), out5); - hvx_vec_store_u(dst6 + off, tail * sizeof(float), out6); - hvx_vec_store_u(dst7 + off, tail * sizeof(float), out7); + hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3); + hvx_vec_store_u(dst4 + off, nloe * sizeof(float), out4); + hvx_vec_store_u(dst5 + off, nloe * sizeof(float), out5); + hvx_vec_store_u(dst6 + off, nloe * sizeof(float), out6); + hvx_vec_store_u(dst7 + off, nloe * sizeof(float), out7); acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); @@ -586,6 +584,7 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo const uint32_t H = v->ne[1]; const uint32_t n_tokens = v->ne[2]; const uint32_t n_seqs = v->ne[3]; + const uint32_t K = octx->op_params[0]; const uint32_t total_rows = H * n_seqs; if (ith >= total_rows) { @@ -604,22 +603,66 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo float local_gate[HTP_GDN_MAX_SV] __attribute__((aligned(128))); float local_q[HTP_GDN_MAX_SV] __attribute__((aligned(128))); float local_k[HTP_GDN_MAX_SV] __attribute__((aligned(128))); - float local_sums[4] __attribute__((aligned(128))); + float local_sums[32] __attribute__((aligned(128))); + + dma_queue * dma = octx->ctx->dma[ith]; + size_t state_aligned = (size_t) S_v * S_v * sizeof(float); + state_aligned = (state_aligned + 127) & ~(size_t)127; + float * s_work[2]; + s_work[0] = (float *) (gctx->vtcm_base + gctx->vtcm_per_thread * ith); + s_work[1] = s_work[0] + state_aligned / sizeof(float); + + struct fastdiv_values fd_H = init_fastdiv_values(H); + struct fastdiv_values fd_q1 = init_fastdiv_values(q->ne[1]); + struct fastdiv_values fd_k1 = init_fastdiv_values(k->ne[1]); + struct fastdiv_values fd_rq3 = init_fastdiv_values(rq3); + struct fastdiv_values fd_rk3 = init_fastdiv_values(rk3); + + const uint64_t state_seq_stride = state->nb[3] / sizeof(float); + const uint64_t state_size_per_snap = (uint64_t) S_v * S_v * H * n_seqs; + + uint32_t ir_prefetch = ith; + int spad_idx = 0; + + // Prefetch preamble (up to 2 steps) + for (int k = 0; k < 2 && ir_prefetch < total_rows; k++) { + const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H); + const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H); + const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v; + // final state lands in snapshot slot 0 (most-recent-first ordering) + float * ps_out = state_out_base + ((uint64_t) piv3 * H + piv1) * S_v * S_v; + + // Push dummy write-back + dma_queue_push(dma, dma_make_ptr(ps_out, s_work[spad_idx]), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), 0); + + // Push fetch + dma_queue_push(dma, dma_make_ptr(s_work[spad_idx], ps_in), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), S_v); + + ir_prefetch += nth; + spad_idx ^= 1; + } + int curr_spad_idx = 0; for (uint32_t ir = ith; ir < total_rows; ir += nth) { - const uint32_t iv1 = ir % H; - const uint32_t iv3 = ir / H; + dma_queue_pop(dma); + dma_queue_pop(dma); - const uint32_t iq1 = iv1 % q->ne[1]; - const uint32_t ik1 = iv1 % k->ne[1]; - const uint32_t iq3 = iv3 / rq3; - const uint32_t ik3 = iv3 / rk3; + float * s_work_curr = s_work[curr_spad_idx]; - float * s_out = state_out_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v; - const float * s_in = state_in_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v; + const uint32_t iv1 = fastmodulo(ir, H, &fd_H); + const uint32_t iv3 = fastdiv(ir, &fd_H); - memcpy(s_out, s_in, gctx->state_bytes); - float * s_work = s_out; + const uint32_t iq1 = fastmodulo(iv1, q->ne[1], &fd_q1); + const uint32_t ik1 = fastmodulo(iv1, k->ne[1], &fd_k1); + const uint32_t iq3 = fastdiv(iv3, &fd_rq3); + const uint32_t ik3 = fastdiv(iv3, &fd_rk3); + + // final state lands in snapshot slot 0 (most-recent-first ordering) + float * s_out = state_out_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v; float * attn_data = dst_base + ((uint64_t) iv3 * n_tokens * H + iv1) * S_v; @@ -635,65 +678,159 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo const float beta_val = *(const float *) ((const uint8_t *) (uintptr_t) beta->data + (uint64_t) iv3 * beta->nb[3] + (uint64_t) t * beta->nb[2] + (uint64_t) iv1 * beta->nb[1]); - memcpy(local_q, q_t, (size_t) S_v * sizeof(float)); - memcpy(local_k, k_t, (size_t) S_v * sizeof(float)); + hvx_copy_f32_au((uint8_t *) local_q, (const uint8_t *) q_t, S_v); + hvx_copy_f32_au((uint8_t *) local_k, (const uint8_t *) k_t, S_v); if (kda) { hvx_exp_f32((uint8_t *) local_gate, (const uint8_t *) g_t, S_v, false); uint32_t j = 0; + for (; j + 8 <= S_v; j += 8) { + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; + float * row4 = s_work_curr + (uint64_t) (j + 4) * S_v; + float * row5 = s_work_curr + (uint64_t) (j + 5) * S_v; + float * row6 = s_work_curr + (uint64_t) (j + 6) * S_v; + float * row7 = s_work_curr + (uint64_t) (j + 7) * S_v; + gdn_mul_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, + local_gate, local_k, S_v, local_sums); + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + + gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, + local_k, local_delta_b, local_q, S_v, local_sums); + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 8 * sizeof(float), res_attn); + } for (; j + 4 <= S_v; j += 4) { - float * row0 = s_work + (uint64_t) (j + 0) * S_v; - float * row1 = s_work + (uint64_t) (j + 1) * S_v; - float * row2 = s_work + (uint64_t) (j + 2) * S_v; - float * row3 = s_work + (uint64_t) (j + 3) * S_v; + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; gdn_mul_dot4_f32(row0, row1, row2, row3, local_gate, local_k, S_v, local_sums); - float local_delta_b[4] __attribute__((aligned(128))); - for (uint32_t r = 0; r < 4; ++r) { - local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; - } + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums); - for (uint32_t r = 0; r < 4; ++r) { - attn_data[j + r] = local_sums[r] * scale; - } + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 4 * sizeof(float), res_attn); } + HVX_Vector vscale_splat = hvx_vec_splat_f32(scale); for (; j < S_v; ++j) { - float * row = s_work + (uint64_t) j * S_v; - const float sum = gdn_mul_dot_f32(row, local_gate, local_k, S_v); - const float dj = (v_t[j] - sum) * beta_val; - attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale; + float * row = s_work_curr + (uint64_t) j * S_v; + HVX_Vector vsum = gdn_mul_dot_f32(row, local_gate, local_k, S_v); + HVX_Vector vv_t = hvx_vec_splat_f32(v_t[j]); + HVX_Vector vdj = hvx_vec_mul_f32_f32(hvx_vec_sub_f32_f32(vv_t, vsum), hvx_vec_splat_f32(beta_val)); + HVX_Vector vres = gdn_add_scaled_dot_f32(row, local_k, vdj, local_q, S_v); + attn_data[j] = hvx_vec_get_f32(hvx_vec_mul_f32_f32(vres, vscale_splat)); } } else { const float gate = expf(g_t[0]); uint32_t j = 0; + for (; j + 8 <= S_v; j += 8) { + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; + float * row4 = s_work_curr + (uint64_t) (j + 4) * S_v; + float * row5 = s_work_curr + (uint64_t) (j + 5) * S_v; + float * row6 = s_work_curr + (uint64_t) (j + 6) * S_v; + float * row7 = s_work_curr + (uint64_t) (j + 7) * S_v; + gdn_mul_scalar_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, + gate, local_k, S_v, local_sums); + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + + gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, + local_k, local_delta_b, local_q, S_v, local_sums); + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 8 * sizeof(float), res_attn); + } for (; j + 4 <= S_v; j += 4) { - float * row0 = s_work + (uint64_t) (j + 0) * S_v; - float * row1 = s_work + (uint64_t) (j + 1) * S_v; - float * row2 = s_work + (uint64_t) (j + 2) * S_v; - float * row3 = s_work + (uint64_t) (j + 3) * S_v; + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; gdn_mul_scalar_dot4_f32(row0, row1, row2, row3, gate, local_k, S_v, local_sums); - float local_delta_b[4] __attribute__((aligned(128))); - for (uint32_t r = 0; r < 4; ++r) { - local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; - } + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums); - for (uint32_t r = 0; r < 4; ++r) { - attn_data[j + r] = local_sums[r] * scale; - } + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 4 * sizeof(float), res_attn); } + HVX_Vector vscale_splat = hvx_vec_splat_f32(scale); for (; j < S_v; ++j) { - float * row = s_work + (uint64_t) j * S_v; - const float sum = gdn_mul_scalar_dot_f32(row, gate, local_k, S_v); - const float dj = (v_t[j] - sum) * beta_val; - attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale; + float * row = s_work_curr + (uint64_t) j * S_v; + HVX_Vector vsum = gdn_mul_scalar_dot_f32(row, gate, local_k, S_v); + HVX_Vector vv_t = hvx_vec_splat_f32(v_t[j]); + HVX_Vector vdj = hvx_vec_mul_f32_f32(hvx_vec_sub_f32_f32(vv_t, vsum), hvx_vec_splat_f32(beta_val)); + HVX_Vector vres = gdn_add_scaled_dot_f32(row, local_k, vdj, local_q, S_v); + attn_data[j] = hvx_vec_get_f32(hvx_vec_mul_f32_f32(vres, vscale_splat)); + } + } + + if (K > 1) { + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + const int64_t target_slot = (int64_t) n_tokens - 1 - (int64_t) t; + if (target_slot >= 0 && target_slot < (int64_t) K) { + float * curr_state_o = state_out_base + (uint64_t) target_slot * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v; + if (curr_state_o != s_out) { + hvx_copy_f32_uu((uint8_t *) curr_state_o, (const uint8_t *) s_work_curr, S_v * S_v); + } } } attn_data += (uint64_t) S_v * H; } + + // Push real write-back + dma_queue_push(dma, dma_make_ptr(s_out, s_work_curr), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), S_v); + + // Prefetch next block (if any) + if (ir_prefetch < total_rows) { + const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H); + const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H); + const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v; + + dma_queue_push(dma, dma_make_ptr(s_work[spad_idx], ps_in), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), S_v); + + ir_prefetch += nth; + spad_idx ^= 1; + } + + curr_spad_idx ^= 1; } + dma_queue_flush(dma); } + static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, void * data) { struct htp_gdn_context * gctx = (struct htp_gdn_context *) data; struct htp_ops_context * octx = gctx->octx; @@ -727,38 +864,65 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo float local_gate[HTP_GDN_MAX_SV] __attribute__((aligned(128))); float local_q[HTP_GDN_MAX_SV] __attribute__((aligned(128))); float local_k[HTP_GDN_MAX_SV] __attribute__((aligned(128))); - float local_sums[8] __attribute__((aligned(128))); + float local_sums[32] __attribute__((aligned(128))); dma_queue * dma = octx->ctx->dma[ith]; - - uint8_t * spad = NULL; - if (gctx->use_vtcm) { - spad = gctx->vtcm_state_base + gctx->vtcm_state_per_thread * ith; + size_t state_aligned = (size_t) S_v * S_v * sizeof(float); + state_aligned = (state_aligned + 127) & ~(size_t)127; + float * s_work[2]; + s_work[0] = (float *) (gctx->vtcm_base + gctx->vtcm_per_thread * ith); + s_work[1] = s_work[0] + state_aligned / sizeof(float); + + struct fastdiv_values fd_H = init_fastdiv_values(H); + struct fastdiv_values fd_q1 = init_fastdiv_values(q->ne[1]); + struct fastdiv_values fd_k1 = init_fastdiv_values(k->ne[1]); + struct fastdiv_values fd_rq3 = init_fastdiv_values(rq3); + struct fastdiv_values fd_rk3 = init_fastdiv_values(rk3); + + const uint64_t state_seq_stride = state->nb[3] / sizeof(float); + + uint32_t ir_prefetch = ith; + int spad_idx = 0; + + // Prefetch preamble (up to 2 steps) + for (int k = 0; k < 2 && ir_prefetch < total_rows; k++) { + const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H); + const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H); + const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v; + // final state lands in snapshot slot 0 (most-recent-first ordering) + float * ps_out = state_out_base + ((uint64_t) piv3 * H + piv1) * S_v * S_v; + + // Push dummy write-back + dma_queue_push(dma, dma_make_ptr(ps_out, s_work[spad_idx]), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), 0); + + // Push fetch + dma_queue_push(dma, dma_make_ptr(s_work[spad_idx], ps_in), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), S_v); + + ir_prefetch += nth; + spad_idx ^= 1; } + int curr_spad_idx = 0; for (uint32_t ir = ith; ir < total_rows; ir += nth) { - const uint32_t iv1 = ir % H; - const uint32_t iv3 = ir / H; + dma_queue_pop(dma); + dma_queue_pop(dma); - const uint32_t iq1 = iv1 % q->ne[1]; - const uint32_t ik1 = iv1 % k->ne[1]; - const uint32_t iq3 = iv3 / rq3; - const uint32_t ik3 = iv3 / rk3; + float * s_work_curr = s_work[curr_spad_idx]; - float * s_out = state_out_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v; - const float * s_in = state_in_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v; - float * s_work; + const uint32_t iv1 = fastmodulo(ir, H, &fd_H); + const uint32_t iv3 = fastdiv(ir, &fd_H); - if (spad) { - dma_queue_push(dma, dma_make_ptr(spad, s_in), - S_v * sizeof(float), S_v * sizeof(float), - S_v * sizeof(float), S_v); - dma_queue_pop(dma); - s_work = (float *) spad; - } else { - s_work = s_out; - memcpy(s_work, s_in, gctx->state_bytes); - } + const uint32_t iq1 = fastmodulo(iv1, q->ne[1], &fd_q1); + const uint32_t ik1 = fastmodulo(iv1, k->ne[1], &fd_k1); + const uint32_t iq3 = fastdiv(iv3, &fd_rq3); + const uint32_t ik3 = fastdiv(iv3, &fd_rk3); + + // final state lands in snapshot slot 0 (most-recent-first ordering) + float * s_out = state_out_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v; float * attn_data = dst_base + ((uint64_t) iv3 * H + iv1) * S_v; @@ -773,111 +937,145 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo const float beta_val = *(const float *) ((const uint8_t *) (uintptr_t) beta->data + (uint64_t) iv3 * beta->nb[3] + (uint64_t) iv1 * beta->nb[1]); - memcpy(local_q, q_t, (size_t) S_v * sizeof(float)); - memcpy(local_k, k_t, (size_t) S_v * sizeof(float)); + hvx_copy_f32_au((uint8_t *) local_q, (const uint8_t *) q_t, S_v); + hvx_copy_f32_au((uint8_t *) local_k, (const uint8_t *) k_t, S_v); if (kda) { hvx_exp_f32((uint8_t *) local_gate, (const uint8_t *) g_t, S_v, false); uint32_t j = 0; for (; j + 8 <= S_v; j += 8) { - float * row0 = s_work + (uint64_t) (j + 0) * S_v; - float * row1 = s_work + (uint64_t) (j + 1) * S_v; - float * row2 = s_work + (uint64_t) (j + 2) * S_v; - float * row3 = s_work + (uint64_t) (j + 3) * S_v; - float * row4 = s_work + (uint64_t) (j + 4) * S_v; - float * row5 = s_work + (uint64_t) (j + 5) * S_v; - float * row6 = s_work + (uint64_t) (j + 6) * S_v; - float * row7 = s_work + (uint64_t) (j + 7) * S_v; + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; + float * row4 = s_work_curr + (uint64_t) (j + 4) * S_v; + float * row5 = s_work_curr + (uint64_t) (j + 5) * S_v; + float * row6 = s_work_curr + (uint64_t) (j + 6) * S_v; + float * row7 = s_work_curr + (uint64_t) (j + 7) * S_v; gdn_mul_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, local_gate, local_k, S_v, local_sums); - float local_delta_b[8] __attribute__((aligned(128))); - for (uint32_t r = 0; r < 8; ++r) { - local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; - } + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, local_k, local_delta_b, local_q, S_v, local_sums); - for (uint32_t r = 0; r < 8; ++r) { - attn_data[j + r] = local_sums[r] * scale; - } + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 8 * sizeof(float), res_attn); } for (; j + 4 <= S_v; j += 4) { - float * row0 = s_work + (uint64_t) (j + 0) * S_v; - float * row1 = s_work + (uint64_t) (j + 1) * S_v; - float * row2 = s_work + (uint64_t) (j + 2) * S_v; - float * row3 = s_work + (uint64_t) (j + 3) * S_v; + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; gdn_mul_dot4_f32(row0, row1, row2, row3, local_gate, local_k, S_v, local_sums); - float local_delta_b[4] __attribute__((aligned(128))); - for (uint32_t r = 0; r < 4; ++r) { - local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; - } + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums); - for (uint32_t r = 0; r < 4; ++r) { - attn_data[j + r] = local_sums[r] * scale; - } + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 4 * sizeof(float), res_attn); } + HVX_Vector vscale_splat = hvx_vec_splat_f32(scale); for (; j < S_v; ++j) { - float * row = s_work + (uint64_t) j * S_v; - const float sum = gdn_mul_dot_f32(row, local_gate, local_k, S_v); - const float dj = (v_t[j] - sum) * beta_val; - attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale; + float * row = s_work_curr + (uint64_t) j * S_v; + HVX_Vector vsum = gdn_mul_dot_f32(row, local_gate, local_k, S_v); + HVX_Vector vv_t = hvx_vec_splat_f32(v_t[j]); + HVX_Vector vdj = hvx_vec_mul_f32_f32(hvx_vec_sub_f32_f32(vv_t, vsum), hvx_vec_splat_f32(beta_val)); + HVX_Vector vres = gdn_add_scaled_dot_f32(row, local_k, vdj, local_q, S_v); + attn_data[j] = hvx_vec_get_f32(hvx_vec_mul_f32_f32(vres, vscale_splat)); } } else { const float gate = expf(g_t[0]); uint32_t j = 0; for (; j + 8 <= S_v; j += 8) { - float * row0 = s_work + (uint64_t) (j + 0) * S_v; - float * row1 = s_work + (uint64_t) (j + 1) * S_v; - float * row2 = s_work + (uint64_t) (j + 2) * S_v; - float * row3 = s_work + (uint64_t) (j + 3) * S_v; - float * row4 = s_work + (uint64_t) (j + 4) * S_v; - float * row5 = s_work + (uint64_t) (j + 5) * S_v; - float * row6 = s_work + (uint64_t) (j + 6) * S_v; - float * row7 = s_work + (uint64_t) (j + 7) * S_v; + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; + float * row4 = s_work_curr + (uint64_t) (j + 4) * S_v; + float * row5 = s_work_curr + (uint64_t) (j + 5) * S_v; + float * row6 = s_work_curr + (uint64_t) (j + 6) * S_v; + float * row7 = s_work_curr + (uint64_t) (j + 7) * S_v; gdn_mul_scalar_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, gate, local_k, S_v, local_sums); - float local_delta_b[8] __attribute__((aligned(128))); - for (uint32_t r = 0; r < 8; ++r) { - local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; - } + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, local_k, local_delta_b, local_q, S_v, local_sums); - for (uint32_t r = 0; r < 8; ++r) { - attn_data[j + r] = local_sums[r] * scale; - } + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 8 * sizeof(float), res_attn); } for (; j + 4 <= S_v; j += 4) { - float * row0 = s_work + (uint64_t) (j + 0) * S_v; - float * row1 = s_work + (uint64_t) (j + 1) * S_v; - float * row2 = s_work + (uint64_t) (j + 2) * S_v; - float * row3 = s_work + (uint64_t) (j + 3) * S_v; + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; gdn_mul_scalar_dot4_f32(row0, row1, row2, row3, gate, local_k, S_v, local_sums); - float local_delta_b[4] __attribute__((aligned(128))); - for (uint32_t r = 0; r < 4; ++r) { - local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; - } + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums); - for (uint32_t r = 0; r < 4; ++r) { - attn_data[j + r] = local_sums[r] * scale; - } + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 4 * sizeof(float), res_attn); } + HVX_Vector vscale_splat = hvx_vec_splat_f32(scale); for (; j < S_v; ++j) { - float * row = s_work + (uint64_t) j * S_v; - const float sum = gdn_mul_scalar_dot_f32(row, gate, local_k, S_v); - const float dj = (v_t[j] - sum) * beta_val; - attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale; + float * row = s_work_curr + (uint64_t) j * S_v; + HVX_Vector vsum = gdn_mul_scalar_dot_f32(row, gate, local_k, S_v); + HVX_Vector vv_t = hvx_vec_splat_f32(v_t[j]); + HVX_Vector vdj = hvx_vec_mul_f32_f32(hvx_vec_sub_f32_f32(vv_t, vsum), hvx_vec_splat_f32(beta_val)); + HVX_Vector vres = gdn_add_scaled_dot_f32(row, local_k, vdj, local_q, S_v); + attn_data[j] = hvx_vec_get_f32(hvx_vec_mul_f32_f32(vres, vscale_splat)); } } - if (spad) { - dma_queue_push(dma, dma_make_ptr(s_out, spad), + // Push real write-back + dma_queue_push(dma, dma_make_ptr(s_out, s_work_curr), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), S_v); + + // Prefetch next block (if any) + if (ir_prefetch < total_rows) { + const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H); + const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H); + const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v; + + dma_queue_push(dma, dma_make_ptr(s_work[spad_idx], ps_in), S_v * sizeof(float), S_v * sizeof(float), S_v * sizeof(float), S_v); - dma_queue_pop(dma); + + ir_prefetch += nth; + spad_idx ^= 1; } + + curr_spad_idx ^= 1; } + dma_queue_flush(dma); } + int op_gated_delta_net(struct htp_ops_context * octx) { const struct htp_tensor * q = octx->src[0]; const struct htp_tensor * k = octx->src[1]; @@ -901,6 +1099,7 @@ int op_gated_delta_net(struct htp_ops_context * octx) { const uint32_t H = v->ne[1]; const uint32_t n_tokens = v->ne[2]; const uint32_t n_seqs = v->ne[3]; + const uint32_t K = octx->op_params[0]; if (S_v == 0 || S_v > HTP_GDN_MAX_SV || H == 0 || n_tokens == 0 || n_seqs == 0) { return HTP_STATUS_NO_SUPPORT; @@ -913,10 +1112,11 @@ int op_gated_delta_net(struct htp_ops_context * octx) { (n_seqs % q->ne[3]) != 0 || (n_seqs % k->ne[3]) != 0) { return HTP_STATUS_NO_SUPPORT; } - if (state->ne[0] * state->ne[1] * state->ne[2] * state->ne[3] != S_v * S_v * H * n_seqs) { + // state holds s0 only: [S_v, S_v, H, n_seqs] + if (state->ne[0] != S_v || state->ne[1] != S_v || state->ne[2] != H || state->ne[3] != n_seqs) { return HTP_STATUS_NO_SUPPORT; } - if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs) { + if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs * K) { return HTP_STATUS_NO_SUPPORT; } @@ -932,18 +1132,11 @@ int op_gated_delta_net(struct htp_ops_context * octx) { size_t state_aligned = (size_t) S_v * S_v * sizeof(float); state_aligned = (state_aligned + 127) & ~(size_t)127; - gctx.use_vtcm = false; - gctx.vtcm_state_base = NULL; - gctx.vtcm_state_per_thread = 0; + assert(octx->ctx->vtcm_base != NULL); + assert(octx->ctx->vtcm_size >= 2 * state_aligned * octx->n_threads); - if (n_tokens == 1 && octx->ctx->vtcm_base) { - size_t vtcm_total = state_aligned * octx->n_threads; - if (octx->ctx->vtcm_size >= vtcm_total) { - gctx.use_vtcm = true; - gctx.vtcm_state_base = octx->ctx->vtcm_base; - gctx.vtcm_state_per_thread = state_aligned; - } - } + gctx.vtcm_base = octx->ctx->vtcm_base; + gctx.vtcm_per_thread = 2 * state_aligned; if (n_tokens == 1) { worker_pool_run_func(octx->ctx->worker_pool, gated_delta_net_f32_tg_thread, &gctx, octx->n_threads); diff --git a/ggml/src/ggml-hexagon/htp/get-rows-ops.c b/ggml/src/ggml-hexagon/htp/get-rows-ops.c index 5a1dc933860..bf7063e9880 100644 --- a/ggml/src/ggml-hexagon/htp/get-rows-ops.c +++ b/ggml/src/ggml-hexagon/htp/get-rows-ops.c @@ -17,9 +17,13 @@ struct get_rows_context { struct htp_ops_context * octx; - uint32_t src1_nrows_per_thread; + uint32_t tasks_per_thread; + uint32_t total_tasks; + uint32_t chunks_per_row; + uint32_t chunk_size; struct fastdiv_values get_rows_div_ne10; struct fastdiv_values get_rows_div_ne10_ne11; + struct fastdiv_values get_rows_div_chunks_per_row; }; #define get_rows_preamble \ @@ -52,20 +56,23 @@ struct get_rows_context { \ const uint32_t nr = ne10 * ne11 * ne12; -static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) { +static void get_rows_thread_f32_f32_dma(unsigned int nth, unsigned int ith, void *data) { struct get_rows_context * grctx = (struct get_rows_context *)data; struct htp_ops_context * octx = grctx->octx; get_rows_preamble; uint64_t qt = HAP_perf_get_qtimer_count(); - // parallelize by src1 elements (which correspond to dst rows) - const uint32_t dr = grctx->src1_nrows_per_thread; + const uint32_t dr = grctx->tasks_per_thread; const uint32_t ir0 = dr * ith; - const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; + if (ir0 >= grctx->total_tasks) { + return; + } + const uint32_t ir1 = MIN(ir0 + dr, grctx->total_tasks); const bool is_i32 = (octx->src[1]->type == HTP_TYPE_I32); + dma_queue * dma_queue = octx->ctx->dma[ith]; for (uint32_t i = ir0; i < ir1; ++i) { const uint32_t i12 = fastdiv(i, &grctx->get_rows_div_ne10_ne11); const uint32_t rem = i - i12 * ne11 * ne10; @@ -73,28 +80,76 @@ static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da const uint32_t i10 = rem - i11 * ne10; const uintptr_t src1_addr = octx->src[1]->data + i10*nb10 + i11*nb11 + i12*nb12; - uint32_t i01 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr; if (i01 >= ne01) { - // invalid index, skip for now to avoid crash continue; } const uintptr_t src0_ptr = octx->src[0]->data + i01*nb01 + i11*nb02 + i12*nb03; const uintptr_t dst_ptr = octx->dst->data + i10*nb1 + i11*nb2 + i12*nb3; - hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00); + + while (!dma_queue_push(dma_queue, dma_make_ptr((void *)dst_ptr, (const void *)src0_ptr), nb1, nb01, ne00 * sizeof(float), 1)) { + dma_queue_pop(dma_queue); + } } + dma_queue_flush(dma_queue); qt = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - qt); - FARF(HIGH, "get-rows-f32-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, + FARF(HIGH, "get-rows-f32-f32-dma %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, ne00, ne01, ne02, ne03, ir0, ir1, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, (unsigned) qt); } -int op_get_rows(struct htp_ops_context * octx) { +static void get_rows_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void *data) { + struct get_rows_context * grctx = (struct get_rows_context *)data; + struct htp_ops_context * octx = grctx->octx; get_rows_preamble; - const uint32_t n_threads = MIN(nr, octx->n_threads); + uint64_t qt = HAP_perf_get_qtimer_count(); + + const uint32_t dr = grctx->tasks_per_thread; + const uint32_t ir0 = dr * ith; + if (ir0 >= grctx->total_tasks) { + return; + } + const uint32_t ir1 = MIN(ir0 + dr, grctx->total_tasks); + + const bool is_i32 = (octx->src[1]->type == HTP_TYPE_I32); + + const uint32_t chunks_per_row = grctx->chunks_per_row; + const uint32_t chunk_size = grctx->chunk_size; + for (uint32_t i = ir0; i < ir1; ++i) { + const uint32_t row_idx = fastdiv(i, &grctx->get_rows_div_chunks_per_row); + const uint32_t chunk_idx = i - row_idx * chunks_per_row; + + const uint32_t i12 = fastdiv(row_idx, &grctx->get_rows_div_ne10_ne11); + const uint32_t rem = row_idx - i12 * ne11 * ne10; + const uint32_t i11 = fastdiv(rem, &grctx->get_rows_div_ne10); + const uint32_t i10 = rem - i11 * ne10; + + const uintptr_t src1_addr = octx->src[1]->data + i10*nb10 + i11*nb11 + i12*nb12; + uint32_t i01 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr; + + if (i01 >= ne01) { + continue; + } + + const uint32_t offset = chunk_idx * chunk_size; + if (offset < ne00) { + const uint32_t copy_size = MIN(chunk_size, ne00 - offset); + const uintptr_t src0_ptr = octx->src[0]->data + i01*nb01 + i11*nb02 + i12*nb03 + offset * sizeof(float); + const uintptr_t dst_ptr = octx->dst->data + i10*nb1 + i11*nb2 + i12*nb3 + offset * sizeof(float); + hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, copy_size); + } + } + + qt = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - qt); + FARF(HIGH, "get-rows-f32-f32-hvx %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, + ne00, ne01, ne02, ne03, ir0, ir1, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, (unsigned) qt); +} + +int op_get_rows(struct htp_ops_context * octx) { + get_rows_preamble; if (octx->src[0]->type != HTP_TYPE_F32) { return HTP_STATUS_NO_SUPPORT; @@ -112,13 +167,52 @@ int op_get_rows(struct htp_ops_context * octx) { return HTP_STATUS_OK; } + const uint32_t nb00 = octx->src[0]->nb[0]; + const uint32_t nb0 = octx->dst->nb[0]; + + const bool can_use_dma = (nb00 == sizeof(float)) && (nb0 == sizeof(float)); + const bool use_dma = can_use_dma && (ne00 >= 2048); + struct get_rows_context grctx; grctx.octx = octx; grctx.get_rows_div_ne10 = init_fastdiv_values(octx->src[1]->ne[0]); grctx.get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src[1]->ne[0] * octx->src[1]->ne[1]); - grctx.src1_nrows_per_thread = (nr + n_threads - 1) / n_threads; + if (use_dma) { + grctx.chunks_per_row = 1; + grctx.chunk_size = ne00; + grctx.total_tasks = nr; + grctx.get_rows_div_chunks_per_row = init_fastdiv_values(1); + + const uint32_t n_threads = MIN(nr, octx->n_threads); + grctx.tasks_per_thread = (nr + n_threads - 1) / n_threads; + + worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32_dma, &grctx, n_threads); + } else { + uint32_t chunks_per_row = 1; + uint32_t chunk_size = ne00; + uint32_t total_tasks = nr; + + if (nr < octx->n_threads) { + const uint32_t min_chunk_size = 1024; + uint32_t max_chunks = ne00 / min_chunk_size; + if (max_chunks == 0) { + max_chunks = 1; + } + chunks_per_row = MIN((octx->n_threads + nr - 1) / nr, max_chunks); + chunk_size = (ne00 + chunks_per_row - 1) / chunks_per_row; + total_tasks = nr * chunks_per_row; + } + + grctx.chunks_per_row = chunks_per_row; + grctx.chunk_size = chunk_size; + grctx.total_tasks = total_tasks; + grctx.get_rows_div_chunks_per_row = init_fastdiv_values(chunks_per_row); - worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32, &grctx, n_threads); + const uint32_t n_threads = MIN(total_tasks, octx->n_threads); + grctx.tasks_per_thread = (total_tasks + n_threads - 1) / n_threads; + + worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32_hvx, &grctx, n_threads); + } return HTP_STATUS_OK; } diff --git a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c index 9e1b778b01f..2796564fb75 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c @@ -17,14 +17,17 @@ #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "hex-dma.h" +#include "hex-fastdiv.h" #include "hmx-profile.h" #include "hmx-queue.h" #include "hmx-utils.h" #include "htp-ctx.h" #include "htp-ops.h" #include "hvx-dump.h" +#include "hvx-copy.h" #include "hvx-reduce.h" #include "hvx-utils.h" +#include "hvx-flash-attn.h" #include "vtcm-utils.h" #include "worker-pool.h" @@ -46,12 +49,12 @@ // g_br = hex_align_up(gqa_factor * Br, 32) replaces Br for all Q/O/S/P/D dimensions. // Layout: Q + O_ping + O_pong + K_dma*2 + V_dma*2 + K_tile + V_tile + S + P + D + vectors + scales // Mask is DMA'd into a VTCM buffer (Br rows per KV block) to avoid DDR reads in softmax. -static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads) { +static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads, bool use_pipeline) { const size_t g_br = hex_align_up(gqa_factor * Br, HMX_FP16_TILE_N_ROWS); const size_t q_tile_size = hex_align_up(g_br * DK * sizeof(__fp16), 4096); // Q: [g_br, DK] const size_t o_tile_size = hex_align_up(g_br * DV * sizeof(__fp16), 4096); // O: [g_br, DV] x2 ping-pong - const size_t k_dma_size = hex_align_up(Bc * DK * sizeof(__fp16), 4096); // K DMA: [Bc, DK] x2 double-buf - const size_t v_dma_size = hex_align_up(Bc * DV * sizeof(__fp16), 4096); // V DMA: [Bc, DV] x2 double-buf + const size_t k_dma_size = hex_align_up(Bc * hex_round_up(DK * sizeof(__fp16), 128), 4096); // K DMA: [Bc, DK] x2 double-buf + const size_t v_dma_size = hex_align_up(Bc * hex_round_up(DV * sizeof(__fp16), 128), 4096); // V DMA: [Bc, DV] x2 double-buf const size_t k_tile_size = hex_align_up(Bc * DK * sizeof(__fp16), 4096); // K tiles: [Bc, DK] interleaved const size_t v_tile_size = hex_align_up(Bc * DV * sizeof(__fp16), 4096); // V tiles: [Bc, DV] interleaved const size_t s_tile_size = hex_align_up(g_br * Bc * sizeof(__fp16), 4096); // S/P:[g_br, Bc] @@ -67,7 +70,7 @@ static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, + k_dma_size * 2 // K DMA x2 + v_dma_size * 2 // V DMA x2 + k_tile_size * 1 // K tiles - + v_tile_size * 1 // V tiles + + v_tile_size * (use_pipeline ? 2 : 1) // V tiles (double-buffered if pipelining) + s_tile_size * 2 // S + P + d_tile_size * 1 // D (diagonal matrix) + col_vec_size * 4 // m_vec, l_vec, s_rowmax, p_rowsum @@ -144,12 +147,13 @@ static int hmx_fa_find_chunk_size(size_t * Br_out, // See .cursor/todos/hmx-flash-attn-bc-search-space.md for the perf trade-off. const size_t bc_unit = HMX_FP16_TILE_N_COLS * 2; // 64 const size_t fp16 = sizeof(__fp16); + const bool can_pipeline = (kv_len >= FA_MIN_KV_BLOCKS * bc_unit && n_threads >= 2); // Approximate per-unit VTCM costs (without per-buffer alignment padding). const size_t per_gbr = (DK + 2 * DV) * fp16 + 4 * fp16; // Q + O×2 + 4 col vectors const size_t per_gbr2 = fp16; // D diagonal matrix const size_t per_bc = - 3 * (DK + DV) * fp16 + 2 * n_threads * fp16; // K_dma×2 + V_dma×2 + K_tile + V_tile + row bufs + 3 * DK * fp16 + (can_pipeline ? 4 : 3) * DV * fp16 + 2 * n_threads * fp16; // K/V DMA x2 + tiles + row bufs const size_t per_gbr_bc = 2 * fp16; // S + P const size_t overhead = 256 * 2 + 13 * 4096; @@ -164,7 +168,6 @@ static int hmx_fa_find_chunk_size(size_t * Br_out, // Pipeline constraint: cap Bc so n_kv_blocks >= FA_MIN_KV_BLOCKS. // Only relax when kv_len is too short to form enough blocks. - const bool can_pipeline = (kv_len >= FA_MIN_KV_BLOCKS * bc_unit && n_threads >= 2); const size_t Bc_limit = can_pipeline ? hex_align_down(kv_len / FA_MIN_KV_BLOCKS, bc_unit) : (kv_len >= bc_unit ? hex_align_down(kv_len, bc_unit) : bc_unit); // Cost coefficients calibrated from profiling @@ -200,7 +203,7 @@ static int hmx_fa_find_chunk_size(size_t * Br_out, } // Exact VTCM verification (alignment padding may push over budget) - while (Bc >= bc_unit && hmx_fa_compute_vtcm_usage(gqa_factor, DK, DV, Br, Bc, n_threads) > vtcm_budget) { + while (Bc >= bc_unit && hmx_fa_compute_vtcm_usage(gqa_factor, DK, DV, Br, Bc, n_threads, can_pipeline) > vtcm_budget) { Bc -= bc_unit; } if (Bc < bc_unit) { @@ -303,6 +306,7 @@ struct hmx_fa_context { uint32_t n_kv_heads; // number of KV heads uint32_t n_heads; // number of Q heads uint32_t G; // GQA factor = n_heads / n_kv_heads + struct fastdiv_values div_G; uint32_t n_kv_blocks; uint32_t neq1; // Q token count @@ -321,7 +325,7 @@ struct hmx_fa_context { __fp16 * vtcm_k_fp16[2]; // K DMA double-buffer [Bc, D] __fp16 * vtcm_v_fp16[2]; // V DMA double-buffer [Bc, D] __fp16 * vtcm_k_tiles; // K tiles (transposed) - __fp16 * vtcm_v_tiles; // V tiles (column-major) + __fp16 * vtcm_v_tiles[2]; // V tiles (column-major, double-buffered) __fp16 * vtcm_s_tiles; // S = QK^T [g_br, Bc] __fp16 * vtcm_p_tiles; // P = softmax(S) [g_br, Bc] __fp16 * vtcm_d_tiles; // Diagonal rescale [g_br, g_br] @@ -402,7 +406,9 @@ static void fa_v_interleave_thread(unsigned int n, unsigned int i, void * data) return; } - hmx_interleave_cols_to_tiles(factx->vtcm_v_tiles, factx->vtcm_v_fp16[args->buf_idx], total_rows, (int) factx->DV, + __fp16 * v_tiles_dest = factx->use_pipeline ? factx->vtcm_v_tiles[args->buf_idx] : factx->vtcm_v_tiles[0]; + + hmx_interleave_cols_to_tiles(v_tiles_dest, factx->vtcm_v_fp16[args->buf_idx], total_rows, (int) factx->DV, (int) args->src_stride, (int) args->n_col_tiles, start, end); } @@ -464,10 +470,10 @@ static void fa_q_load_thread(unsigned int n, unsigned int i, void * data) { for (size_t r = start; r < end; r += 2) { const bool next_row_valid = (r + 1) < n_rows_g; - const size_t q_idx0 = (r + 0) / G; - const size_t h_idx0 = (r + 0) % G; - const size_t q_idx1 = (r + 1) / G; - const size_t h_idx1 = (r + 1) % G; + const size_t q_idx0 = fastdiv(r + 0, &factx->div_G); + const size_t h_idx0 = fastmodulo(r + 0, G, &factx->div_G); + const size_t q_idx1 = fastdiv(r + 1, &factx->div_G); + const size_t h_idx1 = fastmodulo(r + 1, G, &factx->div_G); const uint8_t * q_ptr0 = (const uint8_t *) q->data + (q_start + q_idx0) * q->nb[1] + (kv_head * G + h_idx0) * q->nb[2] + ib3 * q->nb[3]; @@ -567,8 +573,8 @@ static void fa_o_store_thread(unsigned int n, unsigned int i, void * data) { const uint32_t ib3 = args->ib3; for (size_t r = start; r < end; ++r) { - const size_t q_idx = r / G; - const size_t h_idx = r % G; + const size_t q_idx = fastdiv(r, &factx->div_G); + const size_t h_idx = fastmodulo(r, G, &factx->div_G); // FIX(dst-indexing): ggml_flash_attn_ext() creates dst as permute(0,2,1,3) -> // [DV, n_heads, n_tokens, n_seq], so head stride is nb[1] and token stride is nb[2]. @@ -780,11 +786,11 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) { if (args->mask_vtcm) { // Read mask from VTCM buffer (DMA'd per KV block). // GQA dedup (scheme B): skip load when qi unchanged. - const size_t qi0 = (r + 0) / G; + const size_t qi0 = fastdiv(r + 0, &factx->div_G); v_mask0 = *(const HVX_UVector *) (args->mask_vtcm + qi0 * args->mask_vtcm_row_stride + c); v_mask1 = v_neg_inf; if (r + 1 < (int) n_rows_g) { - const size_t qi1 = (r + 1) / G; + const size_t qi1 = fastdiv(r + 1, &factx->div_G); if (qi1 == qi0) { v_mask1 = v_mask0; // scheme B: reuse — same mask row } else { @@ -794,8 +800,8 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) { } else { // Fallback: read mask directly from DDR (when mask->ne[2] > 1). const struct htp_tensor * mask = args->mask; - const size_t q_idx0 = args->q_start + ((r + 0) / G); - const size_t h_idx0 = args->kv_head * G + (r + 0) % G; + const size_t q_idx0 = args->q_start + fastdiv(r + 0, &factx->div_G); + const size_t h_idx0 = args->kv_head * G + fastmodulo(r + 0, G, &factx->div_G); const uint32_t im2_0 = h_idx0 % mask->ne[2]; const uint32_t im3_0 = args->ib3 % mask->ne[3]; @@ -805,12 +811,12 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) { v_mask1 = v_neg_inf; if (r + 1 < (int) n_rows_g) { - const size_t q_idx1 = args->q_start + ((r + 1) / G); + const size_t q_idx1 = args->q_start + fastdiv(r + 1, &factx->div_G); if (q_idx1 == q_idx0) { // scheme B: same mask row in DDR path v_mask1 = v_mask0; } else { - const size_t h_idx1 = args->kv_head * G + (r + 1) % G; + const size_t h_idx1 = args->kv_head * G + fastmodulo(r + 1, G, &factx->div_G); const uint32_t im2_1 = h_idx1 % mask->ne[2]; const uint32_t im3_1 = args->ib3 % mask->ne[3]; const __fp16 * m1_ptr = (const __fp16 *) ((const uint8_t *) mask->data + q_idx1 * mask->nb[1] + @@ -1191,14 +1197,13 @@ static void hmx_fa_o_norm_worker(void * data) { // Row r in the GQA-merged block maps to Q head h = kv_head * G + r % G. // slope(h) = m0^(h+1) when h < n_head_log2, else m1^(2*(h-n_head_log2)+1). // When max_bias == 0, all slopes are 1.0 (no ALiBi). -static __attribute__((noinline)) void fa_compute_slopes(fa_softmax_args_t * sargs, +static __attribute__((noinline)) void fa_compute_slopes( const struct hmx_fa_context * factx, uint32_t kv_head, size_t n_rows_g) { + __fp16 * slopes = factx->vtcm_slopes; if (factx->max_bias == 0.0f) { - for (size_t r = 0; r < n_rows_g; ++r) { - sargs->slopes[r] = 1.0f; - } + hvx_splat_f16_a(slopes, 1.0f, n_rows_g); return; } @@ -1207,10 +1212,32 @@ static __attribute__((noinline)) void fa_compute_slopes(fa_softmax_args_t * sarg const float m0 = factx->m0; const float m1 = factx->m1; + __fp16 temp_slopes[512] __attribute__((aligned(128))); + if (G <= 32) { + // Fast path: Compute G unique slope values in vector registers + HVX_Vector v_val = hvx_alibi_slopes(kv_head, G, n_head_log2, m0, m1); + + __fp16 temp_slopes_aligned[64] __attribute__((aligned(128))); + hvx_vmem(temp_slopes_aligned) = hvx_vec_f32_to_f16(v_val, Q6_V_vzero()); + + for (uint32_t i = 0; i < G; ++i) { + temp_slopes[i] = temp_slopes_aligned[i]; + } + } else { + // Fallback path: G > 32 (rare configurations) + for (uint32_t i = 0; i < G; ++i) { + temp_slopes[i] = (__fp16)alibi_slope(kv_head * G + i, n_head_log2, m0, m1); + } + } + + // Allocate stack buffer to avoid scalar writes to VTCM (which generates L2 misses) + __fp16 local_slopes[n_rows_g] __attribute__((aligned(128))); for (size_t r = 0; r < n_rows_g; ++r) { - const uint32_t h = kv_head * G + r % G; - sargs->slopes[r] = (h < n_head_log2) ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1); + local_slopes[r] = temp_slopes[fastmodulo(r, G, &factx->div_G)]; } + + // Copy to VTCM slopes using HVX block copy (both are aligned to 128 bytes) + hvx_copy_f16_aa((uint8_t *)slopes, (const uint8_t *)local_slopes, n_rows_g); } // ============================================================================ @@ -1248,28 +1275,28 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { if (DK % 32 != 0 || DV % 32 != 0) { return HTP_STATUS_NO_SUPPORT; } - if (neq1 < 32) { - return HTP_STATUS_NO_SUPPORT; - } // GQA factor const uint32_t n_kv_heads = k->ne[2]; const uint32_t G = neq2 / n_kv_heads; // Thread count for multi-thread HVX phases - const uint32_t n_threads = octx->n_threads; + const uint32_t n_threads_init = octx->n_threads; // Compute dynamic block sizes (GQA-aware, accounting for per-thread row bufs) size_t Br, Bc; const size_t vtcm_budget = ctx->vtcm_size; - if (hmx_fa_find_chunk_size(&Br, &Bc, G, DK, DV, neq1, nek1, vtcm_budget, n_threads) != 0) { + if (hmx_fa_find_chunk_size(&Br, &Bc, G, DK, DV, neq1, nek1, vtcm_budget, n_threads_init) != 0) { return HTP_STATUS_VTCM_TOO_SMALL; } const size_t g_br = hex_align_up(G * Br, HMX_FP16_TILE_N_ROWS); const uint32_t n_kv_blocks = (nek1 + Bc - 1) / Bc; - const bool use_pipeline = (n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads >= 2); + const bool use_pipeline = (n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads_init >= 2); + + // Bypass thread pool dispatch for small prompts/non-pipelined prefill by setting n_threads = 1 + const uint32_t n_threads = use_pipeline ? n_threads_init : 1; FARF(HIGH, "hmx-fa: neq1=%u nek1=%u DK=%u DV=%u G=%u Br=%zu Bc=%zu g_br=%zu n_kv_blocks=%u pipeline=%d vtcm=%zu", neq1, nek1, DK, DV, G, Br, Bc, g_br, n_kv_blocks, use_pipeline, vtcm_budget); @@ -1278,13 +1305,14 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { struct hmx_fa_context factx; memset(&factx, 0, sizeof(factx)); factx.octx = octx; - factx.n_threads = octx->ctx->n_threads; + factx.n_threads = n_threads; factx.DK = DK; factx.DV = DV; factx.n_kv = nek1; factx.n_kv_heads = n_kv_heads; factx.n_heads = neq2; factx.G = G; + factx.div_G = init_fastdiv_values(G); factx.neq1 = neq1; factx.Br = (uint32_t) Br; factx.Bc = (uint32_t) Bc; @@ -1328,10 +1356,15 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { factx.m1 = powf(2.0f, -(max_bias / 2.0f) / factx.n_head_log2); // ======== VTCM allocation (GQA-aware) ======== + const size_t size_k_row = DK * sizeof(__fp16); + const size_t size_v_row = DV * sizeof(__fp16); + const size_t size_k_row_padded = hex_round_up(size_k_row, 128); + const size_t size_v_row_padded = hex_round_up(size_v_row, 128); + const size_t q_tile_bytes = hex_align_up(g_br * DK * sizeof(__fp16), 4096); const size_t o_tile_bytes = hex_align_up(g_br * DV * sizeof(__fp16), 4096); - const size_t k_dma_bytes = hex_align_up(Bc * DK * sizeof(__fp16), 4096); - const size_t v_dma_bytes = hex_align_up(Bc * DV * sizeof(__fp16), 4096); + const size_t k_dma_bytes = hex_align_up(Bc * size_k_row_padded, 4096); + const size_t v_dma_bytes = hex_align_up(Bc * size_v_row_padded, 4096); const size_t k_tile_bytes = hex_align_up(Bc * DK * sizeof(__fp16), 4096); const size_t v_tile_bytes = hex_align_up(Bc * DV * sizeof(__fp16), 4096); const size_t s_tile_bytes = hex_align_up(g_br * Bc * sizeof(__fp16), 4096); @@ -1352,7 +1385,12 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { factx.vtcm_v_fp16[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_dma_bytes); factx.vtcm_v_fp16[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_dma_bytes); factx.vtcm_k_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, k_tile_bytes); - factx.vtcm_v_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes); + factx.vtcm_v_tiles[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes); + if (use_pipeline) { + factx.vtcm_v_tiles[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes); + } else { + factx.vtcm_v_tiles[1] = NULL; + } factx.vtcm_s_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, s_tile_bytes); factx.vtcm_p_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, s_tile_bytes); factx.vtcm_d_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, d_tile_bytes); @@ -1401,11 +1439,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { // ======== DMA setup ======== dma_queue * const dma = ctx->dma[0]; - // Padded row sizes for DMA - const size_t size_k_row = nek0 * sizeof(__fp16); - const size_t size_v_row = nev0 * sizeof(__fp16); - const size_t size_k_row_padded = hex_round_up(nek0 * sizeof(__fp16), 128); - const size_t size_v_row_padded = hex_round_up(nev0 * sizeof(__fp16), 128); + // Padded row sizes for DMA (defined in outer scope) const size_t n_row_tiles_g_br = g_br / HMX_FP16_TILE_N_ROWS; const size_t n_tiles_per_bc = Bc / HMX_FP16_TILE_N_COLS; @@ -1459,6 +1493,8 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { // ---- KV block loop with DMA double-buffering ---- size_t buf_idx = 0; + fa_compute_slopes(&factx, kv_head, n_rows_g); + // Prefetch first KV block if (factx.n_kv_blocks > 0) { const uint32_t kv_rows0 = hex_smin(Bc, nek1); @@ -1537,7 +1573,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { ou_job.o_curr = o_tile_curr; ou_job.o_prev = o_tile_prev; ou_job.p_tiles = factx.vtcm_p_tiles; - ou_job.v_tiles = factx.vtcm_v_tiles; + ou_job.v_tiles = factx.vtcm_v_tiles[1 - buf_idx]; ou_job.d_tiles = factx.vtcm_d_tiles; ou_job.hmx_scales = factx.vtcm_hmx_scales_id; ou_job.n_row_tiles = n_row_tiles; @@ -1552,11 +1588,6 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { fa_phase_k_interleave(&factx, kv_rows, k_src_stride, buf_idx); TIMER_STOP(k_interleave); - if (kv_blk > 0) { - hmx_queue_pop(hmx_q); - hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev); - } - // ---- Phase 2: qk_dot(blk) on HMX ‖ V_int(blk) + DMA prefetch on HVX ---- qk_job.q_tiles = factx.vtcm_q_tiles; qk_job.k_tiles = factx.vtcm_k_tiles; @@ -1576,6 +1607,13 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { fa_phase_v_interleave(&factx, kv_rows, v_src_stride, buf_idx, n_tiles_per_bc); TIMER_STOP(v_interleave); + // Pop and swap previous block's output update (deferred HMX pop) + if (kv_blk > 0) { + hmx_queue_pop(hmx_q); + hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev); + } + + // Pop current block's dot product job hmx_queue_pop(hmx_q); TIMER_STOP(qk_dot); @@ -1603,7 +1641,6 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { sargs.mask_vtcm = has_mask_dma ? (const __fp16 *) factx.vtcm_mask_buf : NULL; sargs.mask_vtcm_row_stride = factx.mask_buf_row_stride; sargs.slopes = factx.vtcm_slopes; - fa_compute_slopes(&sargs, &factx, kv_head, n_rows_g); TIMER_START(softmax); fa_phase_softmax_and_build_d(&factx, &sargs, n_row_tiles, n_row_tiles_g_br); @@ -1619,7 +1656,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { ou_job.o_curr = o_tile_curr; ou_job.o_prev = o_tile_prev; ou_job.p_tiles = factx.vtcm_p_tiles; - ou_job.v_tiles = factx.vtcm_v_tiles; + ou_job.v_tiles = factx.vtcm_v_tiles[1 - buf_idx]; ou_job.d_tiles = factx.vtcm_d_tiles; ou_job.hmx_scales = factx.vtcm_hmx_scales_id; ou_job.n_row_tiles = n_row_tiles; @@ -1714,7 +1751,6 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { sargs.mask_vtcm = has_mask_dma ? (const __fp16 *) factx.vtcm_mask_buf : NULL; sargs.mask_vtcm_row_stride = factx.mask_buf_row_stride; sargs.slopes = factx.vtcm_slopes; - fa_compute_slopes(&sargs, &factx, kv_head, n_rows_g); TIMER_START(softmax); fa_phase_softmax_and_build_d(&factx, &sargs, n_row_tiles, n_row_tiles_g_br); @@ -1734,7 +1770,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { const size_t DV_tiles = (size_t) (DV / 32); const __fp16 * restrict d_base = factx.vtcm_d_tiles; const __fp16 * restrict p_base = factx.vtcm_p_tiles; - const __fp16 * restrict v_base = factx.vtcm_v_tiles; + const __fp16 * restrict v_base = factx.vtcm_v_tiles[0]; const __fp16 * restrict op_base = o_tile_prev; __fp16 * restrict oc_base = o_tile_curr; __builtin_assume(n_row_tiles > 0); diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c index 3ef0bcdb26d..dab605210cf 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c @@ -16,6 +16,7 @@ #include "ggml-common.h" #include "hex-dma.h" +#include "hex-fastdiv.h" #include "worker-pool.h" #include "hvx-utils.h" @@ -34,6 +35,10 @@ static const __fp16 q4_0_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { -8, 0, -7, 0, -6, 0, -5, 0, -4, 0, -3, 0, -2, 0, -1, 0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, }; +static const __fp16 q4_1_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { + 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, 8, 0, 9, 0, 10, 0, 11, 0, 12, 0, 13, 0, 14, 0, 15, 0, +}; + // MXFP4 dequantization LUT: maps 4-bit index to fp16 mantissa value // kvalues: 0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6 static const __fp16 mxfp4_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { @@ -62,10 +67,16 @@ static inline size_t get_x4x2_row_stride(int weight_type, int k) { case HTP_TYPE_Q4_0: case HTP_TYPE_IQ4_NL: return (size_t) nb * (QK_Q4_0x4x2 / 2 + HMX_X4X2_DBLK_SIZE); // 144 * nb + case HTP_TYPE_Q4_1: + return (size_t) nb * (QK_Q4_0x4x2 / 2 + 32); // 160 * nb case HTP_TYPE_Q8_0: return (size_t) nb * (QK_Q8_0x4x2 + HMX_X4X2_DBLK_SIZE); // 272 * nb case HTP_TYPE_MXFP4: return (size_t) nb * (QK_MXFP4x4x2 / 2 + HMX_X4X2_MXFP4_EBLK_SIZE); // 136 * nb + case HTP_TYPE_F16: + return (size_t) k * sizeof(__fp16); + case HTP_TYPE_F32: + return (size_t) k * sizeof(float); default: return 0; } @@ -181,45 +192,134 @@ static int hmx_compute_chunks(size_t vtcm_total, // In x4x2, sub-blocks 0..3 use lower nibbles, sub-blocks 4..7 use upper nibbles // of the same 32 packed bytes. static inline HVX_Vector dequantize_x4x2_q4_0_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale, const HVX_Vector vlut_cvt) { + (void)vlut_cvt; + HVX_Vector vq = hvx_vmemu(packed_32); + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector i8 = Q6_Vb_vsplat_R(8); + HVX_Vector v_scales = hvx_vec_repl_f16(hvx_vmemu(scale)); + + HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); + v_quants = Q6_V_vand_VV(v_quants, mask_h4); + + HVX_Vector v_int8 = Q6_Vb_vsub_VbVb(v_quants, i8); + HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(v_int8)); + HVX_Vector v_hf = Q6_Vhf_equals_Vh(v0); + + return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales)); +} + +// Batch-dequantize 4 contiguous x4x2 Q4_0 groups (4x32 = 128 packed bytes) using +// full HVX vector width. +// Output: vector_x2 each hold 32 FP16 values in the first 64 bytes. +static inline HVX_Vector_x2 dequantize_x4x2_q4_0_x4groups_hvx( + const uint8_t *packed_128, bool upper_nibbles, + const __fp16 *scales_4, const HVX_Vector vlut_cvt) { + (void)vlut_cvt; + HVX_Vector vq = hvx_vmemu(packed_128); + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector i8 = Q6_Vb_vsplat_R(8); + HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); + v_quants = Q6_V_vand_VV(v_quants, mask_h4); + + HVX_Vector v_int8 = Q6_Vb_vsub_VbVb(v_quants, i8); + + HVX_VectorPair vp_int16 = Q6_Wh_vunpack_Vb(v_int8); + HVX_Vector v_lo = Q6_V_lo_W(vp_int16); + HVX_Vector v_hi = Q6_V_hi_W(vp_int16); + + v_lo = Q6_Vhf_equals_Vh(v_lo); + v_hi = Q6_Vhf_equals_Vh(v_hi); + + HVX_Vector vscale = hvx_vmemu(scales_4); + HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vscale); + HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vscale, 4)); + + v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01)); + v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23)); + + HVX_Vector_x2 r = { v_lo, v_hi }; + return r; +} + +static inline HVX_Vector dequantize_x4x2_q4_1_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale_offset, const HVX_Vector vlut_cvt) { + (void)vlut_cvt; + HVX_Vector vq = hvx_vmemu(packed_32); + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector v_dm = hvx_vmemu(scale_offset); + HVX_Vector v_scales = hvx_vec_repl_f16(v_dm); + HVX_Vector v_offsets = hvx_vec_repl_f16(Q6_V_vror_VR(v_dm, 2)); + + HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); + v_quants = Q6_V_vand_VV(v_quants, mask_h4); + + HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(v_quants)); + HVX_Vector v_hf = Q6_Vhf_equals_Vh(v0); + + return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales), v_offsets)); +} + +static inline HVX_Vector_x2 dequantize_x4x2_q4_1_x4groups_hvx( + const uint8_t *packed_128, bool upper_nibbles, + const __fp16 *scales_offsets_4, const HVX_Vector vlut_cvt) { + (void)vlut_cvt; + HVX_Vector vq = hvx_vmemu(packed_128); + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); + v_quants = Q6_V_vand_VV(v_quants, mask_h4); + + HVX_VectorPair vp_int16 = Q6_Wh_vunpack_Vb(v_quants); + HVX_Vector v_lo = Q6_V_lo_W(vp_int16); + HVX_Vector v_hi = Q6_V_hi_W(vp_int16); + + v_lo = Q6_Vhf_equals_Vh(v_lo); + v_hi = Q6_Vhf_equals_Vh(v_hi); + + HVX_Vector vscale_offset = hvx_vmemu(scales_offsets_4); + HVX_VectorPair dm_deal = Q6_W_vdeal_VVR(vscale_offset, vscale_offset, -2); + HVX_Vector vd = Q6_V_lo_W(dm_deal); + HVX_Vector vm = Q6_V_hi_W(dm_deal); + + HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vd); + HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vd, 4)); + + HVX_Vector v_os01 = hvx_vec_repl_2x_f16(vm); + HVX_Vector v_os23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vm, 4)); + + v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01), v_os01)); + v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23), v_os23)); + + HVX_Vector_x2 r = { v_lo, v_hi }; + return r; +} + +// LUT-based dequantizers for non-linear IQ4_NL format. +static inline HVX_Vector dequantize_x4x2_iq4_nl_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale, const HVX_Vector vlut_cvt) { HVX_Vector vq = hvx_vmemu(packed_32); const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); HVX_Vector v_scales = hvx_vec_repl_f16(hvx_vmemu(scale)); - // q4x4x2 stores two int4 values per byte. Keep only the selected nibble. HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); v_quants = Q6_V_vand_VV(v_quants, mask_h4); - // Shuffle before LUT v_quants = Q6_Vb_vshuff_Vb(v_quants); - // Use standard vlut16 (not _nomatch) to avoid stale-register NaN. - // _nomatch retains the previous destination-register value for colliding - // indices, but the C intrinsic doesn't model the implicit read so the - // compiler may allocate a register containing garbage/NaN. HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); HVX_Vector v_hf = Q6_V_lo_W(vp); return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales)); } -// Batch-dequantize 4 contiguous x4x2 Q4_0 groups (4x32 = 128 packed bytes) using -// full HVX vector width. One vmemu + one vlut16 replaces 4 separate calls. -// Output: vector_x2 each hold 32 FP16 values in the first 64 bytes. -static inline HVX_Vector_x2 dequantize_x4x2_q4_0_x4groups_hvx( +static inline HVX_Vector_x2 dequantize_x4x2_iq4_nl_x4groups_hvx( const uint8_t *packed_128, bool upper_nibbles, const __fp16 *scales_4, const HVX_Vector vlut_cvt) { - // Load all 128 packed bytes (4 contiguous 32-byte groups) HVX_Vector vq = hvx_vmemu(packed_128); const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); v_quants = Q6_V_vand_VV(v_quants, mask_h4); - // Shuffle before LUT v_quants = Q6_Vb_vshuff_Vb(v_quants); - // Full-width vlut16: 128 byte lookups -> 128 fp16 results in a VectorPair HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); - HVX_Vector v_lo = Q6_V_lo_W(vp); // [group0: 32 fp16 | group1: 32 fp16] - HVX_Vector v_hi = Q6_V_hi_W(vp); // [group2: 32 fp16 | group3: 32 fp16] + HVX_Vector v_lo = Q6_V_lo_W(vp); + HVX_Vector v_hi = Q6_V_hi_W(vp); - // Build per-group scale vectors: first 64 bytes use scale_a, last 64 use scale_b HVX_Vector vscale = hvx_vmemu(scales_4); HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vscale); HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vscale, 4)); @@ -227,9 +327,7 @@ static inline HVX_Vector_x2 dequantize_x4x2_q4_0_x4groups_hvx( v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01)); v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23)); - // Extract individual groups: scatter uses q_mask64 so only first 64 bytes matter - HVX_Vector_x2 r = { v_lo,/* group1 already in [0:63] */ - v_hi /* group2 already in [0:63] */ }; + HVX_Vector_x2 r = { v_lo, v_hi }; return r; } @@ -320,100 +418,176 @@ static inline HVX_Vector_x4 dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t * return r; } +typedef struct { + __fp16 *dst; + const uint8_t *src; + int n_cols; + int k_block; + size_t row_stride; + int weight_type; + int n_tot_tiles; + int n_tiles_per_task; + int n_tasks; + int n_k_tiles; + struct fastdiv_values n_k_tiles_div; +} x4x2_dequantize_state_t; + // Dequantize a tile range from x4x2 weight data (already in VTCM) to tile-major FP16. // Input: vtcm_src has n_cols rows of x4x2 data, each row_stride bytes. // Output: vtcm_dst in tile-major FP16 layout. -static void dequantize_x4x2_weight_to_fp16_tiles_task( - __fp16 *restrict vtcm_dst, - const uint8_t *restrict vtcm_src, - int n_cols, int k_block, - size_t row_stride, int weight_type, - int start_tile, int end_tile) { - - const int n_k_tiles = (unsigned)k_block / HMX_FP16_TILE_N_COLS; - const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL); - const int qrow_size = is_q4 ? ((unsigned)k_block / 2) : k_block; - const HVX_Vector vlut_cvt = (weight_type == HTP_TYPE_IQ4_NL) ? hvx_vmem(iq4_nl_to_fp16_lut) : - (weight_type == HTP_TYPE_MXFP4) ? hvx_vmem(mxfp4_to_fp16_lut) : - hvx_vmem(q4_0_to_fp16_lut); - - // vscatter setup: write dequantized K-values directly to transposed [K][N] tile positions. - // Each int32 element holds a K-row-pair (2 adjacent fp16 values). word[i] at offset i*128 - // maps to K-rows 2i and 2i+1. Column offset (n*4) added per row. - const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); - const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); // 4 bytes = 1 column step - const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); // first 16 words (64 bytes) - - unsigned ct = (unsigned)start_tile / n_k_tiles; // column tile index - unsigned kt = (unsigned)start_tile % n_k_tiles; // K tile index - for (unsigned t = start_tile; t < end_tile; ) { - if (kt >= n_k_tiles) { kt = 0; ct++; } - - // --- Batch-4 fast path for Q4: process 4 contiguous K-tiles with one vlut16 per row --- - if (is_q4 && (kt % 4 == 0) && (t + 4 <= end_tile) && ((t + 3) / n_k_tiles == ct)) { - unsigned blk_idx = (kt * 32) / QK_Q4_0x4x2; - unsigned sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; // 0 or 4 - bool upper = (sub_blk_base >= 4); - unsigned packed_off = blk_idx * (QK_Q4_0x4x2 / 2); // 128 contiguous packed bytes - unsigned scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE - + sub_blk_base * (int)sizeof(__fp16); // 4 consecutive scales - - __fp16 *tile_bases[4]; - for (unsigned g = 0; g < 4; g++) { tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS; } +#define DEFINE_DEQUANTIZE_Q4_TASK(suffix, lut_name, helper_prefix, dblk_size, scale_step) \ +static void dequantize_x4x2_weight_to_fp16_tiles_task_##suffix( \ + const x4x2_dequantize_state_t *state, \ + int start_tile, int end_tile) { \ + \ + const int n_k_tiles = state->n_k_tiles; \ + const int qrow_size = (unsigned)state->k_block / 2; \ + const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; \ + const HVX_Vector vlut_cvt = hvx_vmem(lut_name); \ + \ + const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); \ + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); \ + const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); \ + \ + unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); \ + unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); \ + \ + for (unsigned t = start_tile; t < (unsigned)end_tile; ) { \ + if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } \ + \ + if ((kt % 4 == 0) && (t + 4 <= (unsigned)end_tile) && (fastdiv(t + 3, &n_k_tiles_div) == ct)) { \ + unsigned blk_idx = ((kt * 32) / QK_Q4_0x4x2); \ + unsigned sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; \ + bool upper = (sub_blk_base >= 4); \ + unsigned packed_off = blk_idx * (QK_Q4_0x4x2 / 2); \ + unsigned scale_off = qrow_size + blk_idx * (dblk_size) + sub_blk_base * (scale_step); \ + \ + __fp16 *tile_bases[4]; \ + for (unsigned g = 0; g < 4; g++) { \ + tile_bases[g] = state->dst + (t + g) * HMX_FP16_TILE_N_ELMS; \ + } \ + \ + HVX_Vector v_off = v_scat_base; \ + unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * state->row_stride; \ + \ + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { \ + const uint8_t *r0 = state->src + row_offset; row_offset += state->row_stride; \ + const uint8_t *r1 = state->src + row_offset; row_offset += state->row_stride; \ + \ + HVX_Vector_x2 dv0 = dequantize_x4x2_##helper_prefix##_x4groups_hvx( \ + r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); \ + Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]); \ + Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]); \ + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); \ + \ + HVX_Vector_x2 dv1 = dequantize_x4x2_##helper_prefix##_x4groups_hvx( \ + r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt); \ + Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]); \ + Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]); \ + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); \ + } \ + \ + for (int g = 0; g < 4; g++) { (void) *(volatile HVX_Vector *)(tile_bases[g]); } \ + t += 4; kt += 4; \ + continue; \ + } \ + \ + __fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; \ + { \ + unsigned blk_idx = (kt * 32) / QK_Q4_0x4x2; \ + unsigned sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32; \ + bool upper = (sub_blk >= 4); \ + unsigned byte_off = blk_idx * (QK_Q4_0x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32; \ + unsigned scale_off = qrow_size + blk_idx * (dblk_size) + sub_blk * (scale_step); \ + \ + HVX_Vector v_off = v_scat_base; \ + unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * state->row_stride; \ + unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1; \ + \ + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { \ + const uint8_t *r0 = state->src + row_offset; row_offset += state->row_stride; \ + const uint8_t *r1 = state->src + row_offset; row_offset += state->row_stride; \ + \ + HVX_Vector v0 = dequantize_x4x2_##helper_prefix##_group_hvx( \ + r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); \ + HVX_Vector v1 = (row1 < (unsigned)state->n_cols) \ + ? dequantize_x4x2_##helper_prefix##_group_hvx( \ + r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt) \ + : Q6_V_vzero(); \ + \ + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); \ + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); \ + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); \ + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); \ + } \ + (void) *(volatile HVX_Vector *)(tile_base); \ + } \ + ++t; ++kt; \ + } \ + \ + if (start_tile < end_tile) { \ + (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); \ + } \ +} \ + \ +static void dequantize_x4x2_worker_loop_##suffix(unsigned int n, unsigned int i, void *data) { \ + x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; \ + for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { \ + int start = task_id * state->n_tiles_per_task; \ + int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); \ + dequantize_x4x2_weight_to_fp16_tiles_task_##suffix(state, start, end); \ + } \ +} - HVX_Vector v_off = v_scat_base; +DEFINE_DEQUANTIZE_Q4_TASK(q4_0, q4_0_to_fp16_lut, q4_0, HMX_X4X2_DBLK_SIZE, (int)sizeof(__fp16)) +DEFINE_DEQUANTIZE_Q4_TASK(q4_1, q4_1_to_fp16_lut, q4_1, 32, 4) +DEFINE_DEQUANTIZE_Q4_TASK(iq4_nl, iq4_nl_to_fp16_lut, iq4_nl, HMX_X4X2_DBLK_SIZE, (int)sizeof(__fp16)) - unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride; - unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1; +static void dequantize_x4x2_weight_to_fp16_tiles_task_mxfp4( + const x4x2_dequantize_state_t *state, + int start_tile, int end_tile) { - for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { - const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; - const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride; + const int n_k_tiles = state->n_k_tiles; + const int qrow_size = (unsigned)state->k_block / 2; + const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; + const HVX_Vector vlut_cvt = hvx_vmem(mxfp4_to_fp16_lut); - HVX_Vector_x2 dv0 = dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); - HVX_Vector_x2 dv1 = dequantize_x4x2_q4_0_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt); + const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); + const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); - Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]); - Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); + unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); - Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]); - Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - } + for (unsigned t = start_tile; t < (unsigned)end_tile; ) { + if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } - for (int g = 0; g < 4; g++) { (void) *(volatile HVX_Vector *)(tile_bases[g]); } - t += 4; kt += 4; - continue; - } - - // --- Batch-4 fast path for MXFP4: same nibble layout but E8M0 scales --- - if (weight_type == HTP_TYPE_MXFP4 && (kt % 4 == 0) && (t + 4 <= end_tile) && ((t + 3) / n_k_tiles == ct)) { + // Batch-4 fast path for MXFP4 + if ((kt % 4 == 0) && (t + 4 <= (unsigned)end_tile) && (fastdiv(t + 3, &n_k_tiles_div) == ct)) { int blk_idx = (kt * 32) / QK_MXFP4x4x2; - int sub_blk_base = ((kt * 32) % QK_MXFP4x4x2) / 32; // 0 or 4 + int sub_blk_base = ((kt * 32) % QK_MXFP4x4x2) / 32; bool upper = (sub_blk_base >= 4); - int packed_off = blk_idx * (QK_MXFP4x4x2 / 2); // 128 contiguous packed bytes - int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE; // all 8 E8M0 scales + int packed_off = blk_idx * (QK_MXFP4x4x2 / 2); + int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE; __fp16 * tile_bases[4]; for (int g = 0; g < 4; g++) { - tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS; + tile_bases[g] = state->dst + (t + g) * HMX_FP16_TILE_N_ELMS; } HVX_Vector v_off = v_scat_base; for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { int row0 = ct * HMX_FP16_TILE_N_COLS + r; int row1 = row0 + 1; - const uint8_t * r0 = vtcm_src + row0 * row_stride; - const uint8_t * r1 = vtcm_src + row1 * row_stride; + const uint8_t * r0 = state->src + row0 * state->row_stride; + const uint8_t * r1 = state->src + row1 * state->row_stride; - // Batch-convert all 8 E8M0 scales once per row (stays in HVX register) mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off); HVX_Vector_x4 dv0, dv1; dv0 = dequantize_x4x2_mxfp4_x4groups_hvx(r0 + packed_off, upper, sub_blk_base, vlut_cvt, r0_e8); - if (row1 < n_cols) { + if (row1 < state->n_cols) { mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off); dv1 = dequantize_x4x2_mxfp4_x4groups_hvx(r1 + packed_off, upper, sub_blk_base, vlut_cvt, r1_e8); } else { @@ -434,41 +608,13 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( (void) *(volatile HVX_Vector *) (tile_bases[g]); } - t += 4; + t += 4; kt += 4; continue; } - // --- Single-tile fallback --- - __fp16 *tile_base = vtcm_dst + t * HMX_FP16_TILE_N_ELMS; - - if (is_q4) { - unsigned blk_idx = (kt * 32) / QK_Q4_0x4x2; - unsigned sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32; - bool upper = (sub_blk >= 4); - unsigned byte_off = blk_idx * (QK_Q4_0x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32; - unsigned scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int)sizeof(__fp16); - - HVX_Vector v_off = v_scat_base; // reset to column 0 - unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride; - unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1; - for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { - const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; - const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride; - - HVX_Vector v0 = dequantize_x4x2_q4_0_group_hvx( - r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); - HVX_Vector v1 = (row1 < n_cols) - ? dequantize_x4x2_q4_0_group_hvx( - r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt) - : Q6_V_vzero(); - - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - } - (void) *(volatile HVX_Vector *)(tile_base); - } else if (weight_type == HTP_TYPE_MXFP4) { + // Single-tile fallback + __fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; + { int blk_idx = (kt * 32) / QK_MXFP4x4x2; int sub_blk = ((kt * 32) % QK_MXFP4x4x2) / 32; bool upper = (sub_blk >= 4); @@ -480,15 +626,14 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( int row0 = ct * HMX_FP16_TILE_N_COLS + r; int row1 = row0 + 1; - const uint8_t * r0 = vtcm_src + row0 * row_stride; - const uint8_t * r1 = vtcm_src + row1 * row_stride; + const uint8_t * r0 = state->src + row0 * state->row_stride; + const uint8_t * r1 = state->src + row1 * state->row_stride; - // Batch-convert all 8 E8M0 scales once per row (stays in HVX register) mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off); HVX_Vector v0 = dequantize_x4x2_mxfp4_group_hvx(r0 + byte_off, upper, sub_blk, vlut_cvt, r0_e8); HVX_Vector v1; - if (row1 < n_cols) { + if (row1 < state->n_cols) { mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off); v1 = dequantize_x4x2_mxfp4_group_hvx(r1 + byte_off, upper, sub_blk, vlut_cvt, r1_e8); } else { @@ -501,23 +646,59 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); } (void) *(volatile HVX_Vector *) (tile_base); - } else { - // Q8_0 + } + ++t; ++kt; + } + + if (start_tile < end_tile) { + (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); + } +} + +static void dequantize_x4x2_worker_loop_mxfp4(unsigned int n, unsigned int i, void *data) { + x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; + for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { + int start = task_id * state->n_tiles_per_task; + int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); + dequantize_x4x2_weight_to_fp16_tiles_task_mxfp4(state, start, end); + } +} + +static void dequantize_x4x2_weight_to_fp16_tiles_task_q8_0( + const x4x2_dequantize_state_t *state, + int start_tile, int end_tile) { + + const int n_k_tiles = state->n_k_tiles; + const int qrow_size = state->k_block; + const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; + + const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); + const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); + + unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); + unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); + + for (unsigned t = start_tile; t < (unsigned)end_tile; ) { + if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } + + __fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; + { int blk_idx = (kt * 32) / QK_Q8_0x4x2; int sub_blk = ((kt * 32) % QK_Q8_0x4x2) / 32; int byte_off = blk_idx * QK_Q8_0x4x2 + sub_blk * 32; int scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int)sizeof(__fp16); - HVX_Vector v_off = v_scat_base; // reset to column 0 + HVX_Vector v_off = v_scat_base; for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { int row0 = ct * HMX_FP16_TILE_N_COLS + r; int row1 = row0 + 1; - const uint8_t *r0 = vtcm_src + row0 * row_stride; - const uint8_t *r1 = vtcm_src + row1 * row_stride; + const uint8_t *r0 = state->src + row0 * state->row_stride; + const uint8_t *r1 = state->src + row1 * state->row_stride; HVX_Vector v0 = dequantize_x4x2_q8_0_group_hvx((const int8_t *)(r0 + byte_off), (const __fp16 *)(r0 + scale_off)); - HVX_Vector v1 = (row1 < n_cols) ? dequantize_x4x2_q8_0_group_hvx((const int8_t *)(r1 + byte_off), (const __fp16 *)(r1 + scale_off)) : Q6_V_vzero(); + HVX_Vector v1 = (row1 < state->n_cols) ? dequantize_x4x2_q8_0_group_hvx((const int8_t *)(r1 + byte_off), (const __fp16 *)(r1 + scale_off)) : Q6_V_vzero(); Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); @@ -529,53 +710,151 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( ++t; ++kt; } - // Drain HVX scatter write buffer: a vmem load on the same HW thread retires - // all pending scatter entries to VTCM. Without this, the main thread's HMX - // reads may see stale data because atomic_fetch_sub (release) only orders - // regular stores, not the HVX scatter buffer. if (start_tile < end_tile) { - (void) *(volatile HVX_Vector *)(vtcm_dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); + (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); } } -typedef struct { - __fp16 *dst; - const uint8_t *src; - int n_cols; - int k_block; - size_t row_stride; - int weight_type; - int n_tot_tiles; - int n_tiles_per_task; - int n_tasks; -} x4x2_dequantize_state_t; - -static void dequantize_x4x2_worker_loop(unsigned int n, unsigned int i, void *data) { +static void dequantize_x4x2_worker_loop_q8_0(unsigned int n, unsigned int i, void *data) { x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; + for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { + int start = task_id * state->n_tiles_per_task; + int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); + dequantize_x4x2_weight_to_fp16_tiles_task_q8_0(state, start, end); + } +} + +static void convert_f16_weight_to_fp16_tiles_task( + const x4x2_dequantize_state_t *state, + int start_tile, int end_tile) { + + const int n_k_tiles = state->n_k_tiles; + const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; + + const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); + const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); + + unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); + unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); + + for (unsigned t = start_tile; t < (unsigned)end_tile; ) { + if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } + __fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; + { + int byte_off = kt * 32 * sizeof(__fp16); + + HVX_Vector v_off = v_scat_base; + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { + int row0 = ct * HMX_FP16_TILE_N_COLS + r; + int row1 = row0 + 1; + + const uint8_t *r0 = state->src + row0 * state->row_stride; + const uint8_t *r1 = state->src + row1 * state->row_stride; + + HVX_Vector v0 = hvx_vmemu((const __fp16 *)(r0 + byte_off)); + HVX_Vector v1 = (row1 < state->n_cols) ? hvx_vmemu((const __fp16 *)(r1 + byte_off)) : Q6_V_vzero(); + + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + (void) *(volatile HVX_Vector *)(tile_base); + } + ++t; ++kt; + } + + if (start_tile < end_tile) { + (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); + } +} + +static void convert_f16_worker_loop(unsigned int n, unsigned int i, void *data) { + x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { int start = task_id * state->n_tiles_per_task; int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); + convert_f16_weight_to_fp16_tiles_task(state, start, end); + } +} - dequantize_x4x2_weight_to_fp16_tiles_task( - state->dst, state->src, state->n_cols, state->k_block, - state->row_stride, state->weight_type, start, end); +static void quantize_f32_weight_to_fp16_tiles_task( + const x4x2_dequantize_state_t *state, + int start_tile, int end_tile) { + + const int n_k_tiles = state->n_k_tiles; + const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; + + const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); + const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); + + unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); + unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); + + for (unsigned t = start_tile; t < (unsigned)end_tile; ) { + if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } + + __fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; + { + int byte_off = kt * 32 * sizeof(float); + + HVX_Vector v_off = v_scat_base; + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { + int row0 = ct * HMX_FP16_TILE_N_COLS + r; + int row1 = row0 + 1; + + const uint8_t *r0 = state->src + row0 * state->row_stride; + const uint8_t *r1 = state->src + row1 * state->row_stride; + + HVX_Vector v0_f32 = hvx_vmemu((const float *)(r0 + byte_off)); + HVX_Vector v1_f32 = (row1 < state->n_cols) ? hvx_vmemu((const float *)(r1 + byte_off)) : Q6_V_vzero(); + + HVX_Vector v_out = hvx_vec_f32_to_f16(v0_f32, v1_f32); + + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v_out); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + + HVX_Vector v_out_hi = Q6_V_vror_VR(v_out, 64); + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v_out_hi); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + (void) *(volatile HVX_Vector *)(tile_base); + } + ++t; ++kt; + } + + if (start_tile < end_tile) { + (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); } } +static void quantize_f32_worker_loop(unsigned int n, unsigned int i, void *data) { + x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; + for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { + int start = task_id * state->n_tiles_per_task; + int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); + quantize_f32_weight_to_fp16_tiles_task(state, start, end); + } +} + + static void dequantize_x4x2_weight_chunk_to_fp16_tiles( struct htp_context *ctx, __fp16 *vtcm_dst, const void *vtcm_src, int n_cols, int k_block, - size_t row_stride, int weight_type) { + size_t row_stride, int weight_type, + int n_k_tiles, struct fastdiv_values n_k_tiles_div, + worker_callback_t dequant_worker_fn, int n_threads) { assert(n_cols % HMX_FP16_TILE_N_COLS == 0); assert(k_block % HMX_FP16_TILE_N_COLS == 0); size_t n_col_tiles = n_cols / HMX_FP16_TILE_N_COLS; - size_t n_k_tiles = k_block / HMX_FP16_TILE_N_COLS; size_t n_tot_tiles = n_col_tiles * n_k_tiles; - size_t n_tiles_per_task = hmx_ceil_div(n_tot_tiles, ctx->n_threads); + size_t n_tiles_per_task = (n_threads == 1) ? n_tot_tiles : hmx_ceil_div(n_tot_tiles, n_threads); x4x2_dequantize_state_t state; state.n_tasks = (n_tot_tiles + n_tiles_per_task - 1) / n_tiles_per_task; @@ -587,12 +866,20 @@ static void dequantize_x4x2_weight_chunk_to_fp16_tiles( state.k_block = k_block; state.row_stride = row_stride; state.weight_type = weight_type; + state.n_k_tiles = n_k_tiles; + state.n_k_tiles_div = n_k_tiles_div; - worker_pool_run_func(ctx->worker_pool, dequantize_x4x2_worker_loop, &state, ctx->n_threads); + if (state.n_tasks == 1 || n_threads == 1) { + dequant_worker_fn(1, 0, &state); + } else { + worker_pool_run_func(ctx->worker_pool, dequant_worker_fn, &state, n_threads); + } } // --- End x4x2 dequantizers --- +#pragma clang diagnostic ignored "-Wbackend-plugin" // spurios warning for hmx intrinsics + // requires external HMX lock static void core_dot_chunk_fp16(__fp16 *restrict output, const __fp16 *restrict activation, const __fp16 *restrict weight, const __fp16 *restrict scales, int n_row_tiles, int n_col_tiles, int n_dot_tiles) { @@ -714,11 +1001,11 @@ static void transfer_output_chunk_worker_fn(unsigned int n, unsigned int i, void } static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst, const __fp16 *vtcm_src, - int n_rows, int n_cols, int n) { + int n_rows, int n_cols, int n, int n_threads) { assert(n_cols % HMX_FP16_TILE_N_COLS == 0); size_t n_tot_chunks = n_rows; - size_t n_chunks_per_task = HMX_FP16_TILE_N_ROWS; // must be multiple of HMX_FP16_TILE_N_ROWS (32) + size_t n_chunks_per_task = (n_threads == 1) ? n_tot_chunks : HMX_FP16_TILE_N_ROWS; // must be multiple of HMX_FP16_TILE_N_ROWS (32) output_transfer_task_state_t state; state.n_tasks = (n_tot_chunks + n_chunks_per_task - 1) / n_chunks_per_task; @@ -729,7 +1016,11 @@ static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst, state.n_cols = n_cols; state.n = n; - worker_pool_run_func(ctx->worker_pool, transfer_output_chunk_worker_fn, &state, ctx->n_threads); + if (state.n_tasks == 1 || n_threads == 1) { + transfer_output_chunk_worker_fn(1, 0, &state); + } else { + worker_pool_run_func(ctx->worker_pool, transfer_output_chunk_worker_fn, &state, n_threads); + } } // activations : fp32 -> fp16 @@ -811,12 +1102,12 @@ static void transfer_activation_chunk_worker_fn(unsigned int n, unsigned int i, } } -static void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *dst, const float *src, int n_rows, int k_block, int k_stride) { +static void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *dst, const float *src, int n_rows, int k_block, int k_stride, int n_threads) { assert(k_block % HMX_FP16_TILE_N_COLS == 0 && k_stride % HMX_FP16_TILE_N_COLS == 0); assert(VLEN == 32 * sizeof(float)); size_t n_tot_chunks = n_rows; - size_t n_chunks_per_task = 32; // must be multiple of 32 to ensure correct destination address + size_t n_chunks_per_task = (n_threads == 1) ? n_tot_chunks : 32; // must be multiple of 32 to ensure correct destination address activation_transfer_task_state_t state; state.n_tasks = (n_tot_chunks + n_chunks_per_task - 1) / n_chunks_per_task; @@ -827,7 +1118,11 @@ static void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 * state.k_block = k_block; state.k_stride = k_stride; - worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_worker_fn, &state, ctx->n_threads); + if (state.n_tasks == 1 || n_threads == 1) { + transfer_activation_chunk_worker_fn(1, 0, &state); + } else { + worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_worker_fn, &state, n_threads); + } } // C += AB @@ -869,9 +1164,9 @@ static void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, co } } -int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, +int hmx_matmul_2d_f32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, const uint8_t *restrict permuted_weight, int m, int k, int n, - int weight_type) { + int act_stride, int weight_stride, int weight_type) { if (k % 32 != 0 || n % 32 != 0) { return -1; } if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) { @@ -883,21 +1178,41 @@ int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float * return -1; } + worker_callback_t dequant_worker_fn = NULL; + switch (weight_type) { + case HTP_TYPE_Q4_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_0; break; + case HTP_TYPE_IQ4_NL: dequant_worker_fn = dequantize_x4x2_worker_loop_iq4_nl; break; + case HTP_TYPE_Q4_1: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_1; break; + case HTP_TYPE_MXFP4: dequant_worker_fn = dequantize_x4x2_worker_loop_mxfp4; break; + case HTP_TYPE_Q8_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q8_0; break; + case HTP_TYPE_F16: dequant_worker_fn = convert_f16_worker_loop; break; + case HTP_TYPE_F32: dequant_worker_fn = quantize_f32_worker_loop; break; + default: + return -1; + } + + const int n_k_tiles = k / HMX_FP16_TILE_N_COLS; + const struct fastdiv_values n_k_tiles_div = init_fastdiv_values(n_k_tiles); + + // --- Dynamic Mode Configuration --- + const bool use_pipeline = (m > 32); + const int num_threads = (m <= 32) ? 1 : ctx->n_threads; + // --- Dynamic VTCM layout --- const size_t vec_dot_size = k * sizeof(__fp16); const size_t vtcm_budget = ctx->vtcm_size; size_t vtcm_used = 0; // Pipeline = 4-stage DMA→dequant→HMX→store with HMX worker overlap. - const size_t size_per_n = row_stride + 2 * vec_dot_size; // Q + S0 + S1 (dequant bufs) - const size_t size_per_mn = 2 * sizeof(__fp16); // O x 2 (output double buffer) + const size_t size_per_n = row_stride + (use_pipeline ? 2 * vec_dot_size : vec_dot_size); // Q + S0 + S1 (dequant bufs) + const size_t size_per_mn = (use_pipeline ? 2 : 1) * sizeof(__fp16); // O x 2 (output double buffer) size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0; if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, size_per_n, /*per_m=*/vec_dot_size, size_per_mn, hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, /*m_block_cost=*/(size_t) n * 3, /*n_block_cost=*/(size_t) m * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used)) { - FARF(HIGH, "hmx-mm-q: VTCM too small : m %d k %d n %d budget %zu", m, k, n, vtcm_budget); + FARF(HIGH, "hmx-mm-2d: VTCM too small : m %d k %d n %d budget %zu", m, k, n, vtcm_budget); return -1; } @@ -907,27 +1222,27 @@ int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float * size_t scratch0_size, scratch1_size, scratch2_size; scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); // dequant buf 0 - scratch1_size = scratch0_size; // dequant buf 1 - scratch2_size = output_area_size; // output buf 1 + scratch1_size = use_pipeline ? scratch0_size : 0; // dequant buf 1 + scratch2_size = use_pipeline ? output_area_size : 0; // output buf 1 uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_area_size); __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_size); - void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch1_size); + void *vtcm_scratch1 = scratch1_size ? vtcm_seq_alloc(&vtcm_ptr, scratch1_size) : NULL; void *vtcm_scratch2 = scratch2_size ? vtcm_seq_alloc(&vtcm_ptr, scratch2_size) : NULL; __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); vtcm_used = vtcm_ptr - (uint8_t *) ctx->vtcm_base; if (vtcm_used > vtcm_budget) { - FARF(ERROR, "hmx-mm-q: VTCM overflow: used %zu budget %zu", vtcm_used, vtcm_budget); + FARF(ERROR, "hmx-mm-2d: VTCM overflow: used %zu budget %zu", vtcm_used, vtcm_budget); return -1; } hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 - FARF(HIGH, "hmx-mm-q: standard : m %d k %d n %d wtype %d mc %zu nc %zu vtcm %zu/%zu", + FARF(HIGH, "hmx-mm-2d: standard : m %d k %d n %d wtype %d mc %zu nc %zu vtcm %zu/%zu", m, k, n, weight_type, m_chunk_n_rows, n_chunk_n_cols, vtcm_used, vtcm_budget); TIMER_DEFINE(activation_load); @@ -938,115 +1253,137 @@ int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float * TIMER_DEFINE(total); TIMER_START(total); - // 4-stage pipeline: DMA load (A), dequantize (B), HMX matmul (C), store (D) - // HMX compute (C) runs on dedicated worker thread, overlapping with HVX stages (B, D). - - // A --> B: vtcm_qweight, 1 buffer - // B --> C: vtcm_weight0/vtcm_weight1, 2 buffers - // C --> D: vtcm_output0/vtcm_output1, 2 buffers + int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols); - // Async timeline (C overlaps B+D): - // main+HVX: [A0][Act][B0][A1][sub C0][B1‖C0][A2][wait,sub C1][D0+B2‖C1][wait,sub C2][D1‖C2][wait][D2] - // HMX queue: [████ C0 ████████][████ C1 ████████████][████ C2 ████████] + if (use_pipeline) { + // --- Asynchronous Pipelined Loop (Current implementation) --- + hmx_matmul_job_t job_slots[2]; // persistent double-buffered job descriptors - int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols); - hmx_matmul_job_t job_slots[2]; // persistent double-buffered job descriptors + for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); - for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { - const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); + void *vtcm_qweight = vtcm_weight; + void *vtcm_weight_bufs[2] = { vtcm_scratch0, vtcm_scratch1 }; + void *vtcm_output_bufs[2] = { vtcm_output, vtcm_scratch2 }; - void *vtcm_qweight = vtcm_weight; - void *vtcm_weight_bufs[2] = { vtcm_scratch0, vtcm_scratch1 }; - void *vtcm_output_bufs[2] = { vtcm_output, vtcm_scratch2 }; + // prologue: A0 + const size_t n_cols_A0 = hex_smin(n - 0 * n_chunk_n_cols, n_chunk_n_cols); + { + const uint8_t *qweight_chunk_A0 = permuted_weight; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A0), row_stride, weight_stride, row_stride, n_cols_A0); + } - // prologue: A0 - const size_t n_cols_A0 = hex_smin(n - 0 * n_chunk_n_cols, n_chunk_n_cols); - { - const uint8_t *qweight_chunk_A0 = permuted_weight; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A0), row_stride, row_stride, row_stride, n_cols_A0); - } + { + const float *activation_chunk = activation + mr * act_stride; + transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, act_stride, num_threads); + } - { - const float *activation_chunk = activation + mr * k; - transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, k); - } + // prologue: B0, A1, submit C0 (async), B1 (overlaps C0) + { + // B0: wait for DMA, dequant weight chunk 0 + dma_queue_pop(ctx->dma[0]); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); - // prologue: B0, A1, submit C0 (async), B1 (overlaps C0) - { - // B0: wait for DMA, dequant weight chunk 0 - dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type); + // A1: issue DMA for weight chunk 1 + const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols); + if (1 < n_chunk_cnt) { + const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * weight_stride; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, weight_stride, row_stride, n_cols_A1); + } - // A1: issue DMA for weight chunk 1 - const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols); - if (1 < n_chunk_cnt) { - const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * row_stride; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, row_stride, row_stride, n_cols_A1); + // submit C0 (non-blocking — HMX worker executes in parallel) + hmx_matmul_job_init(&job_slots[0], (__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_activation, + (__fp16 *) vtcm_weight_bufs[0], vtcm_scales, + hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), + hmx_ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); + hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[0])); + + // B1: DMA pop + dequant (runs in parallel with C0 on HMX worker) + if (1 < n_chunk_cnt) { + dma_queue_pop(ctx->dma[0]); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); + } } - // submit C0 (non-blocking — HMX worker executes in parallel) - hmx_matmul_job_init(&job_slots[0], (__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_activation, - (__fp16 *) vtcm_weight_bufs[0], vtcm_scales, - hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), - hmx_ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); - hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[0])); + // main loop: wait C_i → submit C_{i+1} → D_i + B_{i+2} (parallel with C_{i+1}) + for (int i = 0; i < n_chunk_cnt; ++i) { + const size_t nc = i * n_chunk_n_cols; + const size_t nc_p1 = nc + 1 * n_chunk_n_cols; + const size_t nc_p2 = nc + 2 * n_chunk_n_cols; - // B1: DMA pop + dequant (runs in parallel with C0 on HMX worker) - if (1 < n_chunk_cnt) { - dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type); - } - } + const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + const size_t n_cols_p1 = hex_smin(n - nc_p1, n_chunk_n_cols); + const size_t n_cols_p2 = hex_smin(n - nc_p2, n_chunk_n_cols); - // main loop: wait C_i → submit C_{i+1} → D_i + B_{i+2} (parallel with C_{i+1}) - for (int i = 0; i < n_chunk_cnt; ++i) { - const size_t nc = i * n_chunk_n_cols; - const size_t nc_p1 = nc + 1 * n_chunk_n_cols; - const size_t nc_p2 = nc + 2 * n_chunk_n_cols; + // issue A_{i+2}: DMA push (non-blocking) + if (i + 2 < n_chunk_cnt) { + const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * weight_stride; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, weight_stride, row_stride, n_cols_p2); + } - const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); - const size_t n_cols_p1 = hex_smin(n - nc_p1, n_chunk_n_cols); - const size_t n_cols_p2 = hex_smin(n - nc_p2, n_chunk_n_cols); + // wait C_i: block until prologue/previous C completes + hmx_queue_pop(ctx->hmx_queue); - // issue A_{i+2}: DMA push (non-blocking) - if (i + 2 < n_chunk_cnt) { - const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * row_stride; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, row_stride, row_stride, n_cols_p2); - } + // submit C_{i+1} (non-blocking, overlaps with D_i + B_{i+2} below) + if (i + 1 < n_chunk_cnt) { + hmx_matmul_job_init(&job_slots[(i + 1) % 2], (__fp16 *) vtcm_output_bufs[(i + 1) % 2], + (__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2], + vtcm_scales, hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), + hmx_ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); + hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[(i + 1) % 2])); + } + + // D_i: store output (multi-thread HVX, parallel with C_{i+1}) + float *output_chunk = dst + (mr * n + nc); + transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output_bufs[i % 2], n_rows, n_cols, n, num_threads); - // wait C_i: block until prologue/previous C completes - hmx_queue_pop(ctx->hmx_queue); - - // submit C_{i+1} (non-blocking, overlaps with D_i + B_{i+2} below) - // job_slots[(i+1)%2] is safe: C_i just completed, freeing slot i%2's - // counterpart — and (i+1)%2 was last used by C_{i-1} which completed - // before C_i was submitted. - if (i + 1 < n_chunk_cnt) { - hmx_matmul_job_init(&job_slots[(i + 1) % 2], (__fp16 *) vtcm_output_bufs[(i + 1) % 2], - (__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2], - vtcm_scales, hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), - hmx_ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); - hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[(i + 1) % 2])); + // B_{i+2}: DMA pop + dequant (multi-thread HVX, parallel with C_{i+1}) + if (i + 2 < n_chunk_cnt) { + dma_queue_pop(ctx->dma[0]); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); + } } + } + hmx_queue_suspend(ctx->hmx_queue); + } else { + // --- Synchronous Loop (Optimized for small/non-pipelined cases) --- + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); - // D_i: store output (multi-thread HVX, parallel with C_{i+1}) - float *output_chunk = dst + (mr * n + nc); - transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output_bufs[i % 2], n_rows, n_cols, n); + for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); + const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); - // B_{i+2}: DMA pop + dequant (multi-thread HVX, parallel with C_{i+1}) - if (i + 2 < n_chunk_cnt) { + // Load Activation + const float *activation_chunk = activation + mr * act_stride; + transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, act_stride, num_threads); + + for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) { + const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); + + // A: DMA Load Weight + const uint8_t *qweight_chunk = permuted_weight + nc * weight_stride; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight, qweight_chunk), row_stride, weight_stride, row_stride, n_cols); dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type); + + // B: Dequantize / Convert Weight + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_scratch0, vtcm_weight, n_cols, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); + + // C: HMX Compute (Synchronous) + core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HMX_FP16_TILE_N_ROWS); + + // D: Output Store + float *output_chunk = dst + (mr * n + nc); + transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output, n_rows, n_cols, n, num_threads); } } + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); } - hmx_queue_suspend(ctx->hmx_queue); - TIMER_STOP(total); #if defined(ENABLE_PROFILE_TIMERS) - FARF(HIGH, "hex-mm-q: %lld us : m %d k %d n %d", TIMER_US(total), m, k, n); + FARF(HIGH, "hex-mm-2d: %lld us : m %d k %d n %d", TIMER_US(total), m, k, n); if (!use_pipeline) { FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us", TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); @@ -1225,11 +1562,11 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32 dma_queue_pop(ctx->dma[0]); transfer_activation_chunk_threaded(ctx, vtcm_act_g, vtcm_f32_act, (int) n_rows, - params->k, params->k); + params->k, params->k, ctx->n_threads); } else { transfer_activation_chunk_threaded(ctx, vtcm_act_g, activation_chunk, (int) n_rows, - params->k, params->act_stride); + params->k, params->act_stride, ctx->n_threads); } } TIMER_STOP(activation_load); @@ -1279,7 +1616,7 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32 TIMER_START(output_store); { float *output = hmx_matmul_dst_batch_ptr(params, b2_base + g, b3) + mr * params->dst_stride + nc; - transfer_output_chunk_threaded(ctx, output, vtcm_output, (int) n_rows, (int) n_cols, params->dst_stride); + transfer_output_chunk_threaded(ctx, output, vtcm_output, (int) n_rows, (int) n_cols, params->dst_stride, ctx->n_threads); } TIMER_STOP(output_store); } @@ -1299,177 +1636,431 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32 TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); #endif - return 0; + return 0; } -// - int hmx_matmul_f16_f32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, const __fp16 *restrict permuted_weight, int m, int k, int n, int act_stride, int weight_stride) { if (!dst || !activation || !permuted_weight || !m || !n || !k) { return -1; } - if (act_stride < k || weight_stride < k) { return -1; } - if (k % 32 != 0 || n % 32 != 0) { return -1; } + return hmx_matmul_2d_f32(ctx, dst, activation, (const uint8_t *)permuted_weight, m, k, n, + act_stride, weight_stride * (int)sizeof(__fp16), HTP_TYPE_F16); +} - if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) { - return -1; +struct mmid_row_mapping { + uint32_t i1; + uint32_t i2; +}; + +typedef struct { + __fp16 *dst; + const float *src; + int n_tasks; + int n_tot_chunks; + int n_chunks_per_task; + int k_block; + const struct mmid_row_mapping *matrix_rows; + int cur_a; + int mapping_stride; + int ne11; + struct fastdiv_values ne11_div; + size_t nb11; + size_t nb12; + int start_row; + int cne1; +} activation_transfer_gathered_task_state_t; + +typedef struct { + const __fp16 *vtcm_src; + float *dst; + int n_tasks; + int n_tot_chunks; + int n_chunks_per_task; + int n_cols; + const struct mmid_row_mapping *matrix_rows; + int cur_a; + int mapping_stride; + size_t dst_nb1; + size_t dst_nb2; + int start_row; + int cne1; +} output_transfer_scattered_task_state_t; + +static void transfer_activation_chunk_fp32_to_fp16_gathered( + __fp16 *restrict vtcm_dst, + const float *restrict src, + int start_row, + int n_rows, + int k_block, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride, + int ne11, + const struct fastdiv_values * ne11_div, + size_t nb11, + size_t nb12, + int cne1) { + const int n_rows_padded = hex_align_up(n_rows, HMX_FP16_TILE_N_ROWS); + const int n_rows_tiled = (n_rows / HMX_FP16_TILE_N_ROWS) * HMX_FP16_TILE_N_ROWS; + + int r = 0; + + #pragma unroll(2) + for (r = 0; r < n_rows_tiled; r += 2) { + int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index + int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx + + int r_idx0 = start_row + r + 0; + int r_idx1 = start_row + r + 1; + + struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + r_idx0]; + struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + r_idx1]; + + int i11_0 = fastmodulo(mapping0.i1, ne11, ne11_div); + int i11_1 = fastmodulo(mapping1.i1, ne11, ne11_div); + + const float *row0_ptr = (const float *) ((const uint8_t *) src + i11_0 * nb11 + mapping0.i2 * nb12); + const float *row1_ptr = (const float *) ((const uint8_t *) src + i11_1 * nb11 + mapping1.i2 * nb12); + + const HVX_Vector *pv_in0 = (const HVX_Vector *) row0_ptr; + const HVX_Vector *pv_in1 = (const HVX_Vector *) row1_ptr; + + for (int c = 0; c < k_block; c += 32) { + HVX_Vector v0 = *pv_in0++; + HVX_Vector v1 = *pv_in1++; + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index + int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } } - // --- Dynamic VTCM layout --- - const size_t vtcm_budget = ctx->vtcm_size; - const size_t vec_dot_size = k * sizeof(__fp16); + for (; r < n_rows_padded; r += 2) { + int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index + int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx - // DMA-based activation gather for strided tensors (see batched path comment). - const bool use_dma_activation = (act_stride > k); - const size_t f32_scratch_per_m = use_dma_activation ? (size_t) k * sizeof(float) : 0; + const bool row0_valid = (start_row + r + 0) < cne1; + const bool row1_valid = (start_row + r + 1) < cne1; - size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; - // FP16 weight: interleave and activation load have similar per-element cost. - if (hmx_compute_chunks(vtcm_budget, - /*overhead=*/256, - /*per_n=*/3 * vec_dot_size, // W + S0 + S1 - /*per_m=*/vec_dot_size + f32_scratch_per_m, // A + optional F32 scratch - /*per_mn=*/sizeof(__fp16), // O - hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, - /*m_block_cost=*/(size_t) n, - /*n_block_cost=*/(size_t) m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { - FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); - return -1; + const float *row0_ptr = NULL; + const float *row1_ptr = NULL; + + if (row0_valid) { + struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + (start_row + r + 0)]; + int i11_0 = fastmodulo(mapping0.i1, ne11, ne11_div); + row0_ptr = (const float *) ((const uint8_t *) src + i11_0 * nb11 + mapping0.i2 * nb12); + } + if (row1_valid) { + struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + (start_row + r + 1)]; + int i11_1 = fastmodulo(mapping1.i1, ne11, ne11_div); + row1_ptr = (const float *) ((const uint8_t *) src + i11_1 * nb11 + mapping1.i2 * nb12); + } + + const HVX_Vector *pv_in0 = (const HVX_Vector *) row0_ptr; + const HVX_Vector *pv_in1 = (const HVX_Vector *) row1_ptr; + + for (int c = 0; c < k_block; c += 32) { + HVX_Vector v0 = row0_valid ? *pv_in0++ : Q6_V_vzero(); + HVX_Vector v1 = row1_valid ? *pv_in1++ : Q6_V_vzero(); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index + int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } } +} - const size_t weight_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t activation_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); - const size_t scratch_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t f32_scratch_size = use_dma_activation - ? hex_align_up(m_chunk_n_rows * (size_t) k * sizeof(float), HMX_FP16_TILE_SIZE) : 0; +static void transfer_activation_chunk_gathered_worker_fn(unsigned int n, unsigned int i, void *data) { + activation_transfer_gathered_task_state_t *st = data; + int chunk_idx = i; + int chunk_size = st->n_chunks_per_task; + int start_row = st->start_row + chunk_idx * chunk_size; + int n_rows = hex_smin(st->cne1 - start_row, chunk_size); + if (n_rows > 0) { + __fp16 *dst = st->dst + (size_t)(start_row - st->start_row) * st->k_block; + transfer_activation_chunk_fp32_to_fp16_gathered( + dst, st->src, start_row, n_rows, st->k_block, + st->matrix_rows, st->cur_a, st->mapping_stride, + st->ne11, &st->ne11_div, st->nb11, st->nb12, st->cne1); + } +} - // VTCM layout: weight | activation | output | scratch0 | scratch1 | scales | [f32_scratch] - uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; - __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); - __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size); - __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); - void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); - void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); - __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); - float *vtcm_f32_act = use_dma_activation ? (float *) vtcm_seq_alloc(&vtcm_ptr, f32_scratch_size) : NULL; - if ((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) > vtcm_budget) { - FARF(ERROR, "%s: vtcm overflow: used=%zu limit=%zu", __func__, - (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); - return -1; +static void transfer_activation_chunk_gathered_threaded( + struct htp_context *ctx, + __fp16 *dst, + const float *src, + int start_row, + int n_rows, + int k_block, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride, + int ne11, + size_t nb11, + size_t nb12, + int cne1, + int n_threads) { + if (n_rows <= 0) return; + int chunks_per_thread = hmx_ceil_div(n_rows, n_threads); + chunks_per_thread = hex_align_up(chunks_per_thread, HMX_FP16_TILE_N_ROWS); + + int actual_threads = hmx_ceil_div(n_rows, chunks_per_thread); + + activation_transfer_gathered_task_state_t state = { + .dst = dst, + .src = src, + .n_tasks = actual_threads, + .n_tot_chunks = n_rows, + .n_chunks_per_task = chunks_per_thread, + .k_block = k_block, + .matrix_rows = matrix_rows, + .cur_a = cur_a, + .mapping_stride = mapping_stride, + .ne11 = ne11, + .ne11_div = init_fastdiv_values(ne11), + .nb11 = nb11, + .nb12 = nb12, + .start_row = start_row, + .cne1 = cne1, + }; + + if (actual_threads <= 1) { + transfer_activation_chunk_gathered_worker_fn(1, 0, &state); + } else { + worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_gathered_worker_fn, &state, actual_threads); } +} - hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 +static void transfer_output_chunk_fp16_to_fp32_scattered( + float *restrict dst, + const __fp16 *restrict vtcm_src, + int start_row, + int n_rows, + int n_cols, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride, + size_t dst_nb1, + size_t dst_nb2, + int cne1) { + assert(n_cols % HMX_FP16_TILE_N_COLS == 0); + const size_t tile_row_stride = (n_cols / HMX_FP16_TILE_N_COLS) * HMX_FP16_TILE_N_ELMS; - FARF(HIGH, "%s: m=%d k=%d n=%d mc=%zu nc=%zu vtcm=%zu/%zu", - __func__, m, k, n, m_chunk_n_rows, n_chunk_n_cols, - (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + const HVX_Vector one = hvx_vec_splat_f16(1.0); - TIMER_DEFINE(activation_load); - TIMER_DEFINE(weight_load); - TIMER_DEFINE(hmx_core); - TIMER_DEFINE(output_store); + for (size_t r = 0; r < n_rows; r += 2) { + const size_t r0 = r / HMX_FP16_TILE_N_ROWS; + const size_t r1 = (r % HMX_FP16_TILE_N_ROWS) / 2; // index of the row pair within the tile + const __fp16 *row_base = vtcm_src + r0 * tile_row_stride; - TIMER_DEFINE(total); - TIMER_START(total); + int r_idx0 = start_row + (int)r + 0; + int r_idx1 = start_row + (int)r + 1; - HAP_compute_res_hmx_lock(ctx->vtcm_rctx); + if (r_idx0 >= cne1) break; - for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { - // transfer activation matrix chunk into VTCM - const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); - const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); + struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + r_idx0]; + float *output_row0 = (float *) ((uint8_t *) dst + mapping0.i1 * dst_nb1 + mapping0.i2 * dst_nb2); - TIMER_START(activation_load); - { - const float *activation_chunk = activation + mr * act_stride; - if (use_dma_activation) { - const size_t row_bytes = (size_t) k * sizeof(float); - const size_t stride_bytes = (size_t) act_stride * sizeof(float); - dma_queue_push(ctx->dma[0], - dma_make_ptr(vtcm_f32_act, activation_chunk), - row_bytes, stride_bytes, row_bytes, n_rows); - dma_queue_pop(ctx->dma[0]); - transfer_activation_chunk_threaded(ctx, vtcm_activation, - vtcm_f32_act, n_rows, k, k); - } else { - transfer_activation_chunk_threaded(ctx, vtcm_activation, - activation_chunk, n_rows, k, act_stride); + float *output_row1 = NULL; + if (r_idx1 < cne1) { + struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + r_idx1]; + output_row1 = (float *) ((uint8_t *) dst + mapping1.i1 * dst_nb1 + mapping1.i2 * dst_nb2); + } + + #pragma unroll(4) + for (size_t c = 0; c < (size_t)n_cols; c += HMX_FP16_TILE_N_COLS) { + const size_t c0 = c / HMX_FP16_TILE_N_COLS; + const __fp16 *tile = row_base + c0 * HMX_FP16_TILE_N_ELMS; + HVX_Vector v = ((const HVX_Vector *) tile)[r1]; + HVX_VectorPair vp = Q6_Wqf32_vmpy_VhfVhf(v, one); + + volatile HVX_Vector *pv_out0 = (volatile HVX_Vector *) (output_row0 + c); + volatile HVX_Vector *pv_out1 = output_row1 ? (volatile HVX_Vector *) (output_row1 + c) : NULL; + + *pv_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp)); + if (pv_out1) { + *pv_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp)); } } - TIMER_STOP(activation_load); + } +} - const size_t fp16_row_bytes = (size_t) k * sizeof(__fp16); - const size_t weight_row_bytes = (size_t) weight_stride * sizeof(__fp16); +static void transfer_output_chunk_scattered_worker_fn(unsigned int n, unsigned int i, void *data) { + output_transfer_scattered_task_state_t *st = data; + int chunk_idx = i; + int chunk_size = st->n_chunks_per_task; + int start_row = st->start_row + chunk_idx * chunk_size; + int n_rows = hex_smin(st->cne1 - start_row, chunk_size); + if (n_rows > 0) { + const __fp16 *src = st->vtcm_src + (size_t)(start_row - st->start_row) * st->n_cols; + transfer_output_chunk_fp16_to_fp32_scattered( + st->dst, src, start_row, n_rows, st->n_cols, + st->matrix_rows, st->cur_a, st->mapping_stride, + st->dst_nb1, st->dst_nb2, st->cne1); + } +} - void *buf_curr = vtcm_scratch0; - void *buf_next = vtcm_scratch1; +static void transfer_output_chunk_scattered_threaded( + struct htp_context *ctx, + float *dst, + const __fp16 *vtcm_src, + int start_row, + int n_rows, + int n_cols, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride, + size_t dst_nb1, + size_t dst_nb2, + int cne1, + int n_threads) { + if (n_rows <= 0) return; + int chunks_per_thread = hmx_ceil_div(n_rows, n_threads); + chunks_per_thread = hex_align_up(chunks_per_thread, HMX_FP16_TILE_N_ROWS); + + int actual_threads = hmx_ceil_div(n_rows, chunks_per_thread); + + output_transfer_scattered_task_state_t state = { + .vtcm_src = vtcm_src, + .dst = dst, + .n_tasks = actual_threads, + .n_tot_chunks = n_rows, + .n_chunks_per_task = chunks_per_thread, + .n_cols = n_cols, + .matrix_rows = matrix_rows, + .cur_a = cur_a, + .mapping_stride = mapping_stride, + .dst_nb1 = dst_nb1, + .dst_nb2 = dst_nb2, + .start_row = start_row, + .cne1 = cne1, + }; + + if (actual_threads <= 1) { + transfer_output_chunk_scattered_worker_fn(1, 0, &state); + } else { + worker_pool_run_func(ctx->worker_pool, transfer_output_chunk_scattered_worker_fn, &state, actual_threads); + } +} - // issue async DMA for the first weight chunk - // NOTE: use 2D DMA (n_cols rows x fp16_row_bytes) to avoid 16-bit roiwidth overflow. - // The source rows can be strided (e.g. KV-cache K after ggml_permute). - { - const size_t n_cols_first = hex_smin(n, n_chunk_n_cols); +int hmx_matmul_id_2d_f32(struct htp_context *ctx, + float *restrict dst, + const float *activation, + const uint8_t *permuted_weight, + int m, int k, int n, + int ne11, + size_t act_nb1, size_t act_nb2, + size_t dst_nb1, size_t dst_nb2, + int weight_stride, + int weight_type, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride) { + const int cne1 = m; + const int m_padded = hex_align_up(m, 32); - dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), - fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first); - } + if (k % 32 != 0 || n % 32 != 0) { return -1; } - for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) { - const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); - const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); + if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) { + return -1; + } - TIMER_START(weight_load); - { - dma_queue_pop(ctx->dma[0]); // wait until current weight chunk is ready + size_t row_stride = get_x4x2_row_stride(weight_type, k); + if (row_stride == 0) { + return -1; + } - // issue async DMA for the next weight chunk (double buffering) - const size_t nc_next = nc + n_chunk_n_cols; - if (nc_next < n) { - const size_t n_cols_next = hex_smin(n - nc_next, n_chunk_n_cols); - const __fp16 *next_weight_chunk = permuted_weight + nc_next * weight_stride; + worker_callback_t dequant_worker_fn = NULL; + switch (weight_type) { + case HTP_TYPE_Q4_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_0; break; + case HTP_TYPE_IQ4_NL: dequant_worker_fn = dequantize_x4x2_worker_loop_iq4_nl; break; + case HTP_TYPE_Q4_1: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_1; break; + case HTP_TYPE_MXFP4: dequant_worker_fn = dequantize_x4x2_worker_loop_mxfp4; break; + case HTP_TYPE_Q8_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q8_0; break; + case HTP_TYPE_F16: dequant_worker_fn = convert_f16_worker_loop; break; + case HTP_TYPE_F32: dequant_worker_fn = quantize_f32_worker_loop; break; + default: + return -1; + } - dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), - fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_next); - } + const int n_k_tiles = k / HMX_FP16_TILE_N_COLS; + const struct fastdiv_values n_k_tiles_div = init_fastdiv_values(n_k_tiles); - // interleave row-major fp16 from scratch into tile-major in vtcm_weight - hmx_interleave_rows_to_tiles(vtcm_weight, (const __fp16 *) buf_curr, n_cols, k, k, 0, n_cols); + const int num_threads = ctx->n_threads; - hex_swap_ptr(&buf_curr, &buf_next); - } - TIMER_STOP(weight_load); + const size_t vec_dot_size = k * sizeof(__fp16); + const size_t vtcm_budget = ctx->vtcm_size; + size_t vtcm_used = 0; - TIMER_START(hmx_core); - { - core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, k / 32); - } - TIMER_STOP(hmx_core); + const size_t size_per_n = row_stride + vec_dot_size; + const size_t size_per_mn = sizeof(__fp16); - TIMER_START(output_store); - { - float *output = dst + (mr * n + nc); - transfer_output_chunk_threaded(ctx, output, vtcm_output, n_rows, n_cols, n); - } - TIMER_STOP(output_store); - } + size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0; + if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, size_per_n, /*per_m=*/vec_dot_size, size_per_mn, + m_padded, n, + /*m_block_cost=*/(size_t) n * 3, + /*n_block_cost=*/(size_t) m_padded * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used)) { + FARF(HIGH, "hmx-mm-id-2d: VTCM too small : m %d k %d n %d budget %zu", m_padded, k, n, vtcm_budget); + return -1; + } + + const size_t weight_area_size = hex_align_up(n_chunk_n_cols * row_stride, HMX_FP16_TILE_SIZE); + const size_t act_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); + + size_t scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); + + uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; + __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); + __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_area_size); + __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); + void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_size); + __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); + vtcm_used = vtcm_ptr - (uint8_t *) ctx->vtcm_base; + if (vtcm_used > vtcm_budget) { + FARF(ERROR, "hmx-mm-id-2d: VTCM overflow: used %zu budget %zu", vtcm_used, vtcm_budget); + return -1; } - HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); - TIMER_STOP(total); + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); -#if defined(ENABLE_PROFILE_TIMERS) - FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d", __func__, TIMER_US(total), m, k, n); - FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us", - TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); - { - size_t weight_size = (size_t)k * n * sizeof(__fp16); - float bandwidth = 1e-3f * weight_size / (float)TIMER_US(weight_load); - FARF(HIGH, " weight load bandwidth: %.2f GB/s", bandwidth); + for (size_t mr = 0; mr < (size_t) m_padded; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin(m_padded - mr, m_chunk_n_rows); + const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); + + transfer_activation_chunk_gathered_threaded( + ctx, vtcm_activation, activation, (int) mr, (int) n_rows, k, + matrix_rows, cur_a, mapping_stride, ne11, act_nb1, act_nb2, cne1, num_threads); + + for (size_t nc = 0; nc < (size_t) n; nc += n_chunk_n_cols) { + const size_t n_cols = hex_smin((size_t) n - nc, n_chunk_n_cols); + const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); + + const uint8_t *qweight_chunk = permuted_weight + nc * weight_stride; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight, qweight_chunk), row_stride, weight_stride, row_stride, n_cols); + dma_queue_pop(ctx->dma[0]); + + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_scratch0, vtcm_weight, n_cols, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); + + core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HMX_FP16_TILE_N_ROWS); + + transfer_output_chunk_scattered_threaded( + ctx, dst, vtcm_output, (int) mr, (int) n_rows, (int) n_cols, + matrix_rows, cur_a, mapping_stride, dst_nb1, dst_nb2, cne1, num_threads); + } } -#endif + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); return 0; } diff --git a/ggml/src/ggml-hexagon/htp/hmx-ops.c b/ggml/src/ggml-hexagon/htp/hmx-ops.c new file mode 100644 index 00000000000..114d8c14811 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hmx-ops.c @@ -0,0 +1,6 @@ +// HMX operations compiled as a single translation unit. +// This allows interprocedural optimizations within HMX ops without requiring global HTP LTO. + +#include "hmx-queue.c" +#include "hmx-matmul-ops.c" +#include "hmx-flash-attn-ops.c" diff --git a/ggml/src/ggml-hexagon/htp/hmx-ops.h b/ggml/src/ggml-hexagon/htp/hmx-ops.h index f114edb822f..a67842f3ffc 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-ops.h +++ b/ggml/src/ggml-hexagon/htp/hmx-ops.h @@ -52,14 +52,32 @@ int hmx_matmul_f16_f32(struct htp_context *ctx, // Batch semantics match ggml_mul_mat(): src0 broadcasts to src1 in dims 2/3. int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32_batched_params_t *params); -// HMX matrix multiplication — quantised weights (Q4_0/Q8_0/IQ4_NL/MXFP4) -int hmx_matmul_q_f32(struct htp_context *ctx, +// HMX matrix multiplication — all supported weight types (F16/F32/Q4_0/Q4_1/Q8_0/IQ4_NL/MXFP4) +int hmx_matmul_2d_f32(struct htp_context *ctx, float *restrict dst, const float *activation, const uint8_t *permuted_weight, int m, int k, int n, + int act_stride, + int weight_stride, int weight_type); +struct mmid_row_mapping; + +int hmx_matmul_id_2d_f32(struct htp_context *ctx, + float *restrict dst, + const float *activation, + const uint8_t *permuted_weight, + int m, int k, int n, + int ne11, + size_t act_nb1, size_t act_nb2, + size_t dst_nb1, size_t dst_nb2, + int weight_stride, + int weight_type, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride); + // HMX flash attention int hmx_flash_attn_ext(struct htp_ops_context * octx); diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index 6fe3e6c7d85..0f1676f077a 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -79,6 +79,10 @@ struct htp_context { uint64_t max_vmem; + // Persistent DDR scratchpad for MUL_MAT_ID mappings + void * ddr_spad_base; + size_t ddr_spad_size; + struct htp_ops_context octx; #ifdef HTP_HAS_HMX @@ -104,6 +108,7 @@ int op_argsort(struct htp_ops_context * octx); int op_ssm_conv(struct htp_ops_context * octx); int op_cumsum(struct htp_ops_context * octx); int op_fill(struct htp_ops_context * octx); +int op_concat(struct htp_ops_context * octx); int op_diag(struct htp_ops_context * octx); int op_solve_tri(struct htp_ops_context * octx); int op_gated_delta_net(struct htp_ops_context * octx); diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 9d905a30133..fa85bf4ca0c 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -20,6 +20,7 @@ enum htp_data_type { HTP_TYPE_F32 = 0, HTP_TYPE_F16 = 1, HTP_TYPE_Q4_0 = 2, + HTP_TYPE_Q4_1 = 3, HTP_TYPE_Q8_0 = 8, HTP_TYPE_IQ4_NL = 20, HTP_TYPE_I32 = 26, @@ -28,6 +29,7 @@ enum htp_data_type { // types used internally for repack, dyn.quant, etc HTP_TYPE_Q4_0x4x2 = 200, + HTP_TYPE_Q4_1x4x2, HTP_TYPE_Q8_0x4x2, HTP_TYPE_MXFP4x4x2, @@ -56,6 +58,7 @@ enum htp_op_code { HTP_OP_MUL_MAT, HTP_OP_MUL_MAT_ID, HTP_OP_RMS_NORM, + HTP_OP_RMS_NORM_MUL, HTP_OP_UNARY_SILU, HTP_OP_UNARY_GELU, HTP_OP_UNARY_SIGMOID, @@ -89,6 +92,7 @@ enum htp_op_code { HTP_OP_TRI, HTP_OP_PAD, HTP_OP_NORM, + HTP_OP_CONCAT, HTP_OP_INVALID }; diff --git a/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h b/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h new file mode 100644 index 00000000000..f1f2e49e455 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h @@ -0,0 +1,47 @@ +#ifndef HVX_FLASH_ATTN_H +#define HVX_FLASH_ATTN_H + +#include <math.h> +#include "hvx-utils.h" + +// Scalar helper to compute a single ALiBi slope. +static inline float alibi_slope(uint32_t h, uint32_t n_head_log2, float m0, float m1) { + return (h < n_head_log2) ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1); +} + +// Vectorized helper to compute 32 ALiBi slopes starting from (kv_head * G). +static inline HVX_Vector hvx_alibi_slopes( + uint32_t kv_head, + uint32_t G, + uint32_t n_head_log2, + float m0, + float m1 +) { + static const float ramp_32[32] __attribute__((aligned(128))) = { + 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, + 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, + 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, + 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f + }; + HVX_Vector v_ramp = hvx_vmem(ramp_32); + HVX_Vector v_h_base = hvx_vec_splat_f32((float)(kv_head * G)); + HVX_Vector v_h = hvx_vec_add_f32_f32(v_h_base, v_ramp); + + // Compute exponent_m0: h + 1 + HVX_Vector v_exp_m0 = hvx_vec_add_f32_f32(v_h, hvx_vec_splat_f32(1.0f)); + + // Compute exponent_m1: 2 * (h - n_head_log2) + 1 + HVX_Vector v_n_head_log2 = hvx_vec_splat_f32((float)n_head_log2); + HVX_Vector v_h_minus = hvx_vec_sub_f32_f32(v_h, v_n_head_log2); + HVX_Vector v_exp_m1 = hvx_vec_add_f32_f32(hvx_vec_mul_f32_f32(hvx_vec_splat_f32(2.0f), v_h_minus), hvx_vec_splat_f32(1.0f)); + + // Compute powers + HVX_Vector v_pow_m0 = hvx_vec_pow_const_base_f32(m0, v_exp_m0); + HVX_Vector v_pow_m1 = hvx_vec_pow_const_base_f32(m1, v_exp_m1); + + // Select based on h < n_head_log2 + HVX_VectorPred p_cond = Q6_Q_vcmp_gt_VsfVsf(v_n_head_log2, v_h); // v_n_head_log2 > v_h <=> h < n_head_log2 + return Q6_V_vmux_QVV(p_cond, v_pow_m0, v_pow_m1); +} + +#endif /* HVX_FLASH_ATTN_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-log.h b/ggml/src/ggml-hexagon/htp/hvx-log.h new file mode 100644 index 00000000000..7013dae785a --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-log.h @@ -0,0 +1,65 @@ +#ifndef HVX_LOG_H +#define HVX_LOG_H + +#include "hvx-base.h" + +// Approximates ln(x) element-wise for float vectors. +// x must contain positive float elements. +// Uses Abramowitz & Stegun polynomial approximation 4.1.44 for ln(1+y) over [0, 1]. +static inline HVX_Vector hvx_vec_log_f32(HVX_Vector x) { + // x = m * 2^e, where m in [1, 2) + HVX_Vector biased_e = Q6_Vuw_vlsr_VuwR(x, 23); + HVX_Vector e_int = Q6_Vw_vsub_VwVw(biased_e, Q6_V_vsplat_R(127)); + HVX_Vector e_float = Q6_Vsf_equals_Vw(e_int); + + // Extract mantissa and set exponent to 127 (which represents float value in [1.0, 2.0)) + HVX_Vector mant_mask = Q6_V_vsplat_R(0x007FFFFF); + HVX_Vector exp_127 = Q6_V_vsplat_R(0x3F800000); + HVX_Vector m = Q6_V_vor_VV(Q6_V_vand_VV(x, mant_mask), exp_127); + + // y = m - 1.0f, y in [0, 1) + HVX_Vector y = hvx_vec_sub_f32_f32(m, hvx_vec_splat_f32(1.0f)); + + // Abramowitz & Stegun 4.1.44 polynomial approximation of ln(1+y) + HVX_Vector c; + HVX_Vector res; + + c = hvx_vec_splat_f32(-0.0064535442f); + res = hvx_vec_mul_f32_f32(y, c); + + c = hvx_vec_splat_f32(0.0360884937f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + c = hvx_vec_splat_f32(-0.0953293897f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + c = hvx_vec_splat_f32(0.1676540711f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + c = hvx_vec_splat_f32(-0.2407338084f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + c = hvx_vec_splat_f32(0.3317990258f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + c = hvx_vec_splat_f32(-0.4998741238f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + c = hvx_vec_splat_f32(0.9999964239f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + // ln(x) = e * ln(2) + ln(1+y) + HVX_Vector ln2 = hvx_vec_splat_f32(0.69314718056f); + HVX_Vector term_e = hvx_vec_mul_f32_f32(e_float, ln2); + + return hvx_vec_add_f32_f32(term_e, res); +} + +#endif /* HVX_LOG_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-pow.h b/ggml/src/ggml-hexagon/htp/hvx-pow.h new file mode 100644 index 00000000000..48fe0e8eade --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-pow.h @@ -0,0 +1,42 @@ +#ifndef HVX_POW_H +#define HVX_POW_H + +#include <math.h> +#include "hvx-base.h" +#include "hvx-exp.h" +#include "hvx-log.h" + +// Approximates base^exponent element-wise for float vectors. +// base must be a positive constant. exponent is an HVX f32 vector. +// Uses base^x = exp(x * ln(base)). +static inline HVX_Vector hvx_vec_pow_const_base_f32(float base, HVX_Vector exponent) { + float ln_base = logf(base); + HVX_Vector ln_base_v = hvx_vec_splat_f32(ln_base); + HVX_Vector x = hvx_vec_mul_f32_f32(exponent, ln_base_v); + + static const float kInf = INFINITY; + static const float kMaxExp = 88.7228f; + + const HVX_Vector max_exp = hvx_vec_splat_f32(kMaxExp); + const HVX_Vector inf = hvx_vec_splat_f32(kInf); + + return hvx_vec_exp_f32_guard(x, max_exp, inf); +} + +// Approximates base^exponent element-wise for float vectors. +// base and exponent are HVX f32 vectors. base elements must be positive. +// Uses base^exponent = exp(exponent * ln(base)). +static inline HVX_Vector hvx_vec_pow_f32(HVX_Vector base, HVX_Vector exponent) { + HVX_Vector ln_base = hvx_vec_log_f32(base); + HVX_Vector x = hvx_vec_mul_f32_f32(exponent, ln_base); + + static const float kInf = INFINITY; + static const float kMaxExp = 88.7228f; + + const HVX_Vector max_exp = hvx_vec_splat_f32(kMaxExp); + const HVX_Vector inf = hvx_vec_splat_f32(kInf); + + return hvx_vec_exp_f32_guard(x, max_exp, inf); +} + +#endif /* HVX_POW_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h b/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h new file mode 100644 index 00000000000..c5b9a5d47c1 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h @@ -0,0 +1,90 @@ +#ifndef HVX_SIN_COS_H +#define HVX_SIN_COS_H + +#include "hvx-base.h" +#include "hvx-floor.h" + +static inline HVX_Vector hvx_vec_cos_f32(HVX_Vector x) { + HVX_Vector const_inv_pi = hvx_vec_splat_f32(0.3183098861837907f); + HVX_Vector const_half = hvx_vec_splat_f32(0.5f); + HVX_Vector const_pi = hvx_vec_splat_f32(3.141592653589793f); + HVX_Vector const_one = hvx_vec_splat_f32(1.0f); + HVX_Vector const_neg_one = hvx_vec_splat_f32(-1.0f); + + // n = floor(x * (1/pi) + 0.5) + HVX_Vector n_float = hvx_vec_floor_f32(hvx_vec_add_f32_f32(hvx_vec_mul_f32_f32(x, const_inv_pi), const_half)); + + // y = x - n * pi + HVX_Vector y = hvx_vec_sub_f32_f32(x, hvx_vec_mul_f32_f32(n_float, const_pi)); + + // Sign determination: if n is odd, sign is -1.0f, else 1.0f + // half_n = n * 0.5f + HVX_Vector half_n = hvx_vec_mul_f32_f32(n_float, const_half); + // floor_half_n = floor(half_n) + HVX_Vector floor_half_n = hvx_vec_floor_f32(half_n); + // is_odd = half_n > floor_half_n + HVX_VectorPred is_odd = Q6_Q_vcmp_gt_VsfVsf(half_n, floor_half_n); + // sign = vmux(is_odd, -1.0f, 1.0f) + HVX_Vector sign = Q6_V_vmux_QVV(is_odd, const_neg_one, const_one); + + // z = y^2 + HVX_Vector z = hvx_vec_mul_f32_f32(y, y); + + // Chebyshev approximation for cos(y) + HVX_Vector c4 = hvx_vec_splat_f32(2.3557242013849433e-05f); + HVX_Vector c3 = hvx_vec_splat_f32(-0.0013871428263450528f); + HVX_Vector c2 = hvx_vec_splat_f32(0.041665895266688284f); + HVX_Vector c1 = hvx_vec_splat_f32(-0.4999999360426369f); + HVX_Vector c0 = hvx_vec_splat_f32(0.9999999999071725f); + + HVX_Vector cos_y = hvx_vec_add_f32_f32(c3, hvx_vec_mul_f32_f32(z, c4)); + cos_y = hvx_vec_add_f32_f32(c2, hvx_vec_mul_f32_f32(z, cos_y)); + cos_y = hvx_vec_add_f32_f32(c1, hvx_vec_mul_f32_f32(z, cos_y)); + cos_y = hvx_vec_add_f32_f32(c0, hvx_vec_mul_f32_f32(z, cos_y)); + + return hvx_vec_mul_f32_f32(cos_y, sign); +} + +static inline HVX_Vector hvx_vec_sin_f32(HVX_Vector x) { + HVX_Vector const_inv_pi = hvx_vec_splat_f32(0.3183098861837907f); + HVX_Vector const_half = hvx_vec_splat_f32(0.5f); + HVX_Vector const_pi = hvx_vec_splat_f32(3.141592653589793f); + HVX_Vector const_one = hvx_vec_splat_f32(1.0f); + HVX_Vector const_neg_one = hvx_vec_splat_f32(-1.0f); + + // n = floor(x * (1/pi) + 0.5) + HVX_Vector n_float = hvx_vec_floor_f32(hvx_vec_add_f32_f32(hvx_vec_mul_f32_f32(x, const_inv_pi), const_half)); + + // y = x - n * pi + HVX_Vector y = hvx_vec_sub_f32_f32(x, hvx_vec_mul_f32_f32(n_float, const_pi)); + + // Sign determination: if n is odd, sign is -1.0f, else 1.0f + // half_n = n * 0.5f + HVX_Vector half_n = hvx_vec_mul_f32_f32(n_float, const_half); + // floor_half_n = floor(half_n) + HVX_Vector floor_half_n = hvx_vec_floor_f32(half_n); + // is_odd = half_n > floor_half_n + HVX_VectorPred is_odd = Q6_Q_vcmp_gt_VsfVsf(half_n, floor_half_n); + // sign = vmux(is_odd, -1.0f, 1.0f) + HVX_Vector sign = Q6_V_vmux_QVV(is_odd, const_neg_one, const_one); + + // z = y^2 + HVX_Vector z = hvx_vec_mul_f32_f32(y, y); + + // Chebyshev approximation for sin(y) + HVX_Vector s4 = hvx_vec_splat_f32(2.642186986152672e-06f); + HVX_Vector s3 = hvx_vec_splat_f32(-0.00019825318964070864f); + HVX_Vector s2 = hvx_vec_splat_f32(0.00833326283319605f); + HVX_Vector s1 = hvx_vec_splat_f32(-0.16666666082087775f); + HVX_Vector s0 = hvx_vec_splat_f32(0.999999999915155f); + + HVX_Vector sin_y = hvx_vec_add_f32_f32(s3, hvx_vec_mul_f32_f32(z, s4)); + sin_y = hvx_vec_add_f32_f32(s2, hvx_vec_mul_f32_f32(z, sin_y)); + sin_y = hvx_vec_add_f32_f32(s1, hvx_vec_mul_f32_f32(z, sin_y)); + sin_y = hvx_vec_add_f32_f32(s0, hvx_vec_mul_f32_f32(z, sin_y)); + sin_y = hvx_vec_mul_f32_f32(y, sin_y); + + return hvx_vec_mul_f32_f32(sin_y, sign); +} + +#endif /* HVX_SIN_COS_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.h b/ggml/src/ggml-hexagon/htp/hvx-utils.h index e0452811ec3..23373f73ae2 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-utils.h +++ b/ggml/src/ggml-hexagon/htp/hvx-utils.h @@ -14,6 +14,10 @@ #include "hvx-sqrt.h" #include "hvx-arith.h" #include "hvx-div.h" +#include "hvx-floor.h" +#include "hvx-sin-cos.h" #include "hvx-base.h" +#include "hvx-pow.h" +#include "hvx-log.h" #endif /* HVX_UTILS_H */ diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index e8619388478..3715227d2c7 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -12,6 +12,7 @@ #include <HAP_mem.h> #include <HAP_power.h> #include <HAP_ps.h> +#include <HAP_dcvs.h> #include <qurt.h> #include <qurt_thread.h> #include <qurt_memory.h> @@ -63,8 +64,7 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { request.type = HAP_power_set_DCVS_v3; request.dcvs_v3.set_dcvs_enable = TRUE; - request.dcvs_v3.dcvs_enable = TRUE; - request.dcvs_v3.dcvs_option = HAP_DCVS_V2_PERFORMANCE_MODE; + request.dcvs_v3.dcvs_enable = FALSE; request.dcvs_v3.set_bus_params = TRUE; request.dcvs_v3.bus_params.min_corner = HAP_DCVS_VCORNER_MAX; request.dcvs_v3.bus_params.max_corner = HAP_DCVS_VCORNER_MAX; @@ -75,6 +75,10 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { request.dcvs_v3.core_params.target_corner = HAP_DCVS_VCORNER_MAX; request.dcvs_v3.set_sleep_disable = TRUE; request.dcvs_v3.sleep_disable = TRUE; + +#if (__HEXAGON_ARCH__ >= 79) + HAP_set_dcvs_v3_protected_bus_corners(&request, 1); +#endif if ((err = HAP_power_set((void *) ctx, &request)) != 0) { return err; } @@ -103,7 +107,7 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { FARF(ALWAYS, "Setting HMX clock\n"); err = HAP_power_set((void *) ctx, &request); if (err != AEE_SUCCESS) { - FARF(ERROR, "Error setting HMX clock."); + FARF(ERROR, "ggml-hex: error setting HMX clock."); return err; } } @@ -117,7 +121,7 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { FARF(ALWAYS, "Powering HMX on\n"); err = HAP_power_set((void *) ctx, &request); if (err != AEE_SUCCESS) { - FARF(ERROR, "Error powering on HMX."); + FARF(ERROR, "ggml-hex: error powering on HMX."); return err; } } @@ -420,14 +424,21 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que ctx->n_threads = n_hvx; for (int i = 0; i < ctx->n_threads; i++) { - // see discussion https://github.com/ggml-org/llama.cpp/pull/18151#discussion_r2632388541 - ctx->dma[i] = dma_queue_create(128); + ctx->dma[i] = dma_queue_create(256); // queue depth } + ctx->ddr_spad_size = 512 * 1024; // 512 KB + ctx->ddr_spad_base = memalign(128, ctx->ddr_spad_size); + // init worker pool err = worker_pool_init(&ctx->worker_pool, n_hvx); if (err != AEE_SUCCESS) { FARF(ERROR, "Unable to create worker pool"); + if (ctx->ddr_spad_base) { + free(ctx->ddr_spad_base); + ctx->ddr_spad_base = NULL; + ctx->ddr_spad_size = 0; + } return err; } @@ -475,6 +486,12 @@ AEEResult htp_iface_stop(remote_handle64 handle) { vtcm_free(ctx); + if (ctx->ddr_spad_base) { + free(ctx->ddr_spad_base); + ctx->ddr_spad_base = NULL; + ctx->ddr_spad_size = 0; + } + return AEE_SUCCESS; } @@ -538,6 +555,7 @@ static int execute_op(struct htp_ops_context * octx) { case HTP_OP_NORM: case HTP_OP_RMS_NORM: + case HTP_OP_RMS_NORM_MUL: case HTP_OP_SCALE: case HTP_OP_SQR: case HTP_OP_SQRT: @@ -601,6 +619,9 @@ static int execute_op(struct htp_ops_context * octx) { case HTP_OP_PAD: return op_pad(octx); + case HTP_OP_CONCAT: + return op_concat(octx); + case HTP_OP_GATED_DELTA_NET: return op_gated_delta_net(octx); @@ -851,6 +872,11 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { for (uint32_t i=0; i < n_ops; i++) { struct profile_data prof; + if (i == (n_ops-1)) { + // wake up the host before starting the last op + dspqueue_write_early_wakeup_noblock(queue, 0, 0); + } + profile_start(ctx->profiler, &prof); proc_op_req(octx, tens, i, &ops[i]); @@ -867,8 +893,6 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { } } - // dspqueue_write_early_wakeup_noblock(ctx->queue, 10, 0); - struct htp_opbatch_rsp rsp; rsp.id = req.id; rsp.status = HTP_STATUS_OK; diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index 46fc5602dc9..5121c6f9bad 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -40,6 +40,11 @@ struct htp_matmul_context { const void * restrict vx0, const void * restrict vx1, const void * restrict vy0, const void * restrict vy1); + void (*vec_dot_4x1)(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vx2, const void * restrict vx3, + const void * restrict vy0); + // Precomputed values uint32_t src0_nrows_per_thread; uint32_t src1_nrows_per_thread; @@ -48,6 +53,11 @@ struct htp_matmul_context { struct fastdiv_values mm_div_ne1; struct fastdiv_values mm_div_r2; struct fastdiv_values mm_div_r3; + + // Fields for scattered mapping & HMX support in MUL_MAT_ID + const uint32_t * matrix_row_counts; + const struct mmid_row_mapping * matrix_rows; + bool hmx_eligible; }; // vdelta control to expand first 32 e8m0 values into 32 uint32 elements @@ -155,6 +165,13 @@ static inline size_t q8x4x2_row_size(uint32_t ne) { return hex_round_up(ne + nb * 8 * sizeof(__fp16), 128); } +static inline size_t q8_1x4x2_row_size(uint32_t ne) { + // ensures perfect alignment of quants and full row + const uint32_t qk = QK_Q8_0x4x2; + const uint32_t nb = (ne + qk - 1) / qk; + return hex_round_up(ne + nb * 8 * 2 * sizeof(__fp16), 128); +} + static inline HVX_Vector_x8 hvx_vec_load_q4x4x8_full(const uint8_t * restrict ptr) { const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; @@ -223,6 +240,62 @@ static HVX_Vector_x8 hvx_vec_load_q4x4x8_partial(const uint8_t * restrict ptr, u return r; } +static inline HVX_Vector_x8 hvx_vec_load_q4_1x4x8_full(const uint8_t * restrict ptr) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; + + HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes) + HVX_Vector v2_3 = vptr[1]; // ... + HVX_Vector v4_5 = vptr[2]; // ... + HVX_Vector v6_7 = vptr[3]; // ... + + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + + HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F : first 128 elements + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 : second 128 elements + HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F ... + HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4 + HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F + HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4 + HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F + HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4 + + HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 }; + return r; +} + +static HVX_Vector_x8 hvx_vec_load_q4_1x4x8_partial(const uint8_t * restrict ptr, uint32_t n) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; + + const uint32_t qk = QK_Q4_0x4x2; // 256 + const uint32_t nb = n / qk; + const uint32_t nloe = n % qk; + + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + + HVX_Vector_x8 r; + uint32_t i = 0; + + #pragma unroll(2) + for (i=0; i < nb; i++) { + HVX_Vector v = vptr[i]; // 256 elements (128 bytes) + HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements + r.v[i*2+0] = v0; + r.v[i*2+1] = v1; + } + + if (nloe) { + HVX_Vector v = vptr[i]; // 256 elements (128 bytes) + HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements + HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:... + r.v[i*2+0] = Q6_V_lo_W(v0_1_p); + r.v[i*2+1] = Q6_V_hi_W(v0_1_p); + } + + return r; +} + static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8_full(const uint8_t * restrict ptr) { const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; @@ -401,82 +474,96 @@ static inline HVX_Vector hvx_vec_rmpy_x8_partial(HVX_Vector_x8 x, HVX_Vector_x8 return hvx_vec_rmpy_x8_partial(x, y, 512); } -static void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { +static void vec_dot_q4_1x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size assert((unsigned long) vx0 % 128 == 0); assert((unsigned long) vy0 % 128 == 0); const uint32_t qk = QK_Q4_0x4x2 * 4; - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_dblk_size = 8 * 4 * 2 * 2; // 32x (d, m) __fp16 = 128 bytes const uint32_t x_qblk_size = qk / 2; // int4 const uint32_t x_qrow_size = n / 2; // int4 (not padded) - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_dblk_size = 8 * 4 * 4; // 32x (d, s) __fp16 = 128 bytes const uint32_t y_qblk_size = qk; // int8 const uint32_t y_qrow_size = n; // int8 (not padded) const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales/offsets const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales/sums // Row sum (sf) HVX_Vector r0_sum = Q6_V_vzero(); - // Multiply and accumulate into int32. - // Compute combined scale (fp32). - // Apply scale to acc and accumulate into the row sum (qf32). - const uint32_t nb = n / qk; // num full blocks const uint32_t nloe = n % qk; // num leftover elemements uint32_t i = 0; for (; i < nb; i++) { HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_full(r0_x_q + i * x_qblk_size); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); + HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); + + HVX_Vector dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_VectorPair dm_deal = Q6_W_vdeal_VVR(dm, dm, -2); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(dm_deal)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(dm_deal)); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); } // Process leftovers if (nloe) { HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_partial(r0_x_q + i * x_qblk_size, nloe); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); + HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); + + HVX_Vector dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_VectorPair dm_deal = Q6_W_vdeal_VVR(dm, dm, -2); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(dm_deal)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(dm_deal)); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); // Zero out unused elements HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r0_ms = Q6_V_vand_QV(bmask, r0_ms); r0_ia = Q6_V_vand_QV(bmask, r0_ia); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); } r0_sum = hvx_vec_reduce_sum_f32(r0_sum); - hvx_vec_store_u(s0, 4, r0_sum); } -static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0, +static void vec_dot_q4_1x4x2_q8x4x2_2x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vx1, const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size @@ -486,11 +573,11 @@ static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0, const uint32_t qk = QK_Q4_0x4x2 * 4; - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_dblk_size = 8 * 4 * 2 * 2; // 32x (d, m) __fp16 = 128 bytes const uint32_t x_qblk_size = qk / 2; // int4 const uint32_t x_qrow_size = n / 2; // int4 (not padded) - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_dblk_size = 8 * 4 * 4; // 32x (d, s) __fp16 = 128 bytes const uint32_t y_qblk_size = qk; // int8 const uint32_t y_qrow_size = n; // int8 (not padded) @@ -500,77 +587,306 @@ static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0, const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales/sums // Row sum (sf) HVX_Vector r0_sum = Q6_V_vzero(); HVX_Vector r1_sum = Q6_V_vzero(); - // Multiply and accumulate into int32. - // Compute combined scale (fp32). - // Apply scale to acc and accumulate into the row sum (qf32). - const uint32_t nb = n / qk; // num full blocks const uint32_t nloe = n % qk; // num leftover elemements uint32_t i = 0; for (; i < nb; i++) { HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_full(r1_x_q + i * x_qblk_size); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); + HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); + + HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); + + HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); + HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s))); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms); - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa_total, r1_sum)); } // Process leftovers if (nloe) { HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_partial(r1_x_q + i * x_qblk_size, nloe); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); + HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); + + HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); + + HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); + HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s))); // Zero out unused elements HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r0_ms = Q6_V_vand_QV(bmask, r0_ms); r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r1_ms = Q6_V_vand_QV(bmask, r1_ms); r0_ia = Q6_V_vand_QV(bmask, r0_ia); r1_ia = Q6_V_vand_QV(bmask, r1_ia); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms); - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa_total, r1_sum)); } HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); hvx_vec_store_u(s0, 8, rsum); } -static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, +static void vec_dot_q4_1x4x2_q8x4x2_4x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vx2, const void * restrict vx3, + const void * restrict vy0) { + assert(n % 32 == 0); // min sub-block size + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vx2 % 128 == 0); + assert((unsigned long) vx3 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2 * 2; // 32x (d, m) __fp16 = 128 bytes + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 4; // 32x (d, s) __fp16 = 128 bytes + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; // quants first + const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; // then scales + const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; // quants first + const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales/sums + + // Row sum (sf) + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); + HVX_Vector r2_sum = Q6_V_vzero(); + HVX_Vector r3_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; // num full blocks + const uint32_t nloe = n % qk; // num leftover elements + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_full(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 r2_q = hvx_vec_load_q4_1x4x8_full(r2_x_q + i * x_qblk_size); + HVX_Vector_x8 r3_q = hvx_vec_load_q4_1x4x8_full(r3_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); + + HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); + HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); + + HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); + + HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); + HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); + + HVX_Vector r2_dm = *(const HVX_UVector *) (r2_x_d + i * x_dblk_size); + HVX_VectorPair r2_dm_deal = Q6_W_vdeal_VVR(r2_dm, r2_dm, -2); + HVX_Vector r2_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r2_dm_deal)); + HVX_Vector r2_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r2_dm_deal)); + + HVX_Vector r3_dm = *(const HVX_UVector *) (r3_x_d + i * x_dblk_size); + HVX_VectorPair r3_dm_deal = Q6_W_vdeal_VVR(r3_dm, r3_dm, -2); + HVX_Vector r3_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r3_dm_deal)); + HVX_Vector r3_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r3_dm_deal)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); + + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s))); + + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); + HVX_Vector r2_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_m, vy_s))); + + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); + HVX_Vector r3_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_m, vy_s))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); + + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms); + + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r2_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_ms); + + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); + HVX_Vector r3_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_ms); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa_total, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa_total, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa_total, r3_sum)); + } + + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r2_q = hvx_vec_load_q4_1x4x8_partial(r2_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r3_q = hvx_vec_load_q4_1x4x8_partial(r3_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r2_q, vy_q, nloe)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r3_q, vy_q, nloe)); + + HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); + HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); + + HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); + + HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); + HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); + + HVX_Vector r2_dm = *(const HVX_UVector *) (r2_x_d + i * x_dblk_size); + HVX_VectorPair r2_dm_deal = Q6_W_vdeal_VVR(r2_dm, r2_dm, -2); + HVX_Vector r2_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r2_dm_deal)); + HVX_Vector r2_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r2_dm_deal)); + + HVX_Vector r3_dm = *(const HVX_UVector *) (r3_x_d + i * x_dblk_size); + HVX_VectorPair r3_dm_deal = Q6_W_vdeal_VVR(r3_dm, r3_dm, -2); + HVX_Vector r3_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r3_dm_deal)); + HVX_Vector r3_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r3_dm_deal)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); + + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s))); + + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); + HVX_Vector r2_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_m, vy_s))); + + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); + HVX_Vector r3_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_m, vy_s))); + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r0_ms = Q6_V_vand_QV(bmask, r0_ms); + r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r1_ms = Q6_V_vand_QV(bmask, r1_ms); + r2_dd = Q6_V_vand_QV(bmask, r2_dd); + r2_ms = Q6_V_vand_QV(bmask, r2_ms); + r3_dd = Q6_V_vand_QV(bmask, r3_dd); + r3_ms = Q6_V_vand_QV(bmask, r3_ms); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); + r2_ia = Q6_V_vand_QV(bmask, r2_ia); + r3_ia = Q6_V_vand_QV(bmask, r3_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); + + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms); + + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r2_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_ms); + + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); + HVX_Vector r3_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_ms); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa_total, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa_total, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa_total, r3_sum)); + } + + HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } }; + HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in); + hvx_vec_store_u(s0, 16, rsum); +} + + +static void vec_dot_q4_1x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, const void * restrict vx0, const void * restrict vx1, const void * restrict vy0, const void * restrict vy1) { assert(n % 32 == 0); @@ -581,11 +897,11 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * const uint32_t qk = QK_Q4_0x4x2 * 4; - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_dblk_size = 8 * 4 * 2 * 2; // 32x (d, m) __fp16 = 128 bytes const uint32_t x_qblk_size = qk / 2; // int4 const uint32_t x_qrow_size = n / 2; // int4 (not padded) - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_dblk_size = 8 * 4 * 4; // 32x (d, s) __fp16 = 128 bytes const uint32_t y_qblk_size = qk; // int8 const uint32_t y_qrow_size = n; // int8 (not padded) @@ -595,9 +911,9 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first - const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales + const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales/sums const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first - const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales + const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales/sums // Row sums (sf) - 4 accumulators for 2×2 tile HVX_Vector r0_c0_sum = Q6_V_vzero(); @@ -610,13 +926,13 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * uint32_t i = 0; for (; i < nb; i++) { - // Load src1 columns (reused across both src0 rows) + // Load src1 columns HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size); HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size); - // Load src0 rows (reused across both src1 columns) - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size); + // Load src0 rows + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_full(r1_x_q + i * x_qblk_size); // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); @@ -625,16 +941,38 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); // Load scales - HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); - HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector ds0 = *(const HVX_UVector *) (y0_d + i * y_dblk_size); + HVX_VectorPair ds0_deal = Q6_W_vdeal_VVR(ds0, ds0, -2); + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds0_deal)); + HVX_Vector vy0_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds0_deal)); + + HVX_Vector ds1 = *(const HVX_UVector *) (y1_d + i * y_dblk_size); + HVX_VectorPair ds1_deal = Q6_W_vdeal_VVR(ds1, ds1, -2); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds1_deal)); + HVX_Vector vy1_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds1_deal)); + + HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); + + HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); + HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); // Compute combined scales HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy0_s))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r0_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy1_s))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy0_s))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + HVX_Vector r1_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy1_s))); // Apply scales and accumulate HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); @@ -642,40 +980,72 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + HVX_Vector r0_c0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_ms); + HVX_Vector r0_c1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_ms); + HVX_Vector r1_c0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_ms); + HVX_Vector r1_c1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_ms); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa_total, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa_total, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa_total, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa_total, r1_c1_sum)); } // Process leftovers if (nloe) { HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe); HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_partial(r1_x_q + i * x_qblk_size, nloe); HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe)); HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe)); HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe)); HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe)); - HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); - HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector ds0 = *(const HVX_UVector *) (y0_d + i * y_dblk_size); + HVX_VectorPair ds0_deal = Q6_W_vdeal_VVR(ds0, ds0, -2); + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds0_deal)); + HVX_Vector vy0_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds0_deal)); + + HVX_Vector ds1 = *(const HVX_UVector *) (y1_d + i * y_dblk_size); + HVX_VectorPair ds1_deal = Q6_W_vdeal_VVR(ds1, ds1, -2); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds1_deal)); + HVX_Vector vy1_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds1_deal)); + + HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); + + HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); + HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy0_s))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r0_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy1_s))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy0_s))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + HVX_Vector r1_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy1_s))); - // Zero out unused scales + // Zero out unused elements HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); + r0_c0_ms = Q6_V_vand_QV(bmask, r0_c0_ms); r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); + r0_c1_ms = Q6_V_vand_QV(bmask, r0_c1_ms); r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); + r1_c0_ms = Q6_V_vand_QV(bmask, r1_c0_ms); r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); + r1_c1_ms = Q6_V_vand_QV(bmask, r1_c1_ms); + r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); @@ -686,10 +1056,15 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + HVX_Vector r0_c0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_ms); + HVX_Vector r0_c1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_ms); + HVX_Vector r1_c0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_ms); + HVX_Vector r1_c1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_ms); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa_total, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa_total, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa_total, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa_total, r1_c1_sum)); } // Reduce and store results @@ -700,26 +1075,26 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * hvx_vec_store_u(s1, 8, r0_r1_c1_sum); // row0,col1 row1,col1 } -static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { +static void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size assert((unsigned long) vx0 % 128 == 0); assert((unsigned long) vy0 % 128 == 0); const uint32_t qk = QK_Q4_0x4x2 * 4; - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk; // int8 - const uint32_t x_qrow_size = n; // int8 (not padded) + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales // Row sum (sf) HVX_Vector r0_sum = Q6_V_vzero(); @@ -729,12 +1104,12 @@ static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const vo // Apply scale to acc and accumulate into the row sum (qf32). const uint32_t nb = n / qk; // num full blocks - int32_t nloe = n % qk; // num leftover elemements (must be signed) + const uint32_t nloe = n % qk; // num leftover elemements uint32_t i = 0; for (; i < nb; i++) { HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); @@ -751,7 +1126,433 @@ static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const vo // Process leftovers if (nloe) { HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + + // Zero out unused elements + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + } + + r0_sum = hvx_vec_reduce_sum_f32(r0_sum); + + hvx_vec_store_u(s0, 4, r0_sum); +} + +static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0) { + assert(n % 32 == 0); // min sub-block size + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + + // Row sum (sf) + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); + + // Multiply and accumulate into int32. + // Compute combined scale (fp32). + // Apply scale to acc and accumulate into the row sum (qf32). + + const uint32_t nb = n / qk; // num full blocks + const uint32_t nloe = n % qk; // num leftover elemements + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + + // Zero out unused elements + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + } + + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); + hvx_vec_store_u(s0, 8, rsum); +} + +static void vec_dot_q4x4x2_q8x4x2_4x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vx2, const void * restrict vx3, + const void * restrict vy0) { + assert(n % 32 == 0); // min sub-block size + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vx2 % 128 == 0); + assert((unsigned long) vx3 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; + const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; + const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; + const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; + const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); + + // Row sum (sf) + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); + HVX_Vector r2_sum = Q6_V_vzero(); + HVX_Vector r3_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; // num full blocks + const uint32_t nloe = n % qk; // num leftover elements + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 r2_q = hvx_vec_load_q4x4x8_full(r2_x_q + i * x_qblk_size); + HVX_Vector_x8 r3_q = hvx_vec_load_q4x4x8_full(r3_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); + HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); + } + + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r2_q = hvx_vec_load_q4x4x8_partial(r2_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r3_q = hvx_vec_load_q4x4x8_partial(r3_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r2_q, vy_q, nloe)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r3_q, vy_q, nloe)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); + HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r2_dd = Q6_V_vand_QV(bmask, r2_dd); + r3_dd = Q6_V_vand_QV(bmask, r3_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); + r2_ia = Q6_V_vand_QV(bmask, r2_ia); + r3_ia = Q6_V_vand_QV(bmask, r3_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); + } + + HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } }; + HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in); + hvx_vec_store_u(s0, 16, rsum); +} + + +static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + assert((unsigned long) vy1 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + + const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first + const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales + const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first + const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales + + // Row sums (sf) - 4 accumulators for 2×2 tile + HVX_Vector r0_c0_sum = Q6_V_vzero(); + HVX_Vector r0_c1_sum = Q6_V_vzero(); + HVX_Vector r1_c0_sum = Q6_V_vzero(); + HVX_Vector r1_c1_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; // num full blocks + const uint32_t nloe = n % qk; // num leftover elements + + uint32_t i = 0; + for (; i < nb; i++) { + // Load src1 columns (reused across both src0 rows) + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size); + + // Load src0 rows (reused across both src1 columns) + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size); + + // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); + + // Load scales + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + // Compute combined scales + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + + // Apply scales and accumulate + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe)); + + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + + // Zero out unused scales + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); + r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); + r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); + r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); + r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); + r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); + r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); + r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); + + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Reduce and store results + HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); + HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); + + hvx_vec_store_u(s0, 8, r0_r1_c0_sum); // row0,col0 row1,col0 + hvx_vec_store_u(s1, 8, r0_r1_c1_sum); // row0,col1 row1,col1 +} + +static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { + assert(n % 32 == 0); // min sub-block size + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk; // int8 + const uint32_t x_qrow_size = n; // int8 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + + // Row sum (sf) + HVX_Vector r0_sum = Q6_V_vzero(); + + // Multiply and accumulate into int32. + // Compute combined scale (fp32). + // Apply scale to acc and accumulate into the row sum (qf32). + + const uint32_t nb = n / qk; // num full blocks + int32_t nloe = n % qk; // num leftover elemements (must be signed) + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); @@ -869,6 +1670,133 @@ static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0, hvx_vec_store_u(s0, 8, rsum); } +static void vec_dot_q8x4x2_q8x4x2_4x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vx2, const void * restrict vx3, + const void * restrict vy0) { + assert(n % 32 == 0); // min sub-block size + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vx2 % 128 == 0); + assert((unsigned long) vx3 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk; // int8 + const uint32_t x_qrow_size = n; // int8 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; // quants first + const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; // then scales + const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; // quants first + const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + + // Row sum (qf32) + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); + HVX_Vector r2_sum = Q6_V_vzero(); + HVX_Vector r3_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; // num full blocks + int32_t nloe = n % qk; // num leftover elemements (must be signed) + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 r2_q = hvx_vec_load_q8x4x8_full(r2_x_q + i * x_qblk_size); + HVX_Vector_x8 r3_q = hvx_vec_load_q8x4x8_full(r3_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); + HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); + } + + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r2_q = hvx_vec_load_q8x4x8_partial(r2_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r3_q = hvx_vec_load_q8x4x8_partial(r3_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r2_q, vy_q, nloe)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r3_q, vy_q, nloe)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); + HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r2_dd = Q6_V_vand_QV(bmask, r2_dd); + r3_dd = Q6_V_vand_QV(bmask, r3_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); + r2_ia = Q6_V_vand_QV(bmask, r2_ia); + r3_ia = Q6_V_vand_QV(bmask, r3_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); + } + + HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } }; + HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in); + hvx_vec_store_u(s0, 16, rsum); +} + + static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, const void * restrict vx0, const void * restrict vx1, const void * restrict vy0, const void * restrict vy1) { @@ -1069,19 +1997,112 @@ static void vec_dot_iq4nlx4x2_q8x4x2_1x1(const int n, r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); } - r0_sum = hvx_vec_reduce_sum_f32(r0_sum); - - hvx_vec_store_u(s0, 4, r0_sum); + r0_sum = hvx_vec_reduce_sum_f32(r0_sum); + + hvx_vec_store_u(s0, 4, r0_sum); +} + +static void vec_dot_iq4nlx4x2_q8x4x2_2x1(const int n, + float * restrict s0, + const void * restrict vx0, + const void * restrict vx1, + const void * restrict vy0) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; + const uint32_t nloe = n % qk; + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_full(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + } + + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_partial(r1_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + } + + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); + hvx_vec_store_u(s0, 8, rsum); } -static void vec_dot_iq4nlx4x2_q8x4x2_2x1(const int n, +static void vec_dot_iq4nlx4x2_q8x4x2_4x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vx1, + const void * restrict vx2, + const void * restrict vx3, const void * restrict vy0) { assert(n % 32 == 0); assert((unsigned long) vx0 % 128 == 0); assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vx2 % 128 == 0); + assert((unsigned long) vx3 % 128 == 0); assert((unsigned long) vy0 % 128 == 0); const uint32_t qk = QK_Q4_0x4x2 * 4; @@ -1098,12 +2119,18 @@ static void vec_dot_iq4nlx4x2_q8x4x2_2x1(const int n, const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; // quants first + const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; // then scales + const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; // quants first + const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; // then scales const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales HVX_Vector r0_sum = Q6_V_vzero(); HVX_Vector r1_sum = Q6_V_vzero(); + HVX_Vector r2_sum = Q6_V_vzero(); + HVX_Vector r3_sum = Q6_V_vzero(); const uint32_t nb = n / qk; const uint32_t nloe = n % qk; @@ -1113,56 +2140,86 @@ static void vec_dot_iq4nlx4x2_q8x4x2_2x1(const int n, HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size); HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_full(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 r2_q = hvx_vec_load_iq4nlx4x8_full(r2_x_q + i * x_qblk_size); + HVX_Vector_x8 r3_q = hvx_vec_load_iq4nlx4x8_full(r3_x_q + i * x_qblk_size); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); + HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); } if (nloe) { HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe); HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_partial(r1_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r2_q = hvx_vec_load_iq4nlx4x8_partial(r2_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r3_q = hvx_vec_load_iq4nlx4x8_partial(r3_x_q + i * x_qblk_size, nloe); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r2_q, vy_q, nloe)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r3_q, vy_q, nloe)); HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); + HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r2_dd = Q6_V_vand_QV(bmask, r2_dd); + r3_dd = Q6_V_vand_QV(bmask, r3_dd); r0_ia = Q6_V_vand_QV(bmask, r0_ia); r1_ia = Q6_V_vand_QV(bmask, r1_ia); + r2_ia = Q6_V_vand_QV(bmask, r2_ia); + r3_ia = Q6_V_vand_QV(bmask, r3_ia); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); } - HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); - hvx_vec_store_u(s0, 8, rsum); + HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } }; + HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in); + hvx_vec_store_u(s0, 16, rsum); } + static void vec_dot_iq4nlx4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, @@ -1521,6 +2578,176 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0, hvx_vec_store_u(s0, 8, rsum); } +static void vec_dot_mxfp4x4x2_q8x4x2_4x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vx2, const void * restrict vx3, + const void * restrict vy0) { + assert(n % 32 == 0); // min sub-block size + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vx2 % 128 == 0); + assert((unsigned long) vx3 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_MXFP4x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0 + const uint32_t x_qblk_size = qk / 2; // fp4 + const uint32_t x_qrow_size = n / 2; // fp4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; // quants first + const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; // then scales + const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; // quants first + const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy0) + 0; // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + + // Row sum (sf) + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); + HVX_Vector r2_sum = Q6_V_vzero(); + HVX_Vector r3_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; // num full blocks + int32_t nloe = n % qk; // num leftover elemements (must be signed) + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full( y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_full(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 r2_q = hvx_vec_load_mxfp4x4x8_full(r2_x_q + i * x_qblk_size); + HVX_Vector_x8 r3_q = hvx_vec_load_mxfp4x4x8_full(r3_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); + + HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + HVX_Vector r2_d = *(const HVX_UVector *) (r2_x_d + i * x_dblk_size); + HVX_Vector r3_d = *(const HVX_UVector *) (r3_x_d + i * x_dblk_size); + + // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving + HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 + vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half)); + vy_d = Q6_Vsf_equals_Vqf32(vy_d); + + // Convert rX_d scales from e8m0 to fp32 + HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; + HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); + r0_d = Q6_V_vdelta_VV(r0_d, expand); + r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); + r0_d = Q6_Vw_vasl_VwR(r0_d, 23); + r1_d = Q6_V_vdelta_VV(r1_d, expand); + r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); + r1_d = Q6_Vw_vasl_VwR(r1_d, 23); + r2_d = Q6_V_vdelta_VV(r2_d, expand); + r2_d = Q6_V_vand_VV(r2_d, e8m0_mask); + r2_d = Q6_Vw_vasl_VwR(r2_d, 23); + r3_d = Q6_V_vdelta_VV(r3_d, expand); + r3_d = Q6_V_vand_VV(r3_d, e8m0_mask); + r3_d = Q6_Vw_vasl_VwR(r3_d, 23); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d)); + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r2_d, vy_d)); + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r3_d, vy_d)); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); + } + + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial( y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r2_q = hvx_vec_load_mxfp4x4x8_partial(r2_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r3_q = hvx_vec_load_mxfp4x4x8_partial(r3_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); + + HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + HVX_Vector r2_d = *(const HVX_UVector *) (r2_x_d + i * x_dblk_size); + HVX_Vector r3_d = *(const HVX_UVector *) (r3_x_d + i * x_dblk_size); + + // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving + HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 + vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half)); + vy_d = Q6_Vsf_equals_Vqf32(vy_d); + + // Convert rX_d scales from e8m0 to fp32 + HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; + HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); + r0_d = Q6_V_vdelta_VV(r0_d, expand); + r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); + r0_d = Q6_Vw_vasl_VwR(r0_d, 23); + r1_d = Q6_V_vdelta_VV(r1_d, expand); + r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); + r1_d = Q6_Vw_vasl_VwR(r1_d, 23); + r2_d = Q6_V_vdelta_VV(r2_d, expand); + r2_d = Q6_V_vand_VV(r2_d, e8m0_mask); + r2_d = Q6_Vw_vasl_VwR(r2_d, 23); + r3_d = Q6_V_vdelta_VV(r3_d, expand); + r3_d = Q6_V_vand_VV(r3_d, e8m0_mask); + r3_d = Q6_Vw_vasl_VwR(r3_d, 23); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d)); + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r2_d, vy_d)); + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r3_d, vy_d)); + + // Zero-out unused values + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r2_dd = Q6_V_vand_QV(bmask, r2_dd); + r3_dd = Q6_V_vand_QV(bmask, r3_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); + r2_ia = Q6_V_vand_QV(bmask, r2_ia); + r3_ia = Q6_V_vand_QV(bmask, r3_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); + } + + HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } }; + HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in); + hvx_vec_store_u(s0, 16, rsum); +} + + static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, const void * restrict vx0, const void * restrict vx1, const void * restrict vy0, const void * restrict vy1) { @@ -1691,6 +2918,176 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1 } +#if __HVX_ARCH__ < 79 +#define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b)) +#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)) +#else +#define HVX_OP_ADD_F32(a, b) Q6_Vsf_vadd_VsfVsf(a, b) +#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) +#endif + +static void vec_dot_f32_f32_aa_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const HVX_Vector * restrict x = (const HVX_Vector *) vx; + const HVX_Vector * restrict y = (const HVX_Vector *) vy; + + uint32_t nvec = n / VLEN_FP32; // num full fp32 hvx vectors + uint32_t nloe = n % VLEN_FP32; // leftover elements + + HVX_Vector rsum = Q6_V_vzero(); + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; i++) { + HVX_Vector prod = HVX_OP_MUL_F32(x[i], y[i]); + rsum = HVX_OP_ADD_F32(rsum, prod); + } + + if (nloe) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector x_sf = Q6_V_vand_QV(bmask, x[i]); + HVX_Vector y_sf = Q6_V_vand_QV(bmask, y[i]); + HVX_Vector prod = HVX_OP_MUL_F32(x_sf, y_sf); + rsum = HVX_OP_ADD_F32(rsum, prod); + } + + *s = hvx_vec_get_f32(hvx_vec_reduce_sum_f32(rsum)); +} + +static void vec_dot_f32_f32_aa_2x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0) { + const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0; + const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1; + const HVX_Vector * restrict y = (const HVX_Vector *) vy0; + + uint32_t nvec = n / VLEN_FP32; + uint32_t nloe = n % VLEN_FP32; + + HVX_Vector rsum0 = Q6_V_vzero(); + HVX_Vector rsum1 = Q6_V_vzero(); + + uint32_t i = 0; + + #pragma unroll(2) + for (i = 0; i < nvec; i++) { + HVX_Vector y_sf = y[i]; + HVX_Vector prod0 = HVX_OP_MUL_F32(x0[i], y_sf); + HVX_Vector prod1 = HVX_OP_MUL_F32(x1[i], y_sf); + rsum0 = HVX_OP_ADD_F32(rsum0, prod0); + rsum1 = HVX_OP_ADD_F32(rsum1, prod1); + } + + if (nloe) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector y_sf = Q6_V_vand_QV(bmask, y[i]); + HVX_Vector x0_sf = Q6_V_vand_QV(bmask, x0[i]); + HVX_Vector x1_sf = Q6_V_vand_QV(bmask, x1[i]); + HVX_Vector prod0 = HVX_OP_MUL_F32(x0_sf, y_sf); + HVX_Vector prod1 = HVX_OP_MUL_F32(x1_sf, y_sf); + rsum0 = HVX_OP_ADD_F32(rsum0, prod0); + rsum1 = HVX_OP_ADD_F32(rsum1, prod1); + } + + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(rsum0, rsum1); + HVX_VectorAlias va; + va.v = rsum; + s0[0] = va.fp32[0]; + s0[1] = va.fp32[1]; +} + +static void vec_dot_f32_f32_aa_2x2(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1) { + const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0; + const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1; + const HVX_Vector * restrict y0 = (const HVX_Vector *) vy0; + const HVX_Vector * restrict y1 = (const HVX_Vector *) vy1; + + uint32_t nvec = n / VLEN_FP32; + uint32_t nloe = n % VLEN_FP32; + + HVX_Vector r0_c0_sum = Q6_V_vzero(); + HVX_Vector r0_c1_sum = Q6_V_vzero(); + HVX_Vector r1_c0_sum = Q6_V_vzero(); + HVX_Vector r1_c1_sum = Q6_V_vzero(); + + uint32_t i = 0; + + #pragma unroll(2) + for (i = 0; i < nvec; i++) { + HVX_Vector r0_sf = x0[i]; + HVX_Vector r1_sf = x1[i]; + HVX_Vector c0_sf = y0[i]; + HVX_Vector c1_sf = y1[i]; + + r0_c0_sum = HVX_OP_ADD_F32(r0_c0_sum, HVX_OP_MUL_F32(r0_sf, c0_sf)); + r0_c1_sum = HVX_OP_ADD_F32(r0_c1_sum, HVX_OP_MUL_F32(r0_sf, c1_sf)); + r1_c0_sum = HVX_OP_ADD_F32(r1_c0_sum, HVX_OP_MUL_F32(r1_sf, c0_sf)); + r1_c1_sum = HVX_OP_ADD_F32(r1_c1_sum, HVX_OP_MUL_F32(r1_sf, c1_sf)); + } + + if (nloe) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + + HVX_Vector r0_sf = Q6_V_vand_QV(bmask, x0[i]); + HVX_Vector r1_sf = Q6_V_vand_QV(bmask, x1[i]); + HVX_Vector c0_sf = Q6_V_vand_QV(bmask, y0[i]); + HVX_Vector c1_sf = Q6_V_vand_QV(bmask, y1[i]); + + r0_c0_sum = HVX_OP_ADD_F32(r0_c0_sum, HVX_OP_MUL_F32(r0_sf, c0_sf)); + r0_c1_sum = HVX_OP_ADD_F32(r0_c1_sum, HVX_OP_MUL_F32(r0_sf, c1_sf)); + r1_c0_sum = HVX_OP_ADD_F32(r1_c0_sum, HVX_OP_MUL_F32(r1_sf, c0_sf)); + r1_c1_sum = HVX_OP_ADD_F32(r1_c1_sum, HVX_OP_MUL_F32(r1_sf, c1_sf)); + } + + // Reduce and store results + HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); + HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); + + HVX_VectorAlias va0, va1; + va0.v = r0_r1_c0_sum; + va1.v = r0_r1_c1_sum; + s0[0] = va0.fp32[0]; + s0[1] = va0.fp32[1]; + s1[0] = va1.fp32[0]; + s1[1] = va1.fp32[1]; +} + +static void vec_dot_f32_f32_uu_1x1(const int n, float * restrict s, const void * restrict x, const void * restrict y) { + const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x; + const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y; + + uint32_t nvec = n / VLEN_FP32; // num full fp32 hvx vectors + uint32_t nloe = n % VLEN_FP32; // leftover elements + + HVX_Vector rsum = Q6_V_vzero(); + + uint32_t i = 0; + + #pragma unroll(2) + for (i = 0; i < nvec; i++) { + HVX_Vector x_sf = vx[i]; + HVX_Vector y_sf = vy[i]; + + rsum = HVX_OP_ADD_F32(rsum, HVX_OP_MUL_F32(x_sf, y_sf)); + } + + if (nloe) { + HVX_Vector x_sf = vx[i]; + HVX_Vector y_sf = vy[i]; + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + x_sf = Q6_V_vand_QV(bmask, x_sf); + y_sf = Q6_V_vand_QV(bmask, y_sf); + + rsum = HVX_OP_ADD_F32(rsum, HVX_OP_MUL_F32(x_sf, y_sf)); + } + + rsum = hvx_vec_reduce_sum_f32(rsum); + hvx_vec_store_u(&s[0], 4, rsum); +} + static void vec_dot_f16_f16_aa_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { const HVX_Vector * restrict x = (const HVX_Vector *) vx; const HVX_Vector * restrict y = (const HVX_Vector *) vy; @@ -2109,7 +3506,7 @@ static void matmul_2d(unsigned int nth, unsigned int ith, void * data) { // Process the last row (if any) if (src0_end_row != src0_end_row_x2) { uint32_t ir0 = src0_end_row_x2; - const int is0 = (ir0 - src0_start_row); + const int is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), src0_stride, src0_row_size, 1); const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; @@ -2138,7 +3535,6 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) { const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); - const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U); // no work for this thread if (src0_start_row >= src0_end_row) { @@ -2168,39 +3564,89 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) { const uint8_t * restrict src1_col = (const uint8_t *) src1_data; float * restrict dst_col = (float *) dst->data; - // Prefill spad with 2x src0 rows - #pragma unroll(2) - for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { - const uint32_t is0 = (ir0 - src0_start_row); - if (is0 >= MM_SPAD_SRC0_NROWS) { - break; + if (mmctx->vec_dot_4x1 != NULL) { + const uint32_t src0_end_row_x4 = src0_start_row + ((src0_end_row - src0_start_row) & ~3U); + + // Prefill spad with 4x src0 rows + #pragma unroll(4) + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x4; ir0 += 4) { + const uint32_t is0 = (ir0 - src0_start_row); + if (is0 >= MM_SPAD_SRC0_NROWS) { + break; + } + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, 4); + } + + // Process src0 rows + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x4; ir0 += 4) { + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + mmctx->vec_dot_4x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, ss0 + 2 * src0_stride, ss0 + 3 * src0_stride, src1_col); + + // Prefetch next (n + spad_nrows) row + const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS); + const uint32_t is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; + if (pr0 < src0_end_row_x4) { + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size), + src0_stride, src0_row_size, 4); + } + } + + // Process leftovers + uint32_t ir0 = src0_end_row_x4; + if (ir0 + 2 <= src0_end_row) { + const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, 2); + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col); + ir0 += 2; + } + if (ir0 < src0_end_row) { + const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, 1); + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col); + ir0 += 1; } - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), - src0_stride, src0_row_size, 2); - } - - // Process src0 rows - for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { - const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col); + } else { + const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U); - // Prefetch next (n + spad_nrows) row - const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS); - const uint32_t is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; - if (pr0 < src0_end_row_x2) { - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size), + // Prefill spad with 2x src0 rows + #pragma unroll(2) + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const uint32_t is0 = (ir0 - src0_start_row); + if (is0 >= MM_SPAD_SRC0_NROWS) { + break; + } + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), src0_stride, src0_row_size, 2); } - } - // Process the last row (if any) - if (src0_end_row != src0_end_row_x2) { - const uint32_t ir0 = src0_end_row_x2; - const uint32_t is0 = (ir0 - src0_start_row); - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), - src0_stride, src0_row_size, 1); - const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col); + // Process src0 rows + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col); + + // Prefetch next (n + spad_nrows) row + const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS); + const uint32_t is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; + if (pr0 < src0_end_row_x2) { + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size), + src0_stride, src0_row_size, 2); + } + } + + // Process the last row (if any) + if (src0_end_row != src0_end_row_x2) { + const uint32_t ir0 = src0_end_row_x2; + const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, 1); + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col); + } } hvx_copy_f32_ua((uint8_t *) &dst_col[src0_start_row], (uint8_t *) tmp, src0_end_row - src0_start_row); @@ -2245,11 +3691,8 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) { const uint32_t n_ids = ids->ne[0]; // n_expert_used const uint32_t n_as = ne02; // n_expert - const size_t matrix_row_counts_size = n_as * sizeof(uint32_t); - const size_t matrix_row_map_size = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping); - - const uint32_t * matrix_row_counts = (const uint32_t *) src2_spad->data + 0; - const struct mmid_row_mapping * matrix_rows = (const void *) src2_spad->data + matrix_row_counts_size; + const uint32_t * matrix_row_counts = mmctx->matrix_row_counts; + const struct mmid_row_mapping * matrix_rows = mmctx->matrix_rows; const size_t dst_row_size = nb1; const size_t src0_row_size = nb01; @@ -2271,6 +3714,10 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) { continue; } + if (mmctx->hmx_eligible) { + continue; + } + const uint8_t * src0_row = (const uint8_t *) src0->data + (0 + cur_a * nb02 + 0); // Prefill spad with src0 rows @@ -2312,7 +3759,7 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) { // Process the last row (if any) if (src0_end_row != src0_end_row_x2) { uint32_t ir0 = src0_end_row_x2; - const uint32_t is0 = (ir0 - src0_start_row); + const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), src0_row_size_padded, src0_row_size, 1); const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; @@ -2414,7 +3861,7 @@ static void matvec_id(unsigned int nth, unsigned int ith, void * data) { // Process the last row (if any) if (src0_end_row != src0_end_row_x2) { uint32_t ir0 = src0_end_row_x2; - const uint32_t is0 = (ir0 - src0_start_row); + const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), src0_row_size_padded, src0_row_size, 1); const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; @@ -2432,6 +3879,94 @@ static void matvec_id(unsigned int nth, unsigned int ith, void * data) { // *** dynamic quant +static inline void quantize_block_f32_q8_1x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { + assert((unsigned long) x % 128 == 0); + assert((unsigned long) y_q % 128 == 0); + + HVX_Vector * vx = (HVX_Vector *) x; + HVX_Vector zero = Q6_V_vzero(); + + // Use reduce max fp32 to find max(abs(e)) first + HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0])); + HVX_Vector vmax1_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[1])); + HVX_Vector vmax2_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[2])); + HVX_Vector vmax3_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[3])); + + // Load and convert into QF32 + HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements + HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements + HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements + HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements + + // Convert to QF32 + HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero); + HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero); + HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero); + HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero); + + // Combine and convert to fp16 + HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf))); + HVX_Vector vmax23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax3_qf, vmax2_qf))); + + // Convert into fp16 + HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf))); + HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); + + HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 + HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 + HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16); + HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16); + + // Divide input by the scale + HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf); + HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf); + vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf)); + vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf)); + + // Convert to int8 + HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf); + HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf); + HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16); + + *(HVX_Vector *) y_q = vx_i8; + + // --- Sum calculation --- + const HVX_Vector ones = Q6_Vb_vsplat_R(1); + HVX_Vector v_sums = Q6_Vw_vrmpy_VbVb(vx_i8, ones); // sum every 4 consecutive elements + // Sum 8 elements: + v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 4)); + v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 8)); + v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 16)); + + // Copy to stack to extract sums and vmaxes + float vmax0[32] __attribute__((aligned(128))); + float vmax1[32] __attribute__((aligned(128))); + float vmax2[32] __attribute__((aligned(128))); + float vmax3[32] __attribute__((aligned(128))); + int32_t sums[32] __attribute__((aligned(128))); + + hvx_vec_store_u(vmax0, 128, vmax0_sf); + hvx_vec_store_u(vmax1, 128, vmax1_sf); + hvx_vec_store_u(vmax2, 128, vmax2_sf); + hvx_vec_store_u(vmax3, 128, vmax3_sf); + hvx_vec_store_u(sums, 128, v_sums); + + float d0 = vmax0[0] / 127.0f; + float d1 = vmax1[0] / 127.0f; + float d2 = vmax2[0] / 127.0f; + float d3 = vmax3[0] / 127.0f; + + __fp16 * y_d_half = (__fp16 *) y_d; + y_d_half[0] = d0; + y_d_half[1] = (float) sums[0] * d0; + y_d_half[2] = d1; + y_d_half[3] = (float) sums[8] * d1; + y_d_half[4] = d2; + y_d_half[5] = (float) sums[16] * d2; + y_d_half[6] = d3; + y_d_half[7] = (float) sums[24] * d3; +} + static inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { assert((unsigned long) x % 128 == 0); assert((unsigned long) y_q % 128 == 0); @@ -2656,6 +4191,118 @@ static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data) ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } +static void quantize_row_f32_q8_1x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) { + assert(k % 32 == 0); + const uint32_t qk = QK_Q8_0x4x2; + const uint32_t nb = (k + qk - 1) / qk; + + const uint32_t qrow_size = k; // int8 + + const uint32_t dblk_size = 8 * 4; // 8x (d, s) __fp16 = 32 bytes + const uint32_t qblk_size = QK_Q8_0x4x2; // int8 + + uint8_t * restrict y_q = (y + 0); // quants first + uint8_t * restrict y_d = (y + qrow_size); // then scales/sums + + // Temp scales override input since we're working off of the aligned temp buffer in VTCM + uint8_t * restrict t_d = (uint8_t *) x; + + for (uint32_t i = 0; i < nb; i++) { + quantize_block_f32_q8_1x1(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); + quantize_block_f32_q8_1x1(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); + } + + // now copy the scales/sums into final location + hvx_copy_f16_ua(y_d, t_d, nb * 16); +} + +static void quantize_f32_q8_1x4x2(unsigned int nth, unsigned int ith, void * data) { + struct htp_matmul_context * mmctx = data; + struct htp_ops_context * octx = mmctx->octx; + + const struct htp_tensor * src = octx->src[1]; + uint8_t * restrict dst = octx->src1_spad.data; + struct htp_spad * spad = &octx->src0_spad; + uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; + + uint64_t t1 = HAP_perf_get_qtimer_count(); + + const uint32_t ne0 = src->ne[0]; + const uint32_t ne1 = src->ne[1]; + const uint32_t ne2 = src->ne[2]; + const uint32_t ne3 = src->ne[3]; + + const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows + + const uint32_t ir_first = nrows_per_thread * ith; // first row + const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row + + const size_t src_row_size = src->nb[1]; + const size_t dst_row_size = q8_1x4x2_row_size(ne0); + + uint8_t * restrict src_data = (uint8_t *) src->data + (src_row_size * ir_first); + uint8_t * restrict dst_data = (uint8_t *) dst + (dst_row_size * ir_first); + uint8_t * restrict tmp_data = (uint8_t *) spad->data + (spad->size_per_thread * ith); + + const size_t src_row_size_padded = hex_round_up(src_row_size, QK_Q8_0x4x2 * sizeof(float)); + memset(tmp_data, 0, src_row_size_padded); // zero-out temp row data for padding + + for (uint32_t i = ir_first; i < ir_last; ++i) { + hex_l2fetch(src_data, src_row_size, src_row_size, 2); + hvx_copy_f32_aa(tmp_data, src_data, ne0); + + quantize_row_f32_q8_1x4x2((float *) tmp_data, dst_data, ne0); + dst_data += dst_row_size; + src_data += src_row_size; + } + + uint64_t t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "quantize-f32-q8_1x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first, + ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +static void quantize_f32_f32(unsigned int nth, unsigned int ith, void * data) { + struct htp_matmul_context * mmctx = data; + struct htp_ops_context * octx = mmctx->octx; + + const struct htp_tensor * src = octx->src[1]; + uint8_t * restrict dst = octx->src1_spad.data; + uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; + uint32_t dst_stride = octx->src1_spad.stride; + + uint64_t t1 = HAP_perf_get_qtimer_count(); + + const uint32_t ne0 = src->ne[0]; + const uint32_t ne1 = src->ne[1]; + const uint32_t ne2 = src->ne[2]; + const uint32_t ne3 = src->ne[3]; + + const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows + + const uint32_t ir_first = nrows_per_thread * ith; // first row + const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row + + const size_t src_row_size = ne0 * sizeof(float); + const size_t src_stride = src->nb[1]; + + uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first); + uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first); + + for (uint32_t i = ir_first; i < ir_last; ++i) { + hex_l2fetch(src_data, src_row_size, src_stride, 2); + hvx_copy_f32_au(dst_data, src_data, ne0); + + dst_data += dst_stride; + src_data += src_stride; + } + + uint64_t t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "quantize-f32-f32: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first, + ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) { struct htp_matmul_context * mmctx = data; struct htp_ops_context * octx = mmctx->octx; @@ -2751,24 +4398,35 @@ static int htp_mminit_vec_dot(struct htp_matmul_context * mmctx, enum htp_data_t mmctx->vec_dot_1x1 = vec_dot_q4x4x2_q8x4x2_1x1; mmctx->vec_dot_2x1 = vec_dot_q4x4x2_q8x4x2_2x1; mmctx->vec_dot_2x2 = vec_dot_q4x4x2_q8x4x2_2x2; + mmctx->vec_dot_4x1 = vec_dot_q4x4x2_q8x4x2_4x1; + return 0; + case HTP_TYPE_Q4_1: + mmctx->type = "q4_1x4x2-f32"; + mmctx->vec_dot_1x1 = vec_dot_q4_1x4x2_q8x4x2_1x1; + mmctx->vec_dot_2x1 = vec_dot_q4_1x4x2_q8x4x2_2x1; + mmctx->vec_dot_2x2 = vec_dot_q4_1x4x2_q8x4x2_2x2; + mmctx->vec_dot_4x1 = vec_dot_q4_1x4x2_q8x4x2_4x1; return 0; case HTP_TYPE_Q8_0: mmctx->type = "q8x4x2-f32"; mmctx->vec_dot_1x1 = vec_dot_q8x4x2_q8x4x2_1x1; mmctx->vec_dot_2x1 = vec_dot_q8x4x2_q8x4x2_2x1; mmctx->vec_dot_2x2 = vec_dot_q8x4x2_q8x4x2_2x2; + mmctx->vec_dot_4x1 = vec_dot_q8x4x2_q8x4x2_4x1; return 0; case HTP_TYPE_IQ4_NL: mmctx->type = "iq4nlx4x2-f32"; mmctx->vec_dot_1x1 = vec_dot_iq4nlx4x2_q8x4x2_1x1; mmctx->vec_dot_2x1 = vec_dot_iq4nlx4x2_q8x4x2_2x1; mmctx->vec_dot_2x2 = vec_dot_iq4nlx4x2_q8x4x2_2x2; + mmctx->vec_dot_4x1 = vec_dot_iq4nlx4x2_q8x4x2_4x1; return 0; case HTP_TYPE_MXFP4: mmctx->type = "mxfp4x4x2-f32"; mmctx->vec_dot_1x1 = vec_dot_mxfp4x4x2_q8x4x2_1x1; mmctx->vec_dot_2x1 = vec_dot_mxfp4x4x2_q8x4x2_2x1; mmctx->vec_dot_2x2 = vec_dot_mxfp4x4x2_q8x4x2_2x2; + mmctx->vec_dot_4x1 = vec_dot_mxfp4x4x2_q8x4x2_4x1; return 0; default: return -1; @@ -2887,6 +4545,60 @@ static int op_matmul_hvx(struct htp_ops_context * octx) { mmctx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]); mmctx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]); + need_quant = false; + } + } else if (src0->type == HTP_TYPE_F32) { + // Try optimized f32-f32 path first (src1 in VTCM) + const size_t f32_src1_row_size = hex_round_up(ne10 * 4, 128); + const size_t f32_src1_spad_size = hex_round_up(f32_src1_row_size * src1_nrows, 256); + const size_t f32_src0_spad_size = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads; + const size_t f32_dst_spad_size = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads; + + const size_t f32_total_size = f32_src1_spad_size + f32_src0_spad_size + f32_dst_spad_size; + + const bool is_batched = (ne02 > 1) || (ne03 > 1); + const bool is_permuted = htp_is_permuted(octx->src[0]) || htp_is_permuted(octx->src[1]); + + if (!is_batched && !is_permuted && f32_total_size <= octx->ctx->vtcm_size) { + // Optimized path + quant_job_func = quantize_f32_f32; + mmctx->type = "f32-f32"; + mmctx->vec_dot_1x1 = vec_dot_f32_f32_aa_1x1; + mmctx->vec_dot_2x1 = vec_dot_f32_f32_aa_2x1; + mmctx->vec_dot_2x2 = vec_dot_f32_f32_aa_2x2; + + src1_row_size = f32_src1_row_size; + + octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); + + octx->src1_spad.size = octx->src1_spad.size_per_thread; + octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; + } else { + // Fallback to DDR / broadcasting + quant_job_func = NULL; + mmctx->type = "f32-f32"; + mmctx->vec_dot_1x1 = vec_dot_f32_f32_uu_1x1; + matmul_job_func = matmul_4d; + + src1_row_size = nb11; + + octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256); + octx->src1_spad.size_per_thread = hex_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256); + + octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; + + // Init fastdiv for matmul_4d (supports broadcasting) + mmctx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]); + mmctx->mm_div_ne1 = init_fastdiv_values(dst->ne[1]); + mmctx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]); + mmctx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]); + need_quant = false; } } else { @@ -2894,8 +4606,13 @@ static int op_matmul_hvx(struct htp_ops_context * octx) { return HTP_STATUS_NO_SUPPORT; } - quant_job_func = quantize_f32_q8x4x2; - src1_row_size = q8x4x2_row_size(ne10); + if (src0->type == HTP_TYPE_Q4_1) { + quant_job_func = quantize_f32_q8_1x4x2; + src1_row_size = q8_1x4x2_row_size(ne10); + } else { + quant_job_func = quantize_f32_q8x4x2; + src1_row_size = q8x4x2_row_size(ne10); + } htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, 0); } @@ -2959,20 +4676,20 @@ int op_matmul(struct htp_ops_context * octx) { return op_matmul_hvx(octx); } - // HMX supports F16, Q4_0, Q8_0, IQ4_NL, MXFP4 weights. + // HMX supports F16, F32, Q4_0, Q8_0, IQ4_NL, MXFP4 weights. // Other types fall back to HVX. uint32_t wtype = src0->type; - if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_Q4_0 && wtype != HTP_TYPE_Q8_0 && wtype != HTP_TYPE_IQ4_NL && wtype != HTP_TYPE_MXFP4) { + if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32 && wtype != HTP_TYPE_Q4_0 && wtype != HTP_TYPE_Q4_1 && wtype != HTP_TYPE_Q8_0 && wtype != HTP_TYPE_IQ4_NL && wtype != HTP_TYPE_MXFP4) { return op_matmul_hvx(octx); } // Quantised HMX path requires K aligned to 256 (x4x2 super-block). - // F16 HMX path requires K aligned to 32 (tile width). - if (wtype != HTP_TYPE_F16 && src0->ne[0] % 256 != 0) { + // F16 and F32 HMX paths require K aligned to 32 (tile width). + if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32 && src0->ne[0] % 256 != 0) { return op_matmul_hvx(octx); } - if (wtype == HTP_TYPE_F16 && src0->ne[0] % 32 != 0) { + if ((wtype == HTP_TYPE_F16 || wtype == HTP_TYPE_F32) && src0->ne[0] % 32 != 0) { return op_matmul_hvx(octx); } @@ -3017,8 +4734,8 @@ int op_matmul(struct htp_ops_context * octx) { return HTP_STATUS_OK; } - if (src0->type == HTP_TYPE_F16) { - if (is_batched) { + if (is_batched) { + if (src0->type == HTP_TYPE_F16) { hmx_matmul_f16_f32_batched_params_t batch_params = { .dst = (float *) dst->data, .activation = (float *) src1->data, @@ -3042,13 +4759,11 @@ int op_matmul(struct htp_ops_context * octx) { }; ret = hmx_matmul_f16_f32_batched(octx->ctx, &batch_params); } else { - ret = hmx_matmul_f16_f32(octx->ctx, - (float*) dst->data, (float*) src1->data, (const __fp16 *) src0->data, - m_total, k, n, act_stride, wgt_stride); + return op_matmul_hvx(octx); } } else { - ret = hmx_matmul_q_f32(octx->ctx, (float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data, - m_total, k, n, (int) src0->type); + ret = hmx_matmul_2d_f32(octx->ctx, (float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data, + m_total, k, n, act_stride, (int) src0->nb[1], (int) src0->type); } if (ret != 0) { @@ -3093,15 +4808,42 @@ int op_matmul_id(struct htp_ops_context * octx) { size_t matrix_row_counts_size = n_as * sizeof(uint32_t); size_t matrix_row_map_size = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping); + const size_t total_map_size = matrix_row_counts_size + matrix_row_map_size; + + void * mapping_buf = NULL; + bool must_free_mapping = false; + + if (octx->ctx->ddr_spad_base && total_map_size <= octx->ctx->ddr_spad_size) { + mapping_buf = octx->ctx->ddr_spad_base; + } else { + mapping_buf = memalign(128, total_map_size); + if (mapping_buf) { + must_free_mapping = true; + } else { + return HTP_STATUS_INTERNAL_ERR; + } + } + + uint32_t * matrix_row_counts = (uint32_t *) mapping_buf; + struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) ((uint8_t *) mapping_buf + matrix_row_counts_size); + + mmctx->matrix_row_counts = matrix_row_counts; + mmctx->matrix_rows = matrix_rows; if (htp_mminit_vec_dot(mmctx, src0->type) != 0) { + if (must_free_mapping) free(mapping_buf); return HTP_STATUS_NO_SUPPORT; } - quant_job_func = quantize_f32_q8x4x2; - src1_row_size = q8x4x2_row_size(ne10); + if (src0->type == HTP_TYPE_Q4_1) { + quant_job_func = quantize_f32_q8_1x4x2; + src1_row_size = q8_1x4x2_row_size(ne10); + } else { + quant_job_func = quantize_f32_q8x4x2; + src1_row_size = q8x4x2_row_size(ne10); + } - const size_t src2_spad_size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256); + const size_t src2_spad_size_per_thread = 0; // We moved the mapping to DDR! htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, src2_spad_size_per_thread); size_t spad_size = octx->src2_spad.size + octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size; @@ -3117,6 +4859,7 @@ int op_matmul_id(struct htp_ops_context * octx) { // Make sure the reserved vtcm size is sufficient if (octx->ctx->vtcm_size < spad_size) { FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type, octx->ctx->vtcm_size, spad_size); + if (must_free_mapping) free(mapping_buf); return HTP_STATUS_VTCM_TOO_SMALL; } @@ -3136,9 +4879,6 @@ int op_matmul_id(struct htp_ops_context * octx) { if (src1_nrows > 1) { // initialize matrix_row_counts and map - uint32_t * matrix_row_counts = (uint32_t *) octx->src2_spad.data + 0; - struct mmid_row_mapping * matrix_rows = (void *) octx->src2_spad.data + matrix_row_counts_size; - memset(matrix_row_counts, 0, n_as * sizeof(uint32_t)); // group rows by src0 matrix @@ -3148,14 +4888,60 @@ int op_matmul_id(struct htp_ops_context * octx) { assert(i02 >= 0 && i02 < n_as); - MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) { id, iid1 }; + matrix_rows[i02 * n_ids * ids->ne[1] + matrix_row_counts[i02]] = (struct mmid_row_mapping) { id, iid1 }; matrix_row_counts[i02] += 1; } } } - if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + if (must_free_mapping) free(mapping_buf); + return HTP_STATUS_OK; + } + + bool hmx_eligible = false; +#ifdef HTP_HAS_HMX + if (octx->ctx->hmx_enabled && src1_nrows > 1) { + uint32_t wtype = src0->type; + if (ne01 % 32 == 0 && + (wtype == HTP_TYPE_F16 || wtype == HTP_TYPE_F32 || wtype == HTP_TYPE_Q4_0 || wtype == HTP_TYPE_Q4_1 || wtype == HTP_TYPE_Q8_0 || wtype == HTP_TYPE_IQ4_NL || wtype == HTP_TYPE_MXFP4)) { + if ((wtype == HTP_TYPE_F16 || wtype == HTP_TYPE_F32) && ne00 % 32 == 0) { + hmx_eligible = true; + } else if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32 && ne00 % 256 == 0) { + hmx_eligible = true; + } + } + } +#endif + + mmctx->hmx_eligible = hmx_eligible; + + if (hmx_eligible) { + for (uint32_t cur_a = 0; cur_a < n_as; ++cur_a) { + const int32_t cne1 = matrix_row_counts[cur_a]; + if (cne1 == 0) continue; + + int ret = hmx_matmul_id_2d_f32(octx->ctx, (float*) dst->data, (float*) src1->data, + (const uint8_t *) src0->data + cur_a * nb02, + cne1, ne00, ne01, + ne11, + nb11, nb12, + nb1, nb2, + (int) src0->nb[1], (int) src0->type, + matrix_rows, cur_a, n_ids * ids->ne[1]); + if (ret != 0) { + FARF(ERROR, "HMX matmul failed for expert %u, error %d\n", cur_a, ret); + if (must_free_mapping) free(mapping_buf); + return HTP_STATUS_NO_SUPPORT; + } + } + + // HMX has overwritten VTCM, so force dynamic quantization cache to clear + octx->src1_spad.src = NULL; + + if (must_free_mapping) free(mapping_buf); return HTP_STATUS_OK; + } if (octx->src1_spad.src != src1) { const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); @@ -3167,5 +4953,6 @@ int op_matmul_id(struct htp_ops_context * octx) { const uint32_t n_matmul_jobs = octx->n_threads; worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, mmctx, n_matmul_jobs); + if (must_free_mapping) free(mapping_buf); return HTP_STATUS_OK; } diff --git a/ggml/src/ggml-hexagon/htp/pad-ops.c b/ggml/src/ggml-hexagon/htp/pad-ops.c index 3abc3c2ead1..aaa72b31590 100644 --- a/ggml/src/ggml-hexagon/htp/pad-ops.c +++ b/ggml/src/ggml-hexagon/htp/pad-ops.c @@ -511,6 +511,8 @@ int op_pad(struct htp_ops_context * octx) { octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; octx->src0_spad.data = octx->ctx->vtcm_base; octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->src0_spad.src = NULL; + octx->dst_spad.src = NULL; } struct htp_pad_context pctx = { diff --git a/ggml/src/ggml-hexagon/htp/rope-ops.c b/ggml/src/ggml-hexagon/htp/rope-ops.c index b398e19f06e..c839044b84f 100644 --- a/ggml/src/ggml-hexagon/htp/rope-ops.c +++ b/ggml/src/ggml-hexagon/htp/rope-ops.c @@ -7,6 +7,7 @@ #include <math.h> #include <string.h> +#include <stdlib.h> #include "hex-dma.h" #include "hvx-utils.h" @@ -75,6 +76,9 @@ struct htp_rope_context { size_t theta_cache_offset; uint32_t src0_nrows; + struct fastdiv_values div_ne2_ne1; + struct fastdiv_values div_ne1; + uint64_t t_start; }; @@ -117,13 +121,84 @@ static __attribute__((noinline)) void rope_cache_init(const float theta_base, float * cache, const float theta_scale) { // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py - float theta = theta_base; +#if __HVX_ARCH__ >= 79 + const bool is_v79_or_newer = true; +#else + const bool is_v79_or_newer = false; +#endif + + if (is_v79_or_newer && ext_factor == 0.0f) { + // Fast path: fully vectorized + // We process 32 pairs (64 elements) per iteration. + const uint32_t n_blocks = ne0 / 64; + + // Initialize theta scale powers: [1.0f, theta_scale, theta_scale^2, ..., theta_scale^31] + float __attribute__((aligned(128))) theta_powers[32]; + theta_powers[0] = 1.0f; + for (int j = 1; j < 32; j++) { + theta_powers[j] = theta_powers[j - 1] * theta_scale; + } + HVX_Vector v_theta_powers = hvx_vmem(theta_powers); - for (uint32_t i0 = 0; i0 < ne0; i0 += 2) { - const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f; - rope_yarn_one(theta / ff, freq_scale, corr_dims, i0, ext_factor, mscale, cache); + HVX_Vector v_freq_scale = hvx_vec_splat_f32(freq_scale); + HVX_Vector v_mscale = hvx_vec_splat_f32(mscale); + + // Base theta starts at theta_base + float theta_block = theta_base; + // The scale factor for the next block is theta_scale^32 + float theta_scale_32 = 1.0f; + for (int j = 0; j < 32; j++) { + theta_scale_32 *= theta_scale; + } + + for (uint32_t b = 0; b < n_blocks; b++) { + uint32_t i0 = b * 64; + HVX_Vector v_theta_base = hvx_vec_splat_f32(theta_block); + HVX_Vector v_theta = hvx_vec_mul_f32_f32(v_theta_base, v_theta_powers); + + if (freq_factors) { + // Load 32 elements of freq_factors + HVX_Vector v_ff = hvx_vmemu(freq_factors + i0 / 2); + HVX_Vector v_inv_ff = hvx_vec_inverse_f32(v_ff); + v_theta = hvx_vec_mul_f32_f32(v_theta, v_inv_ff); + } + + HVX_Vector v_theta_final = hvx_vec_mul_f32_f32(v_theta, v_freq_scale); + + HVX_Vector vcos = hvx_vec_cos_f32(v_theta_final); + HVX_Vector vsin = hvx_vec_sin_f32(v_theta_final); + + vcos = hvx_vec_mul_f32_f32(vcos, v_mscale); + vsin = hvx_vec_mul_f32_f32(vsin, v_mscale); + + HVX_VectorPair vstore = Q6_W_vshuff_VVR(vsin, vcos, -4); - theta *= theta_scale; + if (((uintptr_t)cache) % 128 == 0) { + hvx_vmem(cache + i0 + 0) = Q6_V_lo_W(vstore); + hvx_vmem(cache + i0 + 32) = Q6_V_hi_W(vstore); + } else { + hvx_vec_store_u(cache + i0 + 0, 32 * sizeof(float), Q6_V_lo_W(vstore)); + hvx_vec_store_u(cache + i0 + 32, 32 * sizeof(float), Q6_V_hi_W(vstore)); + } + + theta_block *= theta_scale_32; + } + + // Leftovers + float theta = theta_block; + for (uint32_t i0 = n_blocks * 64; i0 < ne0; i0 += 2) { + const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f; + rope_yarn_one(theta / ff, freq_scale, corr_dims, i0, ext_factor, mscale, cache); + theta *= theta_scale; + } + } else { + // Fallback to original scalar loop + float theta = theta_base; + for (uint32_t i0 = 0; i0 < ne0; i0 += 2) { + const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f; + rope_yarn_one(theta / ff, freq_scale, corr_dims, i0, ext_factor, mscale, cache); + theta *= theta_scale; + } } } @@ -195,24 +270,18 @@ static void rope_corr_dims(int n_dims, } static inline void hvx_rope_neox_f32_aa(float * restrict dst, const float * restrict src0, uint32_t ne, const float * restrict theta_cache) { - const HVX_Vector * restrict vsrc = (const HVX_Vector *) src0; - const HVX_Vector * restrict vtheta = (const HVX_Vector *) theta_cache; - HVX_Vector * restrict vdst = (HVX_Vector *) dst; - - uint32_t nvec = (ne / (VLEN_FP32 * 2) * 2); // 2 vecs per loop, step of 2 + const uint32_t he = ne / 2; + const uint32_t nvec = he / 32; + const uint32_t nloe = he % 32; - uint32_t he = ne / 2; // half_dims offset in elements - uint32_t hv = he / VLEN_FP32; // half_dims offset in vectors + for (uint32_t i = 0; i < nvec; i++) { + HVX_Vector v0 = ((const HVX_Vector *) src0)[i]; + HVX_Vector v1 = hvx_vmemu(src0 + he + i * 32); - #pragma unroll(2) - for (uint32_t i = 0; i < nvec; i += 2) { - HVX_Vector v0 = vsrc[i/2+0]; - HVX_Vector v1 = vsrc[i/2+hv]; + HVX_Vector v2 = ((const HVX_Vector *) theta_cache)[i * 2 + 0]; + HVX_Vector v3 = ((const HVX_Vector *) theta_cache)[i * 2 + 1]; - HVX_Vector v2 = vtheta[i+0]; - HVX_Vector v3 = vtheta[i+1]; - - HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta + HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_lo_W(vcos_sin)); HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_hi_W(vcos_sin)); @@ -222,37 +291,45 @@ static inline void hvx_rope_neox_f32_aa(float * restrict dst, const float * rest HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s); HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c); - vdst[i/2+0] = Q6_Vsf_equals_Vqf32(v4); - vdst[i/2+hv] = Q6_Vsf_equals_Vqf32(v5); + ((HVX_Vector *) dst)[i] = Q6_Vsf_equals_Vqf32(v4); + hvx_vmemu(dst + he + i * 32) = Q6_Vsf_equals_Vqf32(v5); } - for (uint32_t i = nvec * VLEN_FP32; i < ne; i += 2) { - const float cos_theta = theta_cache[i+0]; - const float sin_theta = theta_cache[i+1]; - float x0 = src0[i/2]; - float x1 = src0[i/2 + he]; - dst[i/2] = x0 * cos_theta - x1 * sin_theta; - dst[i/2 + he] = x0 * sin_theta + x1 * cos_theta; + if (nloe > 0) { + HVX_Vector v0 = hvx_vmemu(src0 + nvec * 32); + HVX_Vector v1 = hvx_vmemu(src0 + he + nvec * 32); + + HVX_Vector v2 = ((const HVX_Vector *) theta_cache)[nvec * 2 + 0]; + HVX_Vector v3 = ((const HVX_Vector *) theta_cache)[nvec * 2 + 1]; + + HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); + + HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_lo_W(vcos_sin)); + HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_hi_W(vcos_sin)); + HVX_Vector vx1_c = Q6_Vqf32_vmpy_VsfVsf(v1, Q6_V_lo_W(vcos_sin)); + HVX_Vector vx1_s = Q6_Vqf32_vmpy_VsfVsf(v1, Q6_V_hi_W(vcos_sin)); + + HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s); + HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c); + + hvx_vec_store_u(dst + nvec * 32, nloe * sizeof(float), Q6_Vsf_equals_Vqf32(v4)); + hvx_vec_store_u(dst + he + nvec * 32, nloe * sizeof(float), Q6_Vsf_equals_Vqf32(v5)); } } static inline void hvx_rope_f32_aa(float * restrict dst, const float * restrict src0, uint32_t ne, const float * restrict theta_cache) { - const HVX_Vector * restrict vsrc = (const HVX_Vector *) src0; - const HVX_Vector * restrict vtheta = (const HVX_Vector *) theta_cache; - HVX_Vector * restrict vdst = (HVX_Vector *) dst; - - uint32_t nvec = (ne / (VLEN_FP32 * 2)) * 2; // 2 vecs per loop, step of two + const uint32_t nvec = ne / 64; + const uint32_t nloe = ne % 64; - #pragma unroll(2) - for (uint32_t i = 0; i < nvec; i+=2) { - HVX_Vector v0 = vsrc[i+0]; - HVX_Vector v1 = vsrc[i+1]; + for (uint32_t i = 0; i < nvec; i++) { + HVX_Vector v0 = ((const HVX_Vector *) src0)[i * 2 + 0]; + HVX_Vector v1 = ((const HVX_Vector *) src0)[i * 2 + 1]; - HVX_Vector v2 = vtheta[i+0]; - HVX_Vector v3 = vtheta[i+1]; + HVX_Vector v2 = ((const HVX_Vector *) theta_cache)[i * 2 + 0]; + HVX_Vector v3 = ((const HVX_Vector *) theta_cache)[i * 2 + 1]; - HVX_VectorPair vx0_x1 = Q6_W_vdeal_VVR(v1, v0, -4); // vx0_x1[0] = x0, vx0_x1[1] = x1 - HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta + HVX_VectorPair vx0_x1 = Q6_W_vdeal_VVR(v1, v0, -4); + HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_lo_W(vcos_sin)); HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_hi_W(vcos_sin)); @@ -264,17 +341,52 @@ static inline void hvx_rope_f32_aa(float * restrict dst, const float * restrict HVX_VectorPair vstore = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(v5), Q6_Vsf_equals_Vqf32(v4), -4); - vdst[i+0] = Q6_V_lo_W(vstore); - vdst[i+1] = Q6_V_hi_W(vstore); + ((HVX_Vector *) dst)[i * 2 + 0] = Q6_V_lo_W(vstore); + ((HVX_Vector *) dst)[i * 2 + 1] = Q6_V_hi_W(vstore); } - for (uint32_t i = nvec * VLEN_FP32; i < ne; i += 2) { - const float cos_theta = theta_cache[i+0]; - const float sin_theta = theta_cache[i+1]; - float x0 = src0[i+0]; - float x1 = src0[i+1]; - dst[i+0] = x0 * cos_theta - x1 * sin_theta; - dst[i+1] = x0 * sin_theta + x1 * cos_theta; + if (nloe > 0) { + if (nloe <= 32) { + HVX_Vector v0 = hvx_vmemu(src0 + nvec * 64); + HVX_Vector v2 = hvx_vmemu(theta_cache + nvec * 64); + + HVX_VectorPair vx0_x1 = Q6_W_vdeal_VVR(Q6_V_vzero(), v0, -4); + HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(Q6_V_vzero(), v2, -4); + + HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_lo_W(vcos_sin)); + HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_hi_W(vcos_sin)); + HVX_Vector vx1_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_hi_W(vx0_x1), Q6_V_lo_W(vcos_sin)); + HVX_Vector vx1_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_hi_W(vx0_x1), Q6_V_hi_W(vcos_sin)); + + HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s); + HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c); + + HVX_VectorPair vstore = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(v5), Q6_Vsf_equals_Vqf32(v4), -4); + + hvx_vec_store_u(dst + nvec * 64, nloe * sizeof(float), Q6_V_lo_W(vstore)); + } else { + HVX_Vector v0 = hvx_vmemu(src0 + nvec * 64); + HVX_Vector v1 = hvx_vmemu(src0 + nvec * 64 + 32); + + HVX_Vector v2 = hvx_vmemu(theta_cache + nvec * 64); + HVX_Vector v3 = hvx_vmemu(theta_cache + nvec * 64 + 32); + + HVX_VectorPair vx0_x1 = Q6_W_vdeal_VVR(v1, v0, -4); + HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); + + HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_lo_W(vcos_sin)); + HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_hi_W(vcos_sin)); + HVX_Vector vx1_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_hi_W(vx0_x1), Q6_V_lo_W(vcos_sin)); + HVX_Vector vx1_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_hi_W(vx0_x1), Q6_V_hi_W(vcos_sin)); + + HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s); + HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c); + + HVX_VectorPair vstore = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(v5), Q6_Vsf_equals_Vqf32(v4), -4); + + ((HVX_Vector *) dst)[nvec * 2 + 0] = Q6_V_lo_W(vstore); + hvx_vec_store_u(dst + nvec * 64 + 32, (nloe - 32) * sizeof(float), Q6_V_hi_W(vstore)); + } } } @@ -348,13 +460,19 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) { const int32_t * pos = (const int32_t *) src1->data; const float * freq_factors = src2 ? (const float *) src2->data : NULL; - uint32_t ir = 0; + const uint32_t i3_start = fastdiv(src0_start_row, &rctx->div_ne2_ne1); + const uint32_t rem = fastmodulo(src0_start_row, ne2 * ne1, &rctx->div_ne2_ne1); + const uint32_t i2_start = fastdiv(rem, &rctx->div_ne1); + const uint32_t i1_start = fastmodulo(rem, ne1, &rctx->div_ne1); + + uint32_t ir = src0_start_row; uint32_t prev_i2 = (uint32_t) -1; - for (uint32_t i3 = 0; i3 < ne3; i3++) { // batch - for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len - for (uint32_t i1 = 0; i1 < ne1; ) { // attn-heads - if (ir < src0_start_row) { ir++; i1++; continue; } + for (uint32_t i3 = i3_start; i3 < ne3; i3++) { // batch + const uint32_t i2_init = (i3 == i3_start) ? i2_start : 0; + for (uint32_t i2 = i2_init; i2 < ne2; i2++) { // seq-len + const uint32_t i1_init = (i3 == i3_start && i2 == i2_start) ? i1_start : 0; + for (uint32_t i1 = i1_init; i1 < ne1; ) { // attn-heads if (ir >= src0_end_row) goto done; // Rows in this block @@ -407,9 +525,6 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) { ne0, rctx->ext_factor, rctx->attn_factor, theta_cache, rctx->theta_scale); } - - // FARF(HIGH, "rope-theta %u: ir %u i1 %u i2 %u i3 %u cache %p : usec %u", ith, ir, i1, i2, i3, theta_cache, - // (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start)); } // Skip output DMA transactions from prev block (if any) @@ -489,7 +604,7 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) { // Aligned row sizes for VTCM const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); - const size_t theta_cache_size_aligned = hex_round_up(src0->ne[0] * sizeof(float), 128); + const size_t theta_cache_size_aligned = hex_round_up(src0->ne[0] * sizeof(float), 256); // Calculate spad sizes per thread size_t src0_spad_per_thread = theta_cache_size_aligned + HTP_ROPE_SPAD_NROWS * src0_row_size_aligned; @@ -546,6 +661,11 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) { rctx.src0_nrows = src0_nrows; rctx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads; + if (src0_nrows > 0) { + rctx.div_ne2_ne1 = init_fastdiv_values(dst->ne[2] * dst->ne[1]); + rctx.div_ne1 = init_fastdiv_values(dst->ne[1]); + } + FARF(HIGH, "rope-f32 n-rows %u n-dims %d ne0 %u ext-factor %.6f theta-scale %.6f attn-factor %.6f\n", rctx.src0_nrows, rctx.n_dims, ne0, rctx.ext_factor, rctx.theta_scale, rctx.attn_factor); diff --git a/ggml/src/ggml-hexagon/htp/set-rows-ops.c b/ggml/src/ggml-hexagon/htp/set-rows-ops.c index 0def7b408bf..58c54967db0 100644 --- a/ggml/src/ggml-hexagon/htp/set-rows-ops.c +++ b/ggml/src/ggml-hexagon/htp/set-rows-ops.c @@ -65,6 +65,9 @@ static void set_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da // parallelize by rows of src0 const uint32_t dr = srctx->src0_nrows_per_thread; const uint32_t ir0 = dr * ith; + if (ir0 >= nr) { + return; + } const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; const bool is_i32 = (octx->src[1]->type == HTP_TYPE_I32); @@ -109,6 +112,9 @@ static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *da // parallelize by rows of src0 const uint32_t dr = srctx->src0_nrows_per_thread; const uint32_t ir0 = dr * ith; + if (ir0 >= nr) { + return; + } const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; const bool is_i32 = (octx->src[1]->type == HTP_TYPE_I32); diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index 40d2d60153a..71fab2cdbcb 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -23,21 +23,26 @@ struct htp_unary_context { // Precomputed values const uint8_t * data_src0; + const uint8_t * data_src1; // weight/scale tensor for RMS_NORM_MUL uint8_t * data_dst; size_t src0_data_row_size; // actual data bytes per row + size_t src1_data_row_size; size_t dst_data_row_size; // actual data bytes per row size_t src0_row_size_aligned; + size_t src1_row_size_aligned; size_t dst_row_size_aligned; size_t src0_spad_half_size; + size_t src1_spad_half_size; size_t dst_spad_half_size; uint32_t block; uint32_t src0_nrows; uint32_t src0_nrows_per_thread; uint32_t nc; + bool broadcast_weight; }; // Convert flat row index to DDR byte offset using the tensor's actual strides. @@ -158,6 +163,71 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src, } } +static void hvx_fast_rms_norm_mul_f32(const uint8_t * restrict src, + const uint8_t * restrict weight, + uint8_t * restrict dst, + const int num_elems, + float epsilon) { + const HVX_Vector * restrict v_src = (const HVX_Vector *) src; + const HVX_Vector * restrict v_weight = (const HVX_Vector *) weight; + HVX_Vector * restrict v_dst = (HVX_Vector *) dst; + + const int nvec = num_elems / VLEN_FP32; // number of full vectors + const int nloe = num_elems % VLEN_FP32; // leftover elements + + // Compute sum of squares for full vectors + HVX_Vector sum_v = Q6_V_vsplat_R(0x00000000); + HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon); + + #pragma unroll(4) + for (int i = 0; i < nvec; i++) { + HVX_Vector v1 = v_src[i]; + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1); + sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2); + } + + // Handle tail elements using vectorized ops with masking + if (nloe > 0) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]); + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1); + sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2); + } + + // Reduce HVX sum + sum_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v)); + + HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems); + HVX_Vector denom_v = hvx_vec_inverse_f32(t_v); + HVX_Vector mean_v = Q6_Vqf32_vmpy_VsfVsf(sum_v, denom_v); + HVX_Vector mean_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(mean_v, epsilon_v); + + // Scale and multiply + HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(mean_epsilon_v)); + + #pragma unroll(4) + for (int i = 0; i < nvec; i++) { + HVX_Vector v1 = v_src[i]; + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v); + HVX_Vector v3 = Q6_Vsf_equals_Vqf32(v2); + HVX_Vector result = Q6_Vqf32_vmpy_VsfVsf(v3, v_weight[i]); + v_dst[i] = Q6_Vsf_equals_Vqf32(result); + } + + // Handle tail elements using vectorized ops with masking + if (nloe > 0) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]); + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v); + HVX_Vector v3 = Q6_Vsf_equals_Vqf32(v2); + HVX_Vector result = Q6_Vqf32_vmpy_VsfVsf(v3, v_weight[nvec]); + HVX_Vector res_v = Q6_Vsf_equals_Vqf32(result); + + // Store with masking to avoid overwriting memory beyond the tensor + hvx_vec_store_a(&v_dst[nvec], nloe * 4, res_v); + } +} + static void hvx_fast_norm_f32(const uint8_t * restrict src, uint8_t * restrict dst, uint8_t * restrict pad, @@ -207,7 +277,7 @@ static void hvx_fast_norm_f32(const uint8_t * restrict src, // scale = rsqrt(variance + epsilon), mean_x broadcast for subtraction HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(var_epsilon_v)); - HVX_Vector mean_x_b = hvx_vec_splat_f32(hvx_vec_get_f32(Q6_Vsf_equals_Vqf32(mean_x_v))); + HVX_Vector mean_x_b = hvx_vec_repl_f32(Q6_Vsf_equals_Vqf32(mean_x_v)); #pragma unroll(4) for (int i = 0; i < nvec; i++) { @@ -269,6 +339,27 @@ static void rms_norm_f32(const float * restrict src, } } +static void rms_norm_mul_f32(const float * restrict src, + const float * restrict weight, + float * restrict dst, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + const size_t weight_row_size, + int32_t * op_params, + bool broadcast_weight) { + float epsilon = 0.f; + memcpy(&epsilon, op_params, sizeof(float)); + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + const uint8_t * restrict w_local = (const uint8_t *)weight + (broadcast_weight ? 0 : ir * weight_row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); + + hvx_fast_rms_norm_mul_f32(src_local, w_local, dst_local, row_elems, epsilon); + } +} + static void norm_f32(const float * restrict src, float * restrict dst, uint8_t * restrict spad, @@ -598,12 +689,20 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * t1 = HAP_perf_get_qtimer_count(); const uint8_t * restrict data_src = uctx->data_src0; + const uint8_t * restrict data_src1 = uctx->data_src1; uint8_t * restrict data_dst = uctx->data_dst; + const struct htp_tensor * src1 = (htp_op == HTP_OP_RMS_NORM_MUL) ? octx->src[1] : NULL; + const uint32_t nb11 = src1 ? src1->nb[1] : 0; + const uint32_t nb12 = src1 ? src1->nb[2] : 0; + const uint32_t nb13 = src1 ? src1->nb[3] : 0; + uint8_t * src0_spad_data = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + uint8_t * src1_spad_data = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread); uint8_t * dst_spad_data = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); size_t src0_spad_half_size = uctx->src0_spad_half_size; + size_t src1_spad_half_size = uctx->src1_spad_half_size; size_t dst_spad_half_size = uctx->dst_spad_half_size; // Non-contiguous tensors have gaps at dim-2/3 boundaries that a single-stride @@ -624,6 +723,12 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * dma_queue * dma_queue = octx->ctx->dma[ith]; + // If weight is broadcasted, load it once per thread at the beginning of execution + if (htp_op == HTP_OP_RMS_NORM_MUL && uctx->broadcast_weight) { + dma_queue_push(dma_queue, dma_make_ptr(src1_spad_data, data_src1), uctx->src1_row_size_aligned, 0, uctx->src1_data_row_size, 1); + dma_queue_flush(dma_queue); + } + for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; spad_idx++) { const uint32_t block_size = unary_block_size(ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1); @@ -636,6 +741,14 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * dma_queue_push(dma_queue, dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src + src0_off), src0_row_size_aligned, nb01, src0_data_row_size, block_size); + + if (htp_op == HTP_OP_RMS_NORM_MUL && !uctx->broadcast_weight) { + const size_t src1_off = unary_row_offset(ir, ne01, ne02, nb11, nb12, nb13); + dma_queue_push(dma_queue, + dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + src1_off), + uctx->src1_row_size_aligned, nb11, uctx->src1_data_row_size, block_size); + } + ir += block_size; } @@ -644,6 +757,10 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * float * dst_spad = (float *) dma_queue_pop(dma_queue).src; float * src0_spad = (float *) dma_queue_pop(dma_queue).dst; + float * src1_spad = NULL; + if (htp_op == HTP_OP_RMS_NORM_MUL && !uctx->broadcast_weight) { + src1_spad = (float *) dma_queue_pop(dma_queue).dst; + } // Process block in VTCM switch (htp_op) { @@ -653,6 +770,12 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * case HTP_OP_RMS_NORM: rms_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); break; + case HTP_OP_RMS_NORM_MUL: + { + const float * w_ptr = uctx->broadcast_weight ? (const float *) src1_spad_data : src1_spad; + rms_norm_mul_f32(src0_spad, w_ptr, dst_spad, block_size, ne0, src0_row_size_aligned, uctx->src1_row_size_aligned, op_params, uctx->broadcast_weight); + } + break; case HTP_OP_SCALE: scale_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); break; @@ -700,9 +823,16 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * if (pref_ir < src0_end_row) { const uint32_t pref_block_size = unary_block_size(pref_ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1); const size_t src0_pref_off = unary_row_offset(pref_ir, ne01, ne02, nb01, nb02, nb03); - dma_queue_push(dma_queue, - dma_make_ptr(src0_spad, data_src + src0_pref_off), - src0_row_size_aligned, nb01, src0_data_row_size, pref_block_size); + dma_queue_push(dma_queue, + dma_make_ptr(src0_spad, data_src + src0_pref_off), + src0_row_size_aligned, nb01, src0_data_row_size, pref_block_size); + + if (htp_op == HTP_OP_RMS_NORM_MUL && !uctx->broadcast_weight) { + const size_t src1_pref_off = unary_row_offset(pref_ir, ne01, ne02, nb11, nb12, nb13); + dma_queue_push(dma_queue, + dma_make_ptr(src1_spad, data_src1 + src1_pref_off), + uctx->src1_row_size_aligned, nb11, uctx->src1_data_row_size, pref_block_size); + } } } ir += block_size; @@ -732,6 +862,9 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { case HTP_OP_RMS_NORM: op_type = "rmsnorm-f32"; break; + case HTP_OP_RMS_NORM_MUL: + op_type = "rmsnorm-mul-f32"; + break; case HTP_OP_SCALE: op_type = "scale-f32"; break; @@ -777,12 +910,44 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { const size_t src0_row_size_aligned = hex_round_up(src0_data_row_size, VLEN); const size_t dst_row_size_aligned = hex_round_up(dst_data_row_size, VLEN); + size_t src1_data_row_size = 0; + size_t src1_row_size_aligned = 0; + bool broadcast_weight = false; + const struct htp_tensor * src1 = NULL; + + if (octx->op == HTP_OP_RMS_NORM_MUL) { + src1 = octx->src[1]; + src1_data_row_size = src1->ne[0] * sizeof(float); + src1_row_size_aligned = hex_round_up(src1_data_row_size, VLEN); + broadcast_weight = (src1->ne[1] * src1->ne[2] * src1->ne[3] == 1); + } + // VTCM scratchpads for all tensors // N rows per thread, padded to HVX vector size // Double buffering requires 2x size per buffer - size_t spad_size_per_row = 2 * (src0_row_size_aligned + dst_row_size_aligned); - size_t vtcm_row_per_thread = (octx->ctx->vtcm_size)/ (n_threads * spad_size_per_row); + size_t spad_size_per_row = 0; + size_t vtcm_row_per_thread = 0; + + if (octx->op == HTP_OP_RMS_NORM_MUL) { + if (broadcast_weight) { + size_t available_vtcm = octx->ctx->vtcm_size; + size_t src1_spad_total = n_threads * src1_row_size_aligned; + if (available_vtcm > src1_spad_total) { + available_vtcm -= src1_spad_total; + } else { + available_vtcm = 0; + } + spad_size_per_row = 2 * (src0_row_size_aligned + dst_row_size_aligned); + vtcm_row_per_thread = available_vtcm / (n_threads * spad_size_per_row); + } else { + spad_size_per_row = 2 * (src0_row_size_aligned + dst_row_size_aligned + src1_row_size_aligned); + vtcm_row_per_thread = (octx->ctx->vtcm_size) / (n_threads * spad_size_per_row); + } + } else { + spad_size_per_row = 2 * (src0_row_size_aligned + dst_row_size_aligned); + vtcm_row_per_thread = (octx->ctx->vtcm_size)/ (n_threads * spad_size_per_row); + } // Make sure the reserved vtcm size is sufficient if (vtcm_row_per_thread == 0) { @@ -797,8 +962,29 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread; octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; + if (octx->op == HTP_OP_RMS_NORM_MUL) { + if (broadcast_weight) { + octx->src1_spad.size_per_thread = src1_row_size_aligned; + } else { + octx->src1_spad.size_per_thread = src1_row_size_aligned * vtcm_row_per_thread * 2; + } + octx->src1_spad.size = n_threads * octx->src1_spad.size_per_thread; + } else { + octx->src1_spad.size = 0; + octx->src1_spad.size_per_thread = 0; + } + octx->src0_spad.data = octx->ctx->vtcm_base; - octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; + if (octx->op == HTP_OP_RMS_NORM_MUL) { + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; + } else { + octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; + } + + octx->src0_spad.src = NULL; + octx->src1_spad.src = NULL; + octx->dst_spad.src = NULL; FARF(HIGH, "%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], @@ -811,19 +997,24 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { .src0_nrows = src0_nrows, .data_src0 = (const uint8_t *)src0->data, + .data_src1 = (octx->op == HTP_OP_RMS_NORM_MUL) ? (const uint8_t *)src1->data : NULL, .data_dst = (uint8_t *)dst->data, .src0_data_row_size = src0_data_row_size, + .src1_data_row_size = src1_data_row_size, .dst_data_row_size = dst_data_row_size, .src0_row_size_aligned = src0_row_size_aligned, + .src1_row_size_aligned = src1_row_size_aligned, .dst_row_size_aligned = dst_row_size_aligned, .src0_spad_half_size = octx->src0_spad.size_per_thread / 2, + .src1_spad_half_size = (octx->op == HTP_OP_RMS_NORM_MUL) ? (octx->src1_spad.size_per_thread / (broadcast_weight ? 1 : 2)) : 0, .dst_spad_half_size = octx->dst_spad.size_per_thread / 2, .block = (octx->src0_spad.size_per_thread / 2) / src0_row_size_aligned, .nc = src0->ne[0], + .broadcast_weight = broadcast_weight, }; worker_pool_run_func(octx->ctx->worker_pool, unary_job_f32_per_thread, &uctx, n_threads); diff --git a/ggml/src/ggml-hexagon/op-desc.h b/ggml/src/ggml-hexagon/op-desc.h deleted file mode 100644 index a1e8ddd8b97..00000000000 --- a/ggml/src/ggml-hexagon/op-desc.h +++ /dev/null @@ -1,153 +0,0 @@ -#ifndef OP_DESC_H -#define OP_DESC_H - -#define GGML_COMMON_IMPL_CPP -#include "ggml-backend-impl.h" -#include "ggml-common.h" - -#include <string> -#include <stdio.h> - -struct op_desc { - char strides[64 * GGML_MAX_SRC]; - char dims[64 * GGML_MAX_SRC]; - char types[16 * GGML_MAX_SRC]; - char buffs[64 * GGML_MAX_SRC]; - char names[64 * GGML_MAX_SRC]; - - int format_tensor_dims(char * str, const struct ggml_tensor * t) { - if (t->ne[2] == 1 && t->ne[3] == 1) { - return sprintf(str, "%d:%d", (int) t->ne[0], (int) t->ne[1]); - } else { - return sprintf(str, "%d:%d:%d:%d", (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]); - } - } - - void format_op_dims(char * str, const struct ggml_tensor * t) { - char * p = str; - - // append src0 and src1 (if any) - if (t->src[0]) { - p += format_tensor_dims(p, t->src[0]); - - for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) { - p += sprintf(p, " x "); - p += format_tensor_dims(p, t->src[i]); - } - - p += sprintf(p, " -> "); - } - - // format self dims separately for better visual alignment - char self[64]; - format_tensor_dims(self, t); - - p += sprintf(p, "%s", self); - } - - int format_tensor_strides(char * str, const struct ggml_tensor * t) { - const char * c = ggml_is_contiguous(t) ? "" : "!"; - - if (t->ne[2] == 1 && t->ne[3] == 1) { - return sprintf(str, "%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], c); - } else { - return sprintf(str, "%zu:%zu:%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], (size_t) t->nb[2], (size_t) t->nb[3], c); - } - } - - void format_op_strides(char * str, const struct ggml_tensor * t) { - char * p = str; - - // append src0 and src1 (if any) - if (t->src[0]) { - p += format_tensor_strides(p, t->src[0]); - - for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) { - p += sprintf(p, " x "); - p += format_tensor_strides(p, t->src[i]); - } - - p += sprintf(p, " -> "); - } - - // format self dims separately for better visual alignment - char self[64]; - format_tensor_strides(self, t); - - p += sprintf(p, "%s", self); - } - - void format_op_types(char * str, const struct ggml_tensor * t) { - char * p = str; - - // append src0 and src1 (if any) - if (t->src[0]) { - p += sprintf(p, "%s", ggml_type_name(t->src[0]->type)); - - for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) { - p += sprintf(p, " x "); - p += sprintf(p, "%s", ggml_type_name(t->src[i]->type)); - } - - p += sprintf(p, " -> "); - } - - p += sprintf(p, "%s", ggml_type_name(t->type)); - } - - const char * tensor_buff_name(const struct ggml_tensor * t) { - if (t->buffer) { - return ggml_backend_buffer_name(t->buffer); - } - return "NONE"; - } - - void format_op_buffs(char * str, const struct ggml_tensor * t) { - char * p = str; - - // append src0 and src1 (if any) - if (t->src[0]) { - p += sprintf(p, "%s", tensor_buff_name(t->src[0])); - - for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) { - p += sprintf(p, " x "); - p += sprintf(p, "%s", tensor_buff_name(t->src[i])); - } - - p += sprintf(p, " -> "); - } - - p += sprintf(p, "%s", tensor_buff_name(t)); - } - - void format_op_names(char * str, const struct ggml_tensor * t) { - char * p = str; - - // append src0 and src1 (if any) - if (t->src[0]) { - p += sprintf(p, "%s", t->src[0]->name); - - for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) { - p += sprintf(p, " x "); - p += sprintf(p, "%s", t->src[i]->name); - } - - p += sprintf(p, " -> "); - } - - p += sprintf(p, "%s", t->name); - } - - void format(const ggml_tensor * op) { - format_op_dims(dims, op); - format_op_strides(strides, op); - format_op_types(types, op); - format_op_buffs(buffs, op); - format_op_names(names, op); - } - - op_desc() {} - op_desc(const ggml_tensor * op) { format(op); } -}; - -#endif // OP_DESC_H diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index ba006d9b31a..4f4f073cb61 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -590,8 +590,8 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net( const int ne20 = op->src[2]->ne[0]; // S_v const int ne21 = op->src[2]->ne[1]; // H const int ne30 = op->src[3]->ne[0]; // G - // state is src[5], 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. - const int K = op->src[5]->ne[1]; + // state is src[5], 4D [S_v, S_v, H_v, n_seqs] (s0 only); K is op param 0. + const int K = ggml_get_op_params_i32(op, 0); const int nsg = op->src[2]->ne[0]/32; @@ -1732,14 +1732,24 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope(ggml_metal_ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_IM2COL); + GGML_TENSOR_LOCALS(int64_t, ne0, op->src[0], ne); + GGML_ASSERT(ggml_is_contiguous(op->src[1])); GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); GGML_ASSERT(op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32); + const bool is_2D = ((const int32_t *)(op->op_params))[6] == 1; + const int64_t KH = is_2D ? ne01 : 1; + const int64_t KW = ne00; + char base[256]; char name[256]; - snprintf(base, 256, "kernel_im2col_%s", ggml_type_name(op->type)); + if (KH*KW <= 1024) { + snprintf(base, 256, "kernel_im2col_%s", ggml_type_name(op->type)); + } else { + snprintf(base, 256, "kernel_im2col_ext_%s", ggml_type_name(op->type)); + } snprintf(name, 256, "%s", base); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 1f212a92f98..4a3ebb5569d 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -215,6 +215,30 @@ void ggml_metal_rsets_free(ggml_metal_rsets_t rsets); // device // +enum ggml_metal_device_id { + GGML_METAL_DEVICE_GENERIC = 0, + + GGML_METAL_DEVICE_M1, + GGML_METAL_DEVICE_M1_PRO, + GGML_METAL_DEVICE_M1_MAX, + GGML_METAL_DEVICE_M1_ULTRA, + GGML_METAL_DEVICE_M2, + GGML_METAL_DEVICE_M2_PRO, + GGML_METAL_DEVICE_M2_MAX, + GGML_METAL_DEVICE_M2_ULTRA, + GGML_METAL_DEVICE_M3, + GGML_METAL_DEVICE_M3_PRO, + GGML_METAL_DEVICE_M3_MAX, + GGML_METAL_DEVICE_M3_ULTRA, + GGML_METAL_DEVICE_M4, + GGML_METAL_DEVICE_M4_PRO, + GGML_METAL_DEVICE_M4_MAX, + GGML_METAL_DEVICE_M5, + GGML_METAL_DEVICE_M5_PRO, + GGML_METAL_DEVICE_M5_MAX, + GGML_METAL_DEVICE_M5_ULTRA, +}; + struct ggml_metal_device_props { int device; char name[128]; @@ -234,6 +258,8 @@ struct ggml_metal_device_props { bool supports_gpu_family_apple7; + enum ggml_metal_device_id device_id; + int op_offload_min_batch_size; }; diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 780dfe81bb3..d583bd6efc0 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -547,6 +547,8 @@ void ggml_metal_encoder_end_encoding(ggml_metal_encoder_t encoder) { // number of seconds since the last graph computation // keep the residency sets wired for that amount of time to avoid being collected by the OS int keep_alive_s; + int loops_per_s; + int time_per_loop_ms; // background heartbeat thread to keep the residency sets alive atomic_bool d_stop; @@ -573,10 +575,13 @@ ggml_metal_rsets_t ggml_metal_rsets_init(void) { res->keep_alive_s = 3*60; } + res->time_per_loop_ms = 5; + res->loops_per_s = 1000/res->time_per_loop_ms; + GGML_LOG_INFO("%s: creating a residency set collection (keep_alive = %d s)\n", __func__, res->keep_alive_s); atomic_store_explicit(&res->d_stop, false, memory_order_relaxed); - atomic_store_explicit(&res->d_loop, 2*res->keep_alive_s, memory_order_relaxed); + atomic_store_explicit(&res->d_loop, res->loops_per_s*res->keep_alive_s, memory_order_relaxed); res->d_group = dispatch_group_create(); @@ -599,8 +604,7 @@ ggml_metal_rsets_t ggml_metal_rsets_init(void) { [res->lock unlock]; } - // half a second - usleep(500 * 1000); + usleep(res->time_per_loop_ms * 1000); } } #endif @@ -628,6 +632,50 @@ void ggml_metal_rsets_free(ggml_metal_rsets_t rsets) { free(rsets); } +static enum ggml_metal_device_id ggml_metal_device_id_parse(const char * name) { + if (!name) { + return GGML_METAL_DEVICE_GENERIC; + } + + static const char prefix[] = "Apple "; + if (strncmp(name, prefix, sizeof(prefix) - 1) != 0) { + return GGML_METAL_DEVICE_GENERIC; + } + const char * suffix = name + sizeof(prefix) - 1; + + static const struct { + const char * name; + enum ggml_metal_device_id id; + } table[] = { + {"M1", GGML_METAL_DEVICE_M1}, + {"M1 Pro", GGML_METAL_DEVICE_M1_PRO}, + {"M1 Max", GGML_METAL_DEVICE_M1_MAX}, + {"M1 Ultra", GGML_METAL_DEVICE_M1_ULTRA}, + {"M2", GGML_METAL_DEVICE_M2}, + {"M2 Pro", GGML_METAL_DEVICE_M2_PRO}, + {"M2 Max", GGML_METAL_DEVICE_M2_MAX}, + {"M2 Ultra", GGML_METAL_DEVICE_M2_ULTRA}, + {"M3", GGML_METAL_DEVICE_M3}, + {"M3 Pro", GGML_METAL_DEVICE_M3_PRO}, + {"M3 Max", GGML_METAL_DEVICE_M3_MAX}, + {"M3 Ultra", GGML_METAL_DEVICE_M3_ULTRA}, + {"M4", GGML_METAL_DEVICE_M4}, + {"M4 Pro", GGML_METAL_DEVICE_M4_PRO}, + {"M4 Max", GGML_METAL_DEVICE_M4_MAX}, + {"M5", GGML_METAL_DEVICE_M5}, + {"M5 Pro", GGML_METAL_DEVICE_M5_PRO}, + {"M5 Max", GGML_METAL_DEVICE_M5_MAX}, + {"M5 Ultra", GGML_METAL_DEVICE_M5_ULTRA}, + }; + + for (size_t i = 0; i < sizeof(table)/sizeof(table[0]); ++i) { + if (strcmp(suffix, table[i].name) == 0) { + return table[i].id; + } + } + return GGML_METAL_DEVICE_GENERIC; +} + ggml_metal_device_t ggml_metal_device_init(int device) { ggml_metal_device_t dev = calloc(1, sizeof(struct ggml_metal_device)); @@ -795,6 +843,8 @@ ggml_metal_device_t ggml_metal_device_init(int device) { dev->props.supports_gpu_family_apple7 = [dev->mtl_device supportsFamily:MTLGPUFamilyApple7]; + dev->props.device_id = ggml_metal_device_id_parse([[dev->mtl_device name] UTF8String]); + dev->props.op_offload_min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32; dev->props.max_buffer_size = dev->mtl_device.maxBufferLength; @@ -933,7 +983,7 @@ void ggml_metal_device_rsets_keep_alive(ggml_metal_device_t dev) { return; } - atomic_store_explicit(&dev->rsets->d_loop, 2*dev->rsets->keep_alive_s, memory_order_relaxed); + atomic_store_explicit(&dev->rsets->d_loop, dev->rsets->loops_per_s*dev->rsets->keep_alive_s, memory_order_relaxed); } struct ggml_metal_event { @@ -1061,7 +1111,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_GLU_OP_SWIGLU_OAI: case GGML_GLU_OP_GEGLU_ERF: case GGML_GLU_OP_GEGLU_QUICK: - return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + return ggml_is_contiguous_1(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16); default: return false; } @@ -1070,8 +1120,17 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_VIEW: case GGML_OP_TRANSPOSE: case GGML_OP_PERMUTE: - case GGML_OP_CONCAT: return true; + case GGML_OP_CONCAT: + { + // kernel_concat copies one float-sized value per element. + // Other scalar types need a type-generic copy kernel first. + const enum ggml_type src0_type = op->src[0]->type; + const enum ggml_type src1_type = op->src[1]->type; + return src0_type == src1_type && + src0_type == op->type && + (src0_type == GGML_TYPE_F32 || src0_type == GGML_TYPE_I32); + } case GGML_OP_ADD: case GGML_OP_SUB: case GGML_OP_MUL: diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 206af227a2c..e2ce56e9e28 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -3635,16 +3635,26 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) { auto pipeline = ggml_metal_library_get_pipeline_im2col(lib, op); - GGML_ASSERT(KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + if (KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + const uint64_t ntptg0 = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)/(KH*KW), N); - const uint64_t ntptg0 = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)/(KH*KW), N); + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); - ggml_metal_encoder_set_pipeline(enc, pipeline); - ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + ggml_metal_encoder_dispatch_threadgroups(enc, IC, OH, OW, ntptg0, KH, KW); + } else { + const uint64_t n_threads = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), N); + const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); - ggml_metal_encoder_dispatch_threadgroups(enc, IC, OH, OW, ntptg0, KH, KW); + ggml_metal_encoder_dispatch_threadgroups(enc, quotient * CHW, OH, OW, n_threads, 1, 1); + } return 1; } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index e772664ba91..0aea68455fb 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1421,7 +1421,8 @@ template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>; template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>; -kernel void kernel_reglu_f32( +template<typename T> +kernel void kernel_reglu( constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, @@ -1429,19 +1430,25 @@ kernel void kernel_reglu_f32( uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { - device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; - device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; - device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); + device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; + device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; + device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1); for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { const float x0 = src0_row[i0]; const float x1 = src1_row[i0]; - dst_row[i0] = x0*x1*(x0 > 0.0f); + dst_row[i0] = (T)(x0*x1*(x0 > 0.0f)); } } -kernel void kernel_geglu_f32( +typedef decltype(kernel_reglu<float>) kernel_reglu_t; + +template [[host_name("kernel_reglu_f32")]] kernel kernel_reglu_t kernel_reglu<float>; +template [[host_name("kernel_reglu_f16")]] kernel kernel_reglu_t kernel_reglu<half>; + +template<typename T> +kernel void kernel_geglu( constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, @@ -1449,9 +1456,9 @@ kernel void kernel_geglu_f32( uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { - device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; - device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; - device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); + device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; + device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; + device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1); for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { const float x0 = src0_row[i0]; @@ -1459,11 +1466,17 @@ kernel void kernel_geglu_f32( const float gelu = 0.5f*x0*(1.0f + precise::tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0))); - dst_row[i0] = gelu*x1; + dst_row[i0] = (T)(gelu*x1); } } -kernel void kernel_swiglu_f32( +typedef decltype(kernel_geglu<float>) kernel_geglu_t; + +template [[host_name("kernel_geglu_f32")]] kernel kernel_geglu_t kernel_geglu<float>; +template [[host_name("kernel_geglu_f16")]] kernel kernel_geglu_t kernel_geglu<half>; + +template<typename T> +kernel void kernel_swiglu( constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, @@ -1471,9 +1484,9 @@ kernel void kernel_swiglu_f32( uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { - device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; - device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; - device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); + device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; + device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; + device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1); for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { const float x0 = src0_row[i0]; @@ -1481,11 +1494,17 @@ kernel void kernel_swiglu_f32( const float silu = x0 / (1.0f + exp(-x0)); - dst_row[i0] = silu*x1; + dst_row[i0] = (T)(silu*x1); } } -kernel void kernel_swiglu_oai_f32( +typedef decltype(kernel_swiglu<float>) kernel_swiglu_t; + +template [[host_name("kernel_swiglu_f32")]] kernel kernel_swiglu_t kernel_swiglu<float>; +template [[host_name("kernel_swiglu_f16")]] kernel kernel_swiglu_t kernel_swiglu<half>; + +template<typename T> +kernel void kernel_swiglu_oai( constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, @@ -1493,9 +1512,9 @@ kernel void kernel_swiglu_oai_f32( uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { - device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; - device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; - device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); + device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; + device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; + device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1); for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { float x0 = src0_row[i0]; @@ -1507,11 +1526,17 @@ kernel void kernel_swiglu_oai_f32( float out_glu = x0 / (1.0f + exp(-x0 * args.alpha)); out_glu = out_glu * (1.0f + x1); - dst_row[i0] = out_glu; + dst_row[i0] = (T)out_glu; } } -kernel void kernel_geglu_erf_f32( +typedef decltype(kernel_swiglu_oai<float>) kernel_swiglu_oai_t; + +template [[host_name("kernel_swiglu_oai_f32")]] kernel kernel_swiglu_oai_t kernel_swiglu_oai<float>; +template [[host_name("kernel_swiglu_oai_f16")]] kernel kernel_swiglu_oai_t kernel_swiglu_oai<half>; + +template<typename T> +kernel void kernel_geglu_erf( constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, @@ -1519,9 +1544,9 @@ kernel void kernel_geglu_erf_f32( uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { - device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; - device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; - device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); + device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; + device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; + device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1); for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { const float x0 = src0_row[i0]; @@ -1529,11 +1554,17 @@ kernel void kernel_geglu_erf_f32( const float gelu_erf = 0.5f*x0*(1.0f+erf_approx<float>(x0*SQRT_2_INV)); - dst_row[i0] = gelu_erf*x1; + dst_row[i0] = (T)(gelu_erf*x1); } } -kernel void kernel_geglu_quick_f32( +typedef decltype(kernel_geglu_erf<float>) kernel_geglu_erf_t; + +template [[host_name("kernel_geglu_erf_f32")]] kernel kernel_geglu_erf_t kernel_geglu_erf<float>; +template [[host_name("kernel_geglu_erf_f16")]] kernel kernel_geglu_erf_t kernel_geglu_erf<half>; + +template<typename T> +kernel void kernel_geglu_quick( constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, @@ -1541,9 +1572,9 @@ kernel void kernel_geglu_quick_f32( uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { - device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; - device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; - device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); + device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; + device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; + device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1); for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { const float x0 = src0_row[i0]; @@ -1551,10 +1582,15 @@ kernel void kernel_geglu_quick_f32( const float gelu_quick = x0*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x0))); - dst_row[i0] = gelu_quick*x1; + dst_row[i0] = (T)(gelu_quick*x1); } } +typedef decltype(kernel_geglu_quick<float>) kernel_geglu_quick_t; + +template [[host_name("kernel_geglu_quick_f32")]] kernel kernel_geglu_quick_t kernel_geglu_quick<float>; +template [[host_name("kernel_geglu_quick_f16")]] kernel kernel_geglu_quick_t kernel_geglu_quick<half>; + kernel void kernel_op_sum_f32( constant ggml_metal_kargs_sum & args, device const float * src0, @@ -2563,9 +2599,9 @@ kernel void kernel_gated_delta_net_impl( const float scale = 1.0f / sqrt((float)S_v); - // input state layout (D, K, n_seqs): per-seq stride is K*H*D; we read slot 0. + // input state layout [S_v, S_v, H, n_seqs] (s0 only): per-seq stride is H*D. // state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous - const uint state_in_base = (i23*K*args.ne21 + i21)*S_v*S_v + i20*S_v; + const uint state_in_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v; device const float * s_ptr = (device const float *) (s) + state_in_base; float ls[NSG]; @@ -2584,9 +2620,8 @@ kernel void kernel_gated_delta_net_impl( device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21); device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G; - // snapshot slot mapping: target_slot = t - shift. When n_tokens < K, only the last - // n_tokens slots are written; earlier slots are left untouched (caller-owned). - const int shift = (int)args.ne22 - (int)K; + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + // When n_tokens < K, only slots 0..n_tokens-1 are written; older slots are caller-owned. // output state base offset: after attention scores const uint attn_size = args.ne22 * args.ne21 * S_v * args.ne23; @@ -2644,7 +2679,7 @@ kernel void kernel_gated_delta_net_impl( g_ptr += args.ne21*G; if (K > 1) { - const int target_slot = (int)t - shift; + const int target_slot = (int)args.ne22 - 1 - (int)t; if (target_slot >= 0 && target_slot < (int)K) { device float * dst_state = (device float *) (dst) + attn_size + (uint)target_slot * state_size_per_snap + state_out_base; FOR_UNROLL (short j = 0; j < NSG; j++) { @@ -4696,59 +4731,59 @@ kernel void kernel_im2col( template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>; template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>; -// TODO: obsolete -- remove -//typedef void (im2col_ext_t)( -// constant ggml_metal_kargs_im2col & args, -// device const float * x, -// device char * dst, -// uint3 tgpig[[threadgroup_position_in_grid]], -// uint3 tgpg[[threadgroups_per_grid]], -// uint3 tpitg[[thread_position_in_threadgroup]], -// uint3 ntg[[threads_per_threadgroup]]); -// -//template <typename T> -//kernel void kernel_im2col_ext( -// constant ggml_metal_kargs_im2col & args, -// device const float * x, -// device char * dst, -// uint3 tgpig[[threadgroup_position_in_grid]], -// uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW -// uint3 tpitg[[thread_position_in_threadgroup]], -// uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1] -// const int64_t KHW = (int64_t)args.KHW; -// -// const int64_t d = tgpig[0] / args.CHW; -// const int64_t chw = tgpig[0] % args.CHW; -// const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1) -// const int64_t HW = tgpig[0] % KHW; -// -// const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0]; -// if (tpitg_0 >= args.N) { -// return; -// } -// -// const int64_t tpitg_1 = HW / args.KW; -// const int64_t tpitg_2 = HW % args.KW; -// -// const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0; -// const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1; -// -// const int64_t offset_dst = -// (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW + -// (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2); -// -// device T * pdst = (device T *) (dst); -// -// if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) { -// pdst[offset_dst] = 0.0f; -// } else { -// const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1; -// pdst[offset_dst] = x[offset_src + iih * args.IW + iiw]; -// } -//} -// -//template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>; -//template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>; +// TODO: optimize +typedef void (im2col_ext_t)( + constant ggml_metal_kargs_im2col & args, + device const float * x, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]); + +template <typename T> +kernel void kernel_im2col_ext( + constant ggml_metal_kargs_im2col & args, + device const float * x, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1] + const int64_t KHW = (int64_t)args.KHW; + + const int64_t d = tgpig[0] / args.CHW; + const int64_t chw = tgpig[0] % args.CHW; + const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1) + const int64_t HW = tgpig[0] % KHW; + + const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0]; + if (tpitg_0 >= args.N) { + return; + } + + const int64_t tpitg_1 = HW / args.KW; + const int64_t tpitg_2 = HW % args.KW; + + const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0; + const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1; + + const int64_t offset_dst = + (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW + + (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2); + + device T * pdst = (device T *) (dst); + + if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) { + pdst[offset_dst] = 0.0f; + } else { + const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1; + pdst[offset_dst] = x[offset_src + iih * args.IW + iiw]; + } +} + +template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>; +template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>; template <typename TK> kernel void kernel_conv_2d( diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index f75d089b574..82ce61d72c6 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -87,6 +87,10 @@ set(GGML_OPENCL_KERNELS mul_mv_q4_1_f32_flat mul_mv_q4_k_f32 mul_mv_q4_k_f32_flat + mul_mv_q5_0_f32 + mul_mv_q5_0_f32_flat + mul_mv_q5_1_f32 + mul_mv_q5_1_f32_flat mul_mv_q5_k_f32 mul_mv_q5_k_f32_flat mul_mv_q6_k_f32 @@ -126,6 +130,8 @@ set(GGML_OPENCL_KERNELS mul_mm_f16_f32_l4_lm mul_mm_q4_0_f32_l4_lm mul_mm_q4_1_f32_l4_lm + mul_mm_q5_0_f32_l4_lm + mul_mm_q5_1_f32_l4_lm mul_mm_q8_0_f32_l4_lm mul_mm_iq4_nl_f32_l4_lm mul_mm_q4_k_f32_l4_lm @@ -136,6 +142,10 @@ set(GGML_OPENCL_KERNELS gemm_noshuffle_q4_0_f32 gemv_noshuffle_q4_1_f32 gemm_noshuffle_q4_1_f32 + gemv_noshuffle_q5_0_f32 + gemm_noshuffle_q5_0_f32 + gemv_noshuffle_q5_1_f32 + gemm_noshuffle_q5_1_f32 gemv_noshuffle_iq4_nl_f32 gemm_noshuffle_iq4_nl_f32 gemv_noshuffle_q8_0_f32 @@ -164,6 +174,7 @@ set(GGML_OPENCL_KERNELS sqr sqrt ssm_conv + gated_delta_net sub sum_rows cumsum diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 42286435bc6..ca2002424df 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -379,6 +379,8 @@ struct ggml_backend_opencl_device_context { GPU_FAMILY gpu_family = GPU_FAMILY::UNKNOWN; ADRENO_GPU_GEN adreno_gen = ADRENO_GPU_GEN::ADRENO_UNKNOWN; + std::regex *opfilter = nullptr; // regex of ops to not claim + std::string opfilter_str = ""; // regex string for opfilter size_t global_mem_size = 0; }; @@ -412,10 +414,9 @@ struct ggml_backend_opencl_context { size_t max_workgroup_size; bool fp16_support; bool has_vector_subgroup_broadcast; + bool has_qcom_subgroup_shuffle = false; // cl_qcom_subgroup_shuffle bool disable_fusion; - std::regex *opfilter = nullptr; // regex of ops to not claim - bool adreno_has_large_buffer; bool adreno_use_large_buffer; ggml_cl_compiler_version adreno_cl_compiler_version; @@ -427,6 +428,8 @@ struct ggml_backend_opencl_context { size_t image2d_max_width; size_t image2d_max_height; + cl_device_svm_capabilities svm_caps; + cl_context context; cl_command_queue queue; @@ -555,7 +558,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_set_rows_f32_i64, kernel_set_rows_f32_i32, kernel_set_rows_f16_i64, kernel_set_rows_f16_i32; cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16; cl_kernel kernel_rope_multi_f32, kernel_rope_multi_f16, kernel_rope_vision_f32, kernel_rope_vision_f16; - cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32, kernel_cpy_i32_i32; + cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32, kernel_cpy_f32_f32_pack, kernel_cpy_i32_i32; cl_kernel kernel_mul_mat_f32_f32; cl_kernel kernel_mul_mat_f16_f16; cl_kernel kernel_mul_mat_f16_f32_1row; @@ -573,7 +576,9 @@ struct ggml_backend_opencl_context { cl_kernel kernel_convert_block_q4_0_trans4_ns, kernel_restore_block_q4_0_trans4_ns; cl_kernel kernel_convert_block_q4_1, kernel_restore_block_q4_1; cl_kernel kernel_convert_block_q4_1_trans4_ns, kernel_restore_block_q4_1_trans4_ns; + cl_kernel kernel_convert_block_q5_0, kernel_restore_block_q5_0; cl_kernel kernel_convert_block_q5_0_trans4_ns, kernel_restore_block_q5_0_trans4_ns; + cl_kernel kernel_convert_block_q5_1, kernel_restore_block_q5_1; cl_kernel kernel_convert_block_q5_1_trans4_ns, kernel_restore_block_q5_1_trans4_ns; cl_kernel kernel_convert_block_q4_k_trans4_ns, kernel_restore_block_q4_k_trans4_ns; cl_kernel kernel_convert_block_q5_k_trans4_ns, kernel_restore_block_q5_k_trans4_ns; @@ -582,11 +587,16 @@ struct ggml_backend_opencl_context { cl_kernel kernel_convert_block_mxfp4_trans4_ns, kernel_restore_block_mxfp4_trans4_ns; cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0, kernel_restore_block_q8_0_trans; cl_kernel kernel_convert_block_q6_K_noshuffle, kernel_restore_block_q6_K_noshuffle; + cl_kernel kernel_convert_bf16_to_f16, kernel_convert_f16_to_bf16; cl_kernel kernel_mul_mat_q4_0_f32_8x_flat; cl_kernel kernel_convert_block_q4_0_noshuffle; cl_kernel kernel_restore_block_q4_0_noshuffle; cl_kernel kernel_convert_block_q4_1_noshuffle; cl_kernel kernel_restore_block_q4_1_noshuffle; + cl_kernel kernel_convert_block_q5_0_noshuffle; + cl_kernel kernel_restore_block_q5_0_noshuffle; + cl_kernel kernel_convert_block_q5_1_noshuffle; + cl_kernel kernel_restore_block_q5_1_noshuffle; cl_kernel kernel_convert_block_q4_K_noshuffle; cl_kernel kernel_restore_block_q4_K_noshuffle; cl_kernel kernel_convert_block_q4_K, kernel_restore_block_q4_K; @@ -600,6 +610,10 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat; cl_kernel kernel_mul_mv_q4_1_f32; cl_kernel kernel_mul_mv_q4_1_f32_flat; + cl_kernel kernel_mul_mv_q5_0_f32; + cl_kernel kernel_mul_mv_q5_0_f32_flat; + cl_kernel kernel_mul_mv_q5_1_f32; + cl_kernel kernel_mul_mv_q5_1_f32_flat; cl_kernel kernel_mul_mv_q4_K_f32; cl_kernel kernel_mul_mv_q4_K_f32_flat; cl_kernel kernel_mul_mv_q5_K_f32; @@ -629,11 +643,15 @@ struct ggml_backend_opencl_context { cl_kernel kernel_softplus_f16, kernel_softplus_f16_4, kernel_softplus_f16_nc; cl_kernel kernel_upscale; cl_kernel kernel_upscale_bilinear; - cl_kernel kernel_concat_f32; + cl_kernel kernel_concat_f32, kernel_concat_f32_pack; cl_kernel kernel_conv_2d_f16; cl_kernel kernel_conv_2d_f32; cl_kernel kernel_conv_2d_f16_f32; cl_kernel kernel_ssm_conv_f32_f32, kernel_ssm_conv_f32_f32_4; + // [size_idx][kda][tgpp] where size_idx: 0=S_V=16, 1=32, 2=64, 3=128; kda: 0 or 1. + // tgpp 0 = TG variant (COLS_PER_LANE_GROUP=1), tgpp 1 = prefill variant (COLS_PER_LANE_GROUP=4). + cl_kernel kernel_gated_delta_net_f32[4][2][2] = {}; + cl_kernel kernel_timestep_embedding; cl_kernel kernel_gemv_moe_q4_0_f32_ns, kernel_gemm_moe_q4_0_f32_ns; cl_kernel kernel_gemv_moe_q4_1_f32_ns, kernel_gemm_moe_q4_1_f32_ns; @@ -654,6 +672,8 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mm_f16_f32_l4_lm; cl_kernel kernel_mul_mm_q4_0_f32_l4_lm; cl_kernel kernel_mul_mm_q4_1_f32_l4_lm; + cl_kernel kernel_mul_mm_q5_0_f32_l4_lm; + cl_kernel kernel_mul_mm_q5_1_f32_l4_lm; cl_kernel kernel_mul_mm_q8_0_f32_l4_lm; cl_kernel kernel_mul_mm_q4_k_f32_l4_lm; cl_kernel kernel_mul_mm_q5_k_f32_l4_lm; @@ -813,6 +833,10 @@ struct ggml_backend_opencl_context { cl_kernel kernel_gemm_noshuffle_q6_K_f32; cl_kernel kernel_gemv_noshuffle_q5_k_f32; cl_kernel kernel_gemm_noshuffle_q5_k_f32; + cl_kernel kernel_gemv_noshuffle_q5_0_f32; + cl_kernel kernel_gemm_noshuffle_q5_0_f32; + cl_kernel kernel_gemv_noshuffle_q5_1_f32; + cl_kernel kernel_gemm_noshuffle_q5_1_f32; cl_kernel kernel_gemv_noshuffle_iq4_nl_f32; cl_kernel kernel_gemm_noshuffle_iq4_nl_f32; #endif // GGML_OPENCL_USE_ADRENO_KERNELS @@ -837,16 +861,16 @@ static std::vector<ggml_backend_device> g_ggml_backend_opencl_devices; static std::vector<std::unique_ptr<ggml_backend_opencl_device_context>> g_ggml_backend_opencl_dev_ctxs; inline std::string read_file(const std::string &path) { - std::ifstream ifs(path); - if (!ifs) { - return ""; - } - std::string text; - ifs.seekg(0, std::ios::end); - text.resize(ifs.tellg()); - ifs.seekg(0, std::ios::beg); - ifs.read(&text[0], text.size()); - return text; + std::ifstream ifs(path); + if (!ifs) { + return ""; + } + std::string text; + ifs.seekg(0, std::ios::end); + text.resize(ifs.tellg()); + ifs.seekg(0, std::ios::beg); + ifs.read(&text[0], text.size()); + return text; } static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, const char* program_buffer, const std::string &compile_opts) { @@ -1105,6 +1129,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { CL_CHECK((backend_ctx->kernel_cpy_f16_f32 = clCreateKernel(prog, "kernel_cpy_f16_f32", &err), err)); CL_CHECK((backend_ctx->kernel_cpy_f32_f16 = clCreateKernel(prog, "kernel_cpy_f32_f16", &err), err)); CL_CHECK((backend_ctx->kernel_cpy_f32_f32 = clCreateKernel(prog, "kernel_cpy_f32_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f32_f32_pack = clCreateKernel(prog, "kernel_cpy_f32_f32_pack", &err), err)); CL_CHECK((backend_ctx->kernel_cpy_i32_i32 = clCreateKernel(prog, "kernel_cpy_i32_i32", &err), err)); GGML_LOG_CONT("."); } @@ -1133,8 +1158,16 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { CL_CHECK((backend_ctx->kernel_restore_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q4_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_1_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q4_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q5_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_0", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q5_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_0", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q5_0_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_0_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q5_0_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_0_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q5_1_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_1_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q5_1_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_1_noshuffle", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q5_0_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_0_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q5_0_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_0_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q5_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_1", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q5_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_1", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q5_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_1_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q5_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_1_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q4_k_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_k_trans4_ns", &err), err)); @@ -1168,6 +1201,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { CL_CHECK((backend_ctx->kernel_restore_block_iq4_nl = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_iq4_nl", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_iq4_nl_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_iq4_nl_noshuffle", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_iq4_nl_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_iq4_nl_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_bf16_to_f16 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_bf16_to_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_f16_to_bf16 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_f16_to_bf16", &err), err)); GGML_LOG_CONT("."); } @@ -1475,6 +1510,74 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { GGML_LOG_CONT("."); } + // mul_mv_q5_0_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q5_0_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q5_0_f32.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_q5_0_f32 = clCreateKernel(prog, "kernel_mul_mv_q5_0_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // mul_mv_q5_0_f32_flat + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q5_0_f32_flat.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q5_0_f32_flat.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_q5_0_f32_flat = clCreateKernel(prog, "kernel_mul_mv_q5_0_f32_flat", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // mul_mv_q5_1_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q5_1_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q5_1_f32.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_q5_1_f32 = clCreateKernel(prog, "kernel_mul_mv_q5_1_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // mul_mv_q5_1_f32_flat + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q5_1_f32_flat.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q5_1_f32_flat.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_q5_1_f32_flat = clCreateKernel(prog, "kernel_mul_mv_q5_1_f32_flat", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // mul_mv_q5_k_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -1825,6 +1928,38 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { GGML_LOG_CONT("."); } + // mul_mm_q5_0_f32_l4_lm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mm_q5_0_f32_l4_lm.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mm_q5_0_f32_l4_lm.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mm_q5_0_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q5_0_f32_l4_lm", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mm_q5_1_f32_l4_lm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mm_q5_1_f32_l4_lm.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mm_q5_1_f32_l4_lm.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mm_q5_1_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q5_1_f32_l4_lm", &err), err)); + GGML_LOG_CONT("."); + } + // mul_mm_q8_0_f32_l4_lm { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -2463,12 +2598,12 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); CL_CHECK((backend_ctx->kernel_upscale = clCreateKernel(backend_ctx->program_upscale, "kernel_upscale", &err), err)); if (backend_ctx->program_upscale) { - cl_int err_bilinear; - backend_ctx->kernel_upscale_bilinear = clCreateKernel(backend_ctx->program_upscale, "kernel_upscale_bilinear", &err_bilinear); - if (err_bilinear != CL_SUCCESS) { + cl_int err_bilinear; + backend_ctx->kernel_upscale_bilinear = clCreateKernel(backend_ctx->program_upscale, "kernel_upscale_bilinear", &err_bilinear); + if (err_bilinear != CL_SUCCESS) { GGML_LOG_WARN("ggml_opencl: kernel_upscale_bilinear not found in upscale.cl. Bilinear upscale will not be available. Error: %d\n", err_bilinear); backend_ctx->kernel_upscale_bilinear = nullptr; - } + } } else { backend_ctx->kernel_upscale_bilinear = nullptr; } @@ -2493,6 +2628,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); CL_CHECK((backend_ctx->kernel_concat_f32 = clCreateKernel(prog, "kernel_concat_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_concat_f32_pack = clCreateKernel(prog, "kernel_concat_f32_pack", &err), err)); CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } @@ -2538,8 +2674,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { GGML_LOG_CONT("."); } - // conv2d - { + // conv2d + { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { #include "conv2d.cl.h" @@ -2597,6 +2733,86 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { GGML_LOG_CONT("."); } + // gated_delta_net: one kernel per (S_V, KDA, tgpp) triple. + { + #ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gated_delta_net.cl.h" + }; + #else + const std::string kernel_src = read_file("gated_delta_net.cl"); + #endif + + const int gdn_sizes[4] = { 16, 32, 64, 128 }; + const int sg_size = backend_ctx->gpu_family == GPU_FAMILY::ADRENO ? 64 : backend_ctx->gpu_family == GPU_FAMILY::INTEL ? 32 : -1; + if (sg_size < 0) { + GGML_LOG_ERROR("Unsupported GPU Family: only Adreno and Intel are supported.\n"); + exit(1); + } + + for (int si = 0; si < 4; si++) { + const int S_V = gdn_sizes[si]; + + // MUST match the dispatcher heuristic in ggml_cl_gated_delta_net exactly. + int lanes_per_column; + if (S_V >= 128) { + lanes_per_column = 8; + } else { + lanes_per_column = std::min(S_V, sg_size); + } + + // Round LANES_PER_COLUMN down until it is: + // * power-of-two + // * divides both S_V and sg_size + while (lanes_per_column > 1 && + (((lanes_per_column & (lanes_per_column - 1)) != 0) || + (S_V % lanes_per_column) != 0 || + (sg_size % lanes_per_column) != 0)) { + lanes_per_column >>= 1; + } + + GGML_ASSERT(lanes_per_column >= 1); + GGML_ASSERT(((lanes_per_column & (lanes_per_column - 1)) == 0)); + GGML_ASSERT((S_V % lanes_per_column) == 0); + GGML_ASSERT((sg_size % lanes_per_column) == 0); + + const bool is_partial_reduce = (lanes_per_column != 1) && (lanes_per_column < sg_size); + int use_qcom_shuffle = 0; + if (is_partial_reduce) { + if (backend_ctx->has_qcom_subgroup_shuffle) { + use_qcom_shuffle = 1; + } + } + for (int kda = 0; kda < 2; kda++) { + for (int tgpp = 0; tgpp < 2; tgpp++) { + const int cpl = (tgpp == 0) ? 1 : 4; + const int spw = (tgpp == 0) ? 1 : 1; + + std::string opts = compile_opts; + opts += " -DS_V=" + std::to_string(S_V); + opts += " -DKDA=" + std::to_string(kda); + opts += " -DSUBGROUP_SIZE=" + std::to_string(sg_size); + opts += " -DLANES_PER_COLUMN=" + std::to_string(lanes_per_column); + opts += " -DCOLS_PER_LANE_GROUP=" + std::to_string(cpl); + opts += " -DUSE_QCOM_SUBGROUP_SHUFFLE=" + std::to_string(use_qcom_shuffle); + + // Since spw=1 is found to be optimal, SUBGROUPS_PER_WG > 1 code in + // the kernel is removed. If you want to experiment with spw > 1, + // Please remember to implement code to handle it. + opts += " -DSUBGROUPS_PER_WG=" + std::to_string(spw); + + cl_program prog = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src.c_str(), opts); + + CL_CHECK((backend_ctx->kernel_gated_delta_net_f32[si][kda][tgpp] = + clCreateKernel(prog, "kernel_gated_delta_net", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + } + } + } + GGML_LOG_CONT("."); + } + // mul_mv_id_q4_0_f32_8x_flat { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -2827,7 +3043,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { #include "gemm_noshuffle_q4_1_f32.cl.h" - }; + }; #else const std::string kernel_src = read_file("gemm_noshuffle_q4_1_f32.cl"); #endif @@ -2861,12 +3077,86 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { GGML_LOG_CONT("."); } + // gemm_noshuffle_q5_0_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_noshuffle_q5_0_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_noshuffle_q5_0_f32.cl"); +#endif + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q5_0_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_q5_0_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemv_noshuffle_q5_0_f32 + { + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable "; + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAST "; + } + +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_noshuffle_q5_0_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_noshuffle_q5_0_f32.cl"); +#endif + cl_program prog = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_gemv_compile_opts); + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q5_0_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q5_0_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemm_noshuffle_q5_1_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_noshuffle_q5_1_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_noshuffle_q5_1_f32.cl"); +#endif + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q5_1_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_q5_1_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemv_noshuffle_q5_1_f32 + { + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable "; + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAST "; + } + +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_noshuffle_q5_1_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_noshuffle_q5_1_f32.cl"); +#endif + cl_program prog = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_gemv_compile_opts); + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q5_1_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q5_1_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // gemm_noshuffle_iq4_nl_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { #include "gemm_noshuffle_iq4_nl_f32.cl.h" - }; + }; #else const std::string kernel_src = read_file("gemm_noshuffle_iq4_nl_f32.cl"); #endif @@ -2905,7 +3195,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { #include "gemm_noshuffle_q8_0_f32.cl.h" - }; + }; #else const std::string kernel_src = read_file("gemm_noshuffle_q8_0_f32.cl"); #endif @@ -2946,7 +3236,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { #include "gemm_noshuffle_q4_k_f32.cl.h" - }; + }; #else const std::string kernel_src = read_file("gemm_noshuffle_q4_k_f32.cl"); #endif @@ -3646,6 +3936,68 @@ static std::vector<ggml_backend_device> ggml_opencl_probe_devices(ggml_backend_r return found_devices; } +static void ggml_opencl_print_backend_info(ggml_backend_opencl_device_context * dev_ctx) { + GGML_ASSERT(dev_ctx); + GGML_ASSERT(dev_ctx->backend_ctx); + + auto * backend_ctx = dev_ctx->backend_ctx; + + GGML_LOG_INFO("ggml_opencl: OpenCL driver: %s\n", + backend_ctx->driver_version.c_str()); + GGML_LOG_INFO("ggml_opencl: vector subgroup broadcast support: %s\n", + backend_ctx->has_vector_subgroup_broadcast ? "true" : "false"); + GGML_LOG_INFO("ggml_opencl: device FP16 support: %s\n", + backend_ctx->fp16_support ? "true" : "false"); + GGML_LOG_INFO("ggml_opencl: mem base addr align: %u\n", + backend_ctx->alignment); + GGML_LOG_INFO("ggml_opencl: global mem size: %zu MB\n", + backend_ctx->global_mem_size/1024/1024); + GGML_LOG_INFO("ggml_opencl: max mem alloc size: %zu MB\n", + backend_ctx->max_alloc_size/1024/1024); + GGML_LOG_INFO("ggml_opencl: device max image buffer size (pixels): %lu\n", + backend_ctx->image_max_buffer_size); + GGML_LOG_INFO("ggml_opencl: device max image2d size: %lu x %lu\n", + backend_ctx->image2d_max_width, backend_ctx->image2d_max_height); + GGML_LOG_INFO("ggml_opencl: device max workgroup size: %lu\n", + backend_ctx->max_workgroup_size); + GGML_LOG_INFO("ggml_opencl: SVM coarse grain buffer support: %s\n", + backend_ctx->svm_caps & CL_DEVICE_SVM_COARSE_GRAIN_BUFFER ? "true" : "false"); + GGML_LOG_INFO("ggml_opencl: SVM fine grain buffer support: %s\n", + backend_ctx->svm_caps & CL_DEVICE_SVM_FINE_GRAIN_BUFFER ? "true" : "false"); + GGML_LOG_INFO("ggml_opencl: SVM fine grain system support: %s\n", + backend_ctx->svm_caps & CL_DEVICE_SVM_FINE_GRAIN_SYSTEM ? "true" : "false"); + GGML_LOG_INFO("ggml_opencl: SVM atomics support: %s\n", + backend_ctx->svm_caps & CL_DEVICE_SVM_ATOMICS ? "true" : "false"); + GGML_LOG_INFO("ggml_opencl: cl_qcom_subgroup_shuffle support: %s\n", + backend_ctx->has_qcom_subgroup_shuffle ? "true" : "false"); + + // Print out configurations +#ifdef GGML_OPENCL_SOA_Q + GGML_LOG_INFO("ggml_opencl: flattening quantized weights representation as struct of arrays (GGML_OPENCL_SOA_Q)\n"); +#endif // GGML_OPENCL_SOA_Q + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + GGML_LOG_INFO("ggml_opencl: using kernels optimized for Adreno (GGML_OPENCL_USE_ADRENO_KERNELS)\n"); + if (backend_ctx->adreno_xmem_gemm_enabled) { + GGML_LOG_INFO("ggml_opencl: Adreno xmem F16xF32 GEMM enabled (temporary weight prepack)\n"); + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + + if (backend_ctx->adreno_use_large_buffer) { + if (!backend_ctx->adreno_has_large_buffer) { + GGML_LOG_INFO("ggml_opencl: Adreno large buffer requested but not supported by driver, will use regular buffer\n"); + backend_ctx->adreno_use_large_buffer = false; + } else { + GGML_LOG_INFO("ggml_opencl: Adreno large buffer enabled\n"); + } + } + + if (dev_ctx->opfilter) { + // for information only, the actual regex object is created in ggml_opencl_is_device_supported + GGML_LOG_INFO("ggml_opencl: opfilter regex = \"%s\"\n", dev_ctx->opfilter_str.c_str()); + } +} + // check if device should be accepted static bool ggml_opencl_is_device_supported(ggml_backend_dev_t dev) { GGML_ASSERT(dev); @@ -3714,6 +4066,13 @@ static bool ggml_opencl_is_device_supported(ggml_backend_dev_t dev) { } clGetDeviceInfo(dev_ctx->device, CL_DEVICE_GLOBAL_MEM_SIZE, sizeof(size_t), &dev_ctx->global_mem_size, NULL); + + const char * str_opfilter = getenv("GGML_OPENCL_OPFILTER"); + if (str_opfilter) { + dev_ctx->opfilter_str = str_opfilter; + dev_ctx->opfilter = new std::regex(str_opfilter, std::regex_constants::icase); + } + return true; } @@ -3765,15 +4124,12 @@ static ggml_backend_opencl_context * ggml_cl_init(ggml_backend_dev_t dev) { char *driver_version = (char *)alloca(driver_version_str_size + 1); clGetDeviceInfo(device, CL_DRIVER_VERSION, driver_version_str_size, driver_version, NULL); driver_version[driver_version_str_size] = '\0'; - GGML_LOG_INFO("ggml_opencl: OpenCL driver: %s\n", driver_version); backend_ctx->driver_version = driver_version; backend_ctx->adreno_cl_compiler_version = get_adreno_cl_compiler_version(driver_version); backend_ctx->has_vector_subgroup_broadcast = (backend_ctx->adreno_cl_compiler_version.type == E031 && backend_ctx->adreno_cl_compiler_version.major >= 47) || (backend_ctx->adreno_cl_compiler_version.type == DX && backend_ctx->adreno_cl_compiler_version.major >= 17); - GGML_LOG_INFO("ggml_opencl: vector subgroup broadcast support: %s\n", - backend_ctx->has_vector_subgroup_broadcast ? "true" : "false"); size_t ext_str_size; clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, 0, NULL, &ext_str_size); @@ -3781,9 +4137,13 @@ static ggml_backend_opencl_context * ggml_cl_init(ggml_backend_dev_t dev) { clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, ext_str_size, ext_buffer, NULL); ext_buffer[ext_str_size] = '\0'; // ensure it is null terminated + // check support for qcom_subgroup_shuffle + if (strstr(ext_buffer, "cl_qcom_subgroup_shuffle") != NULL) { + backend_ctx->has_qcom_subgroup_shuffle = true; + } + // Check if ext_buffer contains cl_khr_fp16 backend_ctx->fp16_support = strstr(ext_buffer, "cl_khr_fp16") != NULL; - GGML_LOG_INFO("ggml_opencl: device FP16 support: %s\n", backend_ctx->fp16_support ? "true" : "false"); // check Adreno large buffer support backend_ctx->adreno_has_large_buffer = strstr(ext_buffer, "cl_qcom_large_buffer") != NULL; @@ -3792,35 +4152,15 @@ static ggml_backend_opencl_context * ggml_cl_init(ggml_backend_dev_t dev) { CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_MEM_BASE_ADDR_ALIGN, sizeof(cl_uint), &base_align_in_bits, NULL)); GGML_ASSERT(base_align_in_bits % 8u == 0); backend_ctx->alignment = base_align_in_bits / 8u; - GGML_LOG_INFO("ggml_opencl: mem base addr align: %u\n", backend_ctx->alignment); backend_ctx->global_mem_size = dev_ctx->global_mem_size; - GGML_LOG_INFO("ggml_opencl: global mem size: %zu MB\n", backend_ctx->global_mem_size/1024/1024); - - clGetDeviceInfo(device, CL_DEVICE_MAX_MEM_ALLOC_SIZE, sizeof(size_t), &backend_ctx->max_alloc_size, NULL); - GGML_LOG_INFO("ggml_opencl: max mem alloc size: %zu MB\n", backend_ctx->max_alloc_size/1024/1024); - - clGetDeviceInfo(device, CL_DEVICE_IMAGE_MAX_BUFFER_SIZE, sizeof(size_t), &backend_ctx->image_max_buffer_size, NULL); - GGML_LOG_INFO("ggml_opencl: device max image buffer size (pixels): %lu\n", backend_ctx->image_max_buffer_size); - clGetDeviceInfo(device, CL_DEVICE_IMAGE2D_MAX_WIDTH, sizeof(size_t), &backend_ctx->image2d_max_width, NULL); - clGetDeviceInfo(device, CL_DEVICE_IMAGE2D_MAX_HEIGHT, sizeof(size_t), &backend_ctx->image2d_max_height, NULL); - GGML_LOG_INFO("ggml_opencl: device max image2d size: %lu x %lu\n", backend_ctx->image2d_max_width, backend_ctx->image2d_max_height); - - clGetDeviceInfo(device, CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(size_t), &backend_ctx->max_workgroup_size, NULL); - GGML_LOG_INFO("ggml_opencl: device max workgroup size: %lu\n", backend_ctx->max_workgroup_size); - - // Check SVM. - cl_device_svm_capabilities svm_caps; - CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_SVM_CAPABILITIES, sizeof(cl_device_svm_capabilities), &svm_caps, 0)); - GGML_LOG_INFO("ggml_opencl: SVM coarse grain buffer support: %s\n", - svm_caps & CL_DEVICE_SVM_COARSE_GRAIN_BUFFER ? "true" : "false"); - GGML_LOG_INFO("ggml_opencl: SVM fine grain buffer support: %s\n", - svm_caps & CL_DEVICE_SVM_FINE_GRAIN_BUFFER ? "true" : "false"); - GGML_LOG_INFO("ggml_opencl: SVM fine grain system support: %s\n", - svm_caps & CL_DEVICE_SVM_FINE_GRAIN_SYSTEM ? "true" : "false"); - GGML_LOG_INFO("ggml_opencl: SVM atomics support: %s\n", - svm_caps & CL_DEVICE_SVM_ATOMICS ? "true" : "false"); + CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_MAX_MEM_ALLOC_SIZE, sizeof(size_t), &backend_ctx->max_alloc_size, NULL)); + CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_IMAGE_MAX_BUFFER_SIZE, sizeof(size_t), &backend_ctx->image_max_buffer_size, NULL)); + CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_IMAGE2D_MAX_WIDTH, sizeof(size_t), &backend_ctx->image2d_max_width, NULL)); + CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_IMAGE2D_MAX_HEIGHT, sizeof(size_t), &backend_ctx->image2d_max_height, NULL)); + CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(size_t), &backend_ctx->max_workgroup_size, NULL)); + CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_SVM_CAPABILITIES, sizeof(cl_device_svm_capabilities), &backend_ctx->svm_caps, 0)); if (opencl_c_version.major >= 3) { // Assume it is not available for 3.0, since it is optional in 3.0. @@ -3836,36 +4176,15 @@ static ggml_backend_opencl_context * ggml_cl_init(ggml_backend_dev_t dev) { backend_ctx->non_uniform_workgroups = true; } - // Print out configurations -#ifdef GGML_OPENCL_SOA_Q - GGML_LOG_INFO("ggml_opencl: flattening quantized weights representation as struct of arrays (GGML_OPENCL_SOA_Q)\n"); -#endif // GGML_OPENCL_SOA_Q - -#ifdef GGML_OPENCL_USE_ADRENO_KERNELS - GGML_LOG_INFO("ggml_opencl: using kernels optimized for Adreno (GGML_OPENCL_USE_ADRENO_KERNELS)\n"); -#endif // GGML_OPENCL_USE_ADRENO_KERNELS - #ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // determine whether to use Adreno xmem GEMM backend_ctx->adreno_xmem_gemm_enabled = getenv("GGML_OPENCL_ADRENO_XMEM_GEMM") != nullptr && backend_ctx->gpu_family == GPU_FAMILY::ADRENO; - if (getenv("GGML_OPENCL_ADRENO_XMEM_GEMM") != nullptr) { - GGML_LOG_INFO("ggml_opencl: Adreno xmem F16xF32 GEMM %s\n", - backend_ctx->adreno_xmem_gemm_enabled ? - "enabled (temporary weight prepack)" : "requested but unsupported by this driver"); - } -#endif // GGML_OPENCL_USE_ADRENO_KERNELS +#endif // determine whether to use large buffer for Adreno backend_ctx->adreno_use_large_buffer = getenv("GGML_OPENCL_ADRENO_USE_LARGE_BUFFER") != nullptr && backend_ctx->gpu_family == GPU_FAMILY::ADRENO; - if (backend_ctx->adreno_use_large_buffer) { - if (!backend_ctx->adreno_has_large_buffer) { - GGML_LOG_INFO("ggml_opencl: Adreno large buffer requested but not supported by driver, will use regular buffer\n"); - backend_ctx->adreno_use_large_buffer = false; - } else { - GGML_LOG_INFO("ggml_opencl: Adreno large buffer enabled\n"); - } - } cl_int err; @@ -3915,12 +4234,6 @@ static ggml_backend_opencl_context * ggml_cl_init(ggml_backend_dev_t dev) { backend_ctx->disable_fusion = getenv("GGML_OPENCL_DISABLE_FUSION") != nullptr; - const char * str_opfilter = getenv("GGML_OPENCL_OPFILTER"); - if (str_opfilter) { - backend_ctx->opfilter = new std::regex(str_opfilter, std::regex_constants::icase); - GGML_LOG_INFO("ggml_opencl: opfilter regex = \"%s\"\n", str_opfilter); - } - dev_ctx->backend_ctx = backend_ctx.release(); return dev_ctx->backend_ctx; } @@ -4725,12 +5038,27 @@ inline bool enable_adreno_trans_weight(const ggml_backend_opencl_context *backen return ((elem_num < 128 * 1024 * 1024) && adreno_kernel); // max element num: 2**27 } +static inline bool use_flat_gemv_for_large_m_q4_K(const ggml_tensor *tensor) { + // gemv_noshuffle variant perf drops for large M, use flat variant for large M. + // threshold is well above typical hidden/FFN dims, but below typical vocab sizes. + // note that this forces large M weights to use LM GEMM. + return tensor->ne[1] >= 32768 && tensor->ne[2] == 1 && tensor->ne[3] == 1; +} + +static inline bool use_flat_gemv_for_large_m_q6_K(const ggml_tensor *tensor) { + // gemv_noshuffle variant perf drops for large M, use flat variant for large M. + // threshold is well above typical hidden/FFN dims, but below typical vocab sizes. + // q6_K flat gemv is worse for smaller K; 2048 seems to be a reasonable threshold. + // note that this forces large M weights to use LM GEMM. + return tensor->ne[1] >= 32768 && tensor->ne[0] >= 2048 && tensor->ne[2] == 1 && tensor->ne[3] == 1; +} + static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *)dev->context; ggml_backend_opencl_context * backend_ctx = dev_ctx->backend_ctx; // reject ops that match the opfilter regex - if (backend_ctx->opfilter && std::regex_match(std::string(ggml_op_desc(op)), *backend_ctx->opfilter)) { + if (dev_ctx->opfilter && std::regex_match(std::string(ggml_op_desc(op)), *dev_ctx->opfilter)) { return false; } @@ -4832,17 +5160,17 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_GELU_ERF: case GGML_UNARY_OP_GELU_QUICK: - return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_UNARY_OP_SIGMOID: return ggml_is_contiguous(op->src[0]); case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_NEG: case GGML_UNARY_OP_EXP: - return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; + return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; case GGML_UNARY_OP_EXPM1: - return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; + return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; case GGML_UNARY_OP_SOFTPLUS: - return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; + return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; default: return false; } @@ -4891,6 +5219,15 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32); case GGML_OP_SSM_CONV: return (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32); + case GGML_OP_GATED_DELTA_NET: + { + // Match the Vulkan backend: only F32 -> F32, S_v in {16, 32, 64, 128}. + if (op->src[0]->type != GGML_TYPE_F32 || op->type != GGML_TYPE_F32) { + return false; + } + const int64_t S_v = op->src[2]->ne[0]; + return S_v == 16 || S_v == 32 || S_v == 64 || S_v == 128; + } case GGML_OP_CONCAT: return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; case GGML_OP_TIMESTEP_EMBEDDING: @@ -4900,9 +5237,12 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_OP_MUL_MAT: if (op->src[0]->type == GGML_TYPE_F16) { return true; + } else if (op->src[0]->type == GGML_TYPE_BF16) { + return true; } else if (op->src[0]->type == GGML_TYPE_F32) { return op->src[1]->type == GGML_TYPE_F32; } else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q4_1 || + op->src[0]->type == GGML_TYPE_Q5_0 || op->src[0]->type == GGML_TYPE_Q5_1 || op->src[0]->type == GGML_TYPE_MXFP4 || op->src[0]->type == GGML_TYPE_IQ4_NL || op->src[0]->type == GGML_TYPE_Q4_K || @@ -5853,10 +6193,60 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, return; } #endif // GGML_OPENCL_USE_ADRENO_KERNELS - return; - } - if (tensor->type == GGML_TYPE_Q5_1) { - ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor)) { + cl_kernel kernel = backend_ctx->kernel_convert_block_q5_0_noshuffle; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qs)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + tensor->extra = extra; + + int M = tensor->ne[1]; + int K = tensor->ne[0]; + GGML_ASSERT(K % 32 == 0); + + // Transpose qs as ushort + transpose_2d_as_16b(backend_ctx, extra->qs, extra->qs, size_qs, K/4, M); + // Transpose qh as uchar + transpose_2d_as_8b(backend_ctx, extra->qh, extra->qh, size_qh, K/8, M); + // Transpose d as ushort + transpose_2d_as_16b(backend_ctx, extra->d, extra->d, size_d, K/32, M); + + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + cl_kernel kernel = backend_ctx->kernel_convert_block_q5_0; + cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qs)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &n_blk)); + + size_t global_work_size[] = {(size_t)CEIL_DIV(n_blk, 64) * 64, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + tensor->extra = extra; + return; + } + if (tensor->type == GGML_TYPE_Q5_1) { + ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); // Allocate the new extra and create aliases from the original. @@ -5954,6 +6344,60 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, return; } #endif // GGML_OPENCL_USE_ADRENO_KERNELS + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor)) { + cl_kernel kernel = backend_ctx->kernel_convert_block_q5_1_noshuffle; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qs)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->m)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + tensor->extra = extra; + + int M = tensor->ne[1]; + int K = tensor->ne[0]; + GGML_ASSERT(K % 32 == 0); + + // Transpose qs as ushort + transpose_2d_as_16b(backend_ctx, extra->qs, extra->qs, size_qs, K/4, M); + // Transpose qh as uchar + transpose_2d_as_8b(backend_ctx, extra->qh, extra->qh, size_qh, K/8, M); + // Transpose d as ushort + transpose_2d_as_16b(backend_ctx, extra->d, extra->d, size_d, K/32, M); + // Transpose m as ushort + transpose_2d_as_16b(backend_ctx, extra->m, extra->m, size_m, K/32, M); + + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + cl_kernel kernel = backend_ctx->kernel_convert_block_q5_1; + cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qs)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->m)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &n_blk)); + + size_t global_work_size[] = {(size_t)CEIL_DIV(n_blk, 64) * 64, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + tensor->extra = extra; return; } if (tensor->type == GGML_TYPE_MXFP4) { @@ -6323,7 +6767,7 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, #ifdef GGML_OPENCL_USE_ADRENO_KERNELS cl_kernel kernel = backend_ctx->kernel_convert_block_q4_K; - if (use_adreno_kernels(backend_ctx, tensor)) { + if (use_adreno_kernels(backend_ctx, tensor) && !use_flat_gemv_for_large_m_q4_K(tensor)) { kernel = backend_ctx->kernel_convert_block_q4_K_noshuffle; } #else @@ -6351,7 +6795,7 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, tensor->extra = extra; #ifdef GGML_OPENCL_USE_ADRENO_KERNELS - if (use_adreno_kernels(backend_ctx, tensor)) { + if (use_adreno_kernels(backend_ctx, tensor) && !use_flat_gemv_for_large_m_q4_K(tensor)) { int M = tensor->ne[1]; int K = tensor->ne[0]; @@ -6550,9 +6994,6 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, cl_buffer_region region; - cl_uchar mask_0F = 0x0F; - cl_uchar mask_F0 = 0xF0; - #ifdef GGML_OPENCL_USE_ADRENO_KERNELS // Adreno MoE Q6_K kernel needs special transposed layout if (use_adreno_moe_kernels(backend_ctx, tensor)) { @@ -6586,6 +7027,9 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, cl_kernel kernel = backend_ctx->kernel_convert_block_q6_k_trans4_ns; + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + int ne00 = tensor->ne[0]; int ne01 = tensor->ne[1]; int ne02 = tensor->ne[2]; @@ -6651,7 +7095,7 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, cl_kernel kernel; #ifdef GGML_OPENCL_USE_ADRENO_KERNELS kernel = backend_ctx->kernel_convert_block_q6_K; - if (use_adreno_kernels(backend_ctx, tensor)) { + if (use_adreno_kernels(backend_ctx, tensor) && !use_flat_gemv_for_large_m_q6_K(tensor)) { kernel = backend_ctx->kernel_convert_block_q6_K_noshuffle; } #else @@ -6684,7 +7128,7 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, tensor->extra = extra; #ifdef GGML_OPENCL_USE_ADRENO_KERNELS - if (use_adreno_kernels(backend_ctx, tensor)) { + if (use_adreno_kernels(backend_ctx, tensor) && !use_flat_gemv_for_large_m_q6_K(tensor)) { cl_int M = tensor->ne[1]; // ne01 cl_int K = tensor->ne[0]; // ne00 @@ -6709,6 +7153,40 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, } #endif // GGML_OPENCL_SOA_Q + // convert bf16 to f16 and store as f16 in device buffer + if (tensor->type == GGML_TYPE_BF16) { + GGML_ASSERT(offset % sizeof(ggml_fp16_t) == 0 && size % sizeof(ggml_fp16_t) == 0 + && "Offset and size must be multiples of 2 for bf16 tensors"); + + ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra; + GGML_ASSERT(extra); + + cl_ulong n_elements = size / sizeof(ggml_fp16_t); + cl_ulong off_dst = (extra->offset + offset) / sizeof(ggml_fp16_t); + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, + size, const_cast<void *>(data), &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_convert_bf16_to_f16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->data_device)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_ulong), &off_dst)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &n_elements)); + + size_t global_work_size[] = { (size_t)CEIL_DIV(n_elements, 64)*64, 1, 1 }; + size_t local_work_size[] = { 64, 1, 1 }; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + CL_CHECK(clReleaseEvent(evt)); + + return; + } + ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra; GGML_ASSERT(extra); @@ -6976,9 +7454,72 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clReleaseMemObject(data_device)); return; } + if (use_adreno_kernels(backend_ctx, tensor)) { + ggml_cl_buffer buf_trans_qs; + ggml_cl_buffer buf_trans_qh; + ggml_cl_buffer buf_trans_d; + ggml_cl_buffer buf_unpacked; + + cl_int M = tensor->ne[1]; + cl_int K = tensor->ne[0]; + + GGML_ASSERT(K % 32 == 0); + + size_t size_qs = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*ggml_blck_size(tensor->type)/2; + size_t size_qh = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(int32_t); + size_t size_d = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(ggml_fp16_t); + + buf_trans_qs.allocate(backend_ctx->context, size_qs); + buf_trans_qh.allocate(backend_ctx->context, size_qh); + buf_trans_d.allocate(backend_ctx->context, size_d); + buf_unpacked.allocate(backend_ctx->context, ggml_nbytes(tensor)); + + transpose_2d_as_16b(backend_ctx, extra->qs, buf_trans_qs.buffer, size_qs, M, K/4); + transpose_2d_as_8b(backend_ctx, extra->qh, buf_trans_qh.buffer, size_qh, M, K/8); + transpose_2d_as_16b(backend_ctx, extra->d, buf_trans_d.buffer, size_d, M, K/32); + + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_kernel kernel = backend_ctx->kernel_restore_block_q5_0_noshuffle; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf_trans_qs.buffer)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_qh.buffer)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &buf_trans_d.buffer)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &buf_unpacked.buffer)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask_F0)); + + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); + CL_CHECK(clEnqueueReadBuffer(queue, buf_unpacked.buffer, CL_TRUE, offset, size, data, 0, NULL, NULL)); + return; + } #endif // GGML_OPENCL_USE_ADRENO_KERNELS - // TODO: normal q5_0 - (void) extra; + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q5_0; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->qs)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &data_device)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); return; } if (tensor->type == GGML_TYPE_Q5_1) { @@ -7018,9 +7559,78 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clReleaseMemObject(data_device)); return; } + + if (use_adreno_kernels(backend_ctx, tensor)) { + ggml_cl_buffer buf_trans_qs; + ggml_cl_buffer buf_trans_qh; + ggml_cl_buffer buf_trans_d; + ggml_cl_buffer buf_trans_m; + ggml_cl_buffer buf_unpacked; + + cl_int M = tensor->ne[1]; + cl_int K = tensor->ne[0]; + GGML_ASSERT(K % 32 == 0); + + size_t size_qs = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*ggml_blck_size(tensor->type)/2; + size_t size_qh = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(int32_t); + size_t size_d = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(ggml_fp16_t); + size_t size_m = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(ggml_fp16_t); + + buf_trans_qs.allocate(backend_ctx->context, size_qs); + buf_trans_qh.allocate(backend_ctx->context, size_qh); + buf_trans_d.allocate(backend_ctx->context, size_d); + buf_trans_m.allocate(backend_ctx->context, size_m); + buf_unpacked.allocate(backend_ctx->context, ggml_nbytes(tensor)); + + // Transpose back: from col-major to row-major + transpose_2d_as_16b(backend_ctx, extra->qs, buf_trans_qs.buffer, size_qs, M, K/4); + transpose_2d_as_8b(backend_ctx, extra->qh, buf_trans_qh.buffer, size_qh, M, K/8); + transpose_2d_as_16b(backend_ctx, extra->d, buf_trans_d.buffer, size_d, M, K/32); + transpose_2d_as_16b(backend_ctx, extra->m, buf_trans_m.buffer, size_m, M, K/32); + + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_kernel kernel = backend_ctx->kernel_restore_block_q5_1_noshuffle; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf_trans_qs.buffer)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_qh.buffer)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &buf_trans_d.buffer)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &buf_trans_m.buffer)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &buf_unpacked.buffer)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_uchar), &mask_F0)); + + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); + CL_CHECK(clEnqueueReadBuffer(queue, buf_unpacked.buffer, CL_TRUE, offset, size, data, 0, NULL, NULL)); + return; + } #endif // GGML_OPENCL_USE_ADRENO_KERNELS - // TODO: normal q5_1 - (void) extra; + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q5_1; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->qs)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->m)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); return; } if (tensor->type == GGML_TYPE_MXFP4) { @@ -7251,7 +7861,7 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clReleaseMemObject(data_device)); return; } - if (use_adreno_kernels(backend_ctx, tensor)) { + if (use_adreno_kernels(backend_ctx, tensor) && !use_flat_gemv_for_large_m_q4_K(tensor)) { int M = tensor->ne[1]; int K = tensor->ne[0]; @@ -7434,9 +8044,6 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, if (tensor->type == GGML_TYPE_Q6_K) { ggml_tensor_extra_cl_q6_K * extra = (ggml_tensor_extra_cl_q6_K *)tensor->extra; - cl_uchar mask_0F = 0x0F; - cl_uchar mask_F0 = 0xF0; - #ifdef GGML_OPENCL_USE_ADRENO_KERNELS if (use_adreno_moe_kernels(backend_ctx, tensor)) { cl_int err; @@ -7446,6 +8053,9 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, cl_kernel kernel = backend_ctx->kernel_restore_block_q6_k_trans4_ns; + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + int ne00 = tensor->ne[0]; int ne01 = tensor->ne[1]; int ne02 = tensor->ne[2]; @@ -7472,7 +8082,7 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clReleaseMemObject(data_device)); return; } - if (use_adreno_kernels(backend_ctx, tensor)) { + if (use_adreno_kernels(backend_ctx, tensor) && !use_flat_gemv_for_large_m_q6_K(tensor)) { static ggml_cl_buffer buf_trans_ql; static ggml_cl_buffer buf_trans_qh; static ggml_cl_buffer buf_trans_s; @@ -7557,6 +8167,41 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, } #endif // GGML_OPENCL_SOA_Q + if (tensor->type == GGML_TYPE_BF16) { + GGML_ASSERT(offset % sizeof(ggml_fp16_t) == 0 && size % sizeof(ggml_fp16_t) == 0 + && "Offset and size must be multiples of 2 for bf16 tensors"); + + ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra; + GGML_ASSERT(extra); + + cl_ulong n_elements = size / sizeof(ggml_fp16_t); + cl_ulong off_src = (extra->offset + tensor->view_offs + offset) / sizeof(ggml_fp16_t); + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, size, NULL, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_convert_f16_to_bf16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &n_elements)); + + size_t global_work_size[] = { (size_t)CEIL_DIV(n_elements, 64)*64, 1, 1 }; + size_t local_work_size[] = { 64, 1, 1 }; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseEvent(evt)); + + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, 0, size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + + return; + } + ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra; CL_CHECK(clEnqueueReadBuffer( @@ -7719,6 +8364,8 @@ static ggml_backend_t ggml_backend_opencl_device_init(ggml_backend_dev_t dev, co /* .context = */ backend_ctx, }; + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) dev->context; + ggml_opencl_print_backend_info(dev_ctx); return backend; GGML_UNUSED(params); @@ -8044,6 +8691,7 @@ static void ggml_cl_copy_to_contiguous(ggml_backend_t backend, const ggml_tensor kernel = backend_ctx->kernel_cpy_f32_f32; break; case GGML_TYPE_F16: + case GGML_TYPE_BF16: // stored as f16 on device kernel = backend_ctx->kernel_cpy_f16_f16; break; default: @@ -8151,7 +8799,14 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c nth *= 2; } - size_t global_work_size[] = {(size_t)ne10*nth, (size_t)ne11, (size_t)ne12}; + int nchunks = 1; + if (src0->type == GGML_TYPE_F32) { + const int chunk_target = nth * 4; + nchunks = (ne00 + chunk_target - 1) / chunk_target; + nchunks = MAX(1, MIN(nchunks, 64)); + } + + size_t global_work_size[] = {(size_t)ne10*nth*nchunks, (size_t)ne11, (size_t)ne12}; size_t local_work_size[] = {(size_t)nth, 1, 1}; backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); @@ -10555,7 +11210,7 @@ static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_t size_t local_work_size[] = { lws0, 1, 1 }; size_t * local_work_size_ptr = local_work_size; - if (d_ne0 % lws0 != 0 && !backend_ctx->non_uniform_workgroups) { + if (d_ne0 % lws0 != 0 && !backend_ctx->non_uniform_workgroups) { local_work_size_ptr = nullptr; } @@ -10727,7 +11382,9 @@ static void ggml_cl_concat(ggml_backend_t backend, const ggml_tensor * src0, con int nth = MIN(64, ne0); - cl_kernel kernel = backend_ctx->kernel_concat_f32; + const bool concat_pack = (dim == 0 && ne0 < 32); + cl_kernel kernel = concat_pack ? backend_ctx->kernel_concat_f32_pack + : backend_ctx->kernel_concat_f32; CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); @@ -10754,10 +11411,28 @@ static void ggml_cl_concat(ggml_backend_t backend, const ggml_tensor * src0, con CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3)); CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_int), &dim)); - size_t global_work_size[] = {(size_t)ne1*nth, (size_t)ne2, (size_t)ne3}; - size_t local_work_size[] = {(size_t)nth, 1, 1}; + if (concat_pack) { + // packed kernel needs the dst dims to unflatten its 1-D row index. + CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne2)); + CL_CHECK(clSetKernelArg(kernel, 26, sizeof(int), &ne3)); + + const int maxwg = (int)backend_ctx->get_kernel_workgroup_size(kernel); + const int base = MIN(64, maxwg); + const int tpr = MIN(ne0, base); // threads per row + const int rpw = MAX(1, base / tpr); // rows per workgroup + const int lsz = tpr * rpw; + const int nrows = ne1*ne2*ne3; + const int nwg = (nrows + rpw - 1) / rpw; + size_t global_work_size[] = {(size_t)nwg*lsz, 1, 1}; + size_t local_work_size[] = {(size_t)lsz, 1, 1}; + backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, local_work_size, dst); + } else { + size_t global_work_size[] = {(size_t)ne1*nth, (size_t)ne2, (size_t)ne3}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + } } static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor * src0, ggml_tensor * dst) { @@ -11004,7 +11679,8 @@ static bool ggml_cl_can_use_adreno_xmem_gemm_f16_f32( if (backend_ctx->gpu_family != GPU_FAMILY::ADRENO) { return false; } - if (src0->type != GGML_TYPE_F16 || src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + if ((src0->type != GGML_TYPE_F16 && src0->type != GGML_TYPE_BF16) || + src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { return false; } if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) { @@ -11742,15 +12418,376 @@ static void ggml_cl_mul_mat_q4_1_f32_adreno(ggml_backend_t backend, const ggml_t backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, local_work_size_t, dst); // gemm - kernel = backend_ctx->kernel_gemm_noshuffle_q4_1_f32; + kernel = backend_ctx->kernel_gemm_noshuffle_q4_1_f32; + int padded_N = N + padding; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_1->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_1->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_1->m)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &padded_N)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_int), &ne1)); + + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne1, 8), (size_t)CEIL_DIV(ne01, 4), 1}; + size_t local_work_size[3] = {1, 128, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_sub_buf_trans)); + CL_CHECK(clReleaseMemObject(b_img)); + CL_CHECK(clReleaseMemObject(b_img_trans)); + } +#else + GGML_UNUSED(backend); + GGML_UNUSED(src0); + GGML_UNUSED(src1); + GGML_UNUSED(dst); +#endif +} + +static void ggml_cl_mul_mat_q5_0_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + ggml_tensor_extra_cl_q5_0 * extra0_q5_0 = (ggml_tensor_extra_cl_q5_0 *)src0->extra; + + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + + const int ne1 = dst->ne[1]; + + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); + + cl_context context = backend_ctx->context; + cl_kernel kernel; + + cl_int err; + cl_image_format img_fmt; + cl_image_desc img_desc; + cl_buffer_region region; + + int M = ne01; + int N = ne1; + int K = ne00; + + if (ne1 == 1) { + cl_mem qs_img = nullptr; + cl_mem b_sub_buf = nullptr; + cl_mem b_img = nullptr; + + // image for qs + img_fmt = { CL_R, CL_UNSIGNED_INT32 }; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = M * K / 2 / 4; + img_desc.buffer = extra0_q5_0->qs; + CL_CHECK((qs_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + kernel = backend_ctx->kernel_gemv_noshuffle_q5_0_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &qs_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_0->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_0->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne01)); + + size_t local_work_size[3] = {64, 4, 1}; + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne01/2, 64)*64, 4, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(qs_img)); + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_img)); + } else { + cl_mem b_sub_buf = nullptr; + cl_mem b_sub_buf_trans = nullptr; + cl_mem b_img = nullptr; + cl_mem b_img_trans = nullptr; + cl_mem d_sub_buf = nullptr; + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // pad N to multiple of 8 + int extra_elements = N % 8; + int padding = 0; + if (extra_elements > 0){ + padding = 8 - extra_elements; + } + + // subbuffer for transposed activations + region.origin = 0; + region.size = K * (N + padding) * sizeof(float)/2; + backend_ctx->prealloc_act_trans.allocate(context, region.size); + CL_CHECK((b_sub_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for transposed activations + img_fmt = {CL_RGBA, CL_HALF_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * (N + padding) / 4; + img_desc.buffer = b_sub_buf_trans; + CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err)); + + // subbuffer for output + region.origin = extrad->offset; + region.size = M * N * sizeof(float); + CL_CHECK((d_sub_buf = clCreateSubBuffer(extrad->data_device, CL_MEM_WRITE_ONLY, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // transpose activations + int height_B = N/4; + if (height_B == 0) { + height_B = 1; + } + int width_B = K/4; + int padded_height_B = (N + padding)/4; + + kernel = backend_ctx->kernel_transpose_32_16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); + + size_t local_work_size_t[2] = { 1, 16 }; + size_t global_work_size_t[2] = { (size_t)width_B, (size_t)padded_height_B }; + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, local_work_size_t, dst); + + // gemm + kernel = backend_ctx->kernel_gemm_noshuffle_q5_0_f32; + int padded_N = N + padding; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q5_0->qs)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_0->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_0->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &d_sub_buf)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &padded_N)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &ne1)); + + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne1, 8), (size_t)CEIL_DIV(ne01, 4), 1}; + size_t local_work_size[3] = {1, 128, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_sub_buf_trans)); + CL_CHECK(clReleaseMemObject(b_img)); + CL_CHECK(clReleaseMemObject(b_img_trans)); + CL_CHECK(clReleaseMemObject(d_sub_buf)); + } +#else + GGML_UNUSED(backend); + GGML_UNUSED(src0); + GGML_UNUSED(src1); + GGML_UNUSED(dst); +#endif +} + +static void ggml_cl_mul_mat_q5_1_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + ggml_tensor_extra_cl_q5_1 * extra0_q5_1 = (ggml_tensor_extra_cl_q5_1 *)src0->extra; + + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + + const int ne1 = dst->ne[1]; + + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); + + cl_context context = backend_ctx->context; + cl_kernel kernel; + + cl_int err; + cl_image_format img_fmt; + cl_image_desc img_desc; + cl_buffer_region region; + + int M = ne01; + int N = ne1; + int K = ne00; + + if (ne1 == 1) { + cl_mem qs_img = nullptr; + cl_mem b_sub_buf = nullptr; + cl_mem b_img = nullptr; + + // image for qs + img_fmt = { CL_R, CL_UNSIGNED_INT32 }; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = M * K / 2 / 4; + img_desc.buffer = extra0_q5_1->qs; + CL_CHECK((qs_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + kernel = backend_ctx->kernel_gemv_noshuffle_q5_1_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &qs_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_1->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_1->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q5_1->m)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &ne01)); + + size_t local_work_size[3] = {64, 4, 1}; + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne01/2, 64)*64, 4, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(qs_img)); + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_img)); + } else { + cl_mem b_sub_buf = nullptr; + cl_mem b_sub_buf_trans = nullptr; + cl_mem b_img = nullptr; + cl_mem b_img_trans = nullptr; + cl_mem d_sub_buf = nullptr; + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // pad N to multiple of 8 + int extra_elements = N % 8; + int padding = 0; + if (extra_elements > 0){ + padding = 8 - extra_elements; + } + + // subbuffer for transposed activations + region.origin = 0; + region.size = K * (N + padding) * sizeof(float)/2; + backend_ctx->prealloc_act_trans.allocate(context, region.size); + CL_CHECK((b_sub_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for transposed activations + img_fmt = {CL_RGBA, CL_HALF_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * (N + padding) / 4; + img_desc.buffer = b_sub_buf_trans; + CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err)); + + // subbuffer for output + region.origin = extrad->offset; + region.size = M * N * sizeof(float); + CL_CHECK((d_sub_buf = clCreateSubBuffer(extrad->data_device, CL_MEM_WRITE_ONLY, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // transpose activations + int height_B = N/4; + if (height_B == 0) { + height_B = 1; + } + int width_B = K/4; + int padded_height_B = (N + padding)/4; + + kernel = backend_ctx->kernel_transpose_32_16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); + + size_t local_work_size_t[2] = { 1, 16 }; + size_t global_work_size_t[2] = { (size_t)width_B, (size_t)padded_height_B }; + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, local_work_size_t, dst); + + // gemm + kernel = backend_ctx->kernel_gemm_noshuffle_q5_1_f32; int padded_N = N + padding; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_1->q)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_1->d)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_1->m)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &b_img_trans)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q5_1->qs)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_1->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_1->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q5_1->m)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &d_sub_buf)); CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne01)); CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &padded_N)); CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &ne00)); @@ -11765,6 +12802,7 @@ static void ggml_cl_mul_mat_q4_1_f32_adreno(ggml_backend_t backend, const ggml_t CL_CHECK(clReleaseMemObject(b_sub_buf_trans)); CL_CHECK(clReleaseMemObject(b_img)); CL_CHECK(clReleaseMemObject(b_img_trans)); + CL_CHECK(clReleaseMemObject(d_sub_buf)); } #else GGML_UNUSED(backend); @@ -12722,7 +13760,8 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co GGML_ASSERT(dst); GGML_ASSERT(dst->extra); - const enum ggml_type src0t = src0->type; + // bf16 is stored as f16 on device + const enum ggml_type src0t = (src0->type == GGML_TYPE_BF16) ? GGML_TYPE_F16 : src0->type; const enum ggml_type src1t = src1->type; ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; @@ -12738,6 +13777,8 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co #ifdef GGML_OPENCL_SOA_Q ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra; ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)src0->extra; + ggml_tensor_extra_cl_q5_0 * extra0_q5_0 = (ggml_tensor_extra_cl_q5_0 *)src0->extra; + ggml_tensor_extra_cl_q5_1 * extra0_q5_1 = (ggml_tensor_extra_cl_q5_1 *)src0->extra; ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra; ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; ggml_tensor_extra_cl_iq4_nl * extra0_iq4_nl = (ggml_tensor_extra_cl_iq4_nl *)src0->extra; @@ -12809,6 +13850,18 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co return; } + // q5_0 x fp32 + if (src0t == GGML_TYPE_Q5_0 && src1t == GGML_TYPE_F32) { + ggml_cl_mul_mat_q5_0_f32_adreno(backend, src0, src1, dst); + return; + } + + // q5_1 x fp32 + if (src0t == GGML_TYPE_Q5_1 && src1t == GGML_TYPE_F32) { + ggml_cl_mul_mat_q5_1_f32_adreno(backend, src0, src1, dst); + return; + } + // iq4_nl x fp32 if (src0t == GGML_TYPE_IQ4_NL && src1t == GGML_TYPE_F32) { ggml_cl_mul_mat_iq4_nl_f32_adreno(backend, src0, src1, dst); @@ -12823,13 +13876,13 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co } // q4_k x fp32 - if (src0t == GGML_TYPE_Q4_K && src1t == GGML_TYPE_F32) { + if (src0t == GGML_TYPE_Q4_K && src1t == GGML_TYPE_F32 && !use_flat_gemv_for_large_m_q4_K(src0)) { ggml_cl_mul_mat_q4_k_f32_adreno(backend, src0, src1, dst); return; } // q6_K x fp32 - if (src0t == GGML_TYPE_Q6_K && src1t == GGML_TYPE_F32) { + if (src0t == GGML_TYPE_Q6_K && src1t == GGML_TYPE_F32 && !use_flat_gemv_for_large_m_q6_K(src0)) { ggml_cl_mul_mat_q6_K_f32_adreno(backend, src0, src1, dst); return; } @@ -13073,6 +14126,93 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); return; } + case GGML_TYPE_Q5_0: { + if (ne11 < 32) { + break; + } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + break; + } + + kernel = backend_ctx->kernel_mul_mm_q5_0_f32_l4_lm; + nth0 = 128; // calculated as (BM*BN)/(TM*TN) + + int batch_stride_a = ne00*ne01; + int batch_stride_b = ne10*ne11; + int batch_stride_d = ne0*ne1; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q5_0->qs)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_0->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_0->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); // stride_a + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); // stride_b + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne01)); // stride_d + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &batch_stride_a)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_b)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &batch_stride_d)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &r3)); + + // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed. + size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + return; + } + case GGML_TYPE_Q5_1: { + if (ne11 < 32) { + break; + } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + break; + } + + kernel = backend_ctx->kernel_mul_mm_q5_1_f32_l4_lm; + nth0 = 128; // calculated as (BM*BN)/(TM*TN) + + int batch_stride_a = ne00*ne01; + int batch_stride_b = ne10*ne11; + int batch_stride_d = ne0*ne1; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q5_1->qs)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_1->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_1->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q5_1->m)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); // stride_a + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10)); // stride_b + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne01)); // stride_d + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_a)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &batch_stride_b)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &batch_stride_d)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &r3)); + + // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed. + size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + return; + } case GGML_TYPE_Q8_0: { if (ne11 < 32) { break; @@ -13609,6 +14749,137 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co #endif // GGML_OPENCL_SOA_Q break; } + case GGML_TYPE_Q5_0: { +#ifdef GGML_OPENCL_SOA_Q + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + kernel = backend_ctx->kernel_mul_mv_q5_0_f32_flat; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q5_0->qs)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_0->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_0->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &r3)); +#else + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + kernel = backend_ctx->kernel_mul_mv_q5_0_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); +#endif // GGML_OPENCL_SOA_Q + break; + } + case GGML_TYPE_Q5_1: { +#ifdef GGML_OPENCL_SOA_Q + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + kernel = backend_ctx->kernel_mul_mv_q5_1_f32_flat; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q5_1->qs)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_1->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_1->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q5_1->m)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &r3)); +#else + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + kernel = backend_ctx->kernel_mul_mv_q5_1_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); +#endif // GGML_OPENCL_SOA_Q + break; + } case GGML_TYPE_Q8_0: { #ifdef GGML_OPENCL_SOA_Q kernel = backend_ctx->kernel_mul_mv_q8_0_f32_flat; @@ -13913,7 +15184,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co } else if (backend_ctx->gpu_family == ADRENO) { nth0 = 64; nth1 = 2; - ndst = 4; + ndst = 16; } else { GGML_ASSERT(false && "TODO: Unknown GPU"); } @@ -14049,6 +15320,8 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_MXFP4 || src0t == GGML_TYPE_Q4_1 || + src0t == GGML_TYPE_Q5_0 || + src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_Q2_K) { @@ -14278,6 +15551,8 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, const int ne1 = dst->ne[1]; const int ne2 = dst->ne[2]; + GGML_UNUSED(ne2); + const int r2 = ne12/ne02; const int r3 = ne13/ne03; const int dst_rows = ne20*ne21; // ne20 = n_used_experts, ne21 = n_rows @@ -14292,6 +15567,8 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, const int n_tile_size = 32; const int max_post_router_tile = (ne20 * ne21 / n_tile_size) + ne02; + GGML_UNUSED(max_post_router_tile); + cl_kernel kernel; // subgroup mat vec @@ -16004,7 +17281,8 @@ static void ggml_cl_cpy(ggml_backend_t backend, const ggml_tensor * src0, const kernel = backend_ctx->kernel_cpy_f32_f16; break; case GGML_TYPE_F32: - kernel = backend_ctx->kernel_cpy_f32_f32; + kernel = ne00 < 32 ? backend_ctx->kernel_cpy_f32_f32_pack + : backend_ctx->kernel_cpy_f32_f32; break; default: GGML_ASSERT(false && "not implemented"); @@ -16056,12 +17334,27 @@ static void ggml_cl_cpy(ggml_backend_t backend, const ggml_tensor * src0, const CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb12)); CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb13)); - const int nth = MIN(64, ne00); + if (kernel == backend_ctx->kernel_cpy_f32_f32_pack) { + const int maxwg = (int)backend_ctx->get_kernel_workgroup_size(kernel); + const int base = MIN(64, maxwg); + const int tpr = MIN(ne00, base); // threads per row + const int rpw = MAX(1, base / tpr); // rows per workgroup + const int lsz = tpr * rpw; // <= base <= maxwg + const int nrows = ne01*ne02*ne03; + const int nwg = (nrows + rpw - 1) / rpw; - size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; - size_t local_work_size[] = {(size_t)nth, 1, 1}; + size_t global_work_size[] = {(size_t)nwg*lsz, 1, 1}; + size_t local_work_size[] = {(size_t)lsz, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, local_work_size, src1); + } else { + const int nth = MIN(64, ne00); - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, src1); + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, src1); + } } static void ggml_cl_dup(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { @@ -17052,6 +18345,185 @@ static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } +static void ggml_cl_gated_delta_net(ggml_backend_t backend, ggml_tensor * dst) { + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + const ggml_tensor * src_q = dst->src[0]; + const ggml_tensor * src_k = dst->src[1]; + const ggml_tensor * src_v = dst->src[2]; + const ggml_tensor * src_g = dst->src[3]; + const ggml_tensor * src_beta = dst->src[4]; + const ggml_tensor * src_state = dst->src[5]; + + GGML_ASSERT(src_q && src_q->extra); + GGML_ASSERT(src_k && src_k->extra); + GGML_ASSERT(src_v && src_v->extra); + GGML_ASSERT(src_g && src_g->extra); + GGML_ASSERT(src_beta && src_beta->extra); + GGML_ASSERT(src_state && src_state->extra); + + ggml_backend_opencl_context * backend_ctx = (ggml_backend_opencl_context *) backend->context; + + const cl_uint S_v = (cl_uint) src_v->ne[0]; + const cl_uint H_v = (cl_uint) src_v->ne[1]; + const cl_uint n_tokens = (cl_uint) src_v->ne[2]; + const cl_uint n_seqs = (cl_uint) src_v->ne[3]; + const cl_uint K = (cl_uint) ggml_get_op_params_i32(dst, 0); + + int si; + switch (S_v) { + case 16: si = 0; break; + case 32: si = 1; break; + case 64: si = 2; break; + case 128: si = 3; break; + default: + GGML_ASSERT(false && "ggml_cl_gated_delta_net: unsupported S_v"); + } + + const int kda = (src_g->ne[0] == (int64_t) S_v) ? 1 : 0; + + // TODO: Optimize when S_v!=128. Not necessary for now as Qwen3.5/6 are all S_v=128 + // token generation mode (tgpp=0): + // process 1 token at a time, so columns per lane (cpl) == 1 + // prompt processing mode (tgpp=1): + // cpl=4 to process 4 tokens for single-token. 4 is chosen for Adreno 750 as per + // work-item/thread has at most 128 registers. + // All Qwen3.5/6 models are S_v == 128, so LANES_PER_COLUMN == 8 + // such that ROWS_PER_LANE = 128/8 = 16 + // Variables in the kernel: + // k_reg, q_reg, g_exp are all 16 floats + // s_shard has cpl*ROWS_PER_LANE = 4*16 = 64 floats + // Total 112 registers used. + // subgroups_per_workgroup (spw) can be set to 1,2,4,8,16 for tg and 1,2,4 for pp + // for S_v=128. + // Empirically found that when spw=1, we get the best performance for both tg and pp + const int tgpp = (n_tokens == 1) ? 0 : 1; + const int cpl = (tgpp == 0) ? 1 : 4; + // spw needs adjustment when S_v != 128 + const int spw = (tgpp == 0) ? 1 : 1; + + cl_kernel kernel = backend_ctx->kernel_gated_delta_net_f32[si][kda][tgpp]; + GGML_ASSERT(kernel != nullptr); + + const cl_uint s_off = S_v * H_v * n_tokens * n_seqs; + + const cl_uint sq1 = (cl_uint)(src_q->nb[1] / sizeof(float)); + const cl_uint sq2 = (cl_uint)(src_q->nb[2] / sizeof(float)); + const cl_uint sq3 = (cl_uint)(src_q->nb[3] / sizeof(float)); + const cl_uint sv1 = (cl_uint)(src_v->nb[1] / sizeof(float)); + const cl_uint sv2 = (cl_uint)(src_v->nb[2] / sizeof(float)); + const cl_uint sv3 = (cl_uint)(src_v->nb[3] / sizeof(float)); + const cl_uint sb1 = (cl_uint)(src_beta->nb[1] / sizeof(float)); + const cl_uint sb2 = (cl_uint)(src_beta->nb[2] / sizeof(float)); + const cl_uint sb3 = (cl_uint)(src_beta->nb[3] / sizeof(float)); + + const cl_uint H_k = (cl_uint) src_q->ne[1]; + const cl_uint rq3 = (cl_uint)(src_v->ne[3] / src_q->ne[3]); + + const float scale = 1.0f / sqrtf((float) S_v); + + ggml_tensor_extra_cl * extra_q = (ggml_tensor_extra_cl *) src_q->extra; + ggml_tensor_extra_cl * extra_k = (ggml_tensor_extra_cl *) src_k->extra; + ggml_tensor_extra_cl * extra_v = (ggml_tensor_extra_cl *) src_v->extra; + ggml_tensor_extra_cl * extra_g = (ggml_tensor_extra_cl *) src_g->extra; + ggml_tensor_extra_cl * extra_beta = (ggml_tensor_extra_cl *) src_beta->extra; + ggml_tensor_extra_cl * extra_state = (ggml_tensor_extra_cl *) src_state->extra; + ggml_tensor_extra_cl * extra_dst = (ggml_tensor_extra_cl *) dst->extra; + + const cl_ulong off_q = extra_q->offset + src_q->view_offs; + const cl_ulong off_k = extra_k->offset + src_k->view_offs; + const cl_ulong off_v = extra_v->offset + src_v->view_offs; + const cl_ulong off_g = extra_g->offset + src_g->view_offs; + const cl_ulong off_beta = extra_beta->offset + src_beta->view_offs; + const cl_ulong off_state = extra_state->offset + src_state->view_offs; + const cl_ulong off_dst = extra_dst->offset + dst->view_offs; + + int idx = 0; + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra_q->data_device)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &off_q)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra_k->data_device)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &off_k)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra_v->data_device)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &off_v)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra_g->data_device)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &off_g)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra_beta->data_device)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &off_beta)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra_state->data_device)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &off_state)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra_dst->data_device)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &off_dst)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &H_v)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &n_tokens)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &n_seqs)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &s_off)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &sq1)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &sq2)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &sq3)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &sv1)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &sv2)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &sv3)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &sb1)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &sb2)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &sb3)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &H_k)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &rq3)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(float), &scale)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &K)); + + // Subgroup size is 64 for Adreno and 32 for Intel + const int sg_size = backend_ctx->gpu_family == GPU_FAMILY::ADRENO ? 64 : backend_ctx->gpu_family == GPU_FAMILY::INTEL ? 32 : -1; + if (sg_size < 0) { + GGML_LOG_ERROR("Unsupported GPU Family: only Adreno and Intel are supported.\n"); + exit(1); + } + + // For the subgroup-shuffle kernel, we can safely prefer 8 lanes/column for S_v>=128 + // For the subgroup-shuffle kernel: + // S_v >= 128 -> prefer 8 lanes/column (good occupancy & register pressure tradeoff) + // else -> min(S_v, subgroup_size) + int lanes_per_column; + if ((int)S_v >= 128) { + lanes_per_column = 8; + } else { + lanes_per_column = std::min((int)S_v, sg_size); + } + + // Max workgroup size for Adreno 750 is 1024 + const int wg_size = sg_size * spw; + + // Ensure lanes_per_column is a power-of-two and divides both S_v and subgroup_size. + // (Required for lane-group shuffle-xor reduction correctness.) + while (lanes_per_column > 1 && + (((lanes_per_column & (lanes_per_column - 1)) != 0) || + (((int)S_v % lanes_per_column) != 0) || + (sg_size % lanes_per_column) != 0)) { + lanes_per_column >>= 1; + } + GGML_ASSERT(lanes_per_column >= 1); + GGML_ASSERT(((lanes_per_column & (lanes_per_column - 1)) == 0)); + GGML_ASSERT(((int)S_v % lanes_per_column) == 0); + GGML_ASSERT((sg_size % lanes_per_column) == 0); + + const int cols_per_wg = spw * (sg_size / lanes_per_column) * cpl; + GGML_ASSERT(cols_per_wg > 0); + GGML_ASSERT(((int)S_v % cols_per_wg) == 0); + + size_t global_work_size[3]; + size_t local_work_size[3]; + + global_work_size[0] = (size_t) H_v * (size_t) wg_size; + global_work_size[1] = (size_t) n_seqs; + global_work_size[2] = (size_t) S_v / (size_t) cols_per_wg; + + local_work_size[0] = (size_t) wg_size; + local_work_size[1] = 1; + local_work_size[2] = 1; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); +} + //------------------------------------------------------------------------------ // Op offloading //------------------------------------------------------------------------------ @@ -17267,8 +18739,8 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_group_norm; break; - case GGML_OP_REPEAT: - if (!any_on_device) { + case GGML_OP_REPEAT: + if (!any_on_device) { return false; } func = ggml_cl_repeat; @@ -17297,6 +18769,14 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_ssm_conv; break; + case GGML_OP_GATED_DELTA_NET: + if (!any_on_device) { + return false; + } + // GDN has 6 source tensors, so it cannot use the standard + // (src0, src1, dst) func signature. Dispatch directly and return. + ggml_cl_gated_delta_net(backend, tensor); + return true; case GGML_OP_CONCAT: if (!any_on_device) { return false; diff --git a/ggml/src/ggml-opencl/kernels/concat.cl b/ggml/src/ggml-opencl/kernels/concat.cl index 0c1b3d785ca..2fbd7851d3d 100644 --- a/ggml/src/ggml-opencl/kernels/concat.cl +++ b/ggml/src/ggml-opencl/kernels/concat.cl @@ -49,3 +49,70 @@ kernel void kernel_concat_f32( *y = *x; } } + +kernel void kernel_concat_f32_pack( + global const char * src0, + ulong offset0, + global const char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int dim, + int ne1, + int ne2, + int ne3 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + int lsz = get_local_size(0); + int tpr = min(ne0, lsz); // threads per row + int rpw = lsz / tpr; // rows per workgroup + int lid = get_local_id(0); + int row = get_group_id(0)*rpw + lid / tpr; + int lane = lid - (lid / tpr) * tpr; + + int nrows = ne1*ne2*ne3; + if (row >= nrows) { + return; + } + + int i1 = row % ne1; + int t = row / ne1; + int i2 = t % ne2; + int i3 = t / ne2; + + int o[4] = {0, 0, 0, 0}; + o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03)); + + for (int i0 = lane; i0 < ne0; i0 += tpr) { + global const float * x; + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + x = (global const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00); + } else { + x = (global const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10); + } + + global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + *y = *x; + } +} diff --git a/ggml/src/ggml-opencl/kernels/cpy.cl b/ggml/src/ggml-opencl/kernels/cpy.cl index 820aa538a34..adbd2e766d2 100644 --- a/ggml/src/ggml-opencl/kernels/cpy.cl +++ b/ggml/src/ggml-opencl/kernels/cpy.cl @@ -183,6 +183,65 @@ kernel void kernel_cpy_f32_f32( } } +kernel void kernel_cpy_f32_f32_pack( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + int lsz = get_local_size(0); + int tpr = min(ne00, lsz); // threads per row + int rpw = lsz / tpr; // rows per workgroup + int lid = get_local_id(0); + int row = get_group_id(0)*rpw + lid / tpr; + int lane = lid - (lid / tpr) * tpr; + + int nrows = ne01*ne02*ne03; + if (row >= nrows) { + return; + } + + int i01 = row % ne01; + int t = row / ne01; + int i02 = t % ne02; + int i03 = t / ne02; + + // linear index of the first element of this row, unflattened over dst dims + long n = (long)row * ne00; + int i3 = (int)(n / ((long)ne2*ne1*ne0)); + long rm = n - (long)i3*ne2*ne1*ne0; + int i2 = (int)(rm / ((long)ne1*ne0)); + rm -= (long)i2*ne1*ne0; + int i1 = (int)(rm / ne0); + int i0 = (int)(rm - (long)i1*ne0); + + global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int i00 = lane; i00 < ne00; i00 += tpr) { + global const float * src = (global float *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + dst_data[i00] = src[0]; + } +} + kernel void kernel_cpy_i32_i32( global int * src0, ulong offset0, diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index c25eabdd72b..226b127ab3b 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -117,6 +117,48 @@ struct block_iq4_nl uint8_t qs[QK4_NL / 2]; }; +//------------------------------------------------------------------------------ +// bf16 to f16 +//------------------------------------------------------------------------------ +kernel void kernel_convert_bf16_to_f16( + global const ushort * src, + global half * dst, + ulong off_dst, + ulong n +) { + uint i = get_global_id(0); + if (i >= n) { + return; + } + + dst[i + off_dst] = (half) as_float((uint) src[i] << 16); +} + +//------------------------------------------------------------------------------ +// f16 to bf16 +//------------------------------------------------------------------------------ +kernel void kernel_convert_f16_to_bf16( + global const half * src, + ulong off_src, + global ushort * dst, + ulong n +) { + uint i = get_global_id(0); + if (i >= n) { + return; + } + + float f = (float) src[i + off_src]; + uint bits = as_uint(f); + if ((bits & 0x7fffffffu) > 0x7f800000u) { + // nan to quiet nan + dst[i] = (ushort)((bits >> 16) | 0x40u); + } else { + uint rounded = bits + 0x7fffu + ((bits >> 16) & 1u); + dst[i] = (ushort)(rounded >> 16); + } +} + //------------------------------------------------------------------------------ // kernel_convert_block_q4_0 // Convert the block_q4_0 format to 2 separate arrays (AOS -> SOA). @@ -495,6 +537,107 @@ kernel void kernel_restore_block_q4_1_trans4_ns( ((__global ushort8 *)(&(b->qs[0])))[0] = pre_block; } +//------------------------------------------------------------------------------ +// kernel_convert_block_q5_0 +// Convert the block_q5_0 format to 3 separate arrays (AOS -> SOA). +// This kernel does not deshuffle the bits. +//------------------------------------------------------------------------------ +kernel void kernel_convert_block_q5_0( + global struct block_q5_0 * src0, + global uchar * dst_qs, + global uint * dst_qh, + global half * dst_d, + ulong n_blk +) { + if (get_global_id(0) >= n_blk) { + return; + } + + global struct block_q5_0 * b = (global struct block_q5_0 *) src0 + get_global_id(0); + global uchar * qs = (global uchar *) dst_qs + (QK5_0/2)*get_global_id(0); + global uint * qh = (global uint *) dst_qh + get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + + *d = b->d; + *qh = *((global uint *)(b->qh)); + + for (int i = 0; i < QK5_0/2; ++i) { + qs[i] = b->qs[i]; + } +} + +kernel void kernel_restore_block_q5_0( + global uchar * src_qs, + global uint * src_qh, + global half * src_d, + global struct block_q5_0 * dst +) { + global struct block_q5_0 * b = (global struct block_q5_0 *) dst + get_global_id(0); + global uchar * qs = (global uchar *) src_qs + (QK5_0/2)*get_global_id(0); + global uint * qh = (global uint *) src_qh + get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + + b->d = *d; + *((global uint *)(b->qh)) = *qh; + for (int i = 0; i < QK5_0/2; ++i) { + b->qs[i] = qs[i]; + } +} + +kernel void kernel_convert_block_q5_0_noshuffle( + global struct block_q5_0 * src0, + global uchar * dst_q, + global uint * dst_qh, + global half * dst_d +) { + global struct block_q5_0 * b = (global struct block_q5_0 *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK5_0/2*get_global_id(0); + global uint * qh = (global uint *) dst_qh + get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + + *d = b->d; + *qh = *((global uint *)(b->qh)); + + for (int i = 0; i < QK5_0/4; ++i) { + uchar x0 = b->qs[2*i + 0]; + uchar x1 = b->qs[2*i + 1]; + + q[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + q[i + QK5_0/4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + +#ifdef ADRENO_GPU + if (get_global_id(0) == 65536*4096) { + printf("%04x - %02x\n", *(global ushort*)d, ((x0 & 0xF0) >> 4) | (x1 & 0xF0)); + } +#endif + } +} + +kernel void kernel_restore_block_q5_0_noshuffle( + global uchar * src_q, + global uint * src_qh, + global half * src_d, + global struct block_q5_0 * dst, + uchar mask_0F, + uchar mask_F0 +) { + global struct block_q5_0 * b = (global struct block_q5_0 *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK5_0/2*get_global_id(0); + global uint * qh = (global uint *) src_qh + get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + + b->d = *d; + *((global uint *)(b->qh)) = *qh; + + for (int i = 0; i < QK5_0/4; ++i) { + uchar x0 = q[i + 0 ]; + uchar x1 = q[i + QK5_0/4]; + + b->qs[2*i + 0] = convert_uchar((x0 & mask_0F) | ((x1 & mask_0F) << 4)); + b->qs[2*i + 1] = convert_uchar(((x0 & mask_F0) >> 4) | (x1 & mask_F0)); + } +} + kernel void kernel_convert_block_q5_0_trans4_ns( __global struct block_q5_0 * src0, __global uint * dst_qs, @@ -594,6 +737,119 @@ kernel void kernel_restore_block_q5_0_trans4_ns( ((__global ushort8 *)(&(b->qs[0])))[0] = pre_block; } +//------------------------------------------------------------------------------ +// kernel_convert_block_q5_1 +// Convert the block_q5_1 format to 4 separate arrays (AOS -> SOA). +// This kernel does not deshuffle the bits. +//------------------------------------------------------------------------------ +kernel void kernel_convert_block_q5_1( + global struct block_q5_1 * src0, + global uchar * dst_qs, + global uint * dst_qh, + global half * dst_d, + global half * dst_m, + ulong n_blk +) { + if (get_global_id(0) >= n_blk) { + return; + } + + global struct block_q5_1 * b = (global struct block_q5_1 *) src0 + get_global_id(0); + global uchar * qs = (global uchar *) dst_qs + (QK5_1/2)*get_global_id(0); + global uint * qh = (global uint *) dst_qh + get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + global half * m = (global half *) dst_m + get_global_id(0); + + *d = b->d; + *m = b->m; + *qh = *((global uint *)(b->qh)); + + for (int i = 0; i < QK5_1/2; ++i) { + qs[i] = b->qs[i]; + } +} + +kernel void kernel_restore_block_q5_1( + global uchar * src_qs, + global uint * src_qh, + global half * src_d, + global half * src_m, + global struct block_q5_1 * dst +) { + global struct block_q5_1 * b = (global struct block_q5_1 *) dst + get_global_id(0); + global uchar * qs = (global uchar *) src_qs + (QK5_1/2)*get_global_id(0); + global uint * qh = (global uint *) src_qh + get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + global half * m = (global half *) src_m + get_global_id(0); + + b->d = *d; + b->m = *m; + *((global uint *)(b->qh)) = *qh; + for (int i = 0; i < QK5_1/2; ++i) { + b->qs[i] = qs[i]; + } +} + +kernel void kernel_convert_block_q5_1_noshuffle( + global struct block_q5_1 * src0, + global uchar * dst_q, + global uint * dst_qh, + global half * dst_d, + global half * dst_m +) { + global struct block_q5_1 * b = (global struct block_q5_1 *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK5_1/2*get_global_id(0); + global uint * qh = (global uint *) dst_qh + get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + global half * m = (global half *) dst_m + get_global_id(0); + + *d = b->d; + *m = b->m; + *qh = *((global uint *)(b->qh)); + + for (int i = 0; i < QK5_1/4; ++i) { + uchar x0 = b->qs[2*i + 0]; + uchar x1 = b->qs[2*i + 1]; + + q[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + q[i + QK5_1/4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + +#ifdef ADRENO_GPU + if (get_global_id(0) == 65536*4096) { + printf("%04x - %02x\n", *(global ushort*)d, ((x0 & 0xF0) >> 4) | (x1 & 0xF0)); + } +#endif + } +} + +kernel void kernel_restore_block_q5_1_noshuffle( + global uchar * src_q, + global uint * src_qh, + global half * src_d, + global half * src_m, + global struct block_q5_1 * dst, + uchar mask_0F, + uchar mask_F0 +) { + global struct block_q5_1 * b = (global struct block_q5_1 *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK5_1/2*get_global_id(0); + global uint * qh = (global uint *) src_qh + get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + global half * m = (global half *) src_m + get_global_id(0); + + b->d = *d; + b->m = *m; + *((global uint *)(b->qh)) = *qh; + + for (int i = 0; i < QK5_1/4; ++i) { + uchar x0 = q[i + 0 ]; + uchar x1 = q[i + QK5_1/4]; + + b->qs[2*i + 0] = convert_uchar((x0 & mask_0F) | ((x1 & mask_0F) << 4)); + b->qs[2*i + 1] = convert_uchar(((x0 & mask_F0) >> 4) | (x1 & mask_F0)); + } +} + kernel void kernel_convert_block_q5_1_trans4_ns( __global struct block_q5_1 * src0, __global uint * dst_qs, diff --git a/ggml/src/ggml-opencl/kernels/gated_delta_net.cl b/ggml/src/ggml-opencl/kernels/gated_delta_net.cl new file mode 100644 index 00000000000..319c9829529 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gated_delta_net.cl @@ -0,0 +1,249 @@ +#pragma OPENCL EXTENSION cl_khr_subgroups : enable + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifndef S_V +#define S_V 128 +#endif +#ifndef KDA +#define KDA 0 +#endif +#ifndef SUBGROUP_SIZE +#define SUBGROUP_SIZE 64 +#endif +#ifndef LANES_PER_COLUMN +#define LANES_PER_COLUMN 8 +#endif +#ifndef COLS_PER_LANE_GROUP +#define COLS_PER_LANE_GROUP 1 +#endif +#ifndef SUBGROUPS_PER_WG +#define SUBGROUPS_PER_WG 1 +#endif +#ifndef USE_QCOM_SUBGROUP_SHUFFLE +#define USE_QCOM_SUBGROUP_SHUFFLE 0 +#endif + +#define WG_SIZE (SUBGROUP_SIZE * SUBGROUPS_PER_WG) +#define LANE_GROUPS_PER_SG (SUBGROUP_SIZE / LANES_PER_COLUMN) +#define COLS_PER_SG (LANE_GROUPS_PER_SG * COLS_PER_LANE_GROUP) +#define COLS_PER_WG (SUBGROUPS_PER_WG * COLS_PER_SG) +#define ROWS_PER_LANE (S_V / LANES_PER_COLUMN) + +#if USE_QCOM_SUBGROUP_SHUFFLE +#pragma OPENCL EXTENSION cl_qcom_subgroup_shuffle : enable +#endif + +// XOR-based parallel sum +// This does a reduction across groups of LANES_PER_COLUMN +static inline float reduce_add_shmem(float partial, __local float * temp, uint lane) { +#if USE_QCOM_SUBGROUP_SHUFFLE + #pragma unroll + for (uint s = LANES_PER_COLUMN / 2u; s > 0u; s >>= 1u) { + partial += qcom_sub_group_shuffle_xor(partial, s, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, partial); + } + return partial; +#else + temp[lane] = partial; + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + #pragma unroll + for (uint s = LANES_PER_COLUMN / 2u; s > 0u; s >>= 1u) { + float other = temp[lane ^ s]; + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + temp[lane] += other; + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + } + const float result = temp[lane]; + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + return result; +#endif +} + +#define REDUCE_PARTIAL(partial, temp_ptr, lid) \ + ((LANES_PER_COLUMN == 1u) ? (partial) : reduce_add_shmem((partial), (temp_ptr), (lid))) + +// force compiler to optimize kernel for a specific fixed work-group size +__attribute__((reqd_work_group_size(WG_SIZE, 1, 1))) +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_32 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_gated_delta_net( + global const char * q_buf, ulong off_q, + global const char * k_buf, ulong off_k, + global const char * v_buf, ulong off_v, + global const char * g_buf, ulong off_g, + global const char * beta_buf, ulong off_beta, + global const char * state_buf, ulong off_state, + global char * dst_buf, ulong off_dst, + uint H_v, + uint n_tokens, + uint n_seqs, + uint s_off, + uint sq1, uint sq2, uint sq3, + uint sv1, uint sv2, uint sv3, + uint sb1, uint sb2, uint sb3, + uint H_k, + uint rq3, + float scale, + uint K) { + + global const float * data_q = (global const float *)(q_buf + off_q); + global const float * data_k = (global const float *)(k_buf + off_k); + global const float * data_v = (global const float *)(v_buf + off_v); + global const float * data_g = (global const float *)(g_buf + off_g); + global const float * data_beta = (global const float *)(beta_buf + off_beta); + global const float * data_state = (global const float *)(state_buf + off_state); + global float * data_dst = (global float *)(dst_buf + off_dst); + + const uint head_id = get_group_id(0); + const uint seq_id = get_group_id(1); + const uint tid = (uint)get_local_id(0); + + const uint sg_id = get_sub_group_id(); // subgroup id + const uint sg_lid = get_sub_group_local_id(); // subgroup lane id + + const uint lane = sg_lid % LANES_PER_COLUMN; + const uint lane_group = sg_lid / LANES_PER_COLUMN; + const uint wg_col_base = get_group_id(2) * COLS_PER_WG; + const uint sg_col_base = wg_col_base + sg_id * COLS_PER_SG; + + const uint iq1 = head_id % H_k; // head index for Q and K + const uint iq3 = seq_id / rq3; // seq index for Q and K + + const uint state_size = S_V * S_V; + // input state holds s0 only [S_v, S_v, H, n_seqs]: per-seq stride is H*D. + const uint state_base = (seq_id * H_v + head_id) * state_size; + const uint q_off_base = iq3 * sq3 + iq1 * sq1; + const uint v_off_base = seq_id * sv3 + head_id * sv1; + const uint gb_off_base = seq_id * sb3 + head_id * sb1; + const uint state_out_base = (seq_id * H_v + head_id) * state_size; + const uint state_size_per_snap = state_size * H_v * n_seqs; + + __local float reduce_temp[WG_SIZE]; + __local float * temp_ptr = reduce_temp + sg_id * SUBGROUP_SIZE; + + float s_shard[COLS_PER_LANE_GROUP][ROWS_PER_LANE]; + #pragma unroll + for (uint cg = 0; cg < COLS_PER_LANE_GROUP; cg++) { + const uint col = sg_col_base + cg * LANE_GROUPS_PER_SG + lane_group; + #pragma unroll + for (uint r = 0; r < ROWS_PER_LANE; r++) { + s_shard[cg][r] = data_state[state_base + col * S_V + r * LANES_PER_COLUMN + lane]; + } + } + + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + // When n_tokens < K only slots 0..n_tokens-1 are written; older slots are caller-owned. + uint attn_off = (seq_id * n_tokens * H_v + head_id) * S_V; + + for (uint t = 0; t < n_tokens; t++) { + const uint q_off = q_off_base + t * sq2; + const uint k_off = q_off; + const uint v_off = v_off_base + t * sv2; + const uint gb_off = gb_off_base + t * sb2; + const float beta_val = data_beta[gb_off]; + + float k_reg[ROWS_PER_LANE]; + float q_reg[ROWS_PER_LANE]; +#if KDA + float g_exp[ROWS_PER_LANE]; + #pragma unroll + for (uint r = 0; r < ROWS_PER_LANE; r++) { + const uint i = r * LANES_PER_COLUMN + lane; + k_reg[r] = data_k[k_off + i]; + q_reg[r] = data_q[q_off + i]; + g_exp[r] = exp(data_g[gb_off * S_V + i]); + } +#else + const float g_val = exp(data_g[gb_off]); + + #pragma unroll + for (uint r = 0; r < ROWS_PER_LANE; r++) { + const uint i = r * LANES_PER_COLUMN + lane; + k_reg[r] = data_k[k_off + i]; + q_reg[r] = data_q[q_off + i]; + } +#endif + + #pragma unroll + for (uint cg = 0; cg < COLS_PER_LANE_GROUP; cg++) { + const uint col = sg_col_base + cg * LANE_GROUPS_PER_SG + lane_group; + float v_val = data_v[v_off + col]; + + float kv_shard = 0.0f; + #pragma unroll + for (uint r = 0; r < ROWS_PER_LANE; r++) { +#if KDA + float gs = g_exp[r] * s_shard[cg][r]; + kv_shard += gs * k_reg[r]; +#else + kv_shard += s_shard[cg][r] * k_reg[r]; +#endif + } + +#if !KDA + kv_shard *= g_val; // Applied once instead of ROWS_PER_LANE times +#endif + + const float kv_col = REDUCE_PARTIAL(kv_shard, temp_ptr, sg_lid); + + const float delta_col = (v_val - kv_col) * beta_val; + + float attn_partial = 0.0f; + #pragma unroll + for (uint r = 0; r < ROWS_PER_LANE; r++) { +#if KDA + float gs = g_exp[r] * s_shard[cg][r]; +#else + float gs = g_val * s_shard[cg][r]; +#endif + s_shard[cg][r] = gs + k_reg[r] * delta_col; + attn_partial += s_shard[cg][r] * q_reg[r]; + } + const float attn_col = REDUCE_PARTIAL(attn_partial, temp_ptr, sg_lid); + + if (lane == 0) { + data_dst[attn_off + col] = attn_col * scale; + } + } + attn_off += S_V * H_v; + + if (K > 1u) { + const int target_slot = (int)n_tokens - 1 - (int)t; + if (target_slot >= 0 && target_slot < (int)K) { + #pragma unroll + for (uint cg = 0; cg < COLS_PER_LANE_GROUP; cg++) { + const uint col = sg_col_base + cg * LANE_GROUPS_PER_SG + lane_group; + const uint slot_base = s_off + (uint)target_slot * state_size_per_snap + state_out_base; + #pragma unroll + for (uint r = 0; r < ROWS_PER_LANE; r++) { + data_dst[slot_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[cg][r]; + } + } + } + } + } + + if (K == 1u) { + #pragma unroll + for (uint cg = 0; cg < COLS_PER_LANE_GROUP; cg++) { + const uint col = sg_col_base + cg * LANE_GROUPS_PER_SG + lane_group; + #pragma unroll + for (uint r = 0; r < ROWS_PER_LANE; r++) { + data_dst[s_off + state_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[cg][r]; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl new file mode 100644 index 00000000000..1d6bd48005e --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl @@ -0,0 +1,131 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_128 +#endif + +kernel void kernel_gemm_noshuffle_q5_0_f32( + global const ushort * src0_qs, // quantized A + global const uchar * src0_qh, // 5th bits + global const half * src0_d, // A scales + __read_only image1d_buffer_t src1, // B (1d image) + global float * dst, // C + int m, // M + int n, // N with padding + int k, // K + int n_no_padding // N without padding +) { + + int n_4 = n >> 2; + + int gy = get_global_id(0); + int gx = get_global_id(1); + int gx_2 = gx << 2; + + half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0; + half8 B; + half4 dequantized_weights; + + global const ushort * weight_ptr = src0_qs + gx_2; + global const uchar * qh_ptr = src0_qh + gx_2; + global const half * scale_ptr = src0_d + gx_2; + + for (int i = 0; i < k; i += 4) { + + B.s0123 = read_imageh(src1, gy*2 + i*n_4); + B.s4567 = read_imageh(src1, gy*2 + i*n_4 + 1); + + ushort4 bits4 = vload4(0, weight_ptr + (i >> 2)*m); + uchar4 bits1 = vload4(0, qh_ptr + (i >> 3)*m); + uchar4 qh = bits1 >> (uchar4)(i & 4); + + half4 scale = vload4(0, scale_ptr + (i >> 5)*m); + + // j=0 + dequantized_weights.s0 = (convert_half((bits4.s0 & 0x000F) | ((qh.s0 & 0x01) << 4)) - 16.0h) * scale.s0; + dequantized_weights.s1 = (convert_half((bits4.s1 & 0x000F) | ((qh.s1 & 0x01) << 4)) - 16.0h) * scale.s1; + dequantized_weights.s2 = (convert_half((bits4.s2 & 0x000F) | ((qh.s2 & 0x01) << 4)) - 16.0h) * scale.s2; + dequantized_weights.s3 = (convert_half((bits4.s3 & 0x000F) | ((qh.s3 & 0x01) << 4)) - 16.0h) * scale.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=1 + B.s0123 = read_imageh(src1, gy*2 + (i+1)*n_4); + B.s4567 = read_imageh(src1, gy*2 + (i+1)*n_4 + 1); + dequantized_weights.s0 = (convert_half(((bits4.s0 & 0x00F0) >> 4) | ((qh.s0 & 0x02) << 3)) - 16.0h) * scale.s0; + dequantized_weights.s1 = (convert_half(((bits4.s1 & 0x00F0) >> 4) | ((qh.s1 & 0x02) << 3)) - 16.0h) * scale.s1; + dequantized_weights.s2 = (convert_half(((bits4.s2 & 0x00F0) >> 4) | ((qh.s2 & 0x02) << 3)) - 16.0h) * scale.s2; + dequantized_weights.s3 = (convert_half(((bits4.s3 & 0x00F0) >> 4) | ((qh.s3 & 0x02) << 3)) - 16.0h) * scale.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=2 + B.s0123 = read_imageh(src1, gy*2 + (i+2)*n_4); + B.s4567 = read_imageh(src1, gy*2 + (i+2)*n_4 + 1); + dequantized_weights.s0 = (convert_half(((bits4.s0 & 0x0F00) >> 8) | ((qh.s0 & 0x04) << 2)) - 16.0h) * scale.s0; + dequantized_weights.s1 = (convert_half(((bits4.s1 & 0x0F00) >> 8) | ((qh.s1 & 0x04) << 2)) - 16.0h) * scale.s1; + dequantized_weights.s2 = (convert_half(((bits4.s2 & 0x0F00) >> 8) | ((qh.s2 & 0x04) << 2)) - 16.0h) * scale.s2; + dequantized_weights.s3 = (convert_half(((bits4.s3 & 0x0F00) >> 8) | ((qh.s3 & 0x04) << 2)) - 16.0h) * scale.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=3 + B.s0123 = read_imageh(src1, gy*2 + (i+3)*n_4); + B.s4567 = read_imageh(src1, gy*2 + (i+3)*n_4 + 1); + dequantized_weights.s0 = (convert_half(((bits4.s0 & 0xF000) >> 12) | ((qh.s0 & 0x08) << 1)) - 16.0h) * scale.s0; + dequantized_weights.s1 = (convert_half(((bits4.s1 & 0xF000) >> 12) | ((qh.s1 & 0x08) << 1)) - 16.0h) * scale.s1; + dequantized_weights.s2 = (convert_half(((bits4.s2 & 0xF000) >> 12) | ((qh.s2 & 0x08) << 1)) - 16.0h) * scale.s2; + dequantized_weights.s3 = (convert_half(((bits4.s3 & 0xF000) >> 12) | ((qh.s3 & 0x08) << 1)) - 16.0h) * scale.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + } + + int idx = (gy<<3)*m + (gx<<2); + + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx); + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl new file mode 100644 index 00000000000..94b4ef6cacc --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl @@ -0,0 +1,134 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_128 +#endif + +kernel void kernel_gemm_noshuffle_q5_1_f32( + global const ushort * src0_qs, // quantized A + global const uchar * src0_qh, // 5th bits + global const half * src0_d, // A scales + global const half * src0_m, // A mins + __read_only image1d_buffer_t src1, // B (1d image) + global float * dst, // C + int m, // M + int n, // N with padding + int k, // K + int n_no_padding // N without padding +) { + + int n_4 = n >> 2; + + int gy = get_global_id(0); + int gx = get_global_id(1); + int gx_2 = gx << 2; + + half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0; + half8 B; + half4 dequantized_weights; + + global const ushort * weight_ptr = src0_qs + gx_2; + global const uchar * qh_ptr = src0_qh + gx_2; + global const half * scale_ptr = src0_d + gx_2; + global const half * min_ptr = src0_m + gx_2; + + for (int i = 0; i < k; i += 4) { + + B.s0123 = read_imageh(src1, gy*2 + i*n_4); + B.s4567 = read_imageh(src1, gy*2 + i*n_4 + 1); + + ushort4 bits4 = vload4(0, weight_ptr + (i >> 2)*m); + uchar4 bits1 = vload4(0, qh_ptr + (i >> 3)*m); + uchar4 qh = bits1 >> (uchar4)(i & 4); + + half4 scale = vload4(0, scale_ptr + (i >> 5)*m); + half4 minv = vload4(0, min_ptr + (i >> 5)*m); + + // j=0 + dequantized_weights.s0 = convert_half((bits4.s0 & 0x000F) | ((qh.s0 & 0x01) << 4)) * scale.s0 + minv.s0; + dequantized_weights.s1 = convert_half((bits4.s1 & 0x000F) | ((qh.s1 & 0x01) << 4)) * scale.s1 + minv.s1; + dequantized_weights.s2 = convert_half((bits4.s2 & 0x000F) | ((qh.s2 & 0x01) << 4)) * scale.s2 + minv.s2; + dequantized_weights.s3 = convert_half((bits4.s3 & 0x000F) | ((qh.s3 & 0x01) << 4)) * scale.s3 + minv.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=1 + B.s0123 = read_imageh(src1, gy*2 + (i+1)*n_4); + B.s4567 = read_imageh(src1, gy*2 + (i+1)*n_4 + 1); + dequantized_weights.s0 = convert_half(((bits4.s0 & 0x00F0) >> 4) | ((qh.s0 & 0x02) << 3)) * scale.s0 + minv.s0; + dequantized_weights.s1 = convert_half(((bits4.s1 & 0x00F0) >> 4) | ((qh.s1 & 0x02) << 3)) * scale.s1 + minv.s1; + dequantized_weights.s2 = convert_half(((bits4.s2 & 0x00F0) >> 4) | ((qh.s2 & 0x02) << 3)) * scale.s2 + minv.s2; + dequantized_weights.s3 = convert_half(((bits4.s3 & 0x00F0) >> 4) | ((qh.s3 & 0x02) << 3)) * scale.s3 + minv.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=2 + B.s0123 = read_imageh(src1, gy*2 + (i+2)*n_4); + B.s4567 = read_imageh(src1, gy*2 + (i+2)*n_4 + 1); + dequantized_weights.s0 = convert_half(((bits4.s0 & 0x0F00) >> 8) | ((qh.s0 & 0x04) << 2)) * scale.s0 + minv.s0; + dequantized_weights.s1 = convert_half(((bits4.s1 & 0x0F00) >> 8) | ((qh.s1 & 0x04) << 2)) * scale.s1 + minv.s1; + dequantized_weights.s2 = convert_half(((bits4.s2 & 0x0F00) >> 8) | ((qh.s2 & 0x04) << 2)) * scale.s2 + minv.s2; + dequantized_weights.s3 = convert_half(((bits4.s3 & 0x0F00) >> 8) | ((qh.s3 & 0x04) << 2)) * scale.s3 + minv.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=3 + B.s0123 = read_imageh(src1, gy*2 + (i+3)*n_4); + B.s4567 = read_imageh(src1, gy*2 + (i+3)*n_4 + 1); + dequantized_weights.s0 = convert_half(((bits4.s0 & 0xF000) >> 12) | ((qh.s0 & 0x08) << 1)) * scale.s0 + minv.s0; + dequantized_weights.s1 = convert_half(((bits4.s1 & 0xF000) >> 12) | ((qh.s1 & 0x08) << 1)) * scale.s1 + minv.s1; + dequantized_weights.s2 = convert_half(((bits4.s2 & 0xF000) >> 12) | ((qh.s2 & 0x08) << 1)) * scale.s2 + minv.s2; + dequantized_weights.s3 = convert_half(((bits4.s3 & 0xF000) >> 12) | ((qh.s3 & 0x08) << 1)) * scale.s3 + minv.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + } + + int idx = (gy<<3)*m + (gx<<2); + + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx); + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl new file mode 100644 index 00000000000..c228f717a94 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl @@ -0,0 +1,291 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#endif + +#define QK5_0 32 +#define NSUBGROUPS 4 +#define SUBGROUP_SIZE 64 + +#define dequantizeBlockAccum_ns_q5_0_sgbroadcast_1_hi(total_sums, bits4, bits1, scale, y) \ + float shared_y; \ + shared_y = sub_group_broadcast(y.s0, 0); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | (((bits1.s0 ) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | (((bits1.s4 ) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 0); \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s0 >> 1) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s4 >> 1) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 0); \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s0 >> 2) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s4 >> 2) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 0); \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s0 >> 3) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s4 >> 3) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 0); \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s0 >> 4) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s4 >> 4) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 0); \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s0 >> 5) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s4 >> 5) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 0); \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s0 >> 6) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s4 >> 6) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 0); \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s0 >> 7) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s4 >> 7) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 1); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | (((bits1.s1 ) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | (((bits1.s5 ) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 1); \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s1 >> 1) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s5 >> 1) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 1); \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s1 >> 2) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s5 >> 2) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 1); \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s1 >> 3) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s5 >> 3) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 1); \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s1 >> 4) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s5 >> 4) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 1); \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s1 >> 5) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s5 >> 5) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 1); \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s1 >> 6) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s5 >> 6) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 1); \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s1 >> 7) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s5 >> 7) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + + +#define dequantizeBlockAccum_ns_q5_0_sgbroadcast_1_lo(total_sums, bits4, bits1, scale, y) \ + shared_y = sub_group_broadcast(y.s0, 2); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | (((bits1.s2 ) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | (((bits1.s6 ) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 2); \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s2 >> 1) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s6 >> 1) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 2); \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s2 >> 2) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s6 >> 2) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 2); \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s2 >> 3) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s6 >> 3) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 2); \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s2 >> 4) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s6 >> 4) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 2); \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s2 >> 5) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s6 >> 5) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 2); \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s2 >> 6) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s6 >> 6) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 2); \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s2 >> 7) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s6 >> 7) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 3); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | (((bits1.s3 ) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | (((bits1.s7 ) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 3); \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s3 >> 1) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s7 >> 1) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 3); \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s3 >> 2) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s7 >> 2) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 3); \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s3 >> 3) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s7 >> 3) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 3); \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s3 >> 4) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s7 >> 4) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 3); \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s3 >> 5) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s7 >> 5) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 3); \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s3 >> 6) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s7 >> 6) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 3); \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s3 >> 7) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s7 >> 7) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + + +#define dequantizeBlockAccum_ns_q5_0_sgbroadcast_8_hi(total_sums, bits4, bits1, scale, y) \ + float8 shared_y; \ + shared_y = sub_group_broadcast(y, 0); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | (((bits1.s0 ) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s0; \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s0 >> 1) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s1; \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s0 >> 2) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s2; \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s0 >> 3) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s3; \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s0 >> 4) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s4; \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s0 >> 5) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s5; \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s0 >> 6) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s6; \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s0 >> 7) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s7; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | (((bits1.s4 ) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s0; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s4 >> 1) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s1; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s4 >> 2) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s2; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s4 >> 3) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s3; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s4 >> 4) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s4; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s4 >> 5) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s5; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s4 >> 6) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s6; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s4 >> 7) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 1); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | (((bits1.s1 ) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s0; \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s1 >> 1) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s1; \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s1 >> 2) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s2; \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s1 >> 3) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s3; \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s1 >> 4) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s4; \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s1 >> 5) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s5; \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s1 >> 6) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s6; \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s1 >> 7) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s7; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | (((bits1.s5 ) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s0; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s5 >> 1) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s1; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s5 >> 2) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s2; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s5 >> 3) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s3; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s5 >> 4) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s4; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s5 >> 5) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s5; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s5 >> 6) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s6; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s5 >> 7) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s7; \ + + +#define dequantizeBlockAccum_ns_q5_0_sgbroadcast_8_lo(total_sums, bits4, bits1, scale, y) \ + shared_y = sub_group_broadcast(y, 2); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | (((bits1.s2 ) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s0; \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s2 >> 1) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s1; \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s2 >> 2) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s2; \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s2 >> 3) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s3; \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s2 >> 4) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s4; \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s2 >> 5) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s5; \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s2 >> 6) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s6; \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s2 >> 7) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s7; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | (((bits1.s6 ) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s0; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s6 >> 1) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s1; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s6 >> 2) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s2; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s6 >> 3) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s3; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s6 >> 4) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s4; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s6 >> 5) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s5; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s6 >> 6) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s6; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s6 >> 7) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 3); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | (((bits1.s3 ) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s0; \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s3 >> 1) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s1; \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s3 >> 2) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s2; \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s3 >> 3) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s3; \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s3 >> 4) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s4; \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s3 >> 5) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s5; \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s3 >> 6) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s6; \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s3 >> 7) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s7; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | (((bits1.s7 ) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s0; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s7 >> 1) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s1; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s7 >> 2) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s2; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s7 >> 3) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s3; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s7 >> 4) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s4; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s7 >> 5) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s5; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s7 >> 6) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s6; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s7 >> 7) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s7; \ + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +__kernel void kernel_gemv_noshuffle_q5_0_f32( + __read_only image1d_buffer_t src0_qs, // quantized A + global ushort * src0_qh, // 5th bits + global half2 * src0_d, // A scales + __read_only image1d_buffer_t src1, // B activations + global float * dst, + ulong offsetd, + int ne00, // K + int ne01) // M +{ + uint groupId = get_local_id(1); + uint gid = get_global_id(0); + ushort slid = get_sub_group_local_id(); + + uint K = ne00; + uint M = ne01; + + uint LINE_STRIDE_A = M / 2; + uint BLOCK_STRIDE_A = NSUBGROUPS * M; + + private uint4 regA; + private half2 regS; + private float8 regB; + + private float2 totalSum = (float2)(0.0f); + + for (uint k = groupId; k < (K / QK5_0); k += NSUBGROUPS) { + regS = src0_d[gid + k * LINE_STRIDE_A]; + + ushort4 qh_raw; + qh_raw.s0 = src0_qh[gid + (4*k + 0) * LINE_STRIDE_A]; + qh_raw.s1 = src0_qh[gid + (4*k + 1) * LINE_STRIDE_A]; + qh_raw.s2 = src0_qh[gid + (4*k + 2) * LINE_STRIDE_A]; + qh_raw.s3 = src0_qh[gid + (4*k + 3) * LINE_STRIDE_A]; + + uchar8 raw = as_uchar8(qh_raw); + uchar8 qh_bytes = (uchar8)(raw.s0, raw.s2, raw.s4, raw.s6, + raw.s1, raw.s3, raw.s5, raw.s7); + + // Load activations + if (slid < 4) { + regB.s0123 = read_imagef(src1, (slid * 2 + k * 8)); + regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8)); + } + + regA.s0 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x; + regA.s1 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x; + regA.s2 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x; + regA.s3 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x; + +#ifdef VECTOR_SUB_GROUP_BROADCAST + dequantizeBlockAccum_ns_q5_0_sgbroadcast_8_hi(totalSum, as_ushort8(regA), qh_bytes, regS, regB); +#else + dequantizeBlockAccum_ns_q5_0_sgbroadcast_1_hi(totalSum, as_ushort8(regA), qh_bytes, regS, regB); +#endif // VECTOR_SUB_GROUP_BROADCAST + + regA.s0 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x; + regA.s1 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x; + regA.s2 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; + regA.s3 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAST + dequantizeBlockAccum_ns_q5_0_sgbroadcast_8_lo(totalSum, as_ushort8(regA), qh_bytes, regS, regB); +#else + dequantizeBlockAccum_ns_q5_0_sgbroadcast_1_lo(totalSum, as_ushort8(regA), qh_bytes, regS, regB); +#endif // VECTOR_SUB_GROUP_BROADCAST + } + + // reduction in local memory, assumes #wave=4 + local float2 reduceLM[SUBGROUP_SIZE * 3]; + if (groupId == 1) { + reduceLM[SUBGROUP_SIZE * 0 + slid] = totalSum; + } + if (groupId == 2) { + reduceLM[SUBGROUP_SIZE * 1 + slid] = totalSum; + } + if (groupId == 3) { + reduceLM[SUBGROUP_SIZE * 2 + slid] = totalSum; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 0 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 1 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 2 + slid]; + } + + // 2 outputs per fiber in wave 0 + if (groupId == 0) { + dst = (global float*)((global char*)dst + offsetd); + vstore2(totalSum, 0, &(dst[gid * 2])); + } + +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl new file mode 100644 index 00000000000..daf1308ea4b --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl @@ -0,0 +1,294 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#endif + +#define QK5_1 32 +#define NSUBGROUPS 4 +#define SUBGROUP_SIZE 64 + +#define dequantizeBlockAccum_ns_q5_1_sgbroadcast_1_hi(total_sums, bits4, bits1, scale, minv, y) \ + float shared_y; \ + shared_y = sub_group_broadcast(y.s0, 0); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | (((bits1.s0 ) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | (((bits1.s4 ) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 0); \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s0 >> 1) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s4 >> 1) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 0); \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s0 >> 2) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s4 >> 2) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 0); \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s0 >> 3) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s4 >> 3) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 0); \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s0 >> 4) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s4 >> 4) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 0); \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s0 >> 5) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s4 >> 5) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 0); \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s0 >> 6) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s4 >> 6) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 0); \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s0 >> 7) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s4 >> 7) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 1); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | (((bits1.s1 ) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | (((bits1.s5 ) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 1); \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s1 >> 1) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s5 >> 1) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 1); \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s1 >> 2) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s5 >> 2) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 1); \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s1 >> 3) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s5 >> 3) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 1); \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s1 >> 4) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s5 >> 4) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 1); \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s1 >> 5) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s5 >> 5) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 1); \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s1 >> 6) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s5 >> 6) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 1); \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s1 >> 7) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s5 >> 7) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + + +#define dequantizeBlockAccum_ns_q5_1_sgbroadcast_1_lo(total_sums, bits4, bits1, scale, minv, y) \ + shared_y = sub_group_broadcast(y.s0, 2); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | (((bits1.s2 ) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | (((bits1.s6 ) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 2); \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s2 >> 1) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s6 >> 1) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 2); \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s2 >> 2) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s6 >> 2) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 2); \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s2 >> 3) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s6 >> 3) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 2); \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s2 >> 4) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s6 >> 4) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 2); \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s2 >> 5) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s6 >> 5) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 2); \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s2 >> 6) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s6 >> 6) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 2); \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s2 >> 7) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s6 >> 7) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 3); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | (((bits1.s3 ) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | (((bits1.s7 ) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 3); \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s3 >> 1) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s7 >> 1) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 3); \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s3 >> 2) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s7 >> 2) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 3); \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s3 >> 3) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s7 >> 3) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 3); \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s3 >> 4) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s7 >> 4) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 3); \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s3 >> 5) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s7 >> 5) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 3); \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s3 >> 6) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s7 >> 6) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 3); \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s3 >> 7) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s7 >> 7) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + + +#define dequantizeBlockAccum_ns_q5_1_sgbroadcast_8_hi(total_sums, bits4, bits1, scale, minv, y) \ + float8 shared_y; \ + shared_y = sub_group_broadcast(y, 0); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | (((bits1.s0 ) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s0; \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s0 >> 1) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s1; \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s0 >> 2) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s2; \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s0 >> 3) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s3; \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s0 >> 4) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s4; \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s0 >> 5) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s5; \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s0 >> 6) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s6; \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s0 >> 7) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s7; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | (((bits1.s4 ) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s0; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s4 >> 1) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s1; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s4 >> 2) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s2; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s4 >> 3) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s3; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s4 >> 4) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s4; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s4 >> 5) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s5; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s4 >> 6) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s6; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s4 >> 7) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 1); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | (((bits1.s1 ) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s0; \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s1 >> 1) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s1; \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s1 >> 2) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s2; \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s1 >> 3) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s3; \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s1 >> 4) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s4; \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s1 >> 5) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s5; \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s1 >> 6) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s6; \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s1 >> 7) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s7; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | (((bits1.s5 ) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s0; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s5 >> 1) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s1; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s5 >> 2) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s2; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s5 >> 3) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s3; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s5 >> 4) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s4; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s5 >> 5) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s5; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s5 >> 6) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s6; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s5 >> 7) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s7; \ + + +#define dequantizeBlockAccum_ns_q5_1_sgbroadcast_8_lo(total_sums, bits4, bits1, scale, minv, y) \ + shared_y = sub_group_broadcast(y, 2); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | (((bits1.s2 ) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s0; \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s2 >> 1) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s1; \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s2 >> 2) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s2; \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s2 >> 3) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s3; \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s2 >> 4) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s4; \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s2 >> 5) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s5; \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s2 >> 6) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s6; \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s2 >> 7) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s7; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | (((bits1.s6 ) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s0; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s6 >> 1) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s1; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s6 >> 2) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s2; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s6 >> 3) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s3; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s6 >> 4) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s4; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s6 >> 5) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s5; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s6 >> 6) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s6; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s6 >> 7) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 3); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | (((bits1.s3 ) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s0; \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s3 >> 1) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s1; \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s3 >> 2) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s2; \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s3 >> 3) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s3; \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s3 >> 4) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s4; \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s3 >> 5) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s5; \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s3 >> 6) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s6; \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s3 >> 7) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s7; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | (((bits1.s7 ) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s0; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s7 >> 1) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s1; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s7 >> 2) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s2; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s7 >> 3) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s3; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s7 >> 4) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s4; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s7 >> 5) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s5; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s7 >> 6) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s6; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s7 >> 7) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s7; \ + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +__kernel void kernel_gemv_noshuffle_q5_1_f32( + __read_only image1d_buffer_t src0_qs, // quantized A + global ushort * src0_qh, // 5th bits + global half2 * src0_d, // A scales + global half2 * src0_m, // A mins + __read_only image1d_buffer_t src1, // B activations + global float * dst, + ulong offsetd, + int ne00, // K + int ne01) // M +{ + uint groupId = get_local_id(1); + uint gid = get_global_id(0); + ushort slid = get_sub_group_local_id(); + + uint K = ne00; + uint M = ne01; + + uint LINE_STRIDE_A = M / 2; + uint BLOCK_STRIDE_A = NSUBGROUPS * M; + + __private uint4 regA; + __private half2 regS; + __private half2 regM; + __private float8 regB; + + __private float2 totalSum = (float2)(0.0f); + + for (uint k = groupId; k < (K / QK5_1); k += NSUBGROUPS) { + regS = src0_d[gid + k * LINE_STRIDE_A]; + regM = src0_m[gid + k * LINE_STRIDE_A]; + + ushort4 qh_raw; + qh_raw.s0 = src0_qh[gid + (4*k + 0) * LINE_STRIDE_A]; + qh_raw.s1 = src0_qh[gid + (4*k + 1) * LINE_STRIDE_A]; + qh_raw.s2 = src0_qh[gid + (4*k + 2) * LINE_STRIDE_A]; + qh_raw.s3 = src0_qh[gid + (4*k + 3) * LINE_STRIDE_A]; + + uchar8 raw = as_uchar8(qh_raw); + uchar8 qh_bytes = (uchar8)(raw.s0, raw.s2, raw.s4, raw.s6, + raw.s1, raw.s3, raw.s5, raw.s7); + + // Load activations + if (slid < 4) { + regB.s0123 = read_imagef(src1, (slid * 2 + k * 8)); + regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8)); + } + + regA.s0 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x; + regA.s1 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x; + regA.s2 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x; + regA.s3 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x; + +#ifdef VECTOR_SUB_GROUP_BROADCAST + dequantizeBlockAccum_ns_q5_1_sgbroadcast_8_hi(totalSum, as_ushort8(regA), qh_bytes, regS, regM, regB); +#else + dequantizeBlockAccum_ns_q5_1_sgbroadcast_1_hi(totalSum, as_ushort8(regA), qh_bytes, regS, regM, regB); +#endif // VECTOR_SUB_GROUP_BROADCAST + + regA.s0 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x; + regA.s1 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x; + regA.s2 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; + regA.s3 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAST + dequantizeBlockAccum_ns_q5_1_sgbroadcast_8_lo(totalSum, as_ushort8(regA), qh_bytes, regS, regM, regB); +#else + dequantizeBlockAccum_ns_q5_1_sgbroadcast_1_lo(totalSum, as_ushort8(regA), qh_bytes, regS, regM, regB); +#endif // VECTOR_SUB_GROUP_BROADCAST + } + + // reduction in local memory, assumes #wave=4 + local float2 reduceLM[SUBGROUP_SIZE * 3]; + if (groupId == 1) { + reduceLM[SUBGROUP_SIZE * 0 + slid] = totalSum; + } + if (groupId == 2) { + reduceLM[SUBGROUP_SIZE * 1 + slid] = totalSum; + } + if (groupId == 3) { + reduceLM[SUBGROUP_SIZE * 2 + slid] = totalSum; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 0 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 1 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 2 + slid]; + } + + // 2 outputs per fiber in wave 0 + if (groupId == 0) { + dst = (global float*)((global char*)dst + offsetd); + vstore2(totalSum, 0, &(dst[gid * 2])); + } + +} diff --git a/ggml/src/ggml-opencl/kernels/get_rows.cl b/ggml/src/ggml-opencl/kernels/get_rows.cl index c2962edc983..9ae4fff09fc 100644 --- a/ggml/src/ggml-opencl/kernels/get_rows.cl +++ b/ggml/src/ggml-opencl/kernels/get_rows.cl @@ -82,21 +82,27 @@ kernel void kernel_get_rows_f32( src1 = (global int*)((global char*)src1 + offset1); dst = (global float*)((global char*)dst + offsetd); - int i10 = get_group_id(0); - int i11 = get_group_id(1); - int i12 = get_group_id(2); + int nchunks = get_num_groups(0) / ne10; + int g = get_group_id(0); + int i10 = g / nchunks; + int chunk = g - i10 * nchunks; + int i11 = get_group_id(1); + int i12 = get_group_id(2); int r = ((global int *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0]; int i02 = i11; int i03 = i12; - for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) { - if (ind >= ne00) { - return; - } - ((global float *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1))[ind] = - ((global float *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03))[ind]; + global float * dst_row = (global float *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1); + global float * src_row = (global float *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03); + + int span = (ne00 + nchunks - 1) / nchunks; + int start = chunk * span; + int end = min(start + span, ne00); + + for (int ind = start + get_local_id(0); ind < end; ind += get_local_size(0)) { + dst_row[ind] = src_row[ind]; } } diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl new file mode 100644 index 00000000000..1e980a478a8 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl @@ -0,0 +1,173 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define LOAD_VEC_A 8 +#define LOAD_VEC_B 4 + +#define BM 64 +#define BN 64 +#define BK 32 +#define TM 4 +#define TN 8 + +kernel void kernel_mul_mm_q5_0_f32_l4_lm( + global uchar4 * src0_qs, + global uint * src0_qh, + global half * src0_d, + global float4 * src1, + ulong offset1, + global float * dst, + ulong offsetd, + + int ne00, + int ne01, + int ne02, + int ne11, + int ne12, + + int stride_a, + int stride_b, + int stride_d, + + int batch_stride_a, + int batch_stride_b, + int batch_stride_d, + + int r2, + int r3 +) { + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float *)((global char*)dst + offsetd); + + local float buf_a[BM * BK]; + local float buf_b[BN * BK]; + + const int batch_idx = get_global_id(2); + + const int i13 = batch_idx / ne12; + const int i12 = batch_idx % ne12; + + const int i03 = i13 / r3; + const int i02 = i12 / r2; + + const int batch_idx_a = i03 * ne02 + i02; + + const int ir = get_group_id(0); + const int ic = get_group_id(1); + + const int tid = get_local_id(0); + const int th_r = tid % (BM / TM); + const int th_c = tid / (BM / TM); + + const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); + const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); + const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); + const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); + + const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; + const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; + + int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; + int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; + + float sums[TM * TN]; + float cache_a[TM]; + float cache_b[TN]; + + for (int i = 0; i < TM * TN; i++) { + sums[i] = 0.0f; + } + + for (int block = 0; block < ne00; block += BK) { + for (int l = 0; l < BM; l += loadstride_a) { + if (ir*BM + loadc_a + l < ne01) { + int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; + int ib = idx / 4; + int iqs = idx % 4; + + float d = (float)src0_d[ib]; + uint qh_val = src0_qh[ib]; + + global uchar4 * qs_ptr = src0_qs + ib*4 + iqs; + uchar4 q = *qs_ptr; + + uint qh_lo = qh_val >> (iqs * 4); + uint qh_hi = qh_val >> (iqs * 4 + 16); + + uchar4 b_lo = (uchar4)((uchar)qh_lo, (uchar)(qh_lo >> 1), (uchar)(qh_lo >> 2), (uchar)(qh_lo >> 3)) & (uchar)1; + uchar4 b_hi = (uchar4)((uchar)qh_hi, (uchar)(qh_hi >> 1), (uchar)(qh_hi >> 2), (uchar)(qh_hi >> 3)) & (uchar)1; + + float4 v1 = (convert_float4((q & (uchar)0x0F) | (b_lo << (uchar)4)) - 16.0f) * d; + float4 v2 = (convert_float4((q >> (uchar)4) | (b_hi << (uchar)4)) - 16.0f) * d; + + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = v1.s0; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = v1.s1; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = v1.s2; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = v1.s3; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3; + } else { + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f; + } + } + + for (int l = 0; l < BN; l += loadstride_b) { + if (ic*BN + loadc_b + l < ne11) { + int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + + for (int i = 0; i < BK; i++) { + for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; + } + + for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; + } + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + const int sums_idx = cc*TM + cr; + sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + const int dr = ir * BM + th_r * TM; + const int dc = ic * BN + th_c * TN; + + const int offsets = batch_idx * batch_stride_d; + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + if (dr + cr < ne01 && dc + cc < ne11) { + dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl new file mode 100644 index 00000000000..ba06be54697 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl @@ -0,0 +1,175 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define LOAD_VEC_A 8 +#define LOAD_VEC_B 4 + +#define BM 64 +#define BN 64 +#define BK 32 +#define TM 4 +#define TN 8 + +kernel void kernel_mul_mm_q5_1_f32_l4_lm( + global uchar4 * src0_qs, + global uint * src0_qh, + global half * src0_d, + global half * src0_m, + global float4 * src1, + ulong offset1, + global float * dst, + ulong offsetd, + + int ne00, + int ne01, + int ne02, + int ne11, + int ne12, + + int stride_a, + int stride_b, + int stride_d, + + int batch_stride_a, + int batch_stride_b, + int batch_stride_d, + + int r2, + int r3 +) { + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float *)((global char*)dst + offsetd); + + local float buf_a[BM * BK]; + local float buf_b[BN * BK]; + + const int batch_idx = get_global_id(2); + + const int i13 = batch_idx / ne12; + const int i12 = batch_idx % ne12; + + const int i03 = i13 / r3; + const int i02 = i12 / r2; + + const int batch_idx_a = i03 * ne02 + i02; + + const int ir = get_group_id(0); + const int ic = get_group_id(1); + + const int tid = get_local_id(0); + const int th_r = tid % (BM / TM); + const int th_c = tid / (BM / TM); + + const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); + const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); + const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); + const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); + + const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; + const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; + + int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; + int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; + + float sums[TM * TN]; + float cache_a[TM]; + float cache_b[TN]; + + for (int i = 0; i < TM * TN; i++) { + sums[i] = 0.0f; + } + + for (int block = 0; block < ne00; block += BK) { + for (int l = 0; l < BM; l += loadstride_a) { + if (ir*BM + loadc_a + l < ne01) { + int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; + int ib = idx / 4; + int iqs = idx % 4; + + float d = (float)src0_d[ib]; + float m = (float)src0_m[ib]; + uint qh_val = src0_qh[ib]; + + global uchar4 * qs = src0_qs + ib*4 + iqs; + uchar4 q = *qs; + + uint qh_lo = qh_val >> (iqs * 4); + uint qh_hi = qh_val >> (iqs * 4 + 16); + + uchar4 b_lo = (uchar4)((uchar)qh_lo, (uchar)(qh_lo >> 1), (uchar)(qh_lo >> 2), (uchar)(qh_lo >> 3)) & (uchar)1; + uchar4 b_hi = (uchar4)((uchar)qh_hi, (uchar)(qh_hi >> 1), (uchar)(qh_hi >> 2), (uchar)(qh_hi >> 3)) & (uchar)1; + + float4 v1 = convert_float4((q & (uchar)0x0F) | (b_lo << (uchar)4)) * d + m; + float4 v2 = convert_float4((q >> (uchar)4) | (b_hi << (uchar)4)) * d + m; + + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = v1.s0; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = v1.s1; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = v1.s2; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = v1.s3; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3; + } else { + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f; + } + } + + for (int l = 0; l < BN; l += loadstride_b) { + if (ic*BN + loadc_b + l < ne11) { + int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + + for (int i = 0; i < BK; i++) { + for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; + } + + for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; + } + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + const int sums_idx = cc*TM + cr; + sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + const int dr = ir * BM + th_r * TM; + const int dc = ic * BN + th_c * TN; + + const int offsets = batch_idx * batch_stride_d; + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + if (dr + cr < ne01 && dc + cc < ne11) { + dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl new file mode 100644 index 00000000000..6d8c9e8f037 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl @@ -0,0 +1,241 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK5_0 32 + +struct block_q5_0 { + half d; + uchar qh[4]; + uchar qs[QK5_0 / 2]; +}; + +inline float block_q5_0_dot_y( + global const struct block_q5_0 * qb_curr, + float sumy, + float16 yl, + int il, + global const float * yb +) { + float d = qb_curr->d; + + float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f); + + global const ushort * qs = ((global const ushort *)((global const uchar *) qb_curr + 6 + il)); + + acc.s0 += yl.s0 * (qs[0] & 0x000F); + acc.s0 += yl.s1 * (qs[0] & 0x0F00); + acc.s0 += yl.s8 * (qs[0] & 0x00F0); + acc.s3 += yl.s9 * (qs[0] & 0xF000); + + acc.s0 += yl.s2 * (qs[1] & 0x000F); + acc.s1 += yl.s3 * (qs[1] & 0x0F00); + acc.s2 += yl.sa * (qs[1] & 0x00F0); + acc.s3 += yl.sb * (qs[1] & 0xF000); + + acc.s0 += yl.s4 * (qs[2] & 0x000F); + acc.s1 += yl.s5 * (qs[2] & 0x0F00); + acc.s2 += yl.sc * (qs[2] & 0x00F0); + acc.s3 += yl.sd * (qs[2] & 0xF000); + + acc.s0 += yl.s6 * (qs[3] & 0x000F); + acc.s1 += yl.s7 * (qs[3] & 0x0F00); + acc.s2 += yl.se * (qs[3] & 0x00F0); + acc.s3 += yl.sf * (qs[3] & 0xF000); + + uint qh_val = *((global const uint *)((global const uchar *) qb_curr + 2)); + uchar qh_lo = (uchar)((qh_val >> il) & 0xFF); + uchar qh_hi = (uchar)((qh_val >> (il + 16)) & 0xFF); + + float qh_sum = 0.0f; + qh_sum += yb[0] * (float)((qh_lo >> 0) & 1); + qh_sum += yb[1] * (float)((qh_lo >> 1) & 1); + qh_sum += yb[2] * (float)((qh_lo >> 2) & 1); + qh_sum += yb[3] * (float)((qh_lo >> 3) & 1); + qh_sum += yb[4] * (float)((qh_lo >> 4) & 1); + qh_sum += yb[5] * (float)((qh_lo >> 5) & 1); + qh_sum += yb[6] * (float)((qh_lo >> 6) & 1); + qh_sum += yb[7] * (float)((qh_lo >> 7) & 1); + qh_sum += yb[16] * (float)((qh_hi >> 0) & 1); + qh_sum += yb[17] * (float)((qh_hi >> 1) & 1); + qh_sum += yb[18] * (float)((qh_hi >> 2) & 1); + qh_sum += yb[19] * (float)((qh_hi >> 3) & 1); + qh_sum += yb[20] * (float)((qh_hi >> 4) & 1); + qh_sum += yb[21] * (float)((qh_hi >> 5) & 1); + qh_sum += yb[22] * (float)((qh_hi >> 6) & 1); + qh_sum += yb[23] * (float)((qh_hi >> 7) & 1); + + return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3 + 16.0f * qh_sum - 16.0f * sumy); +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // each subgroup works on 4 rows +#define N_SIMDGROUP 1 // number of subgroups in a thread group +#define N_SIMDWIDTH 16 // assuming subgroup size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32( + global void * src0, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK5_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + global struct block_q5_0 * x = (global struct block_q5_0 *) src0 + offset0; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK5_0 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q5_0_dot_y(x+ib+0*nb, sumy, yl, il, yb); + sumf.s1 += block_q5_0_dot_y(x+ib+1*nb, sumy, yl, il, yb); + sumf.s2 += block_q5_0_dot_y(x+ib+2*nb, sumy, yl, il, yb); + sumf.s3 += block_q5_0_dot_y(x+ib+3*nb, sumy, yl, il, yb); + + yb += QK5_0 * (N_SIMDWIDTH/2); + } + + float4 tot = (float4)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q5_0_f32( + global void * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl new file mode 100644 index 00000000000..34ec133d398 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl @@ -0,0 +1,243 @@ + +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK5_0 32 + +inline float block_q5_0_dot_y_flat( + global const uchar * x, + global const uint * qh_ptr, + global const half * dh, + float sumy, + float16 yl, + int il, + global const float * yb +) { + float d = *dh; + global const ushort * qs = ((global const ushort *)(x + il)); + + float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f); + + acc.s0 += yl.s0 * (qs[0] & 0x000F); + acc.s0 += yl.s1 * (qs[0] & 0x0F00); + acc.s0 += yl.s8 * (qs[0] & 0x00F0); + acc.s3 += yl.s9 * (qs[0] & 0xF000); + + acc.s0 += yl.s2 * (qs[1] & 0x000F); + acc.s1 += yl.s3 * (qs[1] & 0x0F00); + acc.s2 += yl.sa * (qs[1] & 0x00F0); + acc.s3 += yl.sb * (qs[1] & 0xF000); + + acc.s0 += yl.s4 * (qs[2] & 0x000F); + acc.s1 += yl.s5 * (qs[2] & 0x0F00); + acc.s2 += yl.sc * (qs[2] & 0x00F0); + acc.s3 += yl.sd * (qs[2] & 0xF000); + + acc.s0 += yl.s6 * (qs[3] & 0x000F); + acc.s1 += yl.s7 * (qs[3] & 0x0F00); + acc.s2 += yl.se * (qs[3] & 0x00F0); + acc.s3 += yl.sf * (qs[3] & 0xF000); + + uint qh_val = *qh_ptr; + uchar qh_lo = (uchar)((qh_val >> il) & 0xFF); + uchar qh_hi = (uchar)((qh_val >> (il + 16)) & 0xFF); + + float qh_sum = 0.0f; + qh_sum += yb[0] * (float)((qh_lo >> 0) & 1); + qh_sum += yb[1] * (float)((qh_lo >> 1) & 1); + qh_sum += yb[2] * (float)((qh_lo >> 2) & 1); + qh_sum += yb[3] * (float)((qh_lo >> 3) & 1); + qh_sum += yb[4] * (float)((qh_lo >> 4) & 1); + qh_sum += yb[5] * (float)((qh_lo >> 5) & 1); + qh_sum += yb[6] * (float)((qh_lo >> 6) & 1); + qh_sum += yb[7] * (float)((qh_lo >> 7) & 1); + qh_sum += yb[16] * (float)((qh_hi >> 0) & 1); + qh_sum += yb[17] * (float)((qh_hi >> 1) & 1); + qh_sum += yb[18] * (float)((qh_hi >> 2) & 1); + qh_sum += yb[19] * (float)((qh_hi >> 3) & 1); + qh_sum += yb[20] * (float)((qh_hi >> 4) & 1); + qh_sum += yb[21] * (float)((qh_hi >> 5) & 1); + qh_sum += yb[22] * (float)((qh_hi >> 6) & 1); + qh_sum += yb[23] * (float)((qh_hi >> 7) & 1); + + return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3 + 16.0f * qh_sum - 16.0f * sumy); +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // each subgroup works on 4 rows +#define N_SIMDGROUP 1 // number of subgroups in a thread group +#define N_SIMDWIDTH 16 // assuming subgroup size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32_flat( + global void * src0_qs, + global void * src0_qh, + global void * src0_d, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK5_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + ulong offset0_qs = offset0 * (QK5_0/2); + + global uchar * x = (global uchar *) src0_qs + offset0_qs; + global uint * qh = (global uint *) src0_qh + offset0; + global half * d = (global half *) src0_d + offset0; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK5_0 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q5_0_dot_y_flat(x + ib*(QK5_0/2) + 0*nb*(QK5_0/2), qh + ib + 0*nb, d + ib + 0*nb, sumy, yl, il, yb); + sumf.s1 += block_q5_0_dot_y_flat(x + ib*(QK5_0/2) + 1*nb*(QK5_0/2), qh + ib + 1*nb, d + ib + 1*nb, sumy, yl, il, yb); + sumf.s2 += block_q5_0_dot_y_flat(x + ib*(QK5_0/2) + 2*nb*(QK5_0/2), qh + ib + 2*nb, d + ib + 2*nb, sumy, yl, il, yb); + sumf.s3 += block_q5_0_dot_y_flat(x + ib*(QK5_0/2) + 3*nb*(QK5_0/2), qh + ib + 3*nb, d + ib + 3*nb, sumy, yl, il, yb); + + yb += QK5_0 * (N_SIMDWIDTH/2); + } + + float4 tot = (float4)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q5_0_f32_flat( + global void * src0_qs, + global void * src0_qh, + global void * src0_d, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32_flat(src0_qs, src0_qh, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl new file mode 100644 index 00000000000..1480f675038 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl @@ -0,0 +1,243 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK5_1 32 + +struct block_q5_1 { + half d; + half m; + uchar qh[4]; + uchar qs[QK5_1 / 2]; +}; + +inline float block_q5_1_dot_y( + global const struct block_q5_1 * qb_curr, + float sumy, + float16 yl, + int il, + global const float * yb +) { + float d = qb_curr->d; + float m = qb_curr->m; + + float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f); + + global const ushort * qs = ((global const ushort *)((global const uchar *) qb_curr + 8 + il)); + + acc.s0 += yl.s0 * (qs[0] & 0x000F); + acc.s0 += yl.s1 * (qs[0] & 0x0F00); + acc.s0 += yl.s8 * (qs[0] & 0x00F0); + acc.s3 += yl.s9 * (qs[0] & 0xF000); + + acc.s0 += yl.s2 * (qs[1] & 0x000F); + acc.s1 += yl.s3 * (qs[1] & 0x0F00); + acc.s2 += yl.sa * (qs[1] & 0x00F0); + acc.s3 += yl.sb * (qs[1] & 0xF000); + + acc.s0 += yl.s4 * (qs[2] & 0x000F); + acc.s1 += yl.s5 * (qs[2] & 0x0F00); + acc.s2 += yl.sc * (qs[2] & 0x00F0); + acc.s3 += yl.sd * (qs[2] & 0xF000); + + acc.s0 += yl.s6 * (qs[3] & 0x000F); + acc.s1 += yl.s7 * (qs[3] & 0x0F00); + acc.s2 += yl.se * (qs[3] & 0x00F0); + acc.s3 += yl.sf * (qs[3] & 0xF000); + + uint qh_val = *((global const uint *)((global const uchar *) qb_curr + 4)); + uchar qh_lo = (uchar)((qh_val >> il) & 0xFF); + uchar qh_hi = (uchar)((qh_val >> (il + 16)) & 0xFF); + + float qh_sum = 0.0f; + qh_sum += yb[0] * (float)((qh_lo >> 0) & 1); + qh_sum += yb[1] * (float)((qh_lo >> 1) & 1); + qh_sum += yb[2] * (float)((qh_lo >> 2) & 1); + qh_sum += yb[3] * (float)((qh_lo >> 3) & 1); + qh_sum += yb[4] * (float)((qh_lo >> 4) & 1); + qh_sum += yb[5] * (float)((qh_lo >> 5) & 1); + qh_sum += yb[6] * (float)((qh_lo >> 6) & 1); + qh_sum += yb[7] * (float)((qh_lo >> 7) & 1); + qh_sum += yb[16] * (float)((qh_hi >> 0) & 1); + qh_sum += yb[17] * (float)((qh_hi >> 1) & 1); + qh_sum += yb[18] * (float)((qh_hi >> 2) & 1); + qh_sum += yb[19] * (float)((qh_hi >> 3) & 1); + qh_sum += yb[20] * (float)((qh_hi >> 4) & 1); + qh_sum += yb[21] * (float)((qh_hi >> 5) & 1); + qh_sum += yb[22] * (float)((qh_hi >> 6) & 1); + qh_sum += yb[23] * (float)((qh_hi >> 7) & 1); + + return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3 + 16.0f * qh_sum) + sumy * m; +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // each subgroup works on 4 rows +#define N_SIMDGROUP 1 // number of subgroups in a thread group +#define N_SIMDWIDTH 16 // assuming subgroup size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32( + global void * src0, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK5_1; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + global struct block_q5_1 * x = (global struct block_q5_1 *) src0 + offset0; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK5_1 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q5_1_dot_y(x+ib+0*nb, sumy, yl, il, yb); + sumf.s1 += block_q5_1_dot_y(x+ib+1*nb, sumy, yl, il, yb); + sumf.s2 += block_q5_1_dot_y(x+ib+2*nb, sumy, yl, il, yb); + sumf.s3 += block_q5_1_dot_y(x+ib+3*nb, sumy, yl, il, yb); + + yb += QK5_1 * (N_SIMDWIDTH/2); + } + + float4 tot = (float4)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q5_1_f32( + global void * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl new file mode 100644 index 00000000000..57c2f140958 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl @@ -0,0 +1,247 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK5_1 32 + +inline float block_q5_1_dot_y_flat( + global const uchar * x, + global const uint * qh_ptr, + global const half * dh, + global const half * mh, + float sumy, + float16 yl, + int il, + global const float * yb +) { + float d = *dh; + float m = *mh; + global const ushort * qs = ((global const ushort *)(x + il)); + + float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f); + + acc.s0 += yl.s0 * (qs[0] & 0x000F); + acc.s0 += yl.s1 * (qs[0] & 0x0F00); + acc.s0 += yl.s8 * (qs[0] & 0x00F0); + acc.s3 += yl.s9 * (qs[0] & 0xF000); + + acc.s0 += yl.s2 * (qs[1] & 0x000F); + acc.s1 += yl.s3 * (qs[1] & 0x0F00); + acc.s2 += yl.sa * (qs[1] & 0x00F0); + acc.s3 += yl.sb * (qs[1] & 0xF000); + + acc.s0 += yl.s4 * (qs[2] & 0x000F); + acc.s1 += yl.s5 * (qs[2] & 0x0F00); + acc.s2 += yl.sc * (qs[2] & 0x00F0); + acc.s3 += yl.sd * (qs[2] & 0xF000); + + acc.s0 += yl.s6 * (qs[3] & 0x000F); + acc.s1 += yl.s7 * (qs[3] & 0x0F00); + acc.s2 += yl.se * (qs[3] & 0x00F0); + acc.s3 += yl.sf * (qs[3] & 0xF000); + + uint qh_val = *qh_ptr; + uchar qh_lo = (uchar)((qh_val >> il) & 0xFF); + uchar qh_hi = (uchar)((qh_val >> (il + 16)) & 0xFF); + + float qh_sum = 0.0f; + qh_sum += yb[0] * (float)((qh_lo >> 0) & 1); + qh_sum += yb[1] * (float)((qh_lo >> 1) & 1); + qh_sum += yb[2] * (float)((qh_lo >> 2) & 1); + qh_sum += yb[3] * (float)((qh_lo >> 3) & 1); + qh_sum += yb[4] * (float)((qh_lo >> 4) & 1); + qh_sum += yb[5] * (float)((qh_lo >> 5) & 1); + qh_sum += yb[6] * (float)((qh_lo >> 6) & 1); + qh_sum += yb[7] * (float)((qh_lo >> 7) & 1); + qh_sum += yb[16] * (float)((qh_hi >> 0) & 1); + qh_sum += yb[17] * (float)((qh_hi >> 1) & 1); + qh_sum += yb[18] * (float)((qh_hi >> 2) & 1); + qh_sum += yb[19] * (float)((qh_hi >> 3) & 1); + qh_sum += yb[20] * (float)((qh_hi >> 4) & 1); + qh_sum += yb[21] * (float)((qh_hi >> 5) & 1); + qh_sum += yb[22] * (float)((qh_hi >> 6) & 1); + qh_sum += yb[23] * (float)((qh_hi >> 7) & 1); + + return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3 + 16.0f * qh_sum) + sumy * m; +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // each subgroup works on 4 rows +#define N_SIMDGROUP 1 // number of subgroups in a thread group +#define N_SIMDWIDTH 16 // assuming subgroup size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32_flat( + global void * src0_qs, + global void * src0_qh, + global void * src0_d, + global void * src0_m, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK5_1; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + ulong offset0_qs = offset0 * (QK5_1/2); + + global uchar * x = (global uchar *) src0_qs + offset0_qs; + global uint * qh = (global uint *) src0_qh + offset0; + global half * d = (global half *) src0_d + offset0; + global half * ms = (global half *) src0_m + offset0; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK5_1 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q5_1_dot_y_flat(x + ib*(QK5_1/2) + 0*nb*(QK5_1/2), qh + ib + 0*nb, d + ib + 0*nb, ms + ib + 0*nb, sumy, yl, il, yb); + sumf.s1 += block_q5_1_dot_y_flat(x + ib*(QK5_1/2) + 1*nb*(QK5_1/2), qh + ib + 1*nb, d + ib + 1*nb, ms + ib + 1*nb, sumy, yl, il, yb); + sumf.s2 += block_q5_1_dot_y_flat(x + ib*(QK5_1/2) + 2*nb*(QK5_1/2), qh + ib + 2*nb, d + ib + 2*nb, ms + ib + 2*nb, sumy, yl, il, yb); + sumf.s3 += block_q5_1_dot_y_flat(x + ib*(QK5_1/2) + 3*nb*(QK5_1/2), qh + ib + 3*nb, d + ib + 3*nb, ms + ib + 3*nb, sumy, yl, il, yb); + + yb += QK5_1 * (N_SIMDWIDTH/2); + } + + float4 tot = (float4)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q5_1_f32_flat( + global void * src0_qs, + global void * src0_qh, + global void * src0_d, + global void * src0_m, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32_flat(src0_qs, src0_qh, src0_d, src0_m, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl index 86fe09c6dd6..57b90c05ae5 100644 --- a/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl @@ -33,13 +33,15 @@ inline float block_q_6_K_dot_y_flat( global uchar * blk_qh, global char * blk_scales, global half * blk_d, - global float * yy, int ib, int ip, int is, - int l0 + int l0, + float4 y0, + float4 y1, + float4 y2, + float4 y3 ) { - int y_offset = 128*ip + l0; int q_offset_l = 64*ip + l0; int q_offset_h = 32*ip + l0; @@ -48,36 +50,28 @@ inline float block_q_6_K_dot_y_flat( global uchar * qh = blk_qh + ib*64 + q_offset_h; global char * sc = blk_scales + ib*16 + is; - global float * y = yy + ib * QK_K + y_offset; - float dall = blk_d[ib]; - float sumf = 0; - float4 sums = {0.f, 0.f, 0.f, 0.f}; - - sums.s0 += y[0+ 0] * ((float)((q1[0] & 0xF) | ((qh[0] & Q6_K_MASK1) << 4)) - 32.f); - sums.s1 += y[0+32] * ((float)((q2[0] & 0xF) | ((qh[0] & Q6_K_MASK2) << 2)) - 32.f); - sums.s2 += y[0+64] * ((float)((q1[0] >> 4) | ((qh[0] & Q6_K_MASK3) << 0)) - 32.f); - sums.s3 += y[0+96] * ((float)((q2[0] >> 4) | ((qh[0] & Q6_K_MASK4) >> 2)) - 32.f); - - sums.s0 += y[1+ 0] * ((float)((q1[1] & 0xF) | ((qh[1] & Q6_K_MASK1) << 4)) - 32.f); - sums.s1 += y[1+32] * ((float)((q2[1] & 0xF) | ((qh[1] & Q6_K_MASK2) << 2)) - 32.f); - sums.s2 += y[1+64] * ((float)((q1[1] >> 4) | ((qh[1] & Q6_K_MASK3) << 0)) - 32.f); - sums.s3 += y[1+96] * ((float)((q2[1] >> 4) | ((qh[1] & Q6_K_MASK4) >> 2)) - 32.f); - - sums.s0 += y[2+ 0] * ((float)((q1[2] & 0xF) | ((qh[2] & Q6_K_MASK1) << 4)) - 32.f); - sums.s1 += y[2+32] * ((float)((q2[2] & 0xF) | ((qh[2] & Q6_K_MASK2) << 2)) - 32.f); - sums.s2 += y[2+64] * ((float)((q1[2] >> 4) | ((qh[2] & Q6_K_MASK3) << 0)) - 32.f); - sums.s3 += y[2+96] * ((float)((q2[2] >> 4) | ((qh[2] & Q6_K_MASK4) >> 2)) - 32.f); - - sums.s0 += y[3+ 0] * ((float)((q1[3] & 0xF) | ((qh[3] & Q6_K_MASK1) << 4)) - 32.f); - sums.s1 += y[3+32] * ((float)((q2[3] & 0xF) | ((qh[3] & Q6_K_MASK2) << 2)) - 32.f); - sums.s2 += y[3+64] * ((float)((q1[3] >> 4) | ((qh[3] & Q6_K_MASK3) << 0)) - 32.f); - sums.s3 += y[3+96] * ((float)((q2[3] >> 4) | ((qh[3] & Q6_K_MASK4) >> 2)) - 32.f); - - sumf += dall * (sums.s0 * sc[0] + sums.s1 * sc[2] + sums.s2 * sc[4] + sums.s3 * sc[6]); - - return sumf; + // Vectorized loads: 3 uchar4 weight loads instead of 12 scalar byte reads. + // q_offset_l/h are 4-aligned, so these are aligned vector loads. + uchar4 q1v = vload4(0, q1); + uchar4 q2v = vload4(0, q2); + uchar4 qhv = vload4(0, qh); + + int4 q1i = convert_int4(q1v); + int4 q2i = convert_int4(q2v); + int4 qhi = convert_int4(qhv); + + // Reconstruct the four 6-bit weight groups (low/high nibble of ql OR'd with the + // matching 2-bit plane of qh), same arithmetic as the scalar version, then dot() + // against the cached activation lanes. + float4 w0 = convert_float4((q1i & 0xF) | ((qhi & Q6_K_MASK1) << 4)) - 32.f; + float4 w1 = convert_float4((q2i & 0xF) | ((qhi & Q6_K_MASK2) << 2)) - 32.f; + float4 w2 = convert_float4((q1i >> 4) | ((qhi & Q6_K_MASK3) )) - 32.f; + float4 w3 = convert_float4((q2i >> 4) | ((qhi & Q6_K_MASK4) >> 2)) - 32.f; + + return dall * (dot(y0, w0) * sc[0] + dot(y1, w1) * sc[2] + + dot(y2, w2) * sc[4] + dot(y3, w3) * sc[6]); } #undef N_DST @@ -89,7 +83,7 @@ inline float block_q_6_K_dot_y_flat( #define N_SIMDGROUP 2 #define N_SIMDWIDTH 16 #elif defined (ADRENO_GPU) -#define N_DST 4 +#define N_DST 16 #define N_SIMDGROUP 2 #define N_SIMDWIDTH 64 #endif @@ -146,49 +140,39 @@ kernel void kernel_mul_mv_q6_K_f32_flat( global half * blk_d = (global half *) src0_d + offset_src0_d; global float * yy = (global float *) src1 + r1*ne10 + im*ne00*ne1; - int tid = get_sub_group_local_id()/BLOCK_STRIDE; // first block_stride groups have tid=0 - int ix = get_sub_group_local_id()%BLOCK_STRIDE; // first block is 0..block_stride-1 + int tid = get_sub_group_local_id()%(N_SIMDWIDTH/BLOCK_STRIDE); // within-super-block part, 0..15 + int ix = get_sub_group_local_id()/(N_SIMDWIDTH/BLOCK_STRIDE); // super-block selector, 0..BLOCK_STRIDE-1 int ip = tid/8; // first or second half of (super) block (0 or 1) int il = tid%8; // each half has 8 parts, one per scale int n = 4; // 4 scales at a time (and 4 sums) int l0 = n*il; // offset into half-block, 0..28 int is = 8*ip + l0/16; // 0, 1, 8, 9 - float4 sumf = 0; + float sumf[N_DST]; + for (int row = 0; row < N_DST; row++) { + sumf[row] = 0.f; + } for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) { - if (first_row + 0 < ne01) { - sumf.s0 += block_q_6_K_dot_y_flat(blk_ql + 0*nb*128, blk_qh + 0*nb*64, blk_scales + 0*nb*16, blk_d + 0*nb, yy, ib, ip, is, l0); - } - if (first_row + 1 < ne01) { - sumf.s1 += block_q_6_K_dot_y_flat(blk_ql + 1*nb*128, blk_qh + 1*nb*64, blk_scales + 1*nb*16, blk_d + 1*nb, yy, ib, ip, is, l0); - } - if (first_row + 2 < ne01) { - sumf.s2 += block_q_6_K_dot_y_flat(blk_ql + 2*nb*128, blk_qh + 2*nb*64, blk_scales + 2*nb*16, blk_d + 2*nb, yy, ib, ip, is, l0); - } - if (first_row + 3 < ne01) { - sumf.s3 += block_q_6_K_dot_y_flat(blk_ql + 3*nb*128, blk_qh + 3*nb*64, blk_scales + 3*nb*16, blk_d + 3*nb, yy, ib, ip, is, l0); + global float * y = yy + ib * QK_K + 128*ip + l0; + float4 y0 = vload4(0, y + 0); + float4 y1 = vload4(0, y + 32); + float4 y2 = vload4(0, y + 64); + float4 y3 = vload4(0, y + 96); + + for (int row = 0; row < N_DST; row++) { + if (first_row + row < ne01) { + sumf[row] += block_q_6_K_dot_y_flat( + blk_ql + row*nb*128, blk_qh + row*nb*64, blk_scales + row*nb*16, blk_d + row*nb, + ib, ip, is, l0, y0, y1, y2, y3); + } } } - float4 tot = (float4)( - sub_group_reduce_add(sumf.s0), - sub_group_reduce_add(sumf.s1), - sub_group_reduce_add(sumf.s2), - sub_group_reduce_add(sumf.s3) - ); - if (get_sub_group_local_id() == 0) { - if (first_row + 0 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; - } - if (first_row + 1 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; - } - if (first_row + 2 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; - } - if (first_row + 3 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + for (int row = 0; row < N_DST; row++) { + float tot = sub_group_reduce_add(sumf[row]); + if (get_sub_group_local_id() == 0 && first_row + row < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; } } } diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 6d19538215e..d8bb3638dfd 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -45,6 +45,7 @@ namespace syclexp = sycl::ext::oneapi::experimental; #define GGML_COMMON_IMPL_SYCL #define SYCL_FLASH_ATTN //remove it to disable FLASH_ATTENTION in building. #define SYCL_FAST_FP16 //don't change. remove it will break fattn-tile.hpp building +#define GGML_SYCL_FA_ALL_QUANTS //define it to enable all quantization types in flash attention. undefine it to only support F16, Q4_0 and Q8_0 in flash attention. /* suppress warning spam */ #pragma clang diagnostic push @@ -224,6 +225,7 @@ struct sycl_device_info { int max_wg_per_cu; // max work groups per compute unit - refer to // cudaOccupancyMaxActiveBlocksPerMultiprocessor bool vmm; // virtual memory support + size_t vmm_granularity; // granularity of virtual memory size_t total_vram; sycl_hw_info hw_info; optimize_feature opt_feature; @@ -244,6 +246,8 @@ struct ggml_sycl_device_info { const ggml_sycl_device_info & ggml_sycl_info(); +static constexpr size_t SYCL_BUFFER_ALIGNMENT = 128; + struct ggml_sycl_pool { virtual ~ggml_sycl_pool() = default; diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp index 576f19d79ae..65593402e7d 100644 --- a/ggml/src/ggml-sycl/convert.cpp +++ b/ggml/src/ggml-sycl/convert.cpp @@ -107,6 +107,19 @@ static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int64_t k, #endif } +template <typename dst_t> +static void dequantize_row_q3_K_sycl_reorder(const void *vx, dst_t *y, const int64_t k, + dpct::queue_ptr stream) { + const int64_t nb = k / QK_K; + + dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_q3_K_reorder(vx, y, item_ct1, nb); + }); +} + template <typename dst_t> static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { @@ -652,7 +665,11 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) { case GGML_TYPE_Q2_K: return dequantize_row_q2_K_sycl; case GGML_TYPE_Q3_K: - return dequantize_row_q3_K_sycl; + if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q3_K_sycl_reorder; + } else { + return dequantize_row_q3_K_sycl; + } case GGML_TYPE_Q4_K: if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { return dequantize_row_q4_K_sycl_reorder; @@ -730,7 +747,11 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) { case GGML_TYPE_Q2_K: return dequantize_row_q2_K_sycl; case GGML_TYPE_Q3_K: - return dequantize_row_q3_K_sycl; + if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q3_K_sycl_reorder; + } else { + return dequantize_row_q3_K_sycl; + } case GGML_TYPE_Q4_K: if (dst->src[0]->extra && ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) { diff --git a/ggml/src/ggml-sycl/dequantize.hpp b/ggml/src/ggml-sycl/dequantize.hpp index 2324bfacd22..ca8cd96c08c 100644 --- a/ggml/src/ggml-sycl/dequantize.hpp +++ b/ggml/src/ggml-sycl/dequantize.hpp @@ -20,6 +20,10 @@ typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int typedef void (*dequantize_kernel_t_reorder)(const void *d, const int64_t ib, const void *qs, const int iqs, dfloat2 &v); +#if QK_K == 256 +static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m); +#endif + static __dpct_inline__ void dequantize_q4_0(const void *vx, const int64_t ib, const int iqs, dfloat2 &v) { const block_q4_0 * x = (const block_q4_0 *) vx; @@ -90,6 +94,474 @@ static __dpct_inline__ void dequantize_q4_1(const void *vx, const int64_t ib, #endif // GGML_SYCL_F16 } +static __dpct_inline__ void dequantize_q4_K(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_q4_K * x = (const block_q4_K *) vx; + const sycl::half2 dm = x[ib].dm; + const float dall = dm[0]; + const float dmin = dm[1]; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int il = idx / 64; + const int in = idx % 64; + const int is = 2 * il + (in >= 32 ? 1 : 0); + const int off = in & 31; + const int qsi = 32 * il + off; + + uint8_t sc; + uint8_t m; + get_scale_min_k4(is, x[ib].scales, sc, m); + + const uint8_t q = x[ib].qs[qsi]; + const uint8_t qv = (in >= 32) ? (q >> 4) : (q & 0xF); + return sycl::fma((dfloat) qv, (dfloat) (dall * sc), (dfloat) (-dmin * m)); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("Q4_K dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_q2_K(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_q2_K * x = (const block_q2_K *) vx; + const float dall = x[ib].dm[0]; + const float dmin = x[ib].dm[1]; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int n = idx / 128; + const int r = idx % 128; + const int g = r / 32; + const int l = r % 32; + const int is = 8 * n + l / 16; + + const uint8_t q = x[ib].qs[32 * n + l]; + const uint8_t sc = x[ib].scales[is + 2 * g]; + const float d = dall * (sc & 0xF); + const float m = dmin * (sc >> 4); + + return sycl::fma((dfloat) ((q >> (2 * g)) & 3), (dfloat) d, (dfloat) (-m)); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("Q2_K dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_q3_K(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_q3_K * x = (const block_q3_K *) vx; + const float d_all = x[ib].d; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int n = idx / 128; + const int r = idx % 128; + const int j = r / 32; + const int l = r % 32; + + const int is0 = l / 16; + const int is = 8 * n + 2 * j + is0; + const int shift = 2 * j; + const uint8_t m = 1 << (4 * n + j); + + const int8_t us = is < 4 ? (x[ib].scales[is - 0] & 0xF) | (((x[ib].scales[is + 8] >> 0) & 3) << 4) : + is < 8 ? (x[ib].scales[is - 0] & 0xF) | (((x[ib].scales[is + 4] >> 2) & 3) << 4) : + is < 12 ? (x[ib].scales[is - 8] >> 4) | (((x[ib].scales[is + 0] >> 4) & 3) << 4) : + (x[ib].scales[is - 8] >> 4) | (((x[ib].scales[is - 4] >> 6) & 3) << 4); + + const float dl = d_all * (us - 32); + const uint8_t q = x[ib].qs[32 * n + l]; + const uint8_t h = x[ib].hmask[l]; + const int8_t qv = ((q >> shift) & 3) - ((h & m) ? 0 : 4); + + return (dfloat) (dl * qv); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("Q3_K dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_q5_K(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_q5_K * x = (const block_q5_K *) vx; + const float dall = x[ib].dm[0]; + const float dmin = x[ib].dm[1]; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int il = idx / 64; + const int in = idx % 64; + const int is = 2 * il + (in >= 32 ? 1 : 0); + const int ir = (in & 31) / 2; + const int iq = in & 1; + + const uint8_t q = x[ib].qs[32 * il + 2 * ir + iq]; + const uint8_t h = x[ib].qh[2 * ir + iq]; + const uint8_t qv = (in >= 32) ? (q >> 4) : (q & 0xF); + + uint8_t sc; + uint8_t m; + get_scale_min_k4(is, x[ib].scales, sc, m); + + const float d = dall * sc; + const float mn = dmin * m; + const uint8_t hm = 1 << (2 * il + (in >= 32 ? 1 : 0)); + + return sycl::fma((dfloat) (qv + ((h & hm) ? 16 : 0)), (dfloat) d, (dfloat) (-mn)); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("Q5_K dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_q6_K(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_q6_K * x = (const block_q6_K *) vx; + const float d = x[ib].d; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int ip = idx / 128; + const int in = idx % 128; + const int il = in & 31; + const int ig = in / 32; + const int is = 8 * ip + il / 16; + + const uint8_t ql0 = x[ib].ql[64 * ip + il]; + const uint8_t ql1 = x[ib].ql[64 * ip + il + 32]; + const uint8_t qh = x[ib].qh[32 * ip + il]; + const int8_t * sc = x[ib].scales + is; + + uint8_t qv; + int8_t scale; + if (ig == 0) { + qv = (ql0 & 0xF) | (((qh >> 0) & 3) << 4); + scale = sc[0]; + } else if (ig == 1) { + qv = (ql1 & 0xF) | (((qh >> 2) & 3) << 4); + scale = sc[2]; + } else if (ig == 2) { + qv = (ql0 >> 4) | (((qh >> 4) & 3) << 4); + scale = sc[4]; + } else { + qv = (ql1 >> 4) | (((qh >> 6) & 3) << 4); + scale = sc[6]; + } + + return (dfloat) (d * scale * ((int8_t) qv - 32)); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("Q6_K dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_mxfp4(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { + const block_mxfp4 * x = (const block_mxfp4 *) vx; + const float d = ggml_sycl_e8m0_to_fp32(x[ib].e); + const uint8_t q = x[ib].qs[iqs]; + + v.x() = d * kvalues_mxfp4[q & 0xF] * 0.5f; + v.y() = d * kvalues_mxfp4[q >> 4] * 0.5f; +} + +static __dpct_inline__ void dequantize_q1_0(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { + const block_q1_0 * x = (const block_q1_0 *) vx; + const dfloat d = x[ib].d; + + const int bit_index_0 = iqs + 0; + const int bit_index_1 = iqs + 1; + + const int bit_0 = (x[ib].qs[bit_index_0 / 8] >> (bit_index_0 % 8)) & 1; + const int bit_1 = (x[ib].qs[bit_index_1 / 8] >> (bit_index_1 % 8)) & 1; + + v.x() = (2 * bit_0 - 1) * d; + v.y() = (2 * bit_1 - 1) * d; +} + +static __dpct_inline__ void dequantize_nvfp4(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { + const block_nvfp4 & xb = ((const block_nvfp4 *) vx)[ib]; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int sub = idx / QK_NVFP4_SUB; + const int j = idx % QK_NVFP4_SUB; + const int jh = j % (QK_NVFP4_SUB / 2); + + const float d = ggml_sycl_ue4m3_to_fp32(xb.d[sub]); + const uint8_t q = xb.qs[sub * (QK_NVFP4_SUB / 2) + jh]; + const uint8_t qv = (j < (QK_NVFP4_SUB / 2)) ? (q & 0x0F) : (q >> 4); + + return d * kvalues_mxfp4[qv]; + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +} + +static __dpct_inline__ void dequantize_iq2_xxs(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_iq2_xxs * x = (const block_iq2_xxs *) vx; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int ib8 = idx / 32; + const int r = idx % 32; + const int il = r / 8; + const int j = r % 8; + + const uint16_t * q2 = x[ib].qs + 4 * ib8; + const uint8_t * aux8 = (const uint8_t *) q2; + const uint8_t * grid = (const uint8_t *) (iq2xxs_grid + aux8[il]); + const uint32_t aux32 = q2[2] | (q2[3] << 16); + const float d = (float) x[ib].d * (0.5f + (aux32 >> 28)) * 0.25f; + const uint8_t signs = ksigns_iq2xs[(aux32 >> (7 * il)) & 127]; + + return d * grid[j] * ((signs & kmask_iq2xs[j]) ? -1.f : 1.f); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("IQ2_XXS dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_iq2_xs(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_iq2_xs * x = (const block_iq2_xs *) vx; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int ib8 = idx / 32; + const int r = idx % 32; + const int il = r / 8; + const int j = r % 8; + + const uint16_t * q2 = x[ib].qs + 4 * ib8; + const uint8_t * grid = (const uint8_t *) (iq2xs_grid + (q2[il] & 511)); + const float d = (float) x[ib].d * (0.5f + ((x[ib].scales[ib8] >> (4 * (il / 2))) & 0xf)) * 0.25f; + const uint8_t signs = ksigns_iq2xs[q2[il] >> 9]; + + return d * grid[j] * ((signs & kmask_iq2xs[j]) ? -1.f : 1.f); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("IQ2_XS dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_iq2_s(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_iq2_s * x = (const block_iq2_s *) vx; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int ib8 = idx / 32; + const int r = idx % 32; + const int il = r / 8; + const int j = r % 8; + + const uint16_t grid_id = x[ib].qs[4 * ib8 + il] | ((x[ib].qh[ib8] << (8 - 2 * il)) & 0x300); + const uint8_t * grid = (const uint8_t *) (iq2s_grid + grid_id); + const float d = (float) x[ib].d * (0.5f + ((x[ib].scales[ib8] >> (4 * (il / 2))) & 0xf)) * 0.25f; + const uint8_t signs = x[ib].qs[QK_K / 8 + 4 * ib8 + il]; + + return d * grid[j] * ((signs & kmask_iq2xs[j]) ? -1.f : 1.f); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("IQ2_S dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_iq3_xxs(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_iq3_xxs * x = (const block_iq3_xxs *) vx; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int ib8 = idx / 32; + const int r = idx % 32; + const int il = r / 8; + const int j = r % 8; + + const uint8_t * q3 = x[ib].qs + 8 * ib8; + const uint16_t * gas = (const uint16_t *) (x[ib].qs + QK_K / 4) + 2 * ib8; + const uint8_t * grid1 = (const uint8_t *) (iq3xxs_grid + q3[2 * il + 0]); + const uint8_t * grid2 = (const uint8_t *) (iq3xxs_grid + q3[2 * il + 1]); + const uint32_t aux32 = gas[0] | (gas[1] << 16); + const float d = (float) x[ib].d * (0.5f + (aux32 >> 28)) * 0.5f; + const uint8_t signs = ksigns_iq2xs[(aux32 >> (7 * il)) & 127]; + + if (j < 4) { + return d * grid1[j] * ((signs & kmask_iq2xs[j + 0]) ? -1.f : 1.f); + } + return d * grid2[j - 4] * ((signs & kmask_iq2xs[j + 0]) ? -1.f : 1.f); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("IQ3_XXS dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_iq3_s(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_iq3_s * x = (const block_iq3_s *) vx; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int ib8 = idx / 32; + const int r = idx % 32; + const int il = r / 8; + const int j = r % 8; + + const uint8_t * qs = x[ib].qs + 8 * ib8; + const uint16_t grid1_id = qs[2 * il + 0] | ((x[ib].qh[ib8] << (8 - 2 * il)) & 256); + const uint16_t grid2_id = qs[2 * il + 1] | ((x[ib].qh[ib8] << (7 - 2 * il)) & 256); + const uint8_t * grid1 = (const uint8_t *) (iq3s_grid + grid1_id); + const uint8_t * grid2 = (const uint8_t *) (iq3s_grid + grid2_id); + const float d = (float) x[ib].d * (1 + 2 * ((x[ib].scales[ib8 / 2] >> (4 * (ib8 % 2))) & 0xf)); + const uint8_t signs = x[ib].signs[4 * ib8 + il]; + + if (j < 4) { + return d * grid1[j] * ((signs & kmask_iq2xs[j + 0]) ? -1.f : 1.f); + } + return d * grid2[j - 4] * ((signs & kmask_iq2xs[j + 0]) ? -1.f : 1.f); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("IQ3_S dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_iq1_s(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_iq1_s * x = (const block_iq1_s *) vx; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int ib8 = idx / 32; + const int r = idx % 32; + const int il = r / 8; + const int j = r % 8; + + const float delta = (x[ib].qh[ib8] & 0x8000) ? (-1.f - IQ1S_DELTA) : (-1.f + IQ1S_DELTA); + const float d = (float) x[ib].d * (2 * ((x[ib].qh[ib8] >> 12) & 7) + 1); + const uint16_t grid_id = x[ib].qs[4 * ib8 + il] | (((x[ib].qh[ib8] >> (3 * il)) & 7) << 8); + const uint32_t g = iq1s_grid_gpu[grid_id]; + const int8_t qv = (j < 4) ? ((g >> (8 * j)) & 0x0F) : ((g >> (8 * (j - 4) + 4)) & 0x0F); + + return d * (qv + delta); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("IQ1_S dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_iq1_m(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_iq1_m * x = (const block_iq1_m *) vx; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int ib8 = idx / 32; + const int r = idx % 32; + const int il = r / 8; + const int j = r % 8; + + const uint16_t * sc = (const uint16_t *) x[ib].scales; + iq1m_scale_t scale; + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + + const int ib16 = 2 * ib8 + il / 2; + const float d = (float) scale.f16 * (2 * ((sc[ib16 / 4] >> (3 * (ib16 % 4))) & 0x7) + 1); + + const uint8_t qh = x[ib].qh[2 * ib8 + il / 2]; + const float delta = (qh & (0x08 << (4 * (il % 2)))) ? (-1.f - IQ1M_DELTA) : (-1.f + IQ1M_DELTA); + + const uint16_t grid_id = x[ib].qs[4 * ib8 + il] | (((qh >> (4 * (il % 2))) & 7) << 8); + const uint32_t g = iq1s_grid_gpu[grid_id]; + const int8_t qv = (j < 4) ? ((g >> (8 * j)) & 0x0F) : ((g >> (8 * (j - 4) + 4)) & 0x0F); + + return d * (qv + delta); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("IQ1_M dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_iq4_nl(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { + const block_iq4_nl * x = (const block_iq4_nl *) vx; + const float d = (float) x[ib].d; + + auto dequantize_one = [&](const int idx) -> dfloat { + if (idx < 16) { + return d * kvalues_iq4nl[x[ib].qs[idx] & 0xF]; + } + return d * kvalues_iq4nl[x[ib].qs[idx - 16] >> 4]; + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +} + +static __dpct_inline__ void dequantize_iq4_xs(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_iq4_xs * x = (const block_iq4_xs *) vx; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int ib8 = idx / 32; + const int r = idx % 32; + const int byte_idx = (r < 16) ? r : (r - 16); + const uint8_t q = x[ib].qs[16 * ib8 + byte_idx]; + const uint8_t qv = (r < 16) ? (q & 0x0F) : (q >> 4); + + const float d = (float) x[ib].d * ((((x[ib].scales_l[ib8 / 2] >> (4 * (ib8 % 2))) & 0xf) | + (((x[ib].scales_h >> (2 * ib8)) & 3) << 4)) - 32); + return d * kvalues_iq4nl[qv]; + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("IQ4_XS dequantize not supported for QK_K != 256"); +#endif +} + static __dpct_inline__ void dequantize_q5_0(const void *vx, const int64_t ib, const int iqs, dfloat2 &v) { const block_q5_0 * x = (const block_q5_0 *) vx; @@ -390,6 +862,63 @@ static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restri } +template<typename dst_t> +static void dequantize_block_q3_K_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, + const sycl::nd_item<3> & item_ct1, int64_t n_blocks) { +#if QK_K == 256 + const int64_t i = item_ct1.get_group(2); + if (i >= n_blocks) { + return; + } + + const uint8_t * base = static_cast<const uint8_t *>(vx); + const size_t qs_offset = i * (QK_K / 4); + const size_t hmask_offset = n_blocks * (QK_K / 4) + i * (QK_K / 8); + const size_t scales_offset = n_blocks * (QK_K / 4) + n_blocks * (QK_K / 8) + i * 12; + const size_t d_offset = n_blocks * (QK_K / 4) + n_blocks * (QK_K / 8) + n_blocks * 12 + + i * sizeof(ggml_half); + + const uint8_t * qs = base + qs_offset; + const uint8_t * hmask = base + hmask_offset; + const uint8_t * scales = base + scales_offset; + const float d_all = static_cast<float>(*reinterpret_cast<const ggml_half *>(base + d_offset)); + + const int64_t r = item_ct1.get_local_id(2) / 4; + const int64_t tid = r / 2; + const int64_t is0 = r % 2; + const int64_t l0 = 16 * is0 + 4 * (item_ct1.get_local_id(2) % 4); + const int64_t n = tid / 4; + const int64_t j = tid - 4 * n; + const int64_t is = 8 * n + 2 * j + is0; + const int shift = 2 * j; + uint8_t m = 1 << (4 * n + j); + + uint8_t us = is < 4 + ? (scales[is - 0] & 0xF) | (((scales[is + 8] >> 0) & 3) << 4) + : is < 8 + ? (scales[is - 0] & 0xF) | (((scales[is + 4] >> 2) & 3) << 4) + : is < 12 + ? (scales[is - 8] >> 4) | (((scales[is + 0] >> 4) & 3) << 4) + : (scales[is - 8] >> 4) | (((scales[is - 4] >> 6) & 3) << 4); + + const float dl = d_all * (us - 32); + + dst_t * y = yy + i * QK_K + 128 * n + 32 * j; + const uint8_t * q = qs + 32 * n; + const uint8_t * hm = hmask; + + for (int l = l0; l < l0 + 4; ++l) { + y[l] = dl * ((int8_t) ((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4)); + } +#else + GGML_UNUSED(vx); + GGML_UNUSED(yy); + GGML_UNUSED(item_ct1); + GGML_UNUSED(n_blocks); + GGML_ABORT("Q3_K reorder dequantize not supported for QK_K != 256"); +#endif +} + #if QK_K == 256 static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) { if (j < 4) { diff --git a/ggml/src/ggml-sycl/dmmv.cpp b/ggml/src/ggml-sycl/dmmv.cpp index 4ae431a962e..d80b0a38219 100644 --- a/ggml/src/ggml-sycl/dmmv.cpp +++ b/ggml/src/ggml-sycl/dmmv.cpp @@ -501,6 +501,103 @@ static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx, } } +static void dequantize_mul_mat_vec_q3_k_reorder(const void *__restrict__ vx, + const float *__restrict__ yy, + float *__restrict__ dst, + const int ncols, int nrows, + const sycl::nd_item<3> &item_ct1) { + + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + if (row > nrows) return; + + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + // SOA base pointers for the reordered layout: + // [qs: nb * (QK_K/4)] [hmask: nb * (QK_K/8)] [scales: nb * 12] [d: nb * sizeof(half)] + const int nb = nrows * num_blocks_per_row; + const uint8_t * qs_base = (const uint8_t *)vx; + const uint8_t * hmask_base = qs_base + (size_t)nb * (QK_K / 4); + const uint8_t * scales_base = hmask_base + (size_t)nb * (QK_K / 8); + const sycl::half * d_base = (const sycl::half *)(scales_base + (size_t)nb * 12); + + float tmp = 0; // partial sum for thread in warp + +#if QK_K == 256 + + const uint16_t kmask1 = 0x0303; + const uint16_t kmask2 = 0x0f0f; + + const int tid = + item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16 + const int ix = + item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1 + + const int n = K_QUANTS_PER_ITERATION; // iterations in the inner loop + const int step = 16/K_QUANTS_PER_ITERATION; + const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... + const int in = tid - step*im; // 0....15 or 0...7 + + const uint8_t m = 1 << (4*im); + + const int l0 = n*in; // 0...15 or 0...14 in steps of 2 + const int q_offset = 32*im + l0; + const int y_offset = 128*im + l0; + + uint16_t utmp[4]; + const int8_t * s = (const int8_t *)utmp; + + const uint16_t s_shift = 4*im; + + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + const int bi = ib0 + i; + + const float * y = yy + i * QK_K + y_offset; + const uint8_t * q = qs_base + bi * (QK_K / 4) + q_offset; + const uint8_t * h = hmask_base + bi * (QK_K / 8) + l0; + + const uint16_t * a = (const uint16_t *)(scales_base + bi * 12); + utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4); + utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4); + utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4); + utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4); + + const float d = d_base[bi]; + + float sum = 0; + for (int l = 0; l < n; ++l) { + sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4)) + + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4)) + + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4)) + + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4)); + sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4)) + + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4)) + + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4)) + + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4)); + } + tmp += d * sum; + } +#else + GGML_UNUSED(vx); + GGML_UNUSED(yy); + GGML_UNUSED(ncols); + GGML_UNUSED(item_ct1); + GGML_ABORT("Q3_K reorder DMMV not supported for QK_K != 256"); +#endif + + // sum up partial sums and write back result +#pragma unroll + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { + tmp += + dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (item_ct1.get_local_id(2) == 0) { + dst[row] = tmp; + } +} + /* DPCT1110:6: The total declared local variable size in device function dequantize_mul_mat_vec_q4_k exceeds 128 bytes and may cause high register @@ -1440,6 +1537,22 @@ static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y, }); } +static void dequantize_mul_mat_vec_q3_K_sycl_reorder(const void *vx, const float *y, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int ny = 2 / K_QUANTS_PER_ITERATION; + const int block_num_y = (nrows + ny - 1) / ny; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { + dequantize_mul_mat_vec_q3_k_reorder(vx, y, dst, ncols, nrows, item_ct1); + }); +} + static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y, float *dst, const int ncols, const int nrows, @@ -1581,7 +1694,12 @@ void ggml_sycl_op_dequantize_mul_mat_vec( dequantize_mul_mat_vec_q2_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); break; case GGML_TYPE_Q3_K: - dequantize_mul_mat_vec_q3_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + dequantize_mul_mat_vec_q3_K_sycl_reorder(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + } else { + dequantize_mul_mat_vec_q3_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + } break; case GGML_TYPE_Q4_K: if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && diff --git a/ggml/src/ggml-sycl/fattn-common.hpp b/ggml/src/ggml-sycl/fattn-common.hpp index 03f0c2623c8..c6cc13cfb00 100644 --- a/ggml/src/ggml-sycl/fattn-common.hpp +++ b/ggml/src/ggml-sycl/fattn-common.hpp @@ -1031,7 +1031,7 @@ void launch_fattn( auto KV_max_ptr_ct1 = KV_max.ptr; cgh.parallel_for(sycl::nd_range<3>(blocks_num_KV_max * block_dim_KV_max, block_dim_KV_max), - [=](sycl::nd_item<3> item_ct1) { + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(warp_size)]] { GGML_UNUSED(item_ct1); flash_attn_mask_to_KV_max<ncols1, warp_size>( mask_data_ct0, KV_max_ptr_ct1, iter_k, s31, s33, @@ -1149,7 +1149,7 @@ void launch_fattn( auto K_ne_ct6 = K->ne[2]; cgh.parallel_for(sycl::nd_range<3>(blocks_num_combine * block_dim_combine, block_dim_combine), - [=](sycl::nd_item<3> item_ct1) { + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(warp_size)]] { GGML_UNUSED(item_ct1); flash_attn_stream_k_fixup<DV, ncols1, ncols2>(KQV_data_ct0, dst_tmp_meta_ptr_ct1, Q_ne_ct2, Q_ne_ct3, Q_ne_ct4, @@ -1169,7 +1169,7 @@ void launch_fattn( auto KQV_data_ct2 = (float *) KQV->data; cgh.parallel_for(sycl::nd_range<3>(blocks_num_combine * block_dim_combine, block_dim_combine), - [=](sycl::nd_item<3> item_ct1) { + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(warp_size)]] { GGML_UNUSED(item_ct1); flash_attn_combine_results<DV>( dst_tmp_ptr_ct0, dst_tmp_meta_ptr_ct1, KQV_data_ct2, parallel_blocks, diff --git a/ggml/src/ggml-sycl/gated_delta_net.cpp b/ggml/src/ggml-sycl/gated_delta_net.cpp index 9c2449aba0c..239e00bd7e5 100644 --- a/ggml/src/ggml-sycl/gated_delta_net.cpp +++ b/ggml/src/ggml-sycl/gated_delta_net.cpp @@ -44,9 +44,9 @@ void gated_delta_net_sycl(const float * q, float * attn_data = dst; float * state = dst + attn_score_elems; - // input state layout (D, K, n_seqs) — seq stride is K * D = K * H * S_v * S_v. + // input state holds s0 only [S_v, S_v, H, n_seqs] — seq stride is D = H * S_v * S_v. // output state layout (per-slot D * n_seqs) — same per-(seq,head) offset as before. - const int64_t state_in_offset = sequence * K * H * S_v * S_v + h_idx * S_v * S_v; + const int64_t state_in_offset = sequence * H * S_v * S_v + h_idx * S_v * S_v; const int64_t state_out_offset = (sequence * H + h_idx) * S_v * S_v; const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output state += state_out_offset; @@ -63,9 +63,8 @@ void gated_delta_net_sycl(const float * q, s_shard[r] = curr_state[i]; } - // slot mapping: target_slot = t - shift. When n_tokens < K only the last n_tokens slots - // are written; earlier slots are left untouched (caller-owned). - const int shift = (int) n_tokens - K; + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + // When n_tokens < K only slots 0..n_tokens-1 are written; older slots are caller-owned. for (int t = 0; t < n_tokens; t++) { const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1; @@ -144,7 +143,7 @@ void gated_delta_net_sycl(const float * q, // Write state back to global memory if constexpr (keep_rs_t) { - const int target_slot = t - shift; + const int target_slot = (int) n_tokens - 1 - t; if (target_slot >= 0 && target_slot < K) { float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset; #pragma unroll @@ -315,8 +314,8 @@ void ggml_sycl_op_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dpct::queue_ptr stream = ctx.stream(); - // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. - const int K = (int) src_state->ne[1]; + // K (snapshot slot count) is an op param; state holds s0 only [S_v, S_v, H, n_seqs]. + const int K = ggml_get_op_params_i32(dst, 0); const bool keep_rs = K > 1; if (kda) { diff --git a/ggml/src/ggml-sycl/getrows.cpp b/ggml/src/ggml-sycl/getrows.cpp index ca457454775..298f247f84e 100644 --- a/ggml/src/ggml-sycl/getrows.cpp +++ b/ggml/src/ggml-sycl/getrows.cpp @@ -129,11 +129,11 @@ static void get_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor *sr GGML_UNUSED(ctx); } -template <typename src0_t> +template <typename src0_t, typename dst_t> static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const src0_t *src0_dd, const int32_t *src1_dd, - float *dst_dd, queue_ptr stream) { + dst_t *dst_dd, queue_ptr stream) { GGML_TENSOR_BINARY_OP_LOCALS @@ -170,7 +170,7 @@ static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tens void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_ASSERT(dst->src[1]->type == GGML_TYPE_I32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_I32 ); GGML_ASSERT(dst->src[0]->nb[0] == ggml_type_size(dst->src[0]->type)); GGML_ASSERT(dst->src[1]->nb[0] == ggml_type_size(dst->src[1]->type)); @@ -191,6 +191,66 @@ void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { get_rows_sycl_float(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, src1_i32, (float *)dst->data, ctx.stream()); break; + case GGML_TYPE_I32: + get_rows_sycl_float(ctx, dst->src[0], dst->src[1], dst, (const int32_t *)dst->src[0]->data, + src1_i32, (int32_t *)dst->data, ctx.stream()); + break; + case GGML_TYPE_Q1_0: + get_rows_sycl<QK1_0, 1, dequantize_q1_0>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_MXFP4: + get_rows_sycl<QK_MXFP4, 2, dequantize_mxfp4>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_NVFP4: + get_rows_sycl<QK_NVFP4, 1, dequantize_nvfp4>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_IQ2_XXS: + get_rows_sycl<QK_K, 1, dequantize_iq2_xxs>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_IQ2_XS: + get_rows_sycl<QK_K, 1, dequantize_iq2_xs>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_IQ2_S: + get_rows_sycl<QK_K, 1, dequantize_iq2_s>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_IQ3_XXS: + get_rows_sycl<QK_K, 1, dequantize_iq3_xxs>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_IQ1_S: + get_rows_sycl<QK_K, 1, dequantize_iq1_s>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_IQ1_M: + get_rows_sycl<QK_K, 1, dequantize_iq1_m>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_IQ3_S: + get_rows_sycl<QK_K, 1, dequantize_iq3_s>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_IQ4_NL: + get_rows_sycl<QK4_NL, 1, dequantize_iq4_nl>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_IQ4_XS: + get_rows_sycl<QK_K, 1, dequantize_iq4_xs>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_Q2_K: + get_rows_sycl<QK_K, 1, dequantize_q2_K>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_Q3_K: + get_rows_sycl<QK_K, 1, dequantize_q3_K>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; case GGML_TYPE_Q4_0: get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, src1_i32, (float *)dst->data, ctx.stream()); @@ -199,6 +259,10 @@ void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, src1_i32, (float *)dst->data, ctx.stream()); break; + case GGML_TYPE_Q4_K: + get_rows_sycl<QK_K, 1, dequantize_q4_K>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; case GGML_TYPE_Q5_0: get_rows_sycl<QK5_0, QR5_0, dequantize_q5_0>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, src1_i32, (float *)dst->data, ctx.stream()); @@ -207,6 +271,14 @@ void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { get_rows_sycl<QK5_1, QR5_1, dequantize_q5_1>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, src1_i32, (float *)dst->data, ctx.stream()); break; + case GGML_TYPE_Q5_K: + get_rows_sycl<QK_K, 1, dequantize_q5_K>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_Q6_K: + get_rows_sycl<QK_K, 1, dequantize_q6_K>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; case GGML_TYPE_Q8_0: get_rows_sycl<QK8_0, QR8_0, dequantize_q8_0>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, src1_i32, (float *)dst->data, ctx.stream()); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index b3fbb621196..3f246e8672d 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -19,6 +19,7 @@ #include <cstdlib> #include <float.h> #include <limits> +#include <optional> #include <stdint.h> #include <stdio.h> #include <vector> @@ -37,6 +38,11 @@ #if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC # include <sycl/ext/oneapi/experimental/async_alloc/async_alloc.hpp> #endif +#if SYCL_EXT_ONEAPI_VIRTUAL_MEM +# include <sycl/ext/oneapi/virtual_mem/physical_mem.hpp> +# include <sycl/ext/oneapi/virtual_mem/virtual_mem.hpp> +# define GGML_SYCL_USE_VMM +#endif #include <sycl/half_type.hpp> #include "ggml.h" @@ -70,6 +76,7 @@ int g_ggml_sycl_debug = 0; int g_ggml_sycl_disable_optimize = 0; int g_ggml_sycl_disable_graph = 0; int g_ggml_sycl_disable_dnn = 0; +int g_ggml_sycl_enable_vmm = 1; int g_ggml_sycl_prioritize_dmmv = 0; int g_ggml_sycl_use_async_mem_op = 0; int g_ggml_sycl_use_async_mem_op_requested = 1; @@ -96,13 +103,30 @@ static ggml_sycl_device_info ggml_sycl_init() { // GGML_LOG_INFO("%s: SYCL_USE_XMX: no\n", __func__); // #endif for (int i = 0; i < info.device_count; ++i) { - info.devices[i].vmm = 0; dpct::device_info prop; auto & device = dpct::dev_mgr::instance().get_device(i); SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info( prop, device))); +#if !defined(GGML_SYCL_USE_VMM) + info.devices[i].vmm = 0; +#else + info.devices[i].vmm = device.has(sycl::aspect::ext_oneapi_virtual_mem); + if (info.devices[i].vmm) { + // NB: SYCL's get_mem_granularity always returns the _minimum_ granularity, + // but the L0 API requires a larger page size for allocs above 2 MiB and + // rejects non-multiples with UR_RESULT_ERROR_INVALID_VALUE [sic]. + // Here we clamp it to 2 MiB for simplicity, but other devices may require + // calling zeVirtualMemQueryPageSize or yet unexposed public API. + const size_t physical_page = 2ull << 20; // 2 MiB + info.devices[i].vmm_granularity = std::max<size_t>( + sycl::ext::oneapi::experimental::get_mem_granularity( + device, sycl::context(device)), + physical_page); + } +#endif + info.default_tensor_split[i] = total_vram; total_vram += prop.get_global_mem_size(); @@ -234,6 +258,7 @@ static void ggml_check_sycl() try { g_ggml_sycl_disable_optimize = get_sycl_env("GGML_SYCL_DISABLE_OPT", 0); g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1); g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0); + g_ggml_sycl_enable_vmm = get_sycl_env("GGML_SYCL_ENABLE_VMM", 1); g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0); #ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO g_ggml_sycl_enable_level_zero = get_sycl_env("GGML_SYCL_ENABLE_LEVEL_ZERO", ggml_sycl_info().ext_oneapi_level_zero); @@ -275,6 +300,11 @@ static void ggml_check_sycl() try { #else GGML_LOG_INFO(" GGML_SYCL_SUPPORT_LEVEL_ZERO: no\n"); #endif +#if defined(GGML_SYCL_USE_VMM) + GGML_LOG_INFO(" GGML_SYCL_USE_VMM: yes\n"); +#else + GGML_LOG_INFO(" GGML_SYCL_USE_VMM: no\n"); +#endif GGML_LOG_INFO("Running with Environment Variables:\n"); GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug); @@ -293,6 +323,11 @@ static void ggml_check_sycl() try { GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: %d\n", g_ggml_sycl_disable_dnn); #else GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n"); +#endif +#if defined(GGML_SYCL_USE_VMM) + GGML_LOG_INFO(" GGML_SYCL_ENABLE_VMM: %d\n", g_ggml_sycl_enable_vmm); +#else + GGML_LOG_INFO(" GGML_SYCL_ENABLE_VMM: virtual memory extension is not available\n"); #endif GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv); g_ggml_sycl_use_async_mem_op_requested = get_sycl_env("GGML_SYCL_USE_ASYNC_MEM_OP", 1); @@ -754,7 +789,7 @@ catch (sycl::exception const &exc) { } static size_t ggml_backend_sycl_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { - return 128; + return SYCL_BUFFER_ALIGNMENT; GGML_UNUSED(buft); } @@ -1177,7 +1212,7 @@ static ggml_backend_buffer_t ggml_backend_sycl_split_buffer_type_alloc_buffer(gg } static size_t ggml_backend_sycl_split_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { - return 128; + return SYCL_BUFFER_ALIGNMENT; GGML_UNUSED(buft); } @@ -1462,6 +1497,121 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool { } }; +// pool with virtual memory management +#if defined(GGML_SYCL_USE_VMM) +struct ggml_sycl_pool_vmm : public ggml_sycl_pool { + static const size_t SYCL_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB + + int device; + sycl::context ctx; + sycl::device dev; + + uintptr_t pool_addr = 0; + size_t pool_used = 0; + size_t pool_size = 0; + size_t granularity; + + // physical_mem owns the commits (unlike cuMemMap) + struct mapping { + sycl::ext::oneapi::experimental::physical_mem phys; + void * map_ptr; + }; + std::vector<mapping> mappings; + + explicit ggml_sycl_pool_vmm(queue_ptr qptr_, int device_) : + device(device_), + ctx(qptr_->get_context()), + dev(qptr_->get_device()), + granularity(ggml_sycl_info().devices[device_].vmm_granularity) { + } + + ~ggml_sycl_pool_vmm() { + if (pool_addr == 0) { + return; + } + + // Per spec, unmap must (a) match the exact (ptr, size) of an earlier + // physical_mem::map() call and (b) precede destruction of the + // physical_mem objects (their dtors won't unmap). + for (auto & m : mappings) { + SYCL_CHECK(CHECK_TRY_ERROR(sycl::ext::oneapi::experimental::unmap( + m.map_ptr, m.phys.size(), ctx))); + } + SYCL_CHECK(CHECK_TRY_ERROR(sycl::ext::oneapi::experimental::free_virtual_mem( + pool_addr, SYCL_POOL_VMM_MAX_SIZE, ctx))); + } + + void * alloc(size_t size, size_t * actual_size) override { + // round up the allocation size to the alignment to ensure that all allocations are aligned for all data types + size = GGML_PAD(size, SYCL_BUFFER_ALIGNMENT); + + size_t avail = pool_size - pool_used; + + if (size > avail) { + // round up to the next multiple of the granularity + size_t reserve_size = GGML_PAD(size - avail, granularity); + + GGML_ASSERT(pool_size + reserve_size <= SYCL_POOL_VMM_MAX_SIZE); + + // allocate more physical memory + std::optional<sycl::ext::oneapi::experimental::physical_mem> phys; + SYCL_CHECK(CHECK_TRY_ERROR(phys.emplace(dev, ctx, reserve_size))); + + // reserve virtual address space (if not already reserved) + if (pool_addr == 0) { + SYCL_CHECK(CHECK_TRY_ERROR( + pool_addr = sycl::ext::oneapi::experimental::reserve_virtual_mem( + SYCL_POOL_VMM_MAX_SIZE, ctx))); + } + + // map at the end of the pool + void * map_ptr = nullptr; + SYCL_CHECK(CHECK_TRY_ERROR( + map_ptr = phys->map(pool_addr + pool_size, reserve_size, + sycl::ext::oneapi::experimental::address_access_mode::read_write))); + + // stash these so we could unmap this exact range in dtor + mappings.push_back({ + std::move(*phys), + map_ptr, + }); + + // add to the pool + pool_size += reserve_size; + +#ifdef DEBUG_SYCL_MALLOC + GGML_LOG_INFO("sycl pool[%d]: size increased to %llu MB (reserved %llu MB)\n", + device, (unsigned long long) (pool_size/1024/1024), + (unsigned long long) (reserve_size/1024/1024)); +#endif + } + + GGML_ASSERT(pool_addr != 0); + + void * ptr = reinterpret_cast<void *>(pool_addr + pool_used); + *actual_size = size; + pool_used += size; + +#ifdef DEBUG_SYCL_MALLOC + GGML_LOG_INFO("sycl pool[%d]: allocated %llu bytes at %p\n", device, (unsigned long long) size, ptr); +#endif + + return ptr; + } + + void free(void * ptr, size_t size) override { +#ifdef DEBUG_SYCL_MALLOC + GGML_LOG_INFO("sycl pool[%d]: freed %llu bytes at %p\n", device, (unsigned long long) size, ptr); +#endif + + pool_used -= size; + + // all deallocations must be in reverse order of the allocations + GGML_ASSERT(ptr == reinterpret_cast<void *>(pool_addr + pool_used)); + } +}; +#endif // defined(GGML_SYCL_USE_VMM) + struct ggml_sycl_pool_host : public ggml_sycl_pool { queue_ptr qptr; int device; @@ -1542,20 +1692,19 @@ std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_host(que } std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) { - // TBD: NO VMM support - // if (ggml_sycl_info().devices[device].vmm) { - // return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_vmm(device)); - // } - return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_leg(qptr, device)); +#if defined(GGML_SYCL_USE_VMM) + if (g_ggml_sycl_enable_vmm && ggml_sycl_info().devices[device].vmm) { + return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_vmm(qptr, device)); + } +#endif // defined(GGML_SYCL_USE_VMM) + return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_leg(qptr, device)); } + std::unique_ptr<ggml_sycl_fattn_kv_buffers> ggml_backend_sycl_context::new_fattn_kv_buffers(queue_ptr qptr, int device) { return std::unique_ptr<ggml_sycl_fattn_kv_buffers>(new ggml_sycl_fattn_kv_buffers(qptr, device)); } -// TBD pool with virtual memory management -// struct ggml_sycl_pool_vmm : public ggml_sycl_pool - /// kernels typedef void (*ggml_sycl_op_mul_mat_t)( ggml_backend_sycl_context & ctx, @@ -3400,6 +3549,7 @@ inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) { case GGML_TYPE_Q4_0: case GGML_TYPE_Q8_0: return true; + case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: @@ -3423,6 +3573,7 @@ inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: case GGML_TYPE_Q8_0: + case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: @@ -3642,6 +3793,54 @@ static bool reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d return true; } +static bool reorder_qw_q3_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { + GGML_ASSERT(size % sizeof(block_q3_K) == 0); + GGML_ASSERT(offset % sizeof(block_q3_K) == 0); + + const int nblocks = size / sizeof(block_q3_K); + + sycl_reorder_temp_buffer tmp(stream, size); + if (!tmp) { + GGML_LOG_WARN("%s: failed to allocate %zu bytes for reorder temp buffer, skipping reorder\n", __func__, size); + return false; + } + uint8_t * tmp_buf = static_cast<uint8_t *>(tmp.ptr); + + sycl::event copy_event; + SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size))); + if (!g_ggml_sycl_use_async_mem_op) { + copy_event.wait(); + } + + auto * qs_ptr = data_device; + auto * hmask_ptr = qs_ptr + (QK_K / 4) * nblocks; + auto * scales_ptr = hmask_ptr + (QK_K / 8) * nblocks; + sycl::half * d_ptr = (sycl::half *) (scales_ptr + 12 * nblocks); + + auto reorder_event = stream->parallel_for(nblocks, [=](auto i) { + const block_q3_K * x = (const block_q3_K *) tmp_buf; + const int ib = i; + + for (int j = 0; j < QK_K / 4; ++j) { + qs_ptr[ib * (QK_K / 4) + j] = x[ib].qs[j]; + } + + for (int j = 0; j < QK_K / 8; ++j) { + hmask_ptr[ib * (QK_K / 8) + j] = x[ib].hmask[j]; + } + + for (int j = 0; j < 12; ++j) { + scales_ptr[ib * 12 + j] = x[ib].scales[j]; + } + + d_ptr[ib] = x[ib].d; + }); + if (!g_ggml_sycl_use_async_mem_op) { + reorder_event.wait_and_throw(); + } + return true; +} + static bool reorder_qw_q5_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { GGML_ASSERT(size % sizeof(block_q5_K) == 0); GGML_ASSERT(offset % sizeof(block_q5_K) == 0); @@ -3754,6 +3953,8 @@ static bool reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) { return reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream); case GGML_TYPE_Q8_0: return reorder_qw_q8_0(data_device, ncols, nrows, size, 0, stream); + case GGML_TYPE_Q3_K: + return reorder_qw_q3_k(data_device, size, 0, stream); case GGML_TYPE_Q4_K: return reorder_qw_q4_k(data_device, size, 0, stream); case GGML_TYPE_Q5_K: @@ -3770,7 +3971,9 @@ static bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_ten return !g_ggml_sycl_disable_optimize && //allow optimize, controlled by $GGML_SYCL_DISABLE_OPT ctx.opt_feature.reorder && //allow this device due to good perf, skip the devices with bad perf. dst->op == GGML_OP_MUL_MAT && //limit to some supported cases of Q4_0, to do for more cases. - dst->src[1]->ne[1]==1 && dst->src[1]->ne[2]==1 && dst->src[1]->ne[3]==1; + // ne[1] <= 8 so multi-column decode (spec / MTP verify) also bootstraps the reorder; + // all reorderable types have a _switch_ncols kernel. + dst->src[1]->ne[1] <= 8 && dst->src[1]->ne[2]==1 && dst->src[1]->ne[3]==1; } static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * /* src1 */, @@ -5100,13 +5303,31 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_GET_ROWS: { switch (op->src[0]->type) { + case GGML_TYPE_I32: case GGML_TYPE_F16: case GGML_TYPE_BF16: case GGML_TYPE_F32: + case GGML_TYPE_Q1_0: + case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: case GGML_TYPE_Q8_0: return true; default: diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index 49998f13ba8..cf2b59576aa 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -56,6 +56,65 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r } } +template <typename reorder_vec_dot_q_sycl, int ncols_dst> +static void mul_mat_vec_q_reorder_ncols(const void * __restrict__ vx, const void * __restrict__ vy, + float * __restrict__ dst, const int ncols, const int nrows, + const int stride_col_y_bytes, const int stride_col_dst, + const sycl::nd_item<3> & nd_item) { + using block_type = ggml_sycl_reordered::block_q_t<reorder_vec_dot_q_sycl::gtype>; + using block_traits = typename block_type::traits; + + const auto sg = nd_item.get_sub_group(); + const int sg_range = sg.get_group_linear_range(); + const int workgroup_id = nd_item.get_group_linear_id(); + const int sg_id = sg.get_group_linear_id(); + const int row = workgroup_id * sg_range + sg_id; + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / block_traits::qk; + constexpr int blocks_per_subgroup = ceil_div(block_traits::vdr_mmvq * WARP_SIZE, block_traits::qi); + constexpr int block_elements_per_subgroup = block_traits::qi / block_traits::vdr_mmvq; + const int nblocks = nrows * (ncols / block_traits::qk); + + static_assert(blocks_per_subgroup > 0); + static_assert(block_elements_per_subgroup > 0); + + float partial_sum[ncols_dst] = {0.0f}; + for (int i = sg.get_local_linear_id() / block_elements_per_subgroup; i < blocks_per_row; i += blocks_per_subgroup) { + const int ibx = row * blocks_per_row + i; + + const auto bx_offset = block_type::get_block_offset(ibx, nblocks); + const auto d_offset = block_type::get_d_offset(nrows, ncols, ibx); + const int iby = i * block_type::block_to_q8_1_ratio(); + +#pragma unroll + for (int elem = 0; elem < block_elements_per_subgroup; elem += WARP_SIZE) { + const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup); + +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + const char * vy_j = (const char *)vy + j * stride_col_y_bytes; + const int8_t * q8_1_quant_ptr = (const int8_t *)vy_j + iby * QK8_1; + const sycl::half2* q8_1_ds_ptr = (const sycl::half2 *)(vy_j + ncols + iby * sizeof(sycl::half2)); + + partial_sum[j] += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, q8_1_quant_ptr, q8_1_ds_ptr, iqs); + } + } + } + +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + float sum = sycl::reduce_over_group(nd_item.get_sub_group(), partial_sum[j], std::plus<>()); + + if (sg.leader()) { + dst[j * stride_col_dst + row] = sum; + } + } +} + template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl> static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows, const sycl::nd_item<3> & item_ct1) { @@ -100,6 +159,70 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_ } } +template <int qk, int qi, typename block_q_t, int vdr, + vec_dot_q_sycl_t vec_dot_q_sycl, int ncols_dst> +static void mul_mat_vec_q_ncols( + const void * __restrict__ vx, + const void * __restrict__ vy, + float * __restrict__ dst, + const int ncols, + const int nrows, + const int stride_col_y, + const int stride_col_dst, + const sycl::nd_item<3> & item_ct1) { + + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / qk; + constexpr int blocks_per_warp = (vdr * WARP_SIZE + qi - 1) / qi; + + // partial sums: one per output column + float tmp[ncols_dst] = {0.0f}; + + const block_q_t * x = (const block_q_t *) vx; + const block_q8_1 * y = (const block_q8_1 *) vy; + + for (int i = item_ct1.get_local_id(2) / (qi / vdr); + i < blocks_per_row; + i += blocks_per_warp) { + + const int ibx = row * blocks_per_row + i; + const int iby = i * (qk / QK8_1); + + // read weight block once, dot against all columns + for (size_t elem = 0; elem < qi / vdr; elem += WARP_SIZE) { + const int iqs = elem + vdr * (item_ct1.get_local_id(2) % (qi / vdr)); + +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + tmp[j] += vec_dot_q_sycl(&x[ibx], &y[j * stride_col_y + iby], iqs); + } + } + } + + // reduce within subgroup +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { +#pragma unroll + for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + tmp[j] += dpct::permute_sub_group_by_xor( + item_ct1.get_sub_group(), tmp[j], mask); + } + } + + if (item_ct1.get_local_id(2) == 0) { +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + dst[j * stride_col_dst + row] = tmp[j]; + } + } +} + template <int qk, int qi, typename block_q_t, int vdr> static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx, const void *__restrict__ vy, @@ -553,6 +676,45 @@ static void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, }); } +template <int ncols_dst> +static void reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK4_0 == 0); + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); + constexpr size_t num_subgroups = 16; + GGML_ASSERT(block_num_y % num_subgroups == 0); + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0>, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item); + }); + }); +} + +static void reorder_mul_mat_vec_q4_0_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: reorder_mul_mat_vec_q4_0_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 3: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 4: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 5: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 6: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 7: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 8: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q4_0 reorder multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK4_0 == 0); @@ -571,6 +733,45 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * } } +template <int ncols_dst> +static void mul_mat_vec_q4_0_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK4_0 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK4_0, QI4_0, block_q4_0, + VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_q4_0_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_q4_0_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_q4_0_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_q4_0_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_q4_0_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_q4_0_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_q4_0_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_q4_0_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_q4_0_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q4_0 multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, @@ -595,6 +796,45 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy, } } +template <int ncols_dst> +static void mul_mat_vec_q4_1_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK4_1 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK4_0, QI4_1, block_q4_1, + VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_q4_1_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_q4_1_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_q4_1_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_q4_1_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_q4_1_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_q4_1_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_q4_1_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_q4_1_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_q4_1_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q4_1 multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_mxfp4_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK_MXFP4 == 0); @@ -613,6 +853,45 @@ static void mul_mat_vec_mxfp4_q8_1_sycl(const void * vx, const void * vy, float } } +template <int ncols_dst> +static void mul_mat_vec_mxfp4_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_MXFP4 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK_MXFP4, QI_MXFP4, block_mxfp4, + VDR_MXFP4_Q8_1_MMVQ, vec_dot_mxfp4_q8_1, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_mxfp4_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_mxfp4_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_mxfp4_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_mxfp4_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_mxfp4_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_mxfp4_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_mxfp4_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_mxfp4_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_mxfp4_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for MXFP4 multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_nvfp4_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK_NVFP4 == 0); @@ -631,6 +910,45 @@ static void mul_mat_vec_nvfp4_q8_1_sycl(const void * vx, const void * vy, float } } +template <int ncols_dst> +static void mul_mat_vec_nvfp4_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_NVFP4 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK_NVFP4, QI_NVFP4, block_nvfp4, + VDR_NVFP4_Q8_1_MMVQ, vec_dot_nvfp4_q8_1, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_nvfp4_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_nvfp4_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_nvfp4_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_nvfp4_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_nvfp4_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_nvfp4_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_nvfp4_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_nvfp4_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_nvfp4_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for NVFP4 multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, @@ -655,6 +973,45 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy, } } +template <int ncols_dst> +static void mul_mat_vec_q5_0_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK5_0 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK5_0, QI5_0, block_q5_0, + VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_q5_0_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_q5_0_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_q5_0_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_q5_0_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_q5_0_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_q5_0_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_q5_0_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_q5_0_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_q5_0_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q5_0 multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, @@ -679,6 +1036,45 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy, } } +template <int ncols_dst> +static void mul_mat_vec_q5_1_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK5_1 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK5_1, QI5_1, block_q5_1, + VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_q5_1_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_q5_1_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_q5_1_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_q5_1_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_q5_1_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_q5_1_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_q5_1_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_q5_1_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_q5_1_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q5_1 multi-col MMVQ", ncols_dst); + } +} + static void reorder_mul_mat_vec_q8_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK8_0 == 0); @@ -698,6 +1094,45 @@ static void reorder_mul_mat_vec_q8_0_q8_1_sycl(const void * vx, const void * vy, }); } +template <int ncols_dst> +static void reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK8_0 == 0); + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); + constexpr size_t num_subgroups = 16; + GGML_ASSERT(block_num_y % num_subgroups == 0); + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q8_0>, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item); + }); + }); +} + +static void reorder_mul_mat_vec_q8_0_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: reorder_mul_mat_vec_q8_0_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 3: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 4: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 5: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 6: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 7: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 8: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q8_0 reorder multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, @@ -722,6 +1157,45 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy, } } +template <int ncols_dst> +static void mul_mat_vec_q8_0_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK8_0 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK8_0, QI8_0, block_q8_0, + VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_q8_0_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_q8_0_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_q8_0_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_q8_0_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_q8_0_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_q8_0_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_q8_0_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_q8_0_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_q8_0_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q8_0 multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, @@ -746,6 +1220,45 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy, } } +template <int ncols_dst> +static void mul_mat_vec_q2_K_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK_K, QI2_K, block_q2_K, + VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_q2_K_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_q2_K_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_q2_K_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_q2_K_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_q2_K_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_q2_K_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_q2_K_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_q2_K_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_q2_K_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q2_K multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, @@ -770,6 +1283,105 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy, } } +static void reorder_mul_mat_vec_q3_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, + const int nrows, dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + + // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel. + constexpr size_t num_subgroups = WARP_SIZE; + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups; + + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q3_K>>(vx, vy, dst, ncols, nrows, + nd_item); + }); + }); +} + +template <int ncols_dst> +static void reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); + constexpr size_t num_subgroups = 16; + GGML_ASSERT(block_num_y % num_subgroups == 0); + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q3_K>, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item); + }); + }); +} + +static void reorder_mul_mat_vec_q3_k_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: reorder_mul_mat_vec_q3_k_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 3: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 4: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 5: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 6: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 7: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 8: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q3_K reorder multi-col MMVQ", ncols_dst); + } +} + +template <int ncols_dst> +static void mul_mat_vec_q3_K_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK_K, QI3_K, block_q3_K, + VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_q3_K_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_q3_K_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_q3_K_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_q3_K_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_q3_K_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_q3_K_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_q3_K_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_q3_K_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_q3_K_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q3_K multi-col MMVQ", ncols_dst); + } +} + + static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, @@ -794,6 +1406,51 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy, } } +template <int ncols_dst> +static void mul_mat_vec_q4_K_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK_K, QI4_K, block_q4_K, + VDR_Q4_K_Q8_1_MMVQ, + vec_dot_q4_K_q8_1, + ncols_dst>( + vx, vy, dst, ncols, nrows, + stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_q4_K_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_q4_K_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_q4_K_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_q4_K_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_q4_K_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_q4_K_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_q4_K_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_q4_K_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_q4_K_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q4_K multi-col MMVQ", ncols_dst); + } +} + static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK_K == 0); @@ -814,6 +1471,44 @@ static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy, }); } +template <int ncols_dst> +static void reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); + constexpr size_t num_subgroups = 16; + GGML_ASSERT(block_num_y % num_subgroups == 0); + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K>, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item); + }); + }); +} + +static void reorder_mul_mat_vec_q4_k_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: reorder_mul_mat_vec_q4_k_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 3: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 4: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 5: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 6: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 7: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 8: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q4_K reorder multi-col MMVQ", ncols_dst); + } +} static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, @@ -839,6 +1534,51 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy, } } +template <int ncols_dst> +static void mul_mat_vec_q5_K_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK_K, QI5_K, block_q5_K, + VDR_Q5_K_Q8_1_MMVQ, + vec_dot_q5_K_q8_1, + ncols_dst>( + vx, vy, dst, ncols, nrows, + stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_q5_K_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_q5_K_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_q5_K_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_q5_K_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_q5_K_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_q5_K_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_q5_K_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_q5_K_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_q5_K_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q5_K multi-col MMVQ", ncols_dst); + } +} + static void reorder_mul_mat_vec_q5_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK_K == 0); @@ -859,6 +1599,45 @@ static void reorder_mul_mat_vec_q5_k_q8_1_sycl(const void * vx, const void * vy, }); } +template <int ncols_dst> +static void reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); + constexpr size_t num_subgroups = 16; + GGML_ASSERT(block_num_y % num_subgroups == 0); + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q5_K>, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item); + }); + }); +} + +static void reorder_mul_mat_vec_q5_k_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: reorder_mul_mat_vec_q5_k_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 3: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 4: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 5: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 6: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 7: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 8: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q5_K reorder multi-col MMVQ", ncols_dst); + } +} + static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK_K == 0); @@ -877,6 +1656,46 @@ static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy, }); }); } + +template <int ncols_dst> +static void reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); + constexpr size_t num_subgroups = 16; + GGML_ASSERT(block_num_y % num_subgroups == 0); + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K>, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item); + }); + }); +} + +static void reorder_mul_mat_vec_q6_k_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: reorder_mul_mat_vec_q6_k_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 3: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 4: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 5: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 6: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 7: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 8: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q6_K reorder multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, @@ -901,6 +1720,51 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy, } } +template <int ncols_dst> +static void mul_mat_vec_q6_K_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK_K, QI6_K, block_q6_K, + VDR_Q6_K_Q8_1_MMVQ, + vec_dot_q6_K_q8_1, + ncols_dst>( + vx, vy, dst, ncols, nrows, + stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_q6_K_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_q6_K_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_q6_K_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_q6_K_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_q6_K_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_q6_K_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_q6_K_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_q6_K_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_q6_K_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q6_K multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, @@ -1097,6 +1961,51 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy, } } +template <int ncols_dst> +static void mul_mat_vec_iq4_xs_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK_K, QI4_XS/4, block_iq4_xs, + 1, + vec_dot_iq4_xs_q8_1, + ncols_dst>( + vx, vy, dst, ncols, nrows, + stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_iq4_xs_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_iq4_xs_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for IQ4_XS multi-col MMVQ", ncols_dst); + } +} + void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, @@ -1123,44 +2032,163 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens case GGML_TYPE_Q4_0: if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { - GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_0_q8_1_sycl\n"); - reorder_mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - } else { + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_0_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + reorder_mul_mat_vec_q4_0_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y_bytes, stride_col_dst, stream); + return; + } else { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_0_q8_1_sycl\n"); + reorder_mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } + } else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_0_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_q4_0_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_0_q8_1_sycl\n"); mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); } break; case GGML_TYPE_Q4_1: - mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_1_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_q4_1_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { + mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } break; case GGML_TYPE_Q5_0: - mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_q5_0_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_q5_0_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { + mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } break; case GGML_TYPE_Q5_1: - mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_q5_1_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_q5_1_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { + mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } break; case GGML_TYPE_Q8_0: if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { - GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q8_0_q8_1_sycl\n"); - reorder_mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - } else { + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q8_0_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + reorder_mul_mat_vec_q8_0_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y_bytes, stride_col_dst, stream); + return; + } else { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q8_0_q8_1_sycl\n"); + reorder_mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } + } else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_q8_0_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_q8_0_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { GGML_SYCL_DEBUG("Calling mul_mat_vec_q8_0_q8_1_sycl\n"); mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); } break; case GGML_TYPE_Q2_K: - mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_q2_K_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_q2_K_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { + mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } break; case GGML_TYPE_Q3_K: - mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q3_k_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + reorder_mul_mat_vec_q3_k_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y_bytes, stride_col_dst, stream); + return; + } else { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q3_k_q8_1_sycl\n"); + reorder_mul_mat_vec_q3_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } + } else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_q3_K_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_q3_K_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { + GGML_SYCL_DEBUG("Calling mul_mat_vec_q3_K_q8_1_sycl\n"); + mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } break; case GGML_TYPE_Q4_K: if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { - GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_k_q8_1_sycl\n"); - reorder_mul_mat_vec_q4_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - } else { + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_k_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + reorder_mul_mat_vec_q4_k_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y_bytes, stride_col_dst, stream); + return; + } else { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_k_q8_1_sycl\n"); + reorder_mul_mat_vec_q4_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } + } else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_K_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_q4_K_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_K_q8_1_sycl\n"); mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); } @@ -1168,9 +2196,27 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens case GGML_TYPE_Q5_K: if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { - GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q5_k_q8_1_sycl\n"); - reorder_mul_mat_vec_q5_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - } else { + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q5_k_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + reorder_mul_mat_vec_q5_k_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y_bytes, stride_col_dst, stream); + return; + } else { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q5_k_q8_1_sycl\n"); + reorder_mul_mat_vec_q5_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } + } else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_q5_K_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_q5_K_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { GGML_SYCL_DEBUG("Calling mul_mat_vec_q5_K_q8_1_sycl\n"); mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); } @@ -1178,9 +2224,27 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens case GGML_TYPE_Q6_K: if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { - GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q6_k_q8_1_sycl\n"); - reorder_mul_mat_vec_q6_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - } else { + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q6_k_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + reorder_mul_mat_vec_q6_k_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y_bytes, stride_col_dst, stream); + return; + } else { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q6_k_q8_1_sycl\n"); + reorder_mul_mat_vec_q6_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } + } else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_q6_K_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_q6_K_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { GGML_SYCL_DEBUG("Calling mul_mat_vec_q6_k_q8_1_sycl\n"); mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); } @@ -1210,13 +2274,43 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); break; case GGML_TYPE_IQ4_XS: - mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_iq4_xs_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_iq4_xs_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { + mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } break; case GGML_TYPE_MXFP4: - mul_mat_vec_mxfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_mxfp4_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_mxfp4_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { + mul_mat_vec_mxfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } break; case GGML_TYPE_NVFP4: - mul_mat_vec_nvfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_nvfp4_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_nvfp4_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { + mul_mat_vec_nvfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } break; default: GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(src0->type)); diff --git a/ggml/src/ggml-sycl/quants.hpp b/ggml/src/ggml-sycl/quants.hpp index 806028ef3a3..95287f17510 100644 --- a/ggml/src/ggml-sycl/quants.hpp +++ b/ggml/src/ggml-sycl/quants.hpp @@ -58,6 +58,31 @@ template <> struct block_q_t<GGML_TYPE_Q4_0> { static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } }; +template <> struct block_q_t<GGML_TYPE_Q3_K> { + struct traits { + static constexpr uint32_t qk = QK_K; + static constexpr uint32_t qi = QI3_K; + static constexpr uint32_t qr = QR3_K; + static constexpr uint32_t vdr_mmvq = 1; + }; + + // Reordered layout: [qs (QK_K/4 per block)] [hmask (QK_K/8 per block)] [scales] [d] + static constexpr std::pair<int, int> get_block_offset(const int block_index, const int n_blocks) { + auto qs_offset = block_index * (QK_K / 4); + auto hmask_offset = n_blocks * (QK_K / 4) + block_index * (QK_K / 8); + return { qs_offset, hmask_offset }; + } + + static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) { + auto nblocks = (nrows * (ncols / QK_K)); + auto total_qs_bytes = nblocks * (QK_K / 4) + nblocks * (QK_K / 8); + return { total_qs_bytes + block_index * 12, + total_qs_bytes + nblocks * 12 + block_index * sizeof(ggml_half) }; + } + + static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } +}; + template <> struct block_q_t<GGML_TYPE_Q4_K> { struct traits { static constexpr uint32_t qk = QK_K; diff --git a/ggml/src/ggml-sycl/vecdotq.hpp b/ggml/src/ggml-sycl/vecdotq.hpp index 16b2d65d271..4b58b09ab2c 100644 --- a/ggml/src/ggml-sycl/vecdotq.hpp +++ b/ggml/src/ggml-sycl/vecdotq.hpp @@ -394,6 +394,41 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q8_0> { } }; +template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q3_K> { + static constexpr ggml_type gtype = GGML_TYPE_Q3_K; + + using q3_k_block = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q3_K>; + using q3_k_traits = typename q3_k_block::traits; + + __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair<int, int> ibx_offset, + const std::pair<int, int> d_offset, const int8_t * q8_1_quant_ptr, + const sycl::half2 * q8_1_ds, const int & iqs) { + const uint8_t * base = static_cast<const uint8_t *>(vbq); + const uint8_t * qs = base + ibx_offset.first; + const uint8_t * hmask = base + ibx_offset.second; + const uint8_t * scales = base + d_offset.first; + const ggml_half d = *reinterpret_cast<const ggml_half *>(base + d_offset.second); + + const int bq8_offset = QR3_K * (iqs / (QI3_K / 2)); + const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1 / 2); + + const int vl = get_int_from_uint8(qs, iqs); + const int vh = ~get_int_from_uint8(hmask, iqs % (QI3_K / 2)) >> bq8_offset; + + int u[QR3_K]; + float d8[QR3_K]; + +#pragma unroll + for (int i = 0; i < QR3_K; ++i) { + const int8_t * quant_base_ptr = q8_1_quant_ptr + (bq8_offset + i) * QK8_1; + u[i] = get_int_from_int8_aligned(quant_base_ptr, iqs % QI8_1); + d8[i] = (*(q8_1_ds + bq8_offset + i))[0]; + } + + return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, scales, scale_offset, static_cast<float>(d), d8); + } +}; + static inline float vec_dot_q4_K_q8_1_common(const int * __restrict__ q4, const uint16_t * __restrict__ scales, const ggml_half2 & dm, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt index 65785ae4566..2d9e85794ad 100644 --- a/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/CMakeLists.txt @@ -79,6 +79,12 @@ if (Vulkan_FOUND) "GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT" ) + test_shader_extension_support( + "GL_NV_cooperative_matrix_decode_vector" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp" + "GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT" + ) + test_shader_extension_support( "GL_EXT_integer_dot_product" "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/integer_dot.comp" diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index cbff7aad5b6..ac658951413 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -21,6 +21,19 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher(); #include <vulkan/vulkan.hpp> +// Fallback definitions for VK_NV_cooperative_matrix_decode_vector in case the +// installed Vulkan headers predate the extension. +#ifndef VK_NV_cooperative_matrix_decode_vector +#define VK_NV_cooperative_matrix_decode_vector 1 +#define VK_NV_COOPERATIVE_MATRIX_DECODE_VECTOR_EXTENSION_NAME "VK_NV_cooperative_matrix_decode_vector" +#define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_DECODE_VECTOR_FEATURES_NV ((VkStructureType)1000689000) +typedef struct VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV { + VkStructureType sType; + void* pNext; + VkBool32 cooperativeMatrixDecodeVector; +} VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV; +#endif + // SPIR-V Headers: different SDK installations expose different include paths. // LunarG Vulkan SDK on Windows typically provides <spirv-headers/spirv.hpp>. // Linux packages, MSYS2 and MinGW often use the Khronos layout <spirv/unified1/spirv.hpp>. @@ -49,8 +62,10 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher(); #include <map> #include <set> #include <unordered_map> +#include <shared_mutex> #include <mutex> #include <future> +#include <condition_variable> #include <thread> #if defined(_MSC_VER) @@ -98,6 +113,21 @@ typedef struct VkPhysicalDeviceShaderBfloat16FeaturesKHR { } VkPhysicalDeviceShaderBfloat16FeaturesKHR; #endif +#if !defined(VK_VALVE_shader_mixed_float_dot_product) +#define VK_VALVE_shader_mixed_float_dot_product 1 +#define VK_VALVE_SHADER_MIXED_FLOAT_DOT_PRODUCT_SPEC_VERSION 1 +#define VK_VALVE_SHADER_MIXED_FLOAT_DOT_PRODUCT_EXTENSION_NAME "VK_VALVE_shader_mixed_float_dot_product" +#define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_MIXED_FLOAT_DOT_PRODUCT_FEATURES_VALVE ((VkStructureType)1000673000) +typedef struct VkPhysicalDeviceShaderMixedFloatDotProductFeaturesVALVE { + VkStructureType sType; + void* pNext; + VkBool32 shaderMixedFloatDotProductFloat16AccFloat32; + VkBool32 shaderMixedFloatDotProductFloat16AccFloat16; + VkBool32 shaderMixedFloatDotProductBFloat16Acc; + VkBool32 shaderMixedFloatDotProductFloat8AccFloat32; +} VkPhysicalDeviceShaderMixedFloatDotProductFeaturesVALVE; +#endif + #define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1)) #define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; } @@ -145,8 +175,9 @@ struct vk_pipeline_struct { uint32_t align; // true if fields have been set by ggml_vk_create_pipeline bool initialized {}; - // set to true to request the pipeline is compiled - std::atomic<bool> needed {}; + // true while a compile is in flight, used to dedupe concurrent claims. + // Protected by device->compile_mutex. + bool compile_pending {}; // set to true when the shader has been compiled std::atomic<bool> compiled {}; // number of registers used, extracted from pipeline executable properties @@ -398,6 +429,7 @@ enum vk_conv_shapes { CONV_SHAPE_128x128, CONV_SHAPE_64x32, CONV_SHAPE_32x256, + CONV_SHAPE_64x128, CONV_SHAPE_COUNT, }; @@ -412,6 +444,7 @@ vk_conv_block_size vk_conv_block_sizes[CONV_SHAPE_COUNT] = { { 128, 128, 16 }, // CONV_SHAPE_128x128 { 64, 32, 32 }, // CONV_SHAPE_64x32 { 32, 256, 16 }, // CONV_SHAPE_32x256 + { 64, 128, 16 }, // CONV_SHAPE_64x128 }; enum dmmv_wg_sizes { @@ -447,14 +480,16 @@ struct vk_fa_pipeline_state { }; struct vk_conv2d_pipeline_state { - vk_conv2d_pipeline_state(uint32_t s0, uint32_t s1, uint32_t p0, uint32_t p1, uint32_t d0, uint32_t d1, uint32_t KW, uint32_t KH) - : s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), KW(KW), KH(KH) {} + vk_conv2d_pipeline_state(uint32_t s0, uint32_t s1, uint32_t p0, uint32_t p1, uint32_t d0, uint32_t d1, uint32_t KW, uint32_t KH, uint32_t aligned) + : s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), KW(KW), KH(KH), aligned(aligned) {} uint32_t s0, s1, p0, p1, d0, d1, KW, KH; + // when set, shader can skip K/CRS/NPQ bounds checks and address clamps + uint32_t aligned; bool operator<(const vk_conv2d_pipeline_state &b) const { - return std::tie(s0, s1, p0, p1, d0, d1, KW, KH) < - std::tie(b.s0, b.s1, b.p0, b.p1, b.d0, b.d1, b.KW, b.KH); + return std::tie(s0, s1, p0, p1, d0, d1, KW, KH, aligned) < + std::tie(b.s0, b.s1, b.p0, b.p1, b.d0, b.d1, b.KW, b.KH, b.aligned); } }; @@ -601,6 +636,14 @@ static constexpr std::initializer_list<std::array<int, 3>> rms_norm_mul_rope_vie struct vk_device_struct { std::recursive_mutex mutex; + mutable std::shared_mutex pinned_memory_mutex; + + // Guards compile_pending, all_pipelines, and the dynamic pipeline maps + // (flash_attn, fa_mask_opt, solve_tri, conv2d, etc). The actual compile + // runs with no lock held, so different pipelines can compile in parallel. + // Lock order is device->mutex -> compile_mutex, never the reverse. + std::mutex compile_mutex; + std::condition_variable compile_cv; vk::PhysicalDevice physical_device; vk::PhysicalDeviceProperties properties; @@ -674,6 +717,10 @@ struct vk_device_struct { uint32_t coopmat_int_k; bool coopmat2; + bool coopmat2_bf16_support {}; + bool coopmat2_decode_vector; + + bool dot2_f16 {}; bool pipeline_executable_properties_support {}; @@ -764,7 +811,8 @@ struct vk_device_struct { vk_pipeline pipeline_clamp_f32; vk_pipeline pipeline_pad_f32; vk_pipeline pipeline_roll_f32; - vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32; + vk_pipeline pipeline_repeat_i32, pipeline_repeat_back_f32; + vk_pipeline pipeline_repeat_i16; vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16, pipeline_cpy_bf16_f32, pipeline_cpy_f32_i32, pipeline_cpy_i32_f32; vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_bf16_f32, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32; vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT]; @@ -841,6 +889,7 @@ struct vk_device_struct { vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines]; vk_pipeline pipeline_topk_f32[num_topk_pipelines]; vk_pipeline pipeline_sum_rows_f32; + vk_pipeline pipeline_fwht_f32[4]; vk_pipeline pipeline_cumsum_f32; vk_pipeline pipeline_cumsum_small_f32; vk_pipeline pipeline_cumsum_multipass1_f32; @@ -1131,6 +1180,13 @@ struct vk_op_push_constants { float param4; }; +struct vk_op_fwht_push_constants { + uint32_t n_rows; + uint32_t src_offset; + uint32_t dst_offset; + float scale; +}; + struct vk_op_count_experts_push_constants { uint32_t ne00; uint32_t ne01; @@ -1699,7 +1755,7 @@ struct ggml_vk_garbage_collector { }; static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_context subctx); -static void ggml_vk_load_shaders(vk_device& device); +static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested = nullptr); static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx); static bool vk_memory_logger_enabled = false; @@ -1937,6 +1993,9 @@ struct ggml_backend_vk_context { // Cache most recent tensor that was converted into prealloc_y, and what pipeline it used to convert. vk_pipeline_struct * prealloc_y_last_pipeline_used {}; const ggml_tensor * prealloc_y_last_tensor_used {}; + // True when prealloc_y holds the padded fp16 layout used by the coopmat2 B decode-vector callback. + // If false, then it's contiguous. + bool prealloc_y_last_decode_vector_staging {}; // Track which nodes have been used since the last sync, and whether they were written to std::vector<const ggml_tensor *> unsynced_nodes_written; @@ -2036,6 +2095,15 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk GGML_UNUSED(src3); } +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_fwht_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) { + p.src_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + p.dst_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + GGML_UNUSED(src1); + GGML_UNUSED(src2); + GGML_UNUSED(src3); +} + struct ggml_backend_vk_buffer_context { vk_device_ref device; vk_buffer dev_buffer; @@ -2076,9 +2144,9 @@ void vk_memory_logger::log_deallocation(vk_buffer_ref buf_ref) { const bool device = bool(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eDeviceLocal); std::string type = device ? "device" : "host"; auto it = allocations.find(buf->buffer); - total_device -= device ? it->second : 0; - total_host -= device ? 0 : it->second; if (it != allocations.end()) { + total_device -= device ? it->second : 0; + total_host -= device ? 0 : it->second; VK_LOG_MEMORY(buf->device->name << ": -" << format_size(it->second) << " " << type << " at " << buf->buffer << ". Total device: " << format_size(total_device) << ", total host: " << format_size(total_host)); allocations.erase(it); } else { @@ -2157,10 +2225,135 @@ static void ggml_vk_wait_for_fence(ggml_backend_vk_context * ctx) { ctx->device->device.resetFences({ ctx->fence }); } -// variables to track number of compiles in progress -static uint32_t compile_count = 0; -static std::mutex compile_count_mutex; -static std::condition_variable compile_count_cond; +static constexpr uint32_t kSpvOpCooperativeMatrixLoadTensorNV = 5367; +static constexpr uint32_t kSpvCapabilityCooperativeMatrixDecodeVectorNV = 5447; +static constexpr uint32_t kSpvTensorAddressingDecodeVectorFuncBit = 0x4; + +// Remove SPV_NV_cooperative_matrix_decode_vector usage from a SPIR-V module so it +// can be loaded on drivers that only support SPV_NV_cooperative_matrix2. Drops the +// OpExtension declaration, the CooperativeMatrixDecodeVectorNV OpCapability, and the +// DecodeVectorFunc operand from any OpCooperativeMatrixLoadTensorNV instruction. +// Returns true when the input used the extension (and `out` was populated with a +// stripped copy); returns false otherwise without touching `out`. +static bool ggml_vk_strip_decode_vector(const uint32_t * code, size_t word_count, std::vector<uint32_t> & out) { + static const char kDecodeVectorExt[] = "SPV_NV_cooperative_matrix_decode_vector"; + + if (word_count < 5) { + return false; + } + + bool uses_decode_vector = false; + for (size_t pos = 5; pos < word_count; ) { + uint32_t word = code[pos]; + uint32_t wc = word >> spv::WordCountShift; + uint32_t op = word & spv::OpCodeMask; + GGML_ASSERT(wc > 0 && pos + wc <= word_count); + if (op == spv::OpExtension && wc >= 2) { + const char * s = reinterpret_cast<const char *>(&code[pos + 1]); + if (strcmp(s, kDecodeVectorExt) == 0) { + uses_decode_vector = true; + break; + } + } + pos += wc; + } + + if (!uses_decode_vector) { + return false; + } + + VK_LOG_DEBUG("ggml_vk_strip_decode_vector: stripping SPV_NV_cooperative_matrix_decode_vector"); + + // Bulk-copy unchanged runs and only break the run when an instruction needs to + // be dropped or patched. Use reserve + insert/push_back so the destination buffer + // is touched exactly once (no zero-initialization pass from resize()). + out.clear(); + out.reserve(word_count); + + size_t run_start = 0; + auto flush_run = [&](size_t up_to) { + if (up_to > run_start) { + out.insert(out.end(), code + run_start, code + up_to); + } + }; + + for (size_t pos = 5; pos < word_count; ) { + uint32_t word = code[pos]; + uint32_t wc = word >> spv::WordCountShift; + uint32_t op = word & spv::OpCodeMask; + GGML_ASSERT(wc > 0 && pos + wc <= word_count); + + if (op == spv::OpExtension && wc >= 2) { + const char * s = reinterpret_cast<const char *>(&code[pos + 1]); + if (strcmp(s, kDecodeVectorExt) == 0) { + flush_run(pos); + pos += wc; + run_start = pos; + continue; + } + } + + if (op == spv::OpCapability && wc == 2 && code[pos + 1] == kSpvCapabilityCooperativeMatrixDecodeVectorNV) { + flush_run(pos); + pos += wc; + run_start = pos; + continue; + } + + if (op == kSpvOpCooperativeMatrixLoadTensorNV) { + // [opcode/wc][ResultType][Result][Pointer][Object][TensorLayout][MemOperand mask][mem extras...][TA mask][ta extras...] + GGML_ASSERT(wc >= 8); + + uint32_t mem_mask = code[pos + 6]; + size_t cur = pos + 7; + // Each of these MemoryAccess bits (when set) carries one trailing operand. + cur += (mem_mask & 0x2) ? 1 : 0; // Aligned + cur += (mem_mask & 0x8) ? 1 : 0; // MakePointerAvailable + cur += (mem_mask & 0x10) ? 1 : 0; // MakePointerVisible + cur += (mem_mask & 0x10000) ? 1 : 0; // AliasScopeINTELMask + cur += (mem_mask & 0x20000) ? 1 : 0; // NoAliasINTELMask + GGML_ASSERT(cur < pos + wc); + + uint32_t ta_mask = code[cur]; + if ((ta_mask & kSpvTensorAddressingDecodeVectorFuncBit) == 0) { + pos += wc; + continue; // leave instruction inside the current unchanged run + } + + flush_run(pos); + + // Append unchanged prefix of the instruction (header through the mem-extras). + size_t inst_start = out.size(); + size_t pre_n = cur - pos; + out.insert(out.end(), code + pos, code + pos + pre_n); + + // Emit TA mask with the DecodeVectorFunc bit cleared. + out.push_back(ta_mask & ~kSpvTensorAddressingDecodeVectorFuncBit); + + // TA extras: TensorView (0x1) and DecodeFunc (0x2) are kept verbatim; + // DecodeVectorFunc (0x4) is dropped along with its trailing id operand. + size_t keep_ta_extras = ((ta_mask & 0x1) ? 1 : 0) + ((ta_mask & 0x2) ? 1 : 0); + if (keep_ta_extras) { + out.insert(out.end(), code + cur + 1, code + cur + 1 + keep_ta_extras); + } + + GGML_ASSERT(wc == pre_n + 1 + keep_ta_extras + 1); + + // Patch the instruction header with the new (one-shorter) word count. + uint32_t new_wc = wc - 1; + out[inst_start] = (new_wc << spv::WordCountShift) | op; + + pos += wc; + run_start = pos; + continue; + } + + pos += wc; + } + + flush_run(word_count); + return true; +} static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, size_t spv_size, const void* spv_data, const std::string entrypoint, uint32_t parameter_count, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants, @@ -2233,6 +2426,18 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin shader_module_create_info = vk::ShaderModuleCreateInfo({}, spirv.size() * sizeof(uint32_t), spirv.data()); } +#if defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT) + if (device->coopmat2 && !device->coopmat2_decode_vector) { + const uint32_t * src = spirv.empty() ? reinterpret_cast<const uint32_t *>(spv_data) : spirv.data(); + size_t src_n = spirv.empty() ? spv_size / sizeof(uint32_t) : spirv.size(); + std::vector<uint32_t> stripped; + if (ggml_vk_strip_decode_vector(src, src_n, stripped)) { + spirv = std::move(stripped); + shader_module_create_info = vk::ShaderModuleCreateInfo({}, spirv.size() * sizeof(uint32_t), spirv.data()); + } + } +#endif + pipeline->shader_module = device->device.createShaderModule(shader_module_create_info); vk::PushConstantRange pcr( @@ -2314,7 +2519,6 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin std::cerr << "ggml_vulkan: " << e.what() << std::endl; throw e; } - pipeline->compiled = true; if (vk_instance.debug_utils_support) { vk::DebugUtilsObjectNameInfoEXT duoni; @@ -2363,14 +2567,13 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin } } - device->all_pipelines.push_back(pipeline); - { - std::lock_guard<std::mutex> guard(compile_count_mutex); - assert(compile_count > 0); - compile_count--; + std::lock_guard<std::mutex> guard(device->compile_mutex); + device->all_pipelines.push_back(pipeline); + pipeline->compiled = true; + pipeline->compile_pending = false; } - compile_count_cond.notify_all(); + device->compile_cv.notify_all(); } static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) { @@ -2386,8 +2589,7 @@ static void ggml_pipeline_request_descriptor_sets(ggml_backend_vk_context *ctx, VK_LOG_DEBUG("ggml_pipeline_request_descriptor_sets(" << pipeline->name << ", " << n << ")"); ctx->pipeline_descriptor_set_requirements += n; if (!pipeline->compiled) { - pipeline->needed = true; - ggml_vk_load_shaders(ctx->device); + ggml_vk_load_shaders(ctx->device, pipeline); } ggml_pipeline_allocate_descriptor_sets(ctx); } @@ -2961,7 +3163,7 @@ struct vk_fa_tuning_params { }; static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type, ggml_type v_type); -static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc); +static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type = GGML_TYPE_F16); static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) { @@ -3102,6 +3304,13 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_ FaCodePath path = device->coopmat2 ? FA_COOPMAT2 : device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR; + if (path == FA_COOPMAT2 && k_type == GGML_TYPE_BF16 && !device->coopmat2_bf16_support) { + path = FA_COOPMAT1; + } + if (path == FA_COOPMAT1 && k_type == GGML_TYPE_BF16 && !device->coopmat_bf16_support) { + path = FA_SCALAR; + } + if (path == FA_COOPMAT1 && device->architecture == vk_device_architecture::NVIDIA_TURING) { // Nvidia compiler bug, see https://github.com/ggml-org/llama.cpp/pull/19075#issuecomment-3820716090 path = FA_SCALAR; @@ -3111,7 +3320,7 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_ bool shape_ok = (f32acc && device->coopmat_support_16x16x16_f32acc) || (!f32acc && device->coopmat_support_16x16x16_f16acc); const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc); - bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc); + bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc, k_type); if (!shape_ok || !shmem_ok) { path = FA_SCALAR; @@ -3157,8 +3366,8 @@ static vk_fa_pipeline_state get_fa_pipeline_state(const vk_device& device, const static std::vector<uint32_t> get_fa_spec_constants(const vk_fa_pipeline_state& state) { const auto fa_block_bytes = [](ggml_type t) -> uint32_t { - // decodeBufF32 uses a block of vec4s for a better memory access pattern. - return t == GGML_TYPE_F32 ? 16u : (uint32_t) ggml_type_size(t); + if (t == GGML_TYPE_F32) return 16u; + return (uint32_t) ggml_type_size(t); }; return { /* 0 WorkGroupSize */ state.workgroup_size, @@ -3186,7 +3395,9 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec switch (src0_type) { case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_M: - lut_size = 2*2048 + 4*2048; + // Regular matmul uses the compact uint16_t IQ1 grid; the expanded + // uint32_t grid is only enabled for the q8_1/int-dot vector path. + lut_size = 2*2048; break; case GGML_TYPE_IQ2_XXS: lut_size = 8*256; @@ -3380,10 +3591,26 @@ static bool ggml_vk_fa_scalar_uses_mmq(const vk_device& device, ggml_type k_type #endif } -static void ggml_vk_load_shaders(vk_device& device) { +// load_shaders walks the pipeline list under compile_mutex and either claims +// the requested pipeline for compilation or, if another thread is already +// compiling it, drops the lock and waits on compile_cv. Compiles themselves +// run unlocked. +struct CompileTask { + vk_pipeline pipeline; + size_t spv_size; + const void * spv_data; + std::string entrypoint; + uint32_t parameter_count; + std::array<uint32_t, 3> wg_denoms; + std::vector<uint32_t> specialization_constants; + bool disable_robustness; + bool require_full_subgroups; + uint32_t required_subgroup_size; +}; + +static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")"); - std::lock_guard<std::recursive_mutex> guard(device->mutex); // some shaders have a minimum subgroup size const uint32_t subgroup_size_8 = std::max(device->subgroup_size, 8u); const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u); @@ -3413,6 +3640,15 @@ static void ggml_vk_load_shaders(vk_device& device) { l_mmqid_wg_denoms, m_mmqid_wg_denoms, s_mmqid_wg_denoms; uint32_t l_align, m_align, s_align; + + vk_pipeline wait_pipeline; + CompileTask claimed_task {}; + bool has_claimed_task = false; + + // The rest of the walk reads and writes shared device state, so hold the + // lock until we're done deciding what to compile. + std::unique_lock<std::mutex> compile_lock(device->compile_mutex); + if (device->coopmat2) { // spec constants and tile sizes for non-quant matmul/matmul_id l_warptile = { 256, 128, 256, 64, 1 }; @@ -3439,9 +3675,10 @@ static void ggml_vk_load_shaders(vk_device& device) { s_mmq_wg_denoms_k = { 32, 64, 1 }; // spec constants and tile sizes for quant matmul_id - l_warptile_mmqid = { 256, 128, 128, 32, 1, device->subgroup_size }; - m_warptile_mmqid = { 256, 128, 64, 32, 0, device->subgroup_size }; - s_warptile_mmqid = { 256, 128, 64, 32, 0, device->subgroup_size }; + const uint32_t mmqid_bk = device->coopmat2_decode_vector ? 64u : 32u; + l_warptile_mmqid = { 256, 128, 128, mmqid_bk, 1, device->subgroup_size }; + m_warptile_mmqid = { 256, 128, 64, mmqid_bk, 0, device->subgroup_size }; + s_warptile_mmqid = { 256, 128, 64, mmqid_bk, 0, device->subgroup_size }; l_mmqid_wg_denoms = { 128, 128, 1 }; m_mmqid_wg_denoms = { 128, 64, 1 }; s_mmqid_wg_denoms = { 128, 64, 1 }; @@ -3598,7 +3835,6 @@ static void ggml_vk_load_shaders(vk_device& device) { device->pipeline_matmul_id_bf16 = std::make_shared<vk_matmul_pipeline_struct>(); } - std::vector<std::future<void>> compiles; auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& base_pipeline, const char *name, size_t spv_size, const void* spv_data, const char *entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants, uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) { @@ -3632,23 +3868,33 @@ static void ggml_vk_load_shaders(vk_device& device) { #endif } - if (!pipeline->needed || pipeline->compiled) { + // We only care about the pipeline this call asked for; the rest + // (including the 64-bit indexing variant) are handled by their + // own request_descriptor_sets / load_shaders calls. + if (pipeline.get() != requested.get()) { continue; } - // TODO: We're no longer benefitting from the async compiles (shaders are - // compiled individually, as needed) and this complexity can be removed. - { - // wait until fewer than N compiles are in progress - uint32_t N = std::max(1u, std::thread::hardware_concurrency()); - std::unique_lock<std::mutex> guard(compile_count_mutex); - while (compile_count >= N) { - compile_count_cond.wait(guard); - } - compile_count++; + + if (pipeline->compiled) { + continue; } - compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint, - parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size)); + wait_pipeline = pipeline; + + if (!pipeline->compile_pending) { + pipeline->compile_pending = true; + claimed_task.pipeline = pipeline; + claimed_task.spv_size = spv_size; + claimed_task.spv_data = spv_data; + claimed_task.entrypoint = entrypoint; + claimed_task.parameter_count = parameter_count; + claimed_task.wg_denoms = wg_denoms; + claimed_task.specialization_constants = specialization_constants; + claimed_task.disable_robustness = disable_robustness; + claimed_task.require_full_subgroups = require_full_subgroups; + claimed_task.required_subgroup_size = required_subgroup_size; + has_claimed_task = true; + } } }; @@ -3672,10 +3918,16 @@ static void ggml_vk_load_shaders(vk_device& device) { const uint32_t fa_sgs = fa.first.subgroup_size; const bool fa_ds = fa.first.subgroup_size == 0; + const bool bf16_kv = fa.first.k_type == GGML_TYPE_BF16; const bool use_mmq = ggml_vk_fa_scalar_uses_mmq(device, fa.first.k_type); const void * spv_data = nullptr; size_t spv_size = 0; - if (use_mmq) { + const char *name = nullptr; + if (bf16_kv) { + spv_data = flash_attn_f32_f16_fp32_data; + spv_size = flash_attn_f32_f16_fp32_len; + name = aligned ? "flash_attn_f32_bf16_aligned" : "flash_attn_f32_bf16"; + } else if (use_mmq) { #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->fp16) { if (f32acc) { spv_data = flash_attn_f32_f16_int8_data; spv_size = flash_attn_f32_f16_int8_len; } @@ -3685,16 +3937,22 @@ static void ggml_vk_load_shaders(vk_device& device) { spv_size = flash_attn_f32_f16_fp32_int8_len; } #endif + name = aligned ? "flash_attn_f32_f16_aligned" : "flash_attn_f32_f16"; } else { if (device->fp16) { - if (f32acc) { spv_data = flash_attn_f32_f16_data; spv_size = flash_attn_f32_f16_len; } - else { spv_data = flash_attn_f32_f16_f16acc_data; spv_size = flash_attn_f32_f16_f16acc_len; } + if (device->dot2_f16) { + if (f32acc) { spv_data = flash_attn_f32_f16_dot2_data; spv_size = flash_attn_f32_f16_dot2_len; } + else { spv_data = flash_attn_f32_f16_dot2_f16acc_data; spv_size = flash_attn_f32_f16_dot2_f16acc_len; } + } else { + if (f32acc) { spv_data = flash_attn_f32_f16_data; spv_size = flash_attn_f32_f16_len; } + else { spv_data = flash_attn_f32_f16_f16acc_data; spv_size = flash_attn_f32_f16_f16acc_len; } + } } else { spv_data = flash_attn_f32_f16_fp32_data; spv_size = flash_attn_f32_f16_fp32_len; } + name = aligned ? "flash_attn_f32_f16_aligned" : "flash_attn_f32_f16"; } - const char *name = aligned ? "flash_attn_f32_f16_aligned" : "flash_attn_f32_f16"; ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), aligned ? Bc : 1, true, @@ -3712,11 +3970,25 @@ static void ggml_vk_load_shaders(vk_device& device) { const uint32_t fa_sgs = fa.first.subgroup_size; const bool fa_ds = fa.first.subgroup_size == 0; + const bool bf16_kv = fa.first.k_type == GGML_TYPE_BF16; + const void * spv_data; size_t spv_size; - if (f32acc) { spv_data = flash_attn_f32_f16_cm1_data; spv_size = flash_attn_f32_f16_cm1_len; } - else { spv_data = flash_attn_f32_f16_f16acc_cm1_data; spv_size = flash_attn_f32_f16_f16acc_cm1_len; } - const char *name = aligned ? "flash_attn_f32_f16_aligned_cm1" : "flash_attn_f32_f16_cm1"; + const char *name; + if (bf16_kv) { +#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (!device->coopmat_bf16_support) continue; + spv_data = flash_attn_f32_f16_bf16_cm1_data; + spv_size = flash_attn_f32_f16_bf16_cm1_len; + name = aligned ? "flash_attn_f32_bf16_aligned_cm1" : "flash_attn_f32_bf16_cm1"; +#else + continue; +#endif + } else { + if (f32acc) { spv_data = flash_attn_f32_f16_cm1_data; spv_size = flash_attn_f32_f16_cm1_len; } + else { spv_data = flash_attn_f32_f16_f16acc_cm1_data; spv_size = flash_attn_f32_f16_f16acc_cm1_len; } + name = aligned ? "flash_attn_f32_f16_aligned_cm1" : "flash_attn_f32_f16_cm1"; + } ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), aligned ? Bc : 1, true, @@ -3734,10 +4006,20 @@ static void ggml_vk_load_shaders(vk_device& device) { const bool aligned = fa.first.aligned; const bool f32acc = fa.first.f32acc; + const bool bf16_kv = fa.first.k_type == GGML_TYPE_BF16; const void * spv_data; size_t spv_size; const char * name; - if (aligned) { + if (bf16_kv) { +#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (!device->coopmat2_bf16_support) continue; + spv_data = flash_attn_f32_f16_bf16_cm2_data; + spv_size = flash_attn_f32_f16_bf16_cm2_len; + name = aligned ? "flash_attn_f32_bf16_aligned_cm2" : "flash_attn_f32_bf16_cm2"; +#else + continue; +#endif + } else if (aligned) { if (f32acc) { spv_data = flash_attn_f32_f16_cm2_data; spv_size = flash_attn_f32_f16_cm2_len; name = "flash_attn_f32_f16_aligned_f32acc_cm2"; } else { spv_data = flash_attn_f32_f16_f16acc_cm2_data; spv_size = flash_attn_f32_f16_f16acc_cm2_len; name = "flash_attn_f32_f16_aligned_f16acc_cm2"; } } else { @@ -3958,7 +4240,23 @@ static void ggml_vk_load_shaders(vk_device& device) { #endif // defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) if (device->fp16) { // Create 6 variants, {s,m,l}x{unaligned,aligned} + // Selects dot2 SPIR-V variant at runtime when device->dot2_f16 is true #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ + if (device->mul_mat ## ID ## _l[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + if (device->mul_mat ## ID ## _m[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + if (device->mul_mat ## ID ## _s[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + if (device->mul_mat ## ID ## _l[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + if (device->mul_mat ## ID ## _m[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + if (device->mul_mat ## ID ## _s[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + + // bf16 scalar path promotes to f32, no dot2 variant +#define CREATE_MM_NODOT2(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ if (device->mul_mat ## ID ## _l[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _m[TYPE]) \ @@ -3993,7 +4291,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); - CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM_NODOT2(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q1_0], matmul_q1_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); @@ -4001,7 +4299,6 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); - CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); @@ -4041,8 +4338,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); - CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); - + CREATE_MM_NODOT2(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0], matmul_id_subgroup_q1_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); @@ -4087,8 +4383,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); - CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); - + CREATE_MM_NODOT2(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0], matmul_id_q1_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); @@ -4133,6 +4428,7 @@ static void ggml_vk_load_shaders(vk_device& device) { #undef CREATE_MM2 #undef CREATE_MMQ #undef CREATE_MM +#undef CREATE_MM_NODOT2 } else { // Create 6 variants, {s,m,l}x{unaligned,aligned} #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ @@ -4709,9 +5005,11 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_repeat_i32, "repeat_i32", repeat_i32_len, repeat_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_repeat_i16, "repeat_i16", repeat_i16_len, repeat_i16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + #define CREATE_UNARY(name) \ ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); @@ -4824,6 +5122,24 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + // Intel Arc B390 was observed segfaulting with this shader. + if (device->subgroup_basic && device->subgroup_shuffle && device->vendor_id != VK_VENDOR_ID_INTEL) { + int idx = 0; + for (uint32_t n : {64, 128, 256, 512}) { + if (device->subgroup_size <= n) { + ggml_vk_create_pipeline(device, device->pipeline_fwht_f32[idx], "fwht_f32", fwht_f32_len, fwht_f32_data, "main", 2, sizeof(vk_op_fwht_push_constants), {1, 1, 1}, { device->subgroup_size, n }, 1, true, true, device->subgroup_size); + } + ++idx; + } + } else if (device->driver_id != vk::DriverId::eIntelProprietaryWindows) { + // Disabled on Intel Windows due to a driver bug: https://github.com/ggml-org/llama.cpp/pull/23964#issuecomment-4598226147 + int idx = 0; + for (uint32_t n : {64, 128, 256, 512}) { + const uint32_t block_size = std::min(device->subgroup_size, n); + ggml_vk_create_pipeline(device, device->pipeline_fwht_f32[idx], "fwht_shmem_f32", fwht_shmem_f32_len, fwht_shmem_f32_data, "main", 2, sizeof(vk_op_fwht_push_constants), {1, 1, 1}, { block_size, n }, 1); + ++idx; + } + } const uint32_t cumsum_elem_per_thread = (device->vendor_id == VK_VENDOR_ID_AMD || device->vendor_id == VK_VENDOR_ID_INTEL) ? 2 : 4; ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 256, device->subgroup_size, cumsum_elem_per_thread }, 1, true, true, device->subgroup_size); @@ -4939,7 +5255,8 @@ static void ggml_vk_load_shaders(vk_device& device) { // conv2d, conv_transpose_2d for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) { - uint32_t conv2d_WG_SIZE = 256; + // smaller WG for the small-tile fallback gives more concurrent WGs per SM + uint32_t conv2d_WG_SIZE = (s == CONV_SHAPE_64x32) ? 128 : 256; uint32_t use_collectives = 0; // Enables subgroup ops for preventing the re-calculation of indices. uint32_t conv2d_TS_K = (s == CONV_SHAPE_64x32) ? 4 : 8; uint32_t conv2d_SHMEM_PAD = 4; @@ -4978,18 +5295,77 @@ static void ggml_vk_load_shaders(vk_device& device) { conv2d_BS.CRS); // CRS block size should be capped at subgroup size for correctness when shuffle is used. } - uint32_t conv2d_shmem_req = - (conv2d_BS.K * (conv2d_BS.CRS + conv2d_SHMEM_PAD) + conv2d_BS.CRS * (conv2d_BS.NPQ + conv2d_SHMEM_PAD)) * sizeof(float); - if (device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req) { + // cm1 is used only when cm2 is unavailable; capped at 64x128 (due to shared memory size). + // Requires 16x16x16 f16-acc since that's the fragment shape hard-coded in the shader. + // Subgroup size must be 32 or 64 (to keep WG_SIZE sane) and we need + // subgroup_size_control to force the driver to actually use it. + bool conv2d_use_cm1 = false; +#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + conv2d_use_cm1 = !device->coopmat2 && + device->coopmat_support && device->coopmat_support_16x16x16_f16acc && + device->subgroup_size_control && + (device->subgroup_size == 32 || device->subgroup_size == 64) && + s != CONV_SHAPE_128x128; +#endif + + const uint32_t conv2d_cm1_shmem_pad = 8; + + auto shmem_req = [&](uint32_t pad, bool csh_store, bool fp16_shmem) { + const uint32_t elem_size = fp16_shmem ? (uint32_t)sizeof(uint16_t) : (uint32_t)sizeof(float); + const uint32_t csh_elems = csh_store ? conv2d_BS.K * conv2d_BS.NPQ : 0u; + return (conv2d_BS.K * (conv2d_BS.CRS + pad) + conv2d_BS.CRS * (conv2d_BS.NPQ + pad) + csh_elems) * elem_size; + }; + + // coopmat1 needs to store the output through shared memory, so check up front + // whether it'll fit and disable it before applying coopmat1 parameters. + if (conv2d_use_cm1 && device->properties.limits.maxComputeSharedMemorySize < shmem_req(conv2d_cm1_shmem_pad, true, true)) { + conv2d_use_cm1 = false; + } + + uint32_t conv2d_WM = 16, conv2d_WN = 16; // cm1 subgroup tile, ignored otherwise + if (conv2d_use_cm1) { + conv2d_SHMEM_PAD = conv2d_cm1_shmem_pad; + // 16x16x16 fragments; pick WM/WN to keep WG_SIZE at 256 + // (i.e. 8 subgroups for sg=32, 4 subgroups for sg=64). + const bool sg64 = (device->subgroup_size == 64); + switch (s) { + case CONV_SHAPE_64x32: conv2d_WM = sg64 ? 32 : 16; conv2d_WN = 16; break; + case CONV_SHAPE_64x128: conv2d_WM = 32; conv2d_WN = sg64 ? 64 : 32; break; + case CONV_SHAPE_32x256: conv2d_WM = sg64 ? 16 : 32; conv2d_WN = sg64 ? 128 : 32; break; + default: break; + } + const uint32_t warps_M = conv2d_BS.K / conv2d_WM; + const uint32_t warps_N = conv2d_BS.NPQ / conv2d_WN; + conv2d_WG_SIZE = warps_M * warps_N * device->subgroup_size; + } + + // stage cm2 accumulator through shmem for coalesced global stores; + // skipped on 128x128 where the extra Csh footprint hurts occupancy. + // cm1 always uses the staged path. + uint32_t conv2d_csh_store = (device->coopmat2 && s != CONV_SHAPE_128x128) ? 1u : 0u; + if (conv2d_use_cm1) { + conv2d_csh_store = 1; + } + + // shmem is fp16 on cm2/cm1 (matches Csh), fp32 on scalar + const bool conv2d_use_fp16_shmem = device->coopmat2 || conv2d_use_cm1; + + // shrink CRS if the non-cm1 config still doesn't fit + if (device->properties.limits.maxComputeSharedMemorySize < shmem_req(conv2d_SHMEM_PAD, conv2d_csh_store, conv2d_use_fp16_shmem)) { + GGML_ASSERT(!conv2d_use_cm1); conv2d_BS.CRS = 8; if (use_collectives) { conv2d_BS.CRS = std::min(device->subgroup_size, conv2d_BS.CRS); } + conv2d_csh_store = 0; } std::array<uint32_t, 3> wg_denoms = { conv2d_BS.K, 1, 1 }; std::vector<uint32_t> spec_constants = { conv2d_WG_SIZE, conv2d_BS.K, conv2d_BS.CRS, conv2d_BS.NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD }; + // cm1 needs a fixed subgroup width to match the WG_SIZE we computed + const uint32_t conv2d_required_subgroup_size = conv2d_use_cm1 ? device->subgroup_size : 0; + #define CREATE_CONV(name, type_suffix, spv_suffix) \ for (auto &c : device->pipeline_##name##type_suffix[s]) { \ const vk_conv2d_pipeline_state &state = c.first; \ @@ -5002,10 +5378,14 @@ static void ggml_vk_load_shaders(vk_device& device) { spec_constants_cpy.push_back(state.d1); \ spec_constants_cpy.push_back(state.KW); \ spec_constants_cpy.push_back(state.KH); \ + spec_constants_cpy.push_back(state.aligned); \ + spec_constants_cpy.push_back(conv2d_csh_store); \ + spec_constants_cpy.push_back(conv2d_WM); \ + spec_constants_cpy.push_back(conv2d_WN); \ ggml_vk_create_pipeline( \ device, c.second, #name #type_suffix, \ name##type_suffix##spv_suffix##_len, name##type_suffix##spv_suffix##_data, "main", 3, \ - sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants_cpy, 1, true, use_collectives); \ + sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants_cpy, 1, true, use_collectives || conv2d_required_subgroup_size, conv2d_required_subgroup_size); \ } #define CREATE_CONVS(spv_suffix) \ CREATE_CONV(conv2d, _f32, spv_suffix) \ @@ -5016,6 +5396,11 @@ static void ggml_vk_load_shaders(vk_device& device) { if (device->coopmat2) { CREATE_CONVS(_cm2) } else +#endif +#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (conv2d_use_cm1) { + CREATE_CONVS(_cm1) + } else #endif if (conv2d_UNROLL) { CREATE_CONVS(_unroll) @@ -5037,8 +5422,25 @@ static void ggml_vk_load_shaders(vk_device& device) { } } - for (auto &c : compiles) { - c.wait(); + // Drop compile_mutex so other threads can walk while we compile. + compile_lock.unlock(); + + // Compile what we claimed; create_pipeline_func reacquires compile_mutex + // at the end to flip compile_pending/compiled and notify waiters. + if (has_claimed_task) { + auto & task = claimed_task; + ggml_vk_create_pipeline_func(device, task.pipeline, task.spv_size, task.spv_data, + task.entrypoint, task.parameter_count, task.wg_denoms, + task.specialization_constants, task.disable_robustness, + task.require_full_subgroups, task.required_subgroup_size); + } + + // Another thread may be compiling the pipeline we need; block on it here. + if (wait_pipeline) { + std::unique_lock<std::mutex> wait_lock(device->compile_mutex); + device->compile_cv.wait(wait_lock, [&] { + return wait_pipeline->compiled.load(); + }); } } @@ -5088,11 +5490,13 @@ static vk_device ggml_vk_get_device(size_t idx) { bool amd_shader_core_properties2 = false; bool pipeline_robustness = false; bool coopmat2_support = false; + bool coopmat2_decode_vector_support = false; bool pipeline_executable_properties_support = false; device->coopmat_support = false; device->integer_dot_product = false; device->shader_64b_indexing = false; bool bfloat16_support = false; + bool dot2_f16_support = false; for (const auto& properties : ext_props) { if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) { @@ -5122,6 +5526,9 @@ static vk_device ggml_vk_get_device(size_t idx) { !getenv("GGML_VK_DISABLE_COOPMAT2")) { coopmat2_support = true; #endif + } else if (strcmp(VK_NV_COOPERATIVE_MATRIX_DECODE_VECTOR_EXTENSION_NAME, properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_COOPMAT2_DECODE_VECTOR")) { + coopmat2_decode_vector_support = true; #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 && !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) { @@ -5132,6 +5539,9 @@ static vk_device ggml_vk_get_device(size_t idx) { !getenv("GGML_VK_DISABLE_BFLOAT16")) { bfloat16_support = true; #endif + } else if (strcmp("VK_VALVE_shader_mixed_float_dot_product", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_DOT2")) { + dot2_f16_support = true; } else if (strcmp("VK_KHR_pipeline_executable_properties", properties.extensionName) == 0) { pipeline_executable_properties_support = true; } else if (strcmp("VK_EXT_memory_priority", properties.extensionName) == 0 && @@ -5279,6 +5689,11 @@ static vk_device ggml_vk_get_device(size_t idx) { #endif device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle); +#ifdef __APPLE__ + if (device->vendor_id == VK_VENDOR_ID_AMD) { + device->subgroup_shuffle = false; + } +#endif device->subgroup_clustered = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eClustered); @@ -5399,6 +5814,14 @@ static vk_device ggml_vk_get_device(size_t idx) { } #endif + VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV coopmat2_decode_vector_features {}; + coopmat2_decode_vector_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_DECODE_VECTOR_FEATURES_NV; + if (coopmat2_decode_vector_support) { + last_struct->pNext = (VkBaseOutStructure *)&coopmat2_decode_vector_features; + last_struct = (VkBaseOutStructure *)&coopmat2_decode_vector_features; + device_extensions.push_back(VK_NV_COOPERATIVE_MATRIX_DECODE_VECTOR_EXTENSION_NAME); + } + #if defined(VK_KHR_shader_bfloat16) VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {}; bfloat16_features.pNext = nullptr; @@ -5426,6 +5849,14 @@ static vk_device ggml_vk_get_device(size_t idx) { device_extensions.push_back("VK_KHR_shader_integer_dot_product"); } + VkPhysicalDeviceShaderMixedFloatDotProductFeaturesVALVE dot2_features {}; + dot2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_MIXED_FLOAT_DOT_PRODUCT_FEATURES_VALVE; + if (dot2_f16_support) { + last_struct->pNext = (VkBaseOutStructure *)&dot2_features; + last_struct = (VkBaseOutStructure *)&dot2_features; + device_extensions.push_back("VK_VALVE_shader_mixed_float_dot_product"); + } + VkPhysicalDevicePipelineExecutablePropertiesFeaturesKHR pep_features {}; pep_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_EXECUTABLE_PROPERTIES_FEATURES_KHR; if (pipeline_executable_properties_support) { @@ -5460,6 +5891,8 @@ static vk_device ggml_vk_get_device(size_t idx) { device->bf16 = false; #endif + device->dot2_f16 = dot2_f16_support && dot2_features.shaderMixedFloatDotProductFloat16AccFloat32; + device->pipeline_robustness = pl_robustness_features.pipelineRobustness; device->multi_add = vk12_props.shaderRoundingModeRTEFloat16 && @@ -5518,46 +5951,73 @@ static vk_device ggml_vk_get_device(size_t idx) { found_fp16_256 = false, found_fp32_128 = false, found_fp32_256 = false; + bool found_bf16_128 = false, + found_bf16_256 = false; // need to support fp16*fp16 with fp16/fp32 accumulator, for workgroupsize 128 // with 32x16x16 and 256 with 32x32x16. for (auto &prop : flexible_dimensions) { if (prop.saturatingAccumulation == VK_FALSE && - prop.scope == VK_SCOPE_WORKGROUP_KHR && - prop.AType == VK_COMPONENT_TYPE_FLOAT16_KHR && - prop.BType == VK_COMPONENT_TYPE_FLOAT16_KHR) { - - if (prop.workgroupInvocations == 128 && - prop.MGranularity <= 32 && - prop.NGranularity <= 16 && - prop.KGranularity <= 16) { - if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && - prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) { - found_fp16_128 = true; + prop.scope == VK_SCOPE_WORKGROUP_KHR) { + + if (prop.AType == VK_COMPONENT_TYPE_FLOAT16_KHR && + prop.BType == VK_COMPONENT_TYPE_FLOAT16_KHR) { + + if (prop.workgroupInvocations == 128 && + prop.MGranularity <= 32 && + prop.NGranularity <= 16 && + prop.KGranularity <= 16) { + if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) { + found_fp16_128 = true; + } + if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { + found_fp32_128 = true; + } } - if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && - prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { - found_fp32_128 = true; + if (prop.workgroupInvocations == 256 && + prop.MGranularity <= 32 && + prop.NGranularity <= 32 && + prop.KGranularity <= 16) { + if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) { + found_fp16_256 = true; + } + if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { + found_fp32_256 = true; + } } } - if (prop.workgroupInvocations == 256 && - prop.MGranularity <= 32 && - prop.NGranularity <= 32 && - prop.KGranularity <= 16) { - if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && - prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) { - found_fp16_256 = true; + +#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (prop.AType == VK_COMPONENT_TYPE_BFLOAT16_KHR && + prop.BType == VK_COMPONENT_TYPE_BFLOAT16_KHR && + prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { + + if (prop.workgroupInvocations == 128 && + prop.MGranularity <= 32 && + prop.NGranularity <= 16 && + prop.KGranularity <= 16) { + found_bf16_128 = true; } - if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && - prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { - found_fp32_256 = true; + if (prop.workgroupInvocations == 256 && + prop.MGranularity <= 32 && + prop.NGranularity <= 32 && + prop.KGranularity <= 16) { + found_bf16_256 = true; } } +#endif } } if (found_fp16_128 && found_fp16_256 && found_fp32_128 && found_fp32_256 && coopmat2_props.cooperativeMatrixFlexibleDimensionsMaxDimension >= 512) { device->coopmat2 = true; + device->coopmat2_bf16_support = found_bf16_128 && found_bf16_256; + device->coopmat2_decode_vector = coopmat2_decode_vector_support && coopmat2_decode_vector_features.cooperativeMatrixDecodeVector; } } #endif @@ -5747,6 +6207,19 @@ static vk_device ggml_vk_get_device(size_t idx) { break; } +#if VK_HEADER_VERSION >= 287 + // Honeykrisp driver for Asahi Linux doesn't report VK_VENDOR_ID_APPLE. + // Check for Honeykrisp driver and force same configuration as the VK_VENDOR_ID_APPLE case. + if (device->driver_id == vk::DriverId::eMesaHoneykrisp) { + device->mul_mat_l[i] = false; + device->mul_mat_m[i] = true; + device->mul_mat_s[i] = false; + device->mul_mat_id_l[i] = false; + device->mul_mat_id_m[i] = true; + device->mul_mat_id_s[i] = false; + } +#endif + device->mul_mat_l_int[i] = device->mul_mat_l[i]; device->mul_mat_m_int[i] = device->mul_mat_m[i]; device->mul_mat_s_int[i] = device->mul_mat_s[i]; @@ -5773,8 +6246,12 @@ static vk_device ggml_vk_get_device(size_t idx) { ggml_vk_load_shaders(device); - // Only use transfer queue on AMD non-GCN, when the graphics queue is not enabled - const bool prefers_transfer_queue = device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN && !allow_graphics_queue; + // Prefer a dedicated transfer queue on AMD dGPUs (non-GCN) when graphics queue use is disabled. + const bool prefers_transfer_queue = + device->vendor_id == VK_VENDOR_ID_AMD && + device->architecture != AMD_GCN && + !device->uma && + !allow_graphics_queue; if (!device->single_queue) { const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0; @@ -5840,8 +6317,10 @@ static void ggml_vk_print_gpu_info(size_t idx) { bool fp16_compute = false; bool coopmat_support = false; bool coopmat2_support = false; + bool coopmat2_decode_vector_support = false; bool integer_dot_product = false; bool bfloat16_support = false; + bool dot2_f16_support = false; for (auto properties : ext_props) { if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) { @@ -5858,6 +6337,9 @@ static void ggml_vk_print_gpu_info(size_t idx) { !getenv("GGML_VK_DISABLE_COOPMAT2")) { coopmat2_support = true; #endif + } else if (strcmp(VK_NV_COOPERATIVE_MATRIX_DECODE_VECTOR_EXTENSION_NAME, properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_COOPMAT2_DECODE_VECTOR")) { + coopmat2_decode_vector_support = true; #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 && !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) { @@ -5868,6 +6350,9 @@ static void ggml_vk_print_gpu_info(size_t idx) { !getenv("GGML_VK_DISABLE_BFLOAT16")) { bfloat16_support = true; #endif + } else if (strcmp("VK_VALVE_shader_mixed_float_dot_product", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_DOT2")) { + dot2_f16_support = true; } } @@ -5942,6 +6427,29 @@ static void ggml_vk_print_gpu_info(size_t idx) { } #endif +#if defined(VK_NV_cooperative_matrix2) + VkPhysicalDeviceCooperativeMatrix2FeaturesNV coopmat2_features {}; + coopmat2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_2_FEATURES_NV; + if (coopmat2_support) { + last_struct->pNext = (VkBaseOutStructure *)&coopmat2_features; + last_struct = (VkBaseOutStructure *)&coopmat2_features; + } +#endif + + VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV coopmat2_decode_vector_features {}; + coopmat2_decode_vector_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_DECODE_VECTOR_FEATURES_NV; + if (coopmat2_decode_vector_support) { + last_struct->pNext = (VkBaseOutStructure *)&coopmat2_decode_vector_features; + last_struct = (VkBaseOutStructure *)&coopmat2_decode_vector_features; + } + + VkPhysicalDeviceShaderMixedFloatDotProductFeaturesVALVE dot2_features {}; + dot2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_MIXED_FLOAT_DOT_PRODUCT_FEATURES_VALVE; + if (dot2_f16_support) { + last_struct->pNext = (VkBaseOutStructure *)&dot2_features; + last_struct = (VkBaseOutStructure *)&dot2_features; + } + vkGetPhysicalDeviceFeatures2(physical_device, &device_features2); fp16 = fp16 && vk12_features.shaderFloat16; @@ -5966,11 +6474,34 @@ static void ggml_vk_print_gpu_info(size_t idx) { #endif && ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture); - std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none"; +#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + coopmat2_support = coopmat2_support && + coopmat2_features.cooperativeMatrixWorkgroupScope && + coopmat2_features.cooperativeMatrixFlexibleDimensions && + coopmat2_features.cooperativeMatrixReductions && + coopmat2_features.cooperativeMatrixConversions && + coopmat2_features.cooperativeMatrixPerElementOperations && + coopmat2_features.cooperativeMatrixTensorAddressing && + coopmat2_features.cooperativeMatrixBlockLoads; +#else + coopmat2_support = false; +#endif + + coopmat2_decode_vector_support = coopmat2_decode_vector_support && coopmat2_decode_vector_features.cooperativeMatrixDecodeVector; +#if !defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT) + coopmat2_decode_vector_support = false; +#endif + + std::string matrix_cores = coopmat2_support ? (coopmat2_decode_vector_support ? "NV_coopmat2v" : "NV_coopmat2") + : coopmat_support ? "KHR_coopmat" + : "none"; + + bool dot2_f16 = dot2_f16_support && dot2_features.shaderMixedFloatDotProductFloat16AccFloat32; + const char *fp16_str = fp16 ? (dot2_f16 ? "dot2" : "1") : "0"; std::string device_name = props2.properties.deviceName.data(); - GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | bf16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n", - idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, bf16, subgroup_size, + GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %s | bf16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n", + idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16_str, bf16, subgroup_size, props2.properties.limits.maxComputeSharedMemorySize, integer_dot_product, matrix_cores.c_str()); if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) { @@ -6656,7 +7187,7 @@ static void * ggml_vk_host_malloc(vk_device& device, size_t size) { return nullptr; } - std::lock_guard<std::recursive_mutex> guard(device->mutex); + std::lock_guard<std::shared_mutex> guard(device->pinned_memory_mutex); device->pinned_memory.push_back(std::make_tuple(buf->ptr, size, buf)); return buf->ptr; @@ -6667,7 +7198,7 @@ static void ggml_vk_host_free(vk_device& device, void* ptr) { return; } VK_LOG_MEMORY("ggml_vk_host_free(" << ptr << ")"); - std::lock_guard<std::recursive_mutex> guard(device->mutex); + std::lock_guard<std::shared_mutex> guard(device->pinned_memory_mutex); vk_buffer buf; size_t index; @@ -6691,7 +7222,7 @@ static void ggml_vk_host_free(vk_device& device, void* ptr) { } static void ggml_vk_host_get(const vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) { - std::lock_guard<std::recursive_mutex> guard(device->mutex); + std::shared_lock<std::shared_mutex> guard(device->pinned_memory_mutex); buf = nullptr; buf_offset = 0; for (size_t i = 0; i < device->pinned_memory.size(); i++) { @@ -6812,13 +7343,6 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx->s->buffer->buf.dispatch(wg0, wg1, wg2); } -static void ggml_vk_end_submission(vk_submission& s, std::vector<vk_semaphore> wait_semaphores, std::vector<vk_semaphore> signal_semaphores) { - s.buffer->buf.end(); - - s.wait_semaphores = std::move(wait_semaphores); - s.signal_semaphores = std::move(signal_semaphores); -} - static void ggml_vk_ctx_end(vk_context& ctx) { VK_LOG_DEBUG("ggml_vk_ctx_end(" << ctx << ", " << ctx->seqs.size() << ")"); if (ctx->s == nullptr) { @@ -6971,7 +7495,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1; const uint64_t d_off = offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1; for (uint64_t i0 = 0; i0 < ne0; i0++) { - slices.push_back({ s_off + i1*nb0, d_off + i0*dstnb0, dstnb0 }); + slices.push_back({ s_off + i0*nb0, d_off + i0*dstnb0, dstnb0 }); } } } @@ -7098,8 +7622,12 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { GGML_ASSERT(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent); - for (size_t i = 0; i < height; i++) { - memcpy((uint8_t *)dst->ptr + offset + i * dpitch, (const uint8_t *) src + i * spitch, width); + if (width == spitch && width == dpitch) { + memcpy((uint8_t *)dst->ptr + offset, src, width * height); + } else { + for (size_t i = 0; i < height; i++) { + memcpy((uint8_t *)dst->ptr + offset + i * dpitch, (const uint8_t *) src + i * spitch, width); + } } } else { std::lock_guard<std::recursive_mutex> guard(dst->device->mutex); @@ -7218,8 +7746,12 @@ static void ggml_vk_buffer_read_2d(vk_buffer& src, size_t offset, void * dst, si if(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible && src->device->uma) { GGML_ASSERT(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent); - for (size_t i = 0; i < height; i++) { - memcpy((uint8_t *) dst + i * dpitch, (const uint8_t *) src->ptr + offset + i * spitch, width); + if (width == spitch && width == dpitch) { + memcpy(dst, (const uint8_t *) src->ptr + offset, width * height); + } else { + for (size_t i = 0; i < height; i++) { + memcpy((uint8_t *) dst + i * dpitch, (const uint8_t *) src->ptr + offset + i * spitch, width); + } } } else { std::lock_guard<std::recursive_mutex> guard(src->device->mutex); @@ -7684,6 +8216,40 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context& ggml_vk_sync_buffers(ctx, subctx); } +// Copy/convert tensor into a caller-defined dense layout. Destination strides +// are in output elements, not bytes. +static void ggml_vk_cpy_to_strided( + ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline pipeline, const ggml_tensor * tensor, + const vk_subbuffer & in, const vk_subbuffer & out, + uint32_t nb10, uint32_t nb11, uint32_t nb12, uint32_t nb13) { + VK_LOG_DEBUG("ggml_vk_cpy_to_strided((" << tensor << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << "), "; + std::cerr << "dst_nb=(" << nb10 << ", " << nb11 << ", " << nb12 << ", " << nb13 << "), buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ")"); + const int tensor_type_size = ggml_type_size(tensor->type); + + const uint32_t ne = ggml_nelements(tensor); + std::array<uint32_t, 3> elements; + + if (ne > 262144) { + elements = { 512, 512, CEIL_DIV(ne, 262144) }; + } else if (ne > 512) { + elements = { 512, CEIL_DIV(ne, 512), 1 }; + } else { + elements = { ne, 1, 1 }; + } + + vk_op_unary_push_constants pc = { + (uint32_t)ne, + (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], (uint32_t)tensor->nb[0] / tensor_type_size, (uint32_t)tensor->nb[1] / tensor_type_size, (uint32_t)tensor->nb[2] / tensor_type_size, (uint32_t)tensor->nb[3] / tensor_type_size, + (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], nb10, nb11, nb12, nb13, + 0, + 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }; + init_pushconst_fastdiv(pc); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, pc, elements); + ggml_vk_sync_buffers(ctx, subctx); +} + static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) { switch(type) { case GGML_TYPE_Q8_1: @@ -7941,24 +8507,28 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub } if (y_non_contig) { if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() || - ctx->prealloc_y_last_tensor_used != src1) { + ctx->prealloc_y_last_tensor_used != src1 || + ctx->prealloc_y_last_decode_vector_staging) { if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0)); ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); ctx->prealloc_y_last_tensor_used = src1; + ctx->prealloc_y_last_decode_vector_staging = false; } } if (quantize_y) { if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() || - ctx->prealloc_y_last_tensor_used != src1) { + ctx->prealloc_y_last_tensor_used != src1 || + ctx->prealloc_y_last_decode_vector_staging) { if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne); ctx->prealloc_y_last_pipeline_used = to_q8_1.get(); ctx->prealloc_y_last_tensor_used = src1; + ctx->prealloc_y_last_decode_vector_staging = false; } } @@ -7999,8 +8569,10 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_ return false; } - // General performance issue with q3_k and q6_k due to 2-byte alignment - if (src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q6_K) { + // q6_k only has 2-byte alignment which makes it somewhat problematic, + // using MMVQ is only a win on Intel. + bool mmvq_q6 = device->vendor_id == VK_VENDOR_ID_INTEL; + if (src0_type == GGML_TYPE_Q6_K && !mmvq_q6) { return false; } @@ -8012,7 +8584,7 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_ // Quantization overhead is not worth it for small k switch (device->vendor_id) { case VK_VENDOR_ID_NVIDIA: - if (src0_type == GGML_TYPE_Q2_K || src0_type == GGML_TYPE_IQ1_S || src0_type == GGML_TYPE_IQ1_M) { + if (src0_type == GGML_TYPE_Q2_K || src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_IQ1_S || src0_type == GGML_TYPE_IQ1_M) { return true; } @@ -8039,9 +8611,16 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_ return true; } case VK_VENDOR_ID_INTEL: + if (device->architecture == vk_device_architecture::INTEL_XE2) { + if (src0_type == GGML_TYPE_Q2_K || src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q6_K) { + return true; + } + } + if (device->driver_id == vk::DriverId::eIntelProprietaryWindows) { - // Intel Windows proprietary driver MMVQ performance is worse than fp16, see - // https://github.com/ggml-org/llama.cpp/issues/17628 + // Intel Windows proprietary driver MMVQ performance for !Q2/Q3/Q6 is worse than fp16, + // see https://github.com/ggml-org/llama.cpp/issues/17628 and + // https://github.com/ggml-org/llama.cpp/pull/23056 return false; } @@ -8207,24 +8786,28 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& if (y_non_contig) { GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne); if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() || - ctx->prealloc_y_last_tensor_used != src1) { + ctx->prealloc_y_last_tensor_used != src1 || + ctx->prealloc_y_last_decode_vector_staging) { if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, d_Qy, d_Y); ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); ctx->prealloc_y_last_tensor_used = src1; + ctx->prealloc_y_last_decode_vector_staging = false; } } if (quantize_y) { if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() || - ctx->prealloc_y_last_tensor_used != src1) { + ctx->prealloc_y_last_tensor_used != src1 || + ctx->prealloc_y_last_decode_vector_staging) { if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } ggml_vk_quantize_q8_1(ctx, subctx, d_Qy, d_Y, y_ne); ctx->prealloc_y_last_pipeline_used = to_q8_1.get(); ctx->prealloc_y_last_tensor_used = src1; + ctx->prealloc_y_last_decode_vector_staging = false; } } @@ -8489,6 +9072,68 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con }, pc, { (uint32_t)ne03, (uint32_t)ne01, (uint32_t)ne12 }); } +static int ggml_vk_fwht_pipeline_idx(int64_t n) { + switch (n) { + case 64: return 0; + case 128: return 1; + case 256: return 2; + case 512: return 3; + default: return -1; + } +} + +static bool ggml_vk_can_use_fwht(const ggml_backend_vk_context * ctx, const ggml_tensor * src1, const ggml_tensor * dst) { + if (ctx->num_additional_fused_ops != 0) { + return false; + } + + if (ggml_get_op_params_i32(dst, 1) != GGML_HINT_SRC0_IS_HADAMARD) { + return false; + } + + const int idx = ggml_vk_fwht_pipeline_idx(src1->ne[0]); + if (idx < 0 || ctx->device->pipeline_fwht_f32[idx] == nullptr) { + return false; + } + + if (src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + return false; + } + + if (!ggml_is_contiguous(src1)) { + return false; + } + GGML_ASSERT(ggml_is_contiguous(dst)); + + return true; +} + +static void ggml_vk_fwht(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src, ggml_tensor * dst) { + const int idx = ggml_vk_fwht_pipeline_idx(src->ne[0]); + vk_pipeline pipeline = ctx->device->pipeline_fwht_f32[idx]; + + const uint32_t rows_per_workgroup = 4; + const uint32_t n_rows = (uint32_t)ggml_nrows(src); + const uint32_t max_workgroups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0]; + + const uint32_t total_workgroups = CEIL_DIV(n_rows, rows_per_workgroup); + const uint32_t workgroups_x = std::min(total_workgroups, max_workgroups_x); + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + + const vk_subbuffer src_buf = ggml_vk_tensor_subbuffer(ctx, src, true); + const vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, true); + + vk_op_fwht_push_constants pc = { + n_rows, + 0, + 0, + 1.0f / std::sqrt((float)src->ne[0]), + }; + init_pushconst_tensor_offsets(ctx, pc, src, nullptr, nullptr, nullptr, dst); + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src_buf, dst_buf }, pc, { workgroups_x, 1, 1 }); +} + static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) { ggml_tensor * dst = cgraph->nodes[node_idx]; ggml_tensor * src0 = dst->src[0]; @@ -8522,6 +9167,8 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c m_offset += cur_M_size; } + } else if (ggml_vk_can_use_fwht(ctx, src1, dst)) { + ggml_vk_fwht(ctx, subctx, src1, dst); } else if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 && // detect 0213 permutation, and batch size of 1 src0->nb[0] <= src0->nb[2] && @@ -8611,12 +9258,30 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& // Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) || !ggml_vk_dim01_contiguous(src0); - const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) || + // If src0 is BF16, try to use a BF16 x BF16 multiply + ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16; +#if defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT) + // B must already be, or be convertible to, the matmul B type used by this path. + const bool y_decode_vector_supported = ctx->device->coopmat2_decode_vector && + (f16_type != GGML_TYPE_BF16 || ctx->device->coopmat2_bf16_support) && + (src1->type == GGML_TYPE_F32 || src1->type == f16_type); + // If B is copied to prealloc_y, we can choose a 4-element-aligned row stride. + const bool y_decode_vector_uses_prealloc = !ggml_vk_dim01_contiguous(src1) || src1->type != f16_type; + // Direct B reads are safe only if row starts and the original buffer offset are 4-element aligned. + const bool y_decode_vector_aligned = + (ne10 % 4 == 0) && + (y_decode_vector_uses_prealloc || get_misalign_bytes(ctx, src1) % (4 * ggml_type_size(src1->type)) == 0); + // Stage B only when decode-vector is available and direct B reads would be misaligned. + const bool y_decode_vector_staging = y_decode_vector_supported && !y_decode_vector_aligned; +#else + const bool y_decode_vector_staging = false; +#endif + const bool y_non_contig = y_decode_vector_staging || + (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) || (src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) || !ggml_vk_dim01_contiguous(src1); - // If src0 is BF16, try to use a BF16 x BF16 multiply - ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16; + const uint32_t y_staged_row_stride = y_decode_vector_staging ? (uint32_t)ggml_vk_align_size(ne10, 4) : (uint32_t)ne10; const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig; @@ -8655,11 +9320,11 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11; const uint64_t x_ne = ggml_nelements(src0); - const uint64_t y_ne = padded_n * ne10 * ne12 * ne13; + const uint64_t y_ne = (uint64_t)y_staged_row_stride * padded_n * ne12 * ne13; const uint64_t d_ne = ggml_nelements(dst); const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); - const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); + const uint64_t qy_sz = ggml_type_size(src1->type) * ggml_nelements(src1) / ggml_blck_size(src1->type); const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne; const uint64_t y_sz = quantize_y ? (ggml_vk_align_size(y_ne, 128) * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne); const uint64_t ids_sz = nbi2; @@ -8669,13 +9334,30 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& vk_pipeline to_fp16_vk_1 = nullptr; vk_pipeline to_q8_1 = nullptr; + auto make_y_staged_dst = [&]() { + ggml_tensor y_staged_dst = *src1; + y_staged_dst.type = f16_type; + y_staged_dst.nb[0] = ggml_type_size(f16_type); + y_staged_dst.nb[1] = y_staged_dst.nb[0] * y_staged_row_stride; + y_staged_dst.nb[2] = y_staged_dst.nb[1] * padded_n; + y_staged_dst.nb[3] = y_staged_dst.nb[2] * y_staged_dst.ne[2]; + return y_staged_dst; + }; + if (x_non_contig) { to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type); } else { to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type); } if (y_non_contig) { - to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type); + ggml_tensor y_staged_dst; + const ggml_tensor * y_staged_dst_ptr = nullptr; + if (y_decode_vector_staging) { + y_staged_dst = make_y_staged_dst(); + y_staged_dst_ptr = &y_staged_dst; + } + + to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, y_staged_dst_ptr, f16_type); } else { to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); } @@ -8793,30 +9475,47 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& } if (y_non_contig) { if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() || - ctx->prealloc_y_last_tensor_used != src1) { + ctx->prealloc_y_last_tensor_used != src1 || + ctx->prealloc_y_last_decode_vector_staging != y_decode_vector_staging) { if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } - ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0)); + if (y_decode_vector_staging) { + const ggml_tensor y_staged_dst = make_y_staged_dst(); + const uint32_t y_staged_dst_type_size = ggml_type_size(y_staged_dst.type); + ggml_vk_cpy_to_strided( + ctx, subctx, to_fp16_vk_1, src1, + ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), + (uint32_t)(y_staged_dst.nb[0] / y_staged_dst_type_size), + (uint32_t)(y_staged_dst.nb[1] / y_staged_dst_type_size), + (uint32_t)(y_staged_dst.nb[2] / y_staged_dst_type_size), + (uint32_t)(y_staged_dst.nb[3] / y_staged_dst_type_size)); + } else { + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0)); + } ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); ctx->prealloc_y_last_tensor_used = src1; + ctx->prealloc_y_last_decode_vector_staging = y_decode_vector_staging; } } if (quantize_y) { if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() || - ctx->prealloc_y_last_tensor_used != src1) { + ctx->prealloc_y_last_tensor_used != src1 || + ctx->prealloc_y_last_decode_vector_staging) { if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne); ctx->prealloc_y_last_pipeline_used = to_q8_1.get(); ctx->prealloc_y_last_tensor_used = src1; + ctx->prealloc_y_last_decode_vector_staging = false; } } ggml_vk_sync_buffers(ctx, subctx); uint32_t stride_batch_x = ne00*ne01; - uint32_t stride_batch_y = ne10*ne11; + uint32_t stride_b_y = y_decode_vector_staging ? y_staged_row_stride : ne10; + uint32_t stride_batch_y = y_decode_vector_staging ? y_staged_row_stride * padded_n : ne10*ne11; if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) { stride_batch_x = src0->nb[0] / ggml_type_size(src0->type); @@ -8831,7 +9530,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& ctx, subctx, pipeline, { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz }, { d_D, d_buf_offset, d_sz }, { d_ids, ids_buf_offset, ids_sz }, expert_count_buf, - ne01, ne21, ne10, ne10, ne10, ne01, + ne01, ne21, ne10, ne10, stride_b_y, ne01, stride_batch_x, stride_batch_y, ne20*ne21, n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n ); // NOLINT @@ -8989,24 +9688,28 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte if (y_non_contig) { GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne); if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() || - ctx->prealloc_y_last_tensor_used != src1) { + ctx->prealloc_y_last_tensor_used != src1 || + ctx->prealloc_y_last_decode_vector_staging) { if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, d_Qy, d_Y); ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); ctx->prealloc_y_last_tensor_used = src1; + ctx->prealloc_y_last_decode_vector_staging = false; } } if (quantize_y) { if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() || - ctx->prealloc_y_last_tensor_used != src1) { + ctx->prealloc_y_last_tensor_used != src1 || + ctx->prealloc_y_last_decode_vector_staging) { if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } ggml_vk_quantize_q8_1(ctx, subctx, d_Qy, d_Y, y_ne); ctx->prealloc_y_last_pipeline_used = to_q8_1.get(); ctx->prealloc_y_last_tensor_used = src1; + ctx->prealloc_y_last_decode_vector_staging = false; } } @@ -9105,7 +9808,8 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con const uint32_t Br = params.block_rows; const uint32_t Bc = params.block_cols; - const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float); + // BF16 uses the fp32 shader (FLOAT_TYPE=float) + const uint32_t float_type_size = (device->fp16 && k_type != GGML_TYPE_BF16) ? sizeof(ggml_fp16_t) : sizeof(float); const bool mmq = ggml_vk_fa_scalar_uses_mmq(device, k_type); @@ -9146,7 +9850,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con return supported; } -static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) { +static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type) { // Needs to be kept up to date on shader changes const uint32_t Br = params.block_rows; const uint32_t Bc = params.block_cols; @@ -9176,8 +9880,10 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co const uint32_t vsh_stride = MatBc / 4 * row_split; const uint32_t ksh = ((kvshstride >= vsh_stride) ? (Bc * kvshstride) : (Bc * vsh_stride)) * f16vec4; + // BF16 PVMat accumulator is f32 (no bf16 accumulator support), so pvsh is vec4 (16 bytes) + const uint32_t pvsh_elem_size = (k_type == GGML_TYPE_BF16) ? 16u : f16vec4; const uint32_t osh_stride = params.row_split * MatBr / 4; - const uint32_t pvsh = MatBc * osh_stride * f16vec4; + const uint32_t pvsh = MatBc * osh_stride * pvsh_elem_size; const uint32_t slope = Br * acctype; @@ -9246,7 +9952,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx uint32_t workgroups_y = (uint32_t)neq2; uint32_t workgroups_z = (uint32_t)neq3; - const bool f32acc = !ctx->device->fp16 || dst->op_params[3] == GGML_PREC_F32; + const bool f32acc = !ctx->device->fp16 || dst->op_params[3] == GGML_PREC_F32 || k->type == GGML_TYPE_BF16; // For scalar/coopmat1 FA, we can use the "large" size to accommodate qga. // For coopmat2 FA, we always use the small size (which is still pretty large for gqa). @@ -9308,7 +10014,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx vk_pipeline pipeline = nullptr; { - std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex); + std::lock_guard<std::mutex> guard(ctx->device->compile_mutex); auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16; auto it = pipelines.find(fa_pipeline_state); if (it != pipelines.end()) { @@ -9372,13 +10078,15 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx vk_pipeline pipeline_fa_mask_opt = nullptr; if (use_mask_opt) { - std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex); - auto &pipelines = ctx->device->pipeline_fa_mask_opt; - auto it = pipelines.find({Br, Bc}); - if (it != pipelines.end()) { - pipeline_fa_mask_opt = it->second; - } else { - pipelines[{Br, Bc}] = pipeline_fa_mask_opt = std::make_shared<vk_pipeline_struct>(); + { + std::lock_guard<std::mutex> guard(ctx->device->compile_mutex); + auto &pipelines = ctx->device->pipeline_fa_mask_opt; + auto it = pipelines.find({Br, Bc}); + if (it != pipelines.end()) { + pipeline_fa_mask_opt = it->second; + } else { + pipelines[{Br, Bc}] = pipeline_fa_mask_opt = std::make_shared<vk_pipeline_struct>(); + } } assert(pipeline_fa_mask_opt); ggml_pipeline_request_descriptor_sets(ctx, pipeline_fa_mask_opt, 1); @@ -9489,10 +10197,23 @@ static vk_conv_shapes ggml_vk_conv_select_shape(ggml_backend_vk_context * ctx, u // so small convolutions will still choose a smaller tile. const uint32_t shader_core_count = ctx->device->shader_core_count > 0 ? ctx->device->shader_core_count : 32; - if (K > 64 && n_tiles(CONV_SHAPE_128x128) >= shader_core_count * 2) { + // 128x128 isn't used with cm1 due to shared memory size; fall through to a smaller tile. + bool allow_128x128 = true; +#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (!ctx->device->coopmat2 && ctx->device->coopmat_support && ctx->device->coopmat_support_16x16x16_f16acc) { + allow_128x128 = false; + } +#endif + + if (allow_128x128 && K > 64 && n_tiles(CONV_SHAPE_128x128) >= shader_core_count * 2) { return CONV_SHAPE_128x128; } else if (K <= 32 && n_tiles(CONV_SHAPE_32x256) >= shader_core_count * 2) { return CONV_SHAPE_32x256; + } else if (K <= 64 && n_tiles(CONV_SHAPE_64x128) >= shader_core_count * 2) { + return CONV_SHAPE_64x128; + } else if (!allow_128x128 && K > 64 && n_tiles(CONV_SHAPE_64x128) >= shader_core_count * 2) { + // cm1 fallback for large K when 128x128 isn't available + return CONV_SHAPE_64x128; } else { return CONV_SHAPE_64x32; } @@ -9664,7 +10385,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return nullptr; case GGML_OP_REPEAT: if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) { - return ctx->device->pipeline_repeat_f32; + return ctx->device->pipeline_repeat_i32; + } + if (ggml_type_size(src0->type) == 2 && ggml_type_size(dst->type) == 2) { + return ctx->device->pipeline_repeat_i16; } return nullptr; case GGML_OP_REPEAT_BACK: @@ -9900,7 +10624,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const vk_pipeline pipeline = nullptr; { - std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex); + std::lock_guard<std::mutex> guard(ctx->device->compile_mutex); auto it = ctx->device->pipeline_solve_tri_f32.find(solve_tri_pipeline_state); if (it != ctx->device->pipeline_solve_tri_f32.end()) { pipeline = it->second; @@ -10028,7 +10752,18 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const uint32_t p1 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 3) : 0; uint32_t d0 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 4) : 1; uint32_t d1 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 5) : 1; - vk_conv2d_pipeline_state conv2d_pipeline_state(s0, s1, p0, p1, d0, d1, KW, KH); + + // tile-aligned shapes let the shader skip bounds checks + const uint32_t Cin = (uint32_t)src1->ne[2]; + const uint32_t CRS = Cin * KW * KH; + const uint32_t BS_K = vk_conv_block_sizes[shape].K; + const uint32_t BS_CRS = vk_conv_block_sizes[shape].CRS; + const uint32_t BS_NPQ = vk_conv_block_sizes[shape].NPQ; + const uint32_t aligned = ((K % BS_K == 0) && + (CRS % BS_CRS == 0) && + (NPQ % BS_NPQ == 0)) ? 1u : 0u; + + vk_conv2d_pipeline_state conv2d_pipeline_state(s0, s1, p0, p1, d0, d1, KW, KH, aligned); std::map<vk_conv2d_pipeline_state, vk_pipeline> *pipelines = nullptr; if (op == GGML_OP_CONV_2D) { @@ -10048,7 +10783,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const vk_pipeline pipeline = nullptr; { - std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex); + std::lock_guard<std::mutex> guard(ctx->device->compile_mutex); auto it = pipelines->find(conv2d_pipeline_state); if (it != pipelines->end()) { pipeline = it->second; @@ -10843,8 +11578,8 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s const uint32_t n_tokens = (uint32_t)src_v->ne[2]; const uint32_t n_seqs = (uint32_t)src_v->ne[3]; - // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. - const uint32_t K = (uint32_t)src_state->ne[1]; + // K (snapshot slot count) is an op param; state holds s0 only [S_v, S_v, H, n_seqs]. + const uint32_t K = (uint32_t)ggml_get_op_params_i32(dst, 0); const uint32_t s_off = S_v * H * n_tokens * n_seqs; @@ -13204,7 +13939,9 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_contex ggml_vk_destroy_buffer(ctx->prealloc_y); } ctx->prealloc_y = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_y); + ctx->prealloc_y_last_pipeline_used = nullptr; ctx->prealloc_y_last_tensor_used = nullptr; + ctx->prealloc_y_last_decode_vector_staging = false; } if (ctx->prealloc_split_k == nullptr || (ctx->prealloc_size_split_k > 0 && ctx->prealloc_split_k->size < ctx->prealloc_size_split_k)) { VK_LOG_MEMORY("ggml_vk_preallocate_buffers(split_k_size: " << ctx->prealloc_size_split_k << ")"); @@ -13784,6 +14521,8 @@ static void ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) { VK_LOG_DEBUG("ggml_vk_graph_cleanup()"); ctx->prealloc_y_last_pipeline_used = {}; + ctx->prealloc_y_last_tensor_used = nullptr; + ctx->prealloc_y_last_decode_vector_staging = false; ctx->unsynced_nodes_written.clear(); ctx->unsynced_nodes_read.clear(); @@ -13834,6 +14573,8 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) { ggml_vk_destroy_buffer(ctx->sync_staging); ctx->prealloc_y_last_pipeline_used = nullptr; + ctx->prealloc_y_last_tensor_used = nullptr; + ctx->prealloc_y_last_decode_vector_staging = false; ctx->prealloc_size_x = 0; ctx->prealloc_size_y = 0; @@ -14085,12 +14826,6 @@ static const char * ggml_backend_vk_host_buffer_type_name(ggml_backend_buffer_ty UNUSED(buft); } -static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffer) { - return GGML_VK_NAME "_Host"; - - UNUSED(buffer); -} - static void ggml_backend_vk_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { VK_LOG_MEMORY("ggml_backend_vk_host_buffer_free_buffer()"); ggml_vk_host_free(vk_instance.devices[0], buffer->context); @@ -15027,6 +15762,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->prealloc_y_last_pipeline_used = nullptr; ctx->prealloc_y_last_tensor_used = nullptr; + ctx->prealloc_y_last_decode_vector_staging = false; if (ctx->prealloc_size_add_rms_partials) { ggml_vk_preallocate_buffers(ctx, nullptr); @@ -16048,6 +16784,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm switch (t) { case GGML_TYPE_F32: case GGML_TYPE_F16: + case GGML_TYPE_BF16: case GGML_TYPE_Q8_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q5_0: @@ -16063,6 +16800,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (!fa_kv_ok(op->src[1]->type) || !fa_kv_ok(op->src[2]->type)) { return false; } + if ((op->src[1]->type == GGML_TYPE_BF16) != (op->src[2]->type == GGML_TYPE_BF16)) { + return false; + } if (!coopmat2 && !(device->subgroup_shuffle && device->subgroup_vote)) { // scalar/coopmat1 FA uses subgroupShuffle/subgroupAll return false; @@ -16185,7 +16925,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return false; } case GGML_OP_REPEAT: - return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float); + return ggml_type_size(op->type) == ggml_type_size(op->src[0]->type) && + (ggml_type_size(op->type) == sizeof(float) || ggml_type_size(op->type) == 2); case GGML_OP_REPEAT_BACK: return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_ROPE: @@ -17267,7 +18008,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * src_clone[4], src_clone[5], src_clone[6]); } else if (tensor->op == GGML_OP_GATED_DELTA_NET) { tensor_clone = ggml_gated_delta_net(ggml_ctx, src_clone[0], src_clone[1], - src_clone[2], src_clone[3], src_clone[4], src_clone[5]); + src_clone[2], src_clone[3], src_clone[4], src_clone[5], + ggml_get_op_params_i32(tensor, 0)); } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) { src_clone[0]->flags = tensor->src[0]->flags; tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1], diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt b/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt index e1f613fb4f6..10a9ea21025 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt @@ -11,6 +11,10 @@ if (GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) message(STATUS "Enabling coopmat2 glslc support") endif() +if (GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT) + add_compile_definitions(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT) + message(STATUS "Enabling coopmat2 decode_vector glslc support") +endif() if (GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) message(STATUS "Enabling dot glslc support") diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp index 875c012cd3b..1428ef68d81 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp @@ -7,6 +7,13 @@ #extension GL_KHR_memory_scope_semantics : enable #endif +#ifdef COOPMAT +#extension GL_KHR_cooperative_matrix : enable +#extension GL_KHR_shader_subgroup_basic : enable +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_KHR_memory_scope_semantics : enable +#endif + #ifdef USE_COLLECTIVES # extension GL_KHR_shader_subgroup_shuffle : enable #endif @@ -77,6 +84,39 @@ layout(constant_id = 12) const uint d1 = 1; // Kernel spatial sizes layout(constant_id = 13) const uint KW = 1; layout(constant_id = 14) const uint KH = 1; +// when set, skip bounds checks and address clamps (K/CRS/NPQ are tile-aligned) +layout(constant_id = 15) const uint aligned = 0; +// stage cm2 result through shmem (Csh) for coalesced stores. cm1 always does this. +layout(constant_id = 16) const uint csh_store = 0; + +#ifdef COOPMAT +// cm1 subgroup tile: each subgroup computes a WM x WN region as a grid of +// TM x TN x TK fragments. Requires WM%TM == WN%TN == BS_K%WM == BS_NPQ%WN == +// BS_CRS%TK == 0, and WG_SIZE == (BS_K/WM) * (BS_NPQ/WN) * subgroup_size. +layout(constant_id = 17) const uint WM = 32; +layout(constant_id = 18) const uint WN = 32; +const uint TM = 16; +const uint TN = 16; +const uint TK = 16; +const uint cms_per_row = WM / TM; +const uint cms_per_col = WN / TN; +const uint warps_M = BS_K / WM; +const uint warps_N = BS_NPQ / WN; +#endif + +// without padding, H_idx/W_idx are in bounds by construction (non-TRANSPOSE only) +#ifdef TRANSPOSE +const bool hw_in_bounds = false; +#else +const bool hw_in_bounds = (p0 == 0) && (p1 == 0); +#endif + +// TRANSPOSE stride alignment is trivially satisfied for stride 1 +#ifdef TRANSPOSE +const bool stride_in_bounds = (s0 == 1) && (s1 == 1); +#else +const bool stride_in_bounds = true; +#endif uint32_t tid = gl_LocalInvocationID.x; const uint32_t WG_SIZE = gl_WorkGroupSize.x; @@ -94,7 +134,7 @@ uint32_t n_elems_out = K * NPQ; // Number of blocktiles per input uint32_t NB_CRS = splitWork(CRS, BS_CRS); -#ifdef COOPMAT2 +#if defined(COOPMAT2) || defined(COOPMAT) #define SHMEM_TYPE float16_t #else #define SHMEM_TYPE float @@ -112,6 +152,17 @@ const uint32_t Bsh_len = BS_CRS * Bsh_stride; shared SHMEM_TYPE Ash[Ash_len]; // K x CRS shared SHMEM_TYPE Bsh[Bsh_len]; // CRS x NPQ +#if defined(COOPMAT2) || defined(COOPMAT) +// stage matC through shmem so global stores are row-major (NPQ-contiguous) +const uint32_t Csh_stride = BS_NPQ; +#ifdef COOPMAT +const uint32_t Csh_len = BS_K * Csh_stride; +#else +const uint32_t Csh_len = csh_store != 0 ? BS_K * Csh_stride : 1; +#endif +shared SHMEM_TYPE Csh[Csh_len]; // K x NPQ +#endif + // Threadtile sizes const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K; @@ -161,7 +212,7 @@ ACC_TYPE perElemOpStore(const in uint32_t r, const in uint32_t c, const in ACC_T uint32_t OH_idx = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); // divide by p.OW; uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW; uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3; - if (K_idx < K && NPQ_idx < NPQ) { + if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) { dst_data[dst_idx] = D_TYPE(elem); } return elem; @@ -176,6 +227,13 @@ void main() { #ifdef COOPMAT2 coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator> matC; matC = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator>(0.0); +#elif defined(COOPMAT) + coopmat<float16_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col]; + [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) { + sums[i] = coopmat<float16_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0); + } + const uint warp_r = gl_SubgroupID / warps_N; + const uint warp_c = gl_SubgroupID % warps_N; #else float regC[TS_K][TS_NPQ]; for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { @@ -228,12 +286,15 @@ void main() { uint32_t B_lx = Ac; uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/ #ifdef TRANSPOSE - uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + K_idx * p.nb02 + Cin_idx_a * p.nb03, K * CRS - 1); + uint32_t knl_idx = KW_idx_a + KH_idx_a * p.nb01 + K_idx * p.nb02 + Cin_idx_a * p.nb03; #else - uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02 + K_idx * p.nb03, K * CRS - 1); + uint32_t knl_idx = KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02 + K_idx * p.nb03; #endif + if (aligned == 0) { + knl_idx = min(knl_idx, K * CRS - 1); + } float val = knl_data[knl_idx]; - if (K_idx >= K || CRS_idx_a >= CRS) { + if (aligned == 0 && (K_idx >= K || CRS_idx_a >= CRS)) { val = 0.0; } Ash[B_ly * Ash_stride + B_lx] = SHMEM_TYPE(val); @@ -282,15 +343,27 @@ void main() { uint32_t H_idx = OH_idx * s1 + KH_idx_b * d1 - p1; uint32_t W_idx = OW_idx * s0 + KW_idx_b * d0 - p0; #endif - uint32_t src_idx = - min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1); + uint32_t src_idx = W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13; + // skip clamp when address can't go OOB + if (aligned == 0 || !hw_in_bounds || !stride_in_bounds) { + src_idx = min(max(src_idx, 0), p.Cin * p.N * p.W * p.H - 1); + } float val = src_data[src_idx]; - if (CRS_idx_b >= CRS || NPQ_idx >= NPQ - || H_idx >= p.H || W_idx >= p.W // Lower bound checks aren't necessary. (idx >= 0x80000000 for such case) + bool oob = false; + if (aligned == 0 && (CRS_idx_b >= CRS || NPQ_idx >= NPQ)) { + oob = true; + } + // also catches lower-bound underflow (idx wraps to 0x80000000+) + if (!hw_in_bounds && (H_idx >= p.H || W_idx >= p.W)) { + oob = true; + } #ifdef TRANSPOSE - || (H_idx_x_s1 - H_idx * s1 != 0) || (W_idx_x_s0 - W_idx * s0 != 0) + if (!stride_in_bounds && + ((H_idx_x_s1 - H_idx * s1 != 0) || (W_idx_x_s0 - W_idx * s0 != 0))) { + oob = true; + } #endif - ) { + if (oob) { val = 0.0; } Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val); @@ -303,6 +376,23 @@ void main() { coopMatLoad(matA, Ash, 0, Ash_stride, gl_CooperativeMatrixLayoutRowMajor); coopMatLoad(matB, Bsh, 0, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor); matC = coopMatMulAdd(matA, matB, matC); +#elif defined(COOPMAT) + // each subgroup multiplies its grid of fragments per TK-sized CRS chunk + [[unroll]] for (uint k_step = 0; k_step < BS_CRS / TK; k_step++) { + coopmat<float16_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a[cms_per_row]; + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + const uint a_off = (warp_r * WM + cm_row * TM) * Ash_stride + k_step * TK; + coopMatLoad(cache_a[cm_row], Ash, a_off, Ash_stride, gl_CooperativeMatrixLayoutRowMajor); + } + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b; + const uint b_off = k_step * TK * Bsh_stride + warp_c * WN + cm_col * TN; + coopMatLoad(cache_b, Bsh, b_off, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor); + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a[cm_row], cache_b, sums[cm_col * cms_per_row + cm_row]); + } + } + } #else if (T_y * TS_K < K) { UNROLL for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) { @@ -325,8 +415,51 @@ void main() { barrier(); } /* Save C* */ +#if defined(COOPMAT2) || defined(COOPMAT) + // stage matC into Csh, then write to dst with coalesced NPQ-contiguous stores +#ifdef COOPMAT + const bool use_staged_store = true; +#else + const bool use_staged_store = (csh_store != 0); +#endif + if (use_staged_store) { +#ifdef COOPMAT + // cm1: each subgroup stores its fragment grid into its Csh slot + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + const uint csh_off = (warp_r * WM + cm_row * TM) * Csh_stride + warp_c * WN + cm_col * TN; + coopMatStore(sums[cm_col * cms_per_row + cm_row], Csh, csh_off, Csh_stride, gl_CooperativeMatrixLayoutRowMajor); + } + } +#else + coopMatStore(matC, Csh, 0, Csh_stride, gl_CooperativeMatrixLayoutRowMajor); +#endif + barrier(); + + // cooperative shmem->global: WG threads spread across BS_NPQ (the + // contiguous direction of dst), each iter covers store_rows_per_iter K-rows + const uint32_t store_rows_per_iter = WG_SIZE / BS_NPQ; + const uint32_t store_iters = BS_K / store_rows_per_iter; + const uint32_t k_thread_offset = tid / BS_NPQ; + const uint32_t npq_thread = tid % BS_NPQ; + [[unroll]] for (uint32_t i = 0; i < store_iters; i++) { + uint32_t k_local = i * store_rows_per_iter + k_thread_offset; + uint32_t K_idx = B_idx_K * BS_K + k_local; + uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + npq_thread; + uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); + uint32_t OH_idx = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); + uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW; + uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3; + if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) { + dst_data[dst_idx] = D_TYPE(Csh[k_local * Csh_stride + npq_thread]); + } + } + } #ifdef COOPMAT2 - coopMatPerElementNV(matC, matC, perElemOpStore); + else { + coopMatPerElementNV(matC, matC, perElemOpStore); + } +#endif #else if (T_y * TS_K < K) { for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { @@ -337,7 +470,7 @@ void main() { uint32_t OH_idx = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); // divide by p.OW; uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW; uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3; - if (K_idx < K && NPQ_idx < NPQ) { + if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) { dst_data[dst_idx] = regC[T_ly][T_lx]; } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl index 88d07d2dfd5..e67299fdeca 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl @@ -5,21 +5,60 @@ #include "types.glsl" #if defined(DATA_A_F32) +FLOAT_TYPE dequantize1(uint ib, uint iqs, uint a_offset) { + return data_a[a_offset + ib]; +} vec2 dequantize(uint ib, uint iqs, uint a_offset) { return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]); } +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + return vec4(data_a[a_offset + ib ], data_a[a_offset + ib + 1], + data_a[a_offset + ib + 2], data_a[a_offset + ib + 3]); +} +vec4 dequantize4_2aligned(uint ib, uint iqs, uint a_offset) { + return vec4(data_a[a_offset + ib ], data_a[a_offset + ib + 1], + data_a[a_offset + ib + 2], data_a[a_offset + ib + 3]); +} + #endif #if defined(DATA_A_F16) +FLOAT_TYPE dequantize1(uint ib, uint iqs, uint a_offset) { + return data_a[a_offset + ib]; +} vec2 dequantize(uint ib, uint iqs, uint a_offset) { return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]); } +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + return vec4(data_a[a_offset + ib ], data_a[a_offset + ib + 1], + data_a[a_offset + ib + 2], data_a[a_offset + ib + 3]); +} +vec4 dequantize4_2aligned(uint ib, uint iqs, uint a_offset) { + const vec2 a = data_a_packed32[(a_offset + ib)/2]; + const vec2 b = data_a_packed32[(a_offset + ib)/2 + 1]; + return vec4(a, b); +} #endif #if defined(DATA_A_BF16) +FLOAT_TYPE dequantize1(uint ib, uint iqs, uint a_offset) { + return bf16_to_fp32(data_a[a_offset + ib]); +} vec2 dequantize(uint ib, uint iqs, uint a_offset) { return vec2(bf16_to_fp32(data_a[a_offset + ib]), bf16_to_fp32(data_a[a_offset + ib + 1])); } +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + return vec4(bf16_to_fp32(data_a[a_offset + ib ]), bf16_to_fp32(data_a[a_offset + ib + 1]), + bf16_to_fp32(data_a[a_offset + ib + 2]), bf16_to_fp32(data_a[a_offset + ib + 3])); +} +vec4 dequantize4_2aligned(uint ib, uint iqs, uint a_offset) { + const uint a = data_a_packed32[(a_offset + ib)/2]; + const uint b = data_a_packed32[(a_offset + ib)/2 + 1]; + return vec4(uintBitsToFloat((a & 0x0000ffff) << 16), + uintBitsToFloat( a & 0xffff0000), + uintBitsToFloat((b & 0x0000ffff) << 16), + uintBitsToFloat( b & 0xffff0000)); +} #endif #if defined(DATA_A_Q4_0) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl index c582aba87dc..7171cbfa559 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl @@ -1,4 +1,12 @@ +// Each format defines a scalar dequantFunc<T> plus a V=4 dequantFunc<T>_v +// passed as the optional vector decoder to coopMatLoadTensorNV via +// GL_NV_cooperative_matrix_decode_vector. When the driver doesn't support +// the extension, ggml-vulkan.cpp strips it from the compiled SPIR-V. +#ifdef GL_NV_cooperative_matrix_decode_vector +#extension GL_NV_cooperative_matrix_decode_vector : enable +#endif + #include "types.glsl" layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufF32 { @@ -25,6 +33,19 @@ float16_t dequantFuncQ1_0(const in decodeBufQ1_0 bl, const in uint blockCoords[2 return bit != 0u ? d : -d; } +f16vec4 dequantFuncQ1_0_v(const in decodeBufQ1_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const float16_t md = -d; + const uint idx = coordInBlock[1]; + const uint qs_nib = uint(bl.block.qs[idx >> 3]) >> (idx & 0x4u); + return f16vec4( + (qs_nib & 1u) != 0u ? d : md, + (qs_nib & 2u) != 0u ? d : md, + (qs_nib & 4u) != 0u ? d : md, + (qs_nib & 8u) != 0u ? d : md); +} + layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 { block_q4_0_packed16 block; }; @@ -42,10 +63,28 @@ float16_t dequantFuncQ4_0(const in decodeBufQ4_0 bl, const in uint blockCoords[2 return ret; } +f16vec4 dequantFuncQ4_0_v(const in decodeBufQ4_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint shift = (idx & 0x10) >> 2; // 0 or 4 + const uint qs_i = (idx & 0xE) >> 1; // even, in {0,2,4,6} + const uint qsw = uint32_t(bl.block.qs[qs_i ]) + | (uint32_t(bl.block.qs[qs_i + 1u]) << 16); + // shift in {0,4}: per-byte mask 0x0F isolates the wanted nibble in each byte. + const uint q4 = (qsw >> shift) & 0x0F0F0F0Fu; + const u8vec4 q = unpack8(q4); + return f16vec4((vec4(q) - vec4(8.0)) * vec4(float(d))); +} + layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ4_1 { block_q4_1 block; }; +layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ4_1_packed32 { + block_q4_1_packed32 block; +}; + float16_t dequantFuncQ4_1(const in decodeBufQ4_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { const float16_t d = bl.block.d; @@ -60,10 +99,27 @@ float16_t dequantFuncQ4_1(const in decodeBufQ4_1 bl, const in uint blockCoords[2 return ret; } +f16vec4 dequantFuncQ4_1_v(const in decodeBufQ4_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ4_1_packed32 bl32 = decodeBufQ4_1_packed32(bl); + const float16_t d = bl.block.d; + const float16_t m = bl.block.m; + const uint idx = coordInBlock[1]; + const uint shift = (idx & 0x10) >> 2; // 0 or 4 + const uint qs_w = (idx & 0xC) >> 2; // iqs / 4 in [0,4) + const uint qsw = uint32_t(bl32.block.qs[qs_w]); + const u8vec4 q = unpack8((qsw >> shift) & 0x0F0F0F0Fu); + return f16vec4(vec4(q) * vec4(float(d)) + vec4(float(m))); +} + layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ5_0 { block_q5_0 block; }; +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ5_0_packed16 { + block_q5_0_packed16 block; +}; + float16_t dequantFuncQ5_0(const in decodeBufQ5_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { const float16_t d = bl.block.d; @@ -82,10 +138,32 @@ float16_t dequantFuncQ5_0(const in decodeBufQ5_0 bl, const in uint blockCoords[2 return ret; } +f16vec4 dequantFuncQ5_0_v(const in decodeBufQ5_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ5_0_packed16 bl16 = decodeBufQ5_0_packed16(bl); + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint shift = (idx & 0x10) >> 2; // 0 or 4 + const uint qs_i = (idx & 0xC) >> 1; // packed16 word index, in {0,2,4,6} + const uint qsw = uint32_t(bl16.block.qs[qs_i ]) + | (uint32_t(bl16.block.qs[qs_i + 1u]) << 16); + const u8vec4 ql = unpack8((qsw >> shift) & 0x0F0F0F0Fu); + + const uint uint_qh = uint(bl16.block.qh[1]) << 16 | uint(bl16.block.qh[0]); + const uint qh_pack = uint_qh >> idx; // bits 0..3 = element idx..idx+3 high bits + const uvec4 qh_high = (uvec4(qh_pack, qh_pack >> 1u, qh_pack >> 2u, qh_pack >> 3u) & uvec4(0x01u)) << 4u; + + return f16vec4((vec4(ql) + vec4(qh_high) - vec4(16.0)) * vec4(float(d))); +} + layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufQ5_1 { block_q5_1 block; }; +layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufQ5_1_packed32 { + block_q5_1_packed32 block; +}; + float16_t dequantFuncQ5_1(const in decodeBufQ5_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { const float16_t d = bl.block.d; @@ -105,6 +183,23 @@ float16_t dequantFuncQ5_1(const in decodeBufQ5_1 bl, const in uint blockCoords[2 return ret; } +f16vec4 dequantFuncQ5_1_v(const in decodeBufQ5_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ5_1_packed32 bl32 = decodeBufQ5_1_packed32(bl); + const float16_t d = bl.block.d; + const float16_t m = bl.block.m; + const uint idx = coordInBlock[1]; + const uint shift = (idx & 0x10) >> 2; // 0 or 4 + const uint qs_w = (idx & 0xC) >> 2; // iqs / 4 in [0,4) + const uint qsw = uint32_t(bl32.block.qs[qs_w]); + const u8vec4 ql = unpack8((qsw >> shift) & 0x0F0F0F0Fu); + + const uint qh_pack = bl.block.qh >> idx; // bits 0..3 = element idx..idx+3 high bits + const uvec4 qh_high = (uvec4(qh_pack, qh_pack >> 1u, qh_pack >> 2u, qh_pack >> 3u) & uvec4(0x01u)) << 4u; + + return f16vec4((vec4(ql) + vec4(qh_high)) * vec4(float(d)) + vec4(float(m))); +} + layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ8_0 { block_q8_0_packed16 block; }; @@ -121,6 +216,17 @@ float16_t dequantFuncQ8_0(const in decodeBufQ8_0 bl, const in uint blockCoords[2 return ret; } +f16vec4 dequantFuncQ8_0_v(const in decodeBufQ8_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint base = idx >> 1u; + const uint w = uint(uint16_t(bl.block.qs[base])) + | (uint(uint16_t(bl.block.qs[base + 1u])) << 16u); + const i8vec4 qi = unpack8(int32_t(w)); + return f16vec4(vec4(qi) * vec4(float(d))); +} + layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_K { block_q2_K block; }; @@ -129,6 +235,10 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2 block_q2_K_packed16 block; }; +layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_K_packed32 { + block_q2_K_packed32 block; +}; + float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl); @@ -147,10 +257,36 @@ float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2 return ret; } +f16vec4 dequantFuncQ2_K_v(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ2_K_packed32 bl32 = decodeBufQ2_K_packed32(bl); + const f16vec2 dm = bl.block.dm; + const uint idx = coordInBlock[1]; + + const uint scalesi = idx >> 4; // 0..15 + const uint qsshift = (idx & 0x60) >> 4; // 0,2,4,6 + + // qs_i (packed16) = ((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1) is even for idx % 4 == 0, + // so qs_w (packed32) = qs_i / 2 = ((idx & 0x80) >> 4) + ((idx & 0x1Cu) >> 2). + const uint qs_w = ((idx & 0x80) >> 4) + ((idx & 0x1Cu) >> 2); + const uint qsw = uint32_t(bl32.block.qs[qs_w]); + const uint qs4 = (qsw >> qsshift) & 0x03030303u; + const u8vec4 qi = unpack8(qs4); + + const uint scales = bl.block.scales[scalesi]; + const float16_t d_sub = dm.x * float16_t(scales & 0xF); + const float16_t m_sub = dm.y * float16_t(scales >> 4); + return f16vec4(vec4(qi) * vec4(float(d_sub)) - vec4(float(m_sub))); +} + layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ3_K { block_q3_K block; }; +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ3_K_packed16 { + block_q3_K_packed16 block; +}; + float16_t dequantFuncQ3_K(const in decodeBufQ3_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { const uint idx = coordInBlock[1]; @@ -179,6 +315,47 @@ float16_t dequantFuncQ3_K(const in decodeBufQ3_K bl, const in uint blockCoords[2 return ret; } +f16vec4 dequantFuncQ3_K_v(const in decodeBufQ3_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ3_K_packed16 bl16 = decodeBufQ3_K_packed16(bl); + const uint idx = coordInBlock[1]; + + const uint n = idx >> 7; // 0,1 + const uint is = idx >> 4; // 0..15 + const uint halfsplit = (idx & 0x60) >> 5; // 0,1,2,3 + const uint qsshift = halfsplit << 1; // 0,2,4,6 + const uint hbit = (n << 2) + halfsplit; // 0..7 (bit position in hmask byte) + + uint32_t scaleidx0 = (is < 8) ? is : (is - 8); + uint32_t scaleidx0shift = (is < 8) ? 0u : 4u; + uint32_t scaleidx1 = is + 8 - (is / 4) * 4; + uint32_t scaleidx1shift = (is / 4) * 2; + + const int8_t us = int8_t( + ((bl.block.scales[scaleidx0] >> scaleidx0shift) & 0xF) | + (((bl.block.scales[scaleidx1] >> scaleidx1shift) & 3) << 4)); + const float16_t dl = bl.block.d * float16_t(int(us) - 32); + + // For idx % 4 == 0: (idx & 0x1F) == (idx & 0x1C) is a multiple of 4. + const uint qsi = (n << 5) + (idx & 0x1Cu); + const uint hmi = (idx & 0x1Cu); + + // Two adjacent uint16 packed16 reads, combined into a uint32 in registers. + // After this: byte j of qsw / hmw holds the data for element idx+j. + const uint qsw = uint32_t(bl16.block.qs[qsi >> 1]) + | (uint32_t(bl16.block.qs[(qsi >> 1) + 1u]) << 16); + const uint hmw = uint32_t(bl16.block.hmask[hmi >> 1]) + | (uint32_t(bl16.block.hmask[(hmi >> 1) + 1u]) << 16); + + // qsshift in {0,2,4,6} and hbit in {0..7}: per-byte masks isolate the wanted bits + // with no inter-byte leakage. + const uint ql4 = (qsw >> qsshift) & 0x03030303u; + const uint qh4 = (hmw >> hbit) & 0x01010101u; + + const ivec4 q = ivec4(unpack8(ql4 | (qh4 << 2))) - ivec4(4); + return f16vec4(vec4(q) * vec4(float(dl))); +} + layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K { block_q4_K block; }; @@ -187,6 +364,10 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4 block_q4_K_packed16 block; }; +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed32 { + block_q4_K_packed32 block; +}; + layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed128 { block_q4_K_packed128 block; }; @@ -334,6 +515,55 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2 return float16_t(ret); } +f16vec4 dequantFuncQ4_K_v(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ4_K_packed32 bl32 = decodeBufQ4_K_packed32(bl); + decodeBufQ4_K_packed128 bl128 = decodeBufQ4_K_packed128(bl); + const uint idx = coordInBlock[1]; + + const uint is = idx >> 5; // 0..7 + +#if defined(IS_MUL_MM2) && defined(DATA_A_Q4_K) + vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)]; + float d = v.x; + float m = v.y; +#else + uvec4 v = bl128.block.q4k[0]; + const vec2 loadd = vec2(unpackFloat2x16(v.x)); + + uint32_t sc; + uint32_t mbyte; + + uint32_t scale0 = v.y; + uint32_t scale4 = v.z; + uint32_t scale8 = v.w; + + uint32_t sc_lo = scale0; + uint32_t mb_lo = scale4; + uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2); + uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2); + + sc = is < 4 ? sc_lo : sc_hi; + mbyte = is < 4 ? mb_lo : mb_hi; + sc = sc >> (8 * (is & 3)); + mbyte = mbyte >> (8 * (is & 3)); + sc &= 0x3F; + mbyte &= 0x3F; + + const float d = loadd.x * float(sc); + const float m = loadd.y * float(mbyte); +#endif + + // idx in [0,256); vector decode uses idx a multiple of 4. packed32 word index: + // (qs_i >> 1) == (idx >> 6) * 8 + ((idx & 0x1E) >> 2). sh is 0 or 4 only, so a + // single (w >> sh) & 0x0F0F0F0F isolates all four nibbles without inter-byte leakage. + const uint sh = (idx & 0x20u) >> 3u; + const uint w = uint32_t(bl32.block.qs[(idx >> 6) * 8u + ((idx & 0x1Eu) >> 2)]); + const u8vec4 q = unpack8((w >> sh) & 0x0F0F0F0Fu); + + return f16vec4(vec4(d) * vec4(q) - vec4(m)); +} + layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K { block_q5_K block; }; @@ -346,6 +576,10 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5 block_q5_K_packed128 block; }; +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed32 { + block_q5_K_packed32 block; +}; + float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl); @@ -399,6 +633,58 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2 return float16_t(ret); } +f16vec4 dequantFuncQ5_K_v(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ5_K_packed32 bl32 = decodeBufQ5_K_packed32(bl); + decodeBufQ5_K_packed128 bl128 = decodeBufQ5_K_packed128(bl); + const uint idx = coordInBlock[1]; + const uint is = idx >> 5; + +#if defined(IS_MUL_MM2) && defined(DATA_A_Q5_K) + vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)]; + float d = v.x; + float m = v.y; +#else + uvec4 v = bl128.block.q5k[0]; + + const f16vec2 loadd = unpackFloat2x16(v.x); + + uint32_t sc; + uint32_t mbyte; + + uint32_t scale0 = v.y; + uint32_t scale4 = v.z; + uint32_t scale8 = v.w; + + uint32_t sc_lo = scale0; + uint32_t mb_lo = scale4; + uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2); + uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2); + + sc = is < 4 ? sc_lo : sc_hi; + mbyte = is < 4 ? mb_lo : mb_hi; + sc = sc >> (8 * (is & 3)); + mbyte = mbyte >> (8 * (is & 3)); + sc &= 0x3F; + mbyte &= 0x3F; + + const float16_t d = loadd.x * float16_t(sc); + const float16_t m = loadd.y * float16_t(mbyte); +#endif + + // sh is 0 or 4; mask 0x0F0F0F0F covers the four nibbles regardless (no inter-byte leakage). + const uint sh = (idx & 0x20u) >> 3u; + const uint qs_w = (idx >> 6) * 8u + ((idx & 0x1Eu) >> 2); + const uint qh_w = (idx & 0x1Eu) >> 2; + + const uint ql4 = (uint32_t(bl32.block.qs[qs_w]) >> sh) & 0x0F0F0F0Fu; + // qh stores bit `is` per element across 4 consecutive bytes; one shift+mask handles all 4. + const uint qh4 = ((uint32_t(bl32.block.qh[qh_w]) >> is) & 0x01010101u) << 4u; + + const u8vec4 qi = unpack8(ql4 | qh4); + return f16vec4(vec4(qi) * vec4(d) - vec4(m)); +} + layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K { block_q6_K block; }; @@ -431,6 +717,35 @@ float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2 return ret; } +f16vec4 dequantFuncQ6_K_v(const in decodeBufQ6_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ6_K_packed16 bl16 = decodeBufQ6_K_packed16(bl); + const uint idx = coordInBlock[1]; + + const uint b = (idx & 0x40) >> 6; + const uint qhshift = (idx & 0x60) >> 4; // 0,2,4,6 + const uint is = idx >> 4; + const uint sh = b * 4; // 0 or 4 + + const float16_t dscale = bl.block.d * float16_t(bl.block.scales[is]); + + const uint ql_i = ((idx & 0x80) >> 2) + ((idx & 0x3E) >> 1); + const uint qh_i = ((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1); + + // Two adjacent uint16 packed16 reads, combined into a uint32 in registers. + // After this: byte j of qlw / qhw holds the data for element idx+j. + const uint qlw = uint32_t(bl16.block.ql[ql_i ]) | (uint32_t(bl16.block.ql[ql_i + 1]) << 16); + const uint qhw = uint32_t(bl16.block.qh[qh_i ]) | (uint32_t(bl16.block.qh[qh_i + 1]) << 16); + + // sh in {0,4} and qhshift in {0,2,4,6}: per-byte masks 0x0F / 0x03 keep only the + // wanted bits with no inter-byte leakage; place qh's 2 bits at nibble high position. + const uint ql4 = (qlw >> sh) & 0x0F0F0F0Fu; + const uint qh4 = ((qhw >> qhshift) & 0x03030303u) << 4u; + + const ivec4 qi = ivec4(unpack8(ql4 | qh4)); + return f16vec4((vec4(qi) - vec4(32.0f)) * vec4(float(dscale))); +} + #if defined(DATA_A_IQ1_S) layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1_S { block_iq1_s block; @@ -453,6 +768,29 @@ float16_t dequantFuncIQ1_S(const in decodeBufIQ1_S bl, const in uint blockCoords float16_t ret = float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * int(idx % 8), 2)) + float16_t(delta)); return ret; } + +f16vec4 dequantFuncIQ1_S_v(const in decodeBufIQ1_S bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + + const uint ib32 = idx >> 5; + const uint ib8 = idx >> 3; + const int i8b = int(idx & 4); // 0 or 4 + + const uint qh = bl.block.qh[ib32]; + const uint qs = bl.block.qs[ib8]; + const float dl = float(d) * float(2 * bitfieldExtract(qh, 12, 3) + 1); + const float delta = ((qh & 0x8000u) != 0u) ? -IQ1S_DELTA : IQ1S_DELTA; + const uint grid = iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]; + + const ivec4 q = ivec4( + bitfieldExtract(int(grid), 2 * (i8b + 0), 2), + bitfieldExtract(int(grid), 2 * (i8b + 1), 2), + bitfieldExtract(int(grid), 2 * (i8b + 2), 2), + bitfieldExtract(int(grid), 2 * (i8b + 3), 2)); + return f16vec4((vec4(q) + vec4(delta)) * dl); +} #endif #if defined(DATA_A_IQ1_M) @@ -485,6 +823,33 @@ float16_t dequantFuncIQ1_M(const in decodeBufIQ1_M bl, const in uint blockCoords float16_t ret = d * float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * i8, 2)) + float16_t(delta)); return ret; } + +f16vec4 dequantFuncIQ1_M_v(const in decodeBufIQ1_M bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufIQ1_M_packed64 bl64 = decodeBufIQ1_M_packed64(bl); + const uint idx = coordInBlock[1]; + + uvec2 scales = unpack32(bl64.block.scales); + const float16_t d = uint16BitsToHalf(uint16_t(((scales.x & 0xF000) >> 12) | ((scales.x & 0xF0000000) >> 24) | ((scales.y & 0xF000) >> 4) | ((scales.y & 0xF0000000) >> 16))); + + const uint ib8 = idx >> 3; + const uint ib16 = idx >> 4; + const int i8b = int(idx & 4); // 0 or 4 -- i8 base for the V=4 group + + const uint sc = bl.block.scales[ib8 / 8]; + const uint qs = bl.block.qs[ib8]; + const uint qh = bl.block.qh[ib16] >> (4 * (ib8 & 1)); + const float dl = 2.0 * float(bitfieldExtract(sc, 3 * int(ib16 & 3), 3)) + 1.0; + const float delta = ((qh & 8u) != 0u) ? -IQ1S_DELTA : IQ1S_DELTA; + const uint grid = iq1s_grid[qs | ((qh & 7u) << 8)]; + + const ivec4 q = ivec4( + bitfieldExtract(int(grid), 2 * (i8b + 0), 2), + bitfieldExtract(int(grid), 2 * (i8b + 1), 2), + bitfieldExtract(int(grid), 2 * (i8b + 2), 2), + bitfieldExtract(int(grid), 2 * (i8b + 3), 2)); + return f16vec4((vec4(q) + vec4(delta)) * (float(d) * dl)); +} #endif #if defined(DATA_A_IQ2_XXS) @@ -520,6 +885,33 @@ float16_t dequantFuncIQ2_XXS(const in decodeBufIQ2_XXS bl, const in uint blockCo vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf); return float16_t(ret[idx & 1]); } + +f16vec4 dequantFuncIQ2_XXS_v(const in decodeBufIQ2_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufIQ2_XXS_packed16 bl16 = decodeBufIQ2_XXS_packed16(bl); + const uint idx = coordInBlock[1]; + + const uint ib32 = idx >> 5; + const uint ib8 = (idx & 0x18) >> 3; + const uint iqs = 8 * ib32 + ib8; + + const uint qs = bl.block.qs[iqs]; + const uint signscale = pack32(u16vec2(bl16.block.qs[4*ib32+2], bl16.block.qs[4*ib32+3])); + const float dscale = float(bl.block.d) * 0.25 * (0.5 + float(signscale >> 28)); + + uint sign = bitfieldExtract(signscale, 7 * int(ib8), 7); + sign |= bitCount(sign) << 7; + const uint sb = sign >> (idx & 7u); + + const uint g2 = iq2xxs_grid[qs][(idx & 4) >> 2]; + const u8vec4 g = unpack8(g2); + + return f16vec4( + dscale * float(g.x) * ((sb & 1u) != 0u ? -1.0 : 1.0), + dscale * float(g.y) * ((sb & 2u) != 0u ? -1.0 : 1.0), + dscale * float(g.z) * ((sb & 4u) != 0u ? -1.0 : 1.0), + dscale * float(g.w) * ((sb & 8u) != 0u ? -1.0 : 1.0)); +} #endif #if defined(DATA_A_IQ2_XS) @@ -548,6 +940,31 @@ float16_t dequantFuncIQ2_XS(const in decodeBufIQ2_XS bl, const in uint blockCoor vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf); return float16_t(ret[idx & 1]); } + +f16vec4 dequantFuncIQ2_XS_v(const in decodeBufIQ2_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const uint idx = coordInBlock[1]; + + const uint is = idx >> 5; + const uint sshift = (idx & 0x10) >> 2; + const uint iqs = idx >> 3; + + const uint16_t qs = bl.block.qs[iqs]; + const float dscale = float(bl.block.d) * 0.25 * (0.5 + float((bl.block.scales[is] >> sshift) & 0xF)); + + uint sign = uint(qs >> 9); + sign |= bitCount(sign) << 7; + const uint sb = sign >> (idx & 7u); + + const uint g2 = iq2xs_grid[qs & 0x1FF][(idx & 4) >> 2]; + const u8vec4 g = unpack8(g2); + + return f16vec4( + dscale * float(g.x) * ((sb & 1u) != 0u ? -1.0 : 1.0), + dscale * float(g.y) * ((sb & 2u) != 0u ? -1.0 : 1.0), + dscale * float(g.z) * ((sb & 4u) != 0u ? -1.0 : 1.0), + dscale * float(g.w) * ((sb & 8u) != 0u ? -1.0 : 1.0)); +} #endif #if defined(DATA_A_IQ2_S) @@ -576,6 +993,32 @@ float16_t dequantFuncIQ2_S(const in decodeBufIQ2_S bl, const in uint blockCoords const vec2 v = db * vec2(sign01) * vec2(unpack8(g2)); return float16_t(v[idx & 1]); } + +f16vec4 dequantFuncIQ2_S_v(const in decodeBufIQ2_S bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const uint idx = coordInBlock[1]; + + const uint ib32 = idx >> 5; + const uint ib8 = idx >> 3; + const uint qhshift = 2 * (ib8 % 4); + + const uint scale = (bl.block.scales[ib32] >> ((idx & 0x10) >> 2)) & 0xf; + const uint qs = bl.block.qs[ib8]; + const uint qh = bl.block.qh[ib32]; + const uint sb = uint(bl.block.qs[QUANT_K / 8 + ib8]) >> (idx & 0x6u); + + const float d = float(bl.block.d); + const float db = d * 0.25 * (0.5 + scale); + + const uint g2 = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 4) >> 2]; + const u8vec4 g = unpack8(g2); + + return f16vec4( + db * float(g.x) * ((sb & 1u) != 0u ? -1.0 : 1.0), + db * float(g.y) * ((sb & 2u) != 0u ? -1.0 : 1.0), + db * float(g.z) * ((sb & 4u) != 0u ? -1.0 : 1.0), + db * float(g.w) * ((sb & 8u) != 0u ? -1.0 : 1.0)); +} #endif #if defined(DATA_A_IQ3_XXS) @@ -609,6 +1052,32 @@ float16_t dequantFuncIQ3_XXS(const in decodeBufIQ3_XXS bl, const in uint blockCo const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); return float16_t(v[idx & 1]); } + +f16vec4 dequantFuncIQ3_XXS_v(const in decodeBufIQ3_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufIQ3_XXS_packed16 bl16 = decodeBufIQ3_XXS_packed16(bl); + const uint idx = coordInBlock[1]; + + const uint iqs = idx >> 2; + const uint is = QUANT_K / 4 + ((idx & 0xE0) >> 3); + + const float d = float(bl.block.d); + const uint qs = bl.block.qs[iqs]; + const uint signs = pack32(u16vec2(bl16.block.qs[is/2+0], bl16.block.qs[is/2+1])); + const float db = d * 0.5 * (0.5 + (signs >> 28)); + + const uint sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7); + const uint sb = (sign7 | (bitCount(sign7) << 7)) >> (idx & 0x6u); + + const uint grid = iq3xxs_grid[qs]; + const u8vec4 g = unpack8(grid); + + return f16vec4( + db * float(g.x) * ((sb & 1u) != 0u ? -1.0 : 1.0), + db * float(g.y) * ((sb & 2u) != 0u ? -1.0 : 1.0), + db * float(g.z) * ((sb & 4u) != 0u ? -1.0 : 1.0), + db * float(g.w) * ((sb & 8u) != 0u ? -1.0 : 1.0)); +} #endif #if defined(DATA_A_IQ3_S) @@ -635,6 +1104,30 @@ float16_t dequantFuncIQ3_S(const in decodeBufIQ3_S bl, const in uint blockCoords return float16_t(v[idx & 1]); } + +f16vec4 dequantFuncIQ3_S_v(const in decodeBufIQ3_S bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const uint idx = coordInBlock[1]; + + const uint iqs = idx >> 2; + const uint iqh = idx >> 5; + + const float d = float(bl.block.d); + const uint qs = bl.block.qs[iqs]; + const uint qh = bl.block.qh[iqh]; + const uint sb = uint(bl.block.signs[iqs / 2]) >> (idx & 0x6u); + const uint scale = bl.block.scales[iqs / 16]; + const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf)); + + const uint grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)]; + const u8vec4 g = unpack8(grid); + + return f16vec4( + db * float(g.x) * ((sb & 1u) != 0u ? -1.0 : 1.0), + db * float(g.y) * ((sb & 2u) != 0u ? -1.0 : 1.0), + db * float(g.z) * ((sb & 4u) != 0u ? -1.0 : 1.0), + db * float(g.w) * ((sb & 8u) != 0u ? -1.0 : 1.0)); +} #endif #if defined(DATA_A_IQ4_XS) @@ -642,6 +1135,10 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4 block_iq4_xs block; }; +layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufIQ4_XS_packed32 { + block_iq4_xs_packed32 block; +}; + float16_t dequantFuncIQ4_XS(const in decodeBufIQ4_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { const float16_t d = bl.block.d; @@ -657,6 +1154,30 @@ float16_t dequantFuncIQ4_XS(const in decodeBufIQ4_XS bl, const in uint blockCoor float16_t ret = d * float16_t(int(sl | (sh << 4)) - 32) * float16_t(kvalues_iq4nl[q]); return ret; } + +f16vec4 dequantFuncIQ4_XS_v(const in decodeBufIQ4_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufIQ4_XS_packed32 bl32 = decodeBufIQ4_XS_packed32(bl); + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + + const uint ib32 = idx >> 5; // 0..7 + const uint sl = (bl32.block.scales_l >> (4 * ib32)) & 0xF; + const uint sh = (uint(bl32.block.scales_h) >> (2 * ib32)) & 0x3; + const uint qshift = (idx & 0x10) >> 2; // {0, 4} + const uint qs_w = 4 * ib32 + ((idx & 0xC) >> 2); // iqs / 4, in [0,32) + + const float16_t dl = d * float16_t(int(sl | (sh << 4)) - 32); + + const uint qsw = bl32.block.qs[qs_w]; + const u8vec4 qv = unpack8((qsw >> qshift) & 0x0F0F0F0Fu); + const vec4 ret = vec4( + float(kvalues_iq4nl[qv.x]), + float(kvalues_iq4nl[qv.y]), + float(kvalues_iq4nl[qv.z]), + float(kvalues_iq4nl[qv.w])) * float(dl); + return f16vec4(ret); +} #endif #if defined(DATA_A_IQ4_NL) @@ -664,6 +1185,10 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4 block_iq4_nl block; }; +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL_packed16 { + block_iq4_nl_packed16 block; +}; + float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { const float16_t d = bl.block.d; @@ -676,6 +1201,24 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor float16_t ret = float16_t(kvalues_iq4nl[qs]) * d; return ret; } + +f16vec4 dequantFuncIQ4_NL_v(const in decodeBufIQ4_NL bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufIQ4_NL_packed16 bl16 = decodeBufIQ4_NL_packed16(bl); + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint shift = (idx & 0x10) >> 2; // 0 or 4 + const uint qs_i = (idx & 0xC) >> 1; // packed16 word index, in {0,2,4,6} + const uint qsw = uint32_t(bl16.block.qs[qs_i ]) + | (uint32_t(bl16.block.qs[qs_i + 1u]) << 16); + // shift in {0,4}: per-byte mask 0x0F isolates the wanted nibble in each byte. + const u8vec4 q = unpack8((qsw >> shift) & 0x0F0F0F0Fu); + return f16vec4( + float(d) * float(kvalues_iq4nl[q.x]), + float(d) * float(kvalues_iq4nl[q.y]), + float(d) * float(kvalues_iq4nl[q.z]), + float(d) * float(kvalues_iq4nl[q.w])); +} #endif #if defined(DATA_A_MXFP4) @@ -695,6 +1238,26 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords float16_t ret = float16_t(kvalues_mxfp4[qs] * d * 0.5); return ret; } + +f16vec4 dequantFuncMXFP4_v(const in decodeBufMXFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float d = e8m0_to_fp32(bl.block.e); + const uint idx = coordInBlock[1]; + const uint iqs = idx & 0xF; + const uint shift = (idx & 0x10) >> 2; + uvec4 qv = uvec4( + uint(bl.block.qs[iqs]), + uint(bl.block.qs[iqs + 1u]), + uint(bl.block.qs[iqs + 2u]), + uint(bl.block.qs[iqs + 3u])); + qv = (qv >> shift) & 0xFu; + const vec4 ret = vec4( + float(kvalues_mxfp4[qv.x]), + float(kvalues_mxfp4[qv.y]), + float(kvalues_mxfp4[qv.z]), + float(kvalues_mxfp4[qv.w])) * d * 0.5f; + return f16vec4(ret); +} #endif #if defined(DATA_A_NVFP4) @@ -702,6 +1265,10 @@ layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufNVF block_nvfp4 block; }; +layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufNVFP4_packed32 { + block_nvfp4_packed32 block; +}; + float16_t dequantFuncNVFP4(const in decodeBufNVFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { const uint idx = coordInBlock[1]; @@ -713,56 +1280,97 @@ float16_t dequantFuncNVFP4(const in decodeBufNVFP4 bl, const in uint blockCoords qs = (qs >> shift) & 0xF; return float16_t(kvalues_mxfp4[qs] * d * 0.5); } + +f16vec4 dequantFuncNVFP4_v(const in decodeBufNVFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufNVFP4_packed32 bl32 = decodeBufNVFP4_packed32(bl); + const uint idx = coordInBlock[1]; + const uint sub = idx >> 4; + const uint qs_w = ((idx & 0x30) >> 3) + ((idx & 0x4u) >> 2); // iqs / 4, in [0,8) + const uint shift = (idx & 0x8) >> 1; + const float d = ue4m3_to_fp32(bl.block.d[sub]); + + const uint qsw = uint32_t(bl32.block.qs[qs_w]); + const u8vec4 qv = unpack8((qsw >> shift) & 0x0F0F0F0Fu); + const vec4 ret = vec4( + float(kvalues_mxfp4[qv.x]), + float(kvalues_mxfp4[qv.y]), + float(kvalues_mxfp4[qv.z]), + float(kvalues_mxfp4[qv.w])) * d * 0.5f; + return f16vec4(ret); +} #endif #if defined(DATA_A_Q1_0) #define dequantFuncA dequantFuncQ1_0 +#define dequantFuncA_v dequantFuncQ1_0_v #elif defined(DATA_A_Q4_0) #define dequantFuncA dequantFuncQ4_0 +#define dequantFuncA_v dequantFuncQ4_0_v #elif defined(DATA_A_Q4_1) #define dequantFuncA dequantFuncQ4_1 +#define dequantFuncA_v dequantFuncQ4_1_v #elif defined(DATA_A_Q5_0) #define dequantFuncA dequantFuncQ5_0 +#define dequantFuncA_v dequantFuncQ5_0_v #elif defined(DATA_A_Q5_1) #define dequantFuncA dequantFuncQ5_1 +#define dequantFuncA_v dequantFuncQ5_1_v #elif defined(DATA_A_Q8_0) #define dequantFuncA dequantFuncQ8_0 +#define dequantFuncA_v dequantFuncQ8_0_v #elif defined(DATA_A_Q2_K) #define dequantFuncA dequantFuncQ2_K +#define dequantFuncA_v dequantFuncQ2_K_v #elif defined(DATA_A_Q3_K) #define dequantFuncA dequantFuncQ3_K +#define dequantFuncA_v dequantFuncQ3_K_v #elif defined(DATA_A_Q4_K) #define dequantFuncA dequantFuncQ4_K +#define dequantFuncA_v dequantFuncQ4_K_v #define fetch_scales fetch_scalesQ4_K #define store_scales store_scalesQ4_K #elif defined(DATA_A_Q5_K) #define dequantFuncA dequantFuncQ5_K +#define dequantFuncA_v dequantFuncQ5_K_v #define fetch_scales fetch_scalesQ5_K #define store_scales store_scalesQ4_K #elif defined(DATA_A_Q6_K) #define dequantFuncA dequantFuncQ6_K +#define dequantFuncA_v dequantFuncQ6_K_v #elif defined(DATA_A_IQ1_S) #define dequantFuncA dequantFuncIQ1_S +#define dequantFuncA_v dequantFuncIQ1_S_v #elif defined(DATA_A_IQ1_M) #define dequantFuncA dequantFuncIQ1_M +#define dequantFuncA_v dequantFuncIQ1_M_v #elif defined(DATA_A_IQ2_XXS) #define dequantFuncA dequantFuncIQ2_XXS +#define dequantFuncA_v dequantFuncIQ2_XXS_v #elif defined(DATA_A_IQ2_XS) #define dequantFuncA dequantFuncIQ2_XS +#define dequantFuncA_v dequantFuncIQ2_XS_v #elif defined(DATA_A_IQ2_S) #define dequantFuncA dequantFuncIQ2_S +#define dequantFuncA_v dequantFuncIQ2_S_v #elif defined(DATA_A_IQ3_XXS) #define dequantFuncA dequantFuncIQ3_XXS +#define dequantFuncA_v dequantFuncIQ3_XXS_v #elif defined(DATA_A_IQ3_S) #define dequantFuncA dequantFuncIQ3_S +#define dequantFuncA_v dequantFuncIQ3_S_v #elif defined(DATA_A_IQ4_XS) #define dequantFuncA dequantFuncIQ4_XS +#define dequantFuncA_v dequantFuncIQ4_XS_v #elif defined(DATA_A_IQ4_NL) #define dequantFuncA dequantFuncIQ4_NL +#define dequantFuncA_v dequantFuncIQ4_NL_v #elif defined(DATA_A_MXFP4) #define dequantFuncA dequantFuncMXFP4 +#define dequantFuncA_v dequantFuncMXFP4_v #elif defined(DATA_A_NVFP4) #define dequantFuncA dequantFuncNVFP4 +#define dequantFuncA_v dequantFuncNVFP4_v #elif defined(DATA_A_F32) #define dequantFuncA dequantFuncF32 #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl new file mode 100644 index 00000000000..c474bfe09ce --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl @@ -0,0 +1,27 @@ +#ifdef DOT2_F16 +#extension GL_EXT_spirv_intrinsics : require + +spirv_instruction(extensions = ["SPV_VALVE_mixed_float_dot_product"], + capabilities = [6912], id = 6916) +float v_dot2_f32_f16(f16vec2 a, f16vec2 b, float acc); + +ACC_TYPE dot_product(f16vec4 a, f16vec4 b, ACC_TYPE acc) { + return ACC_TYPE(v_dot2_f32_f16(a.zw, b.zw, v_dot2_f32_f16(a.xy, b.xy, float(acc)))); +} + +ACC_TYPE dot_product(f16vec2 a, f16vec2 b, ACC_TYPE acc) { + return ACC_TYPE(v_dot2_f32_f16(a, b, float(acc))); +} + +#else + +ACC_TYPE dot_product(FLOAT_TYPEV4 a, FLOAT_TYPEV4 b, ACC_TYPE acc) { + return fma(ACC_TYPE(a.x), ACC_TYPE(b.x), fma(ACC_TYPE(a.y), ACC_TYPE(b.y), + fma(ACC_TYPE(a.z), ACC_TYPE(b.z), fma(ACC_TYPE(a.w), ACC_TYPE(b.w), acc)))); +} + +ACC_TYPE dot_product(FLOAT_TYPEV2 a, FLOAT_TYPEV2 b, ACC_TYPE acc) { + return fma(ACC_TYPE(a.x), ACC_TYPE(b.x), fma(ACC_TYPE(a.y), ACC_TYPE(b.y), acc)); +} + +#endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp b/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp new file mode 100644 index 00000000000..65e9c678401 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp @@ -0,0 +1,7 @@ +#version 460 + +#extension GL_NV_cooperative_matrix_decode_vector : require + +void main() +{ +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 6ac095489b3..91fb07c93e7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -21,6 +21,7 @@ #extension GL_KHR_shader_subgroup_vote : enable #include "types.glsl" +#include "dot_product_funcs.glsl" #include "flash_attn_base.glsl" #include "flash_attn_dequant.glsl" @@ -318,7 +319,7 @@ void main() { K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); } [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Sf[r][c] += dot(ACC_TYPEV4(Q_cache[r]), ACC_TYPEV4(K_Tf)); + Sf[r][c] = dot_product(Q_cache[r], K_Tf, Sf[r][c]); } } } @@ -341,7 +342,7 @@ void main() { K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); } [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Sf[r][c] += dot(ACC_TYPEV4(Qf[tile_row(r) * qf_stride + d * D_split + d_tid]), ACC_TYPEV4(K_Tf)); + Sf[r][c] = dot_product(Qf[tile_row(r) * qf_stride + d * D_split + d_tid], K_Tf, Sf[r][c]); } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index 9a7957da97b..66dcf610219 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -97,8 +97,17 @@ layout (binding = 6) readonly buffer MO {uint32_t data_mask_opt[];}; #define FA_TYPE_Q5_0 6u #define FA_TYPE_Q5_1 7u #define FA_TYPE_Q8_0 8u +#define FA_TYPE_BF16 30u #define FA_TYPE_Q1_0 41u +#if defined(BFLOAT16) +#define O_TYPE float +#define O_TYPEV4 vec4 +#else +#define O_TYPE FLOAT_TYPE +#define O_TYPEV4 FLOAT_TYPEV4 +#endif + // Number of matrix elements per buffer block, derived from the K/V type spec // constant. F32 is treated as a vec4 "block" of 4 floats. F16 uses block size 1 // and bypasses the dequant path entirely. Quants follow their ggml block sizes. @@ -111,6 +120,7 @@ uint fa_block_elems(uint ty) { case FA_TYPE_Q5_0: return uint(QUANT_K_Q5_0); case FA_TYPE_Q5_1: return uint(QUANT_K_Q5_1); case FA_TYPE_Q8_0: return uint(QUANT_K_Q8_0); + case FA_TYPE_BF16: return 1u; case FA_TYPE_Q1_0: return uint(QUANT_K_Q1_0); // cm2-only, harmless elsewhere default: return 1u; } @@ -248,7 +258,7 @@ const float FATTN_KQ_MAX_OFFSET = 3.0f*0.6931f; // Store the output when doing grouped query attention. // Rows index by Q's dimension 2, and the first N rows are valid. -void gqaStore(const in uint32_t r, const in uint32_t c, const in FLOAT_TYPEV4 elems, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +void gqaStore(const in uint32_t r, const in uint32_t c, const in O_TYPEV4 elems, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) { uint32_t offset = (iq2 + r) * HSV / 4 + c; data_ov4[o_offset + offset] = D_TYPEV4(elems); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index bffcc095be3..23ae3833e52 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -6,6 +6,10 @@ #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require +#if defined(BFLOAT16) +#extension GL_EXT_bfloat16 : enable +#endif + #extension GL_KHR_shader_subgroup_basic : enable #extension GL_KHR_shader_subgroup_arithmetic : enable #extension GL_KHR_shader_subgroup_vote : enable @@ -14,7 +18,9 @@ #include "types.glsl" #include "flash_attn_base.glsl" +#if !defined(BFLOAT16) #include "flash_attn_dequant.glsl" +#endif // These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd const uint32_t MatBr = 16; @@ -27,32 +33,32 @@ const uint32_t cols_per_thread = Bc / cols_per_iter; layout (binding = 0) readonly buffer Q {float data_q[];}; layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];}; -layout (binding = 1) readonly buffer K {float16_t data_k[];}; -layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];}; -layout (binding = 2) readonly buffer V {float16_t data_v[];}; -layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; +layout (binding = 1) readonly buffer K {FLOAT_TYPE data_k[];}; +layout (binding = 1) readonly buffer KV4 {FLOAT_TYPEV4 data_kv4[];}; +layout (binding = 2) readonly buffer V {FLOAT_TYPE data_v[];}; +layout (binding = 2) readonly buffer VV4 {FLOAT_TYPEV4 data_vv4[];}; layout (binding = 3) readonly buffer M {float16_t data_m[];}; shared float tmpsh[row_split]; -const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4 -shared f16vec4 Qf[Br * qstride]; +const uint32_t qstride = HSK_pad / 4 + 2; +shared FLOAT_TYPEV4 Qf[Br * qstride]; const uint psh_stride = Br / 4 + 2; -shared f16vec4 Psh[Bc * psh_stride]; +shared FLOAT_TYPEV4 Psh[Bc * psh_stride]; // Avoid padding for hsk==256 to make it fit in 48KB shmem. const uint32_t sfshstride = (HSK <= 128) ? (Br / 4 + 2) : Br / 4; shared ACC_TYPEV4 sfsh[Bc * sfshstride]; const uint32_t D_pad = HSK_pad > HSV_pad ? HSK_pad : HSV_pad; -const uint32_t kvsh_stride = (SHMEM_STAGING != 0 ? D_pad : MatBr) / 4 + 2; // in units of f16vec4 +const uint32_t kvsh_stride = (SHMEM_STAGING != 0 ? D_pad : MatBr) / 4 + 2; const uint v_cols = MatBc / 4 * row_split; // total cols, 4 vec4s per MatBc * number of subgroups const uint vsh_stride = v_cols; -shared f16vec4 kvsh[(kvsh_stride >= vsh_stride) ? (Bc * kvsh_stride) : (Bc * vsh_stride)]; +shared FLOAT_TYPEV4 kvsh[(kvsh_stride >= vsh_stride) ? (Bc * kvsh_stride) : (Bc * vsh_stride)]; const uint32_t osh_stride = row_split * MatBr / 4; -shared f16vec4 pvsh[MatBc * osh_stride]; +shared O_TYPEV4 pvsh[MatBc * osh_stride]; shared ACC_TYPE slope[Br]; @@ -76,7 +82,7 @@ void main() { if ((HSK % 16) != 0) { [[unroll]] for (uint i = 0; i < Br * qstride; i += gl_WorkGroupSize.x) { if (i + tid < Br * qstride) { - Qf[i + tid] = f16vec4(0); + Qf[i + tid] = FLOAT_TYPEV4(0); } } barrier(); @@ -89,15 +95,15 @@ void main() { uint32_t r = (idx + tid) / (HSK / 4); if (r < Br && d < HSK / 4 && i * Br + r < N) { - Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale); + Qf[r * qstride + d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale); } } barrier(); - f16vec4 Of[rows_per_thread][d_per_thread]; + O_TYPEV4 Of[rows_per_thread][d_per_thread]; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { [[unroll]] for (uint32_t d = 0; d < d_per_thread; ++d) { - Of[r][d] = f16vec4(0.0); + Of[r][d] = O_TYPEV4(0.0); } } @@ -222,15 +228,18 @@ void main() { uint32_t d = (idx + tid) % (HSK_pad / 4); uint32_t c = (idx + tid) / (HSK_pad / 4); if (idx + gl_WorkGroupSize.x <= Bc * HSK_pad / 4 || c < Bc) { - f16vec4 K_Tf = f16vec4(0); + FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0); if ((!KV_bounds_check || j * Bc + c < KV) && (HSK == HSK_pad || d < HSK / 4)) { +#if !defined(BFLOAT16) if (USE_DECODE_K) { uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE_K + 4 * d; uint ib = coord / BLOCK_SIZE_K; uint iqs = (coord % BLOCK_SIZE_K); K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); - } else { - K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); + } else +#endif + { + K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); } } @@ -244,16 +253,16 @@ void main() { // Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16 // This is written transposed in order to allow for N being 8 if implementations need it coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0); - coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat; - coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat; + coopmat<FLOAT_TYPE, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat; + coopmat<FLOAT_TYPE, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat; [[unroll]] for (uint32_t d = 0; d < HSK_pad / 16; ++d) { // If SHMEM_STAGING is set, a Bc * HSK_pad size tile of K is loaded to shmem - // If not, f16 K is loaded directly from global memory if aligned, otherwise + // If not, K is loaded directly from global memory if aligned, otherwise // staged through a Bc * MatBr size staging buffer. - // If K is not type f16, then it is always staged for dequantization. + // If K is a quant type, then it is always staged for dequantization. if (SHMEM_STAGING == 0) { - // For quants we always need to dequant into kvsh; for f16 we can load + // For quants we always need to dequant into kvsh; for f16/bf16 we can load // directly from global memory when alignment / bounds allow it. const bool stage_k = USE_DECODE_K || KV_bounds_check || d * 16 + 16 > HSK; if (stage_k) { @@ -262,15 +271,18 @@ void main() { uint32_t col_vec = (idx + tid) % (MatBr / 4); uint32_t row = (idx + tid) / (MatBr / 4); if (idx + tid < Bc * MatBr / 4) { - f16vec4 K_Tf = f16vec4(0); + FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0); if ((!KV_bounds_check || j * Bc + row < KV) && (HSK == HSK_pad || d * 16 + col_vec * 4 < HSK)) { +#if !defined(BFLOAT16) if (USE_DECODE_K) { uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE_K + d * 16 + col_vec * 4; uint ib = coord / BLOCK_SIZE_K; uint iqs = (coord % BLOCK_SIZE_K); K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); - } else { - K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]); + } else +#endif + { + K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]); } } @@ -357,7 +369,7 @@ void main() { [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { const uint d_local = d0 / threads_per_rowgroup; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Of[r][d_local] = float16_t(eMf[r]) * Of[r][d_local]; + Of[r][d_local] = O_TYPE(eMf[r]) * Of[r][d_local]; } } @@ -368,10 +380,10 @@ void main() { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; r += 4) { const uint row = tile_row(r); if (KV_bounds_check && j * Bc + col >= KV) { - Psh[col * psh_stride + row / 4] = f16vec4(0.0f); + Psh[col * psh_stride + row / 4] = FLOAT_TYPEV4(0.0f); } else { const vec4 mfvec = vec4(Mf[r], Mf[r + 1], Mf[r + 2], Mf[r + 3]); - const f16vec4 Pf = f16vec4(exp(vec4(sfsh[row / 4 + col * sfshstride]) - mfvec)); + const FLOAT_TYPEV4 Pf = FLOAT_TYPEV4(exp(vec4(sfsh[row / 4 + col * sfshstride]) - mfvec)); [[unroll]] for (uint32_t vec_idx = 0; vec_idx < 4; ++vec_idx) { Lf[r + vec_idx] += Pf[vec_idx]; } @@ -385,15 +397,18 @@ void main() { uint32_t d = (idx + tid) % (HSV_pad / 4); uint32_t c = (idx + tid) / (HSV_pad / 4); if (idx + gl_WorkGroupSize.x <= Bc * HSV_pad / 4 || c < Bc) { - f16vec4 V_Tf = f16vec4(0); + FLOAT_TYPEV4 V_Tf = FLOAT_TYPEV4(0); if ((!KV_bounds_check || j * Bc + c < KV) && (HSV == HSV_pad || d < HSV / 4)) { +#if !defined(BFLOAT16) if (USE_DECODE_V) { uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE_V + 4 * d; uint ib = coord / BLOCK_SIZE_V; uint iqs = (coord % BLOCK_SIZE_V); V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); - } else { - V_Tf = f16vec4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]); + } else +#endif + { + V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]); } } @@ -409,7 +424,7 @@ void main() { [[unroll]] for (uint32_t hsv_tile = 0; hsv_tile < num_hsv_tiles; ++hsv_tile) { const uint hsv_offset = (hsv_tile * row_split + gl_SubgroupID) * 16; - coopmat<float16_t, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> PVMat = coopmat<float16_t, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0); + coopmat<O_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> PVMat = coopmat<O_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0); // Preload V tiles for [Bc, 16 * num subgroups] const uint v_rows = Bc; @@ -417,11 +432,11 @@ void main() { const uint v_loads_per_thread = v_total / gl_WorkGroupSize.x; // If SHMEM_STAGING is set, a Bc * HSV_pad size tile of V is loaded to shmem. - // If not, f16 V is loaded directly from global memory if aligned, otherwise + // If not, V is loaded directly from global memory if aligned, otherwise // staged through a Bc * MatBr size staging buffer. - // If V is not type f16, then it is always staged for dequantization. + // If V is a quant type, then it is always staged for dequantization. if (SHMEM_STAGING == 0) { - // For quants we always preload via kvsh. For f16 we only preload when + // For quants we always preload via kvsh. For f16/bf16 we only preload when // alignment / bounds force it (otherwise we coopMatLoad direct from data_vv4). const bool stage_v = USE_DECODE_V || KV_bounds_check; if (stage_v) { @@ -438,13 +453,16 @@ void main() { const uint iqs = coord % BLOCK_SIZE_V; if (!KV_bounds_check || (v_row < KV && v_col < HSV)) { +#if !defined(BFLOAT16) if (USE_DECODE_V) { kvsh[row * vsh_stride + col] = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); - } else { + } else +#endif + { kvsh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4]; } } else { - kvsh[row * vsh_stride + col] = f16vec4(0.0f); + kvsh[row * vsh_stride + col] = FLOAT_TYPEV4(0.0f); } } } @@ -459,7 +477,7 @@ void main() { if (SHMEM_STAGING == 0) { if (!USE_DECODE_V && !KV_bounds_check) { - // F16 values can be loaded directly from global memory + // F16/BF16 values can be loaded directly from global memory const uint v_tile_row = j * Bc + bc_chunk * MatBc; const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4; coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor); @@ -573,7 +591,7 @@ void main() { [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { const uint d_local = d0 / threads_per_rowgroup; - Of[r][d_local] *= float16_t(ms); + Of[r][d_local] *= O_TYPE(ms); } } else { vs = exp(sink - Mf[r]); @@ -591,7 +609,7 @@ void main() { [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { const uint d_local = d0 / threads_per_rowgroup; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Of[r][d_local] *= float16_t(Lfrcp[r]); + Of[r][d_local] *= O_TYPE(Lfrcp[r]); #if defined(FLOAT_TYPE_MAX) Of[r][d_local] = clamp(Of[r][d_local], -FLOAT_TYPE_MAX, FLOAT_TYPE_MAX); #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 141bb870883..b9c03fe499d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -8,9 +8,16 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require +#if defined(BFLOAT16) +#extension GL_EXT_bfloat16 : enable +#endif + #extension GL_KHR_memory_scope_semantics : enable #extension GL_KHR_cooperative_matrix : enable #extension GL_NV_cooperative_matrix2 : enable +#ifdef GL_NV_cooperative_matrix_decode_vector +#extension GL_NV_cooperative_matrix_decode_vector : enable +#endif #extension GL_EXT_buffer_reference : enable #extension GL_KHR_shader_subgroup_ballot : enable #extension GL_KHR_shader_subgroup_vote : enable @@ -18,7 +25,9 @@ #include "types.glsl" #include "flash_attn_base.glsl" +#if !defined(BFLOAT16) #include "dequant_funcs_cm2.glsl" +#endif // buffer_reference stride = sizeof(struct) = FaBlockBytesK/V. layout(buffer_reference, std430, buffer_reference_align = 1) buffer decodeBufFA_K { @@ -28,6 +37,7 @@ layout(buffer_reference, std430, buffer_reference_align = 1) buffer decodeBufFA_ uint8_t raw[FaBlockBytesV]; }; +#if !defined(BFLOAT16) float16_t faDecodeK(const decodeBufFA_K bl_in, const uint blockCoords[2], const uint coordInBlock[2]) { switch (FaTypeK) { case FA_TYPE_F32: return dequantFuncF32 (decodeBufF32 (bl_in), blockCoords, coordInBlock); @@ -54,6 +64,42 @@ float16_t faDecodeV(const decodeBufFA_V bl_in, const uint blockCoords[2], const } } +// V=4 vector decode for K/V; dispatches to per-format _v decoders. +f16vec4 faDecodeKVector(const decodeBufFA_K bl_in, const uint blockCoords[2], const uint coordInBlock[2]) { + switch (FaTypeK) { + case 0u: return f16vec4(decodeBufF32(bl_in).block); + case 2u: return dequantFuncQ4_0_v(decodeBufQ4_0(bl_in), blockCoords, coordInBlock); + case 3u: return dequantFuncQ4_1_v(decodeBufQ4_1(bl_in), blockCoords, coordInBlock); + case 6u: return dequantFuncQ5_0_v(decodeBufQ5_0(bl_in), blockCoords, coordInBlock); + case 7u: return dequantFuncQ5_1_v(decodeBufQ5_1(bl_in), blockCoords, coordInBlock); + case 8u: return dequantFuncQ8_0_v(decodeBufQ8_0(bl_in), blockCoords, coordInBlock); + case 41u: return dequantFuncQ1_0_v(decodeBufQ1_0(bl_in), blockCoords, coordInBlock); + default: return f16vec4(0); + } +} + +f16vec4 faDecodeVVector(const decodeBufFA_V bl_in, const uint blockCoords[2], const uint coordInBlock[2]) { + switch (FaTypeV) { + case 0u: return f16vec4(decodeBufF32(bl_in).block); + case 2u: return dequantFuncQ4_0_v(decodeBufQ4_0(bl_in), blockCoords, coordInBlock); + case 3u: return dequantFuncQ4_1_v(decodeBufQ4_1(bl_in), blockCoords, coordInBlock); + case 6u: return dequantFuncQ5_0_v(decodeBufQ5_0(bl_in), blockCoords, coordInBlock); + case 7u: return dequantFuncQ5_1_v(decodeBufQ5_1(bl_in), blockCoords, coordInBlock); + case 8u: return dequantFuncQ8_0_v(decodeBufQ8_0(bl_in), blockCoords, coordInBlock); + case 41u: return dequantFuncQ1_0_v(decodeBufQ1_0(bl_in), blockCoords, coordInBlock); + default: return f16vec4(0); + } +} + +#ifdef GL_NV_cooperative_matrix_decode_vector +#define FADECODEK , faDecodeK, faDecodeKVector +#define FADECODEV , faDecodeV, faDecodeVVector +#else +#define FADECODEK , faDecodeK +#define FADECODEV , faDecodeV +#endif +#endif + layout (binding = 0) readonly buffer Q {uint8_t data_q[];}; layout (binding = 1) readonly buffer K {uint8_t data_k[];}; layout (binding = 2) readonly buffer V {uint8_t data_v[];}; @@ -157,15 +203,15 @@ void main() { tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1); coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseAccumulator> Q; - coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA> Qf16; + coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA> Qf16; uint32_t q_offset = gqa_iq1*p.nb01*4/*sizeof(float)*/ + iq2*p.nb02+iq3*p.nb03; coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK_pad)); - Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA>(Q); - Qf16 *= float16_t(p.scale); + Q *= Q_TYPE(p.scale); + Qf16 = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA>(Q); - coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0); + coopmat<O_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O = coopmat<O_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0); coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M; @@ -253,16 +299,20 @@ void main() { coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0); - coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T; + coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T; uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13; // F16: bs_k==1 (direct load). F32: bs_k==4 (vec4 / dequantFuncF32). Q4/Q8 family: bs_k==32. Q1_0: bs_k==128. +#if defined(BFLOAT16) + coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose); +#else const bool k_use_decode = (bs_k > 1u); if (k_use_decode) { - coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose, faDecodeK); + coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose FADECODEK); } else { coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose); } +#endif S = coopMatMulAdd(Qf16, K_T, S); if (LOGIT_SOFTCAP) { @@ -313,22 +363,26 @@ void main() { coopMatPerElementNV(P, P, replacePadding, ACC_TYPE(0.0), R, C); } - coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA> P_A = coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA>(P); + coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA> P_A = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA>(P); // compute rowsum by multiplying by matrix of all ones. - coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB> One = coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB>(1.0); + coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB> One = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB>(1.0); rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0); rowsum = coopMatMulAdd(P_A, One, rowsum); - coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV_pad, gl_MatrixUseB> V; + coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Bc, HSV_pad, gl_MatrixUseB> V; uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23; +#if defined(BFLOAT16) + coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad)); +#else const bool v_use_decode = (bs_v > 1u); if (v_use_decode) { - coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad), faDecodeV); + coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad) FADECODEV); } else { coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad)); } +#endif L = eM*L + rowsum; @@ -340,7 +394,7 @@ void main() { // resize eM by using smear/reduce coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce); - O *= coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(eMdiag); + O *= coopmat<O_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(eMdiag); O = coopMatMulAdd(P_A, V, O); } @@ -389,7 +443,7 @@ void main() { if (sink > Mr[i]) { ms = exp(Mr[i] - sink); - O[i] *= float16_t(ms); + O[i] *= O_TYPE(ms); } else { vs = exp(sink - Mr[i]); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl index 02106f33cbe..8704479d960 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl @@ -28,6 +28,9 @@ layout (binding = 2) readonly buffer V_PACKED_Q5_1 { block_q5_1_packed16 data[]; layout (binding = 1) readonly buffer K_PACKED_Q8_0 { block_q8_0_packed16 data[]; } k_packed_q8_0; layout (binding = 2) readonly buffer V_PACKED_Q8_0 { block_q8_0_packed16 data[]; } v_packed_q8_0; +layout (binding = 1) readonly buffer K_PACKED_BF16 { u16vec4 data[]; } k_packed_bf16; +layout (binding = 2) readonly buffer V_PACKED_BF16 { u16vec4 data[]; } v_packed_bf16; + // Q4_1 and Q5_1 packed32 views: aliased to the same memory as the packed16 // views, used by the MMQ K-side hot path for fast 4-uint loads. layout (binding = 1) readonly buffer K_PACKED_Q4_1_P32 { block_q4_1_packed32 data[]; } k_packed_q4_1_p32; @@ -99,6 +102,9 @@ layout (binding = 1) readonly buffer K_PACKED_Q5_1_P32 { block_q5_1_packed32 dat return FLOAT_TYPE(BUF.data[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y); \ } +#define FA_DEQUANT4_BF16(BUF) \ + return FLOAT_TYPEV4(bf16_to_fp32(uvec4(BUF.data[(a_offset + ib) / 4]))); + FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { if (binding_idx == BINDING_IDX_K) { switch (FaTypeK) { @@ -108,6 +114,7 @@ FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { case FA_TYPE_Q5_0: FA_DEQUANT4_Q5_0(k_packed_q5_0) case FA_TYPE_Q5_1: FA_DEQUANT4_Q5_1(k_packed_q5_1) case FA_TYPE_Q8_0: FA_DEQUANT4_Q8_0(k_packed_q8_0) + case FA_TYPE_BF16: FA_DEQUANT4_BF16(k_packed_bf16) } } else { switch (FaTypeV) { @@ -117,6 +124,7 @@ FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { case FA_TYPE_Q5_0: FA_DEQUANT4_Q5_0(v_packed_q5_0) case FA_TYPE_Q5_1: FA_DEQUANT4_Q5_1(v_packed_q5_1) case FA_TYPE_Q8_0: FA_DEQUANT4_Q8_0(v_packed_q8_0) + case FA_TYPE_BF16: FA_DEQUANT4_BF16(v_packed_bf16) } } return FLOAT_TYPEV4(0); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp b/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp new file mode 100644 index 00000000000..a2069964adb --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp @@ -0,0 +1,115 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : require +#ifndef FWHT_SHMEM +#extension GL_KHR_shader_subgroup_basic : enable +#extension GL_KHR_shader_subgroup_shuffle : enable +#endif + +layout(constant_id = 0) const uint BLOCK_SIZE = 32; +layout(constant_id = 1) const uint N = 128; + +layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in; + +layout(push_constant) uniform parameter +{ + uint n_rows; + uint src_offset; + uint dst_offset; + float scale; +}; + +layout(binding = 0, std430) readonly buffer A { float data_a[]; }; +layout(binding = 1, std430) writeonly buffer D { float data_d[]; }; + +const uint EL_W = N / BLOCK_SIZE; + +#ifdef FWHT_SHMEM +shared float shmem[4 * N]; +#endif + +void main() { +#ifdef FWHT_SHMEM + const uint tid = gl_LocalInvocationID.x; + const uint shmem_base = gl_LocalInvocationID.y * N; + const uint row_id = gl_LocalInvocationID.y; +#else + const uint tid = gl_SubgroupInvocationID; + const uint row_id = gl_SubgroupID; +#endif + + for (uint base_row = gl_WorkGroupID.x * gl_WorkGroupSize.y; + base_row < n_rows; + base_row += gl_NumWorkGroups.x * gl_WorkGroupSize.y) { + const uint row = base_row + row_id; + const uint row_offset = row * N; + +#ifndef FWHT_SHMEM + if (row >= n_rows) { + continue; + } +#endif + + float reg[EL_W]; + + [[unroll]] + for (uint i = 0; i < EL_W; ++i) { + reg[i] = row < n_rows ? data_a[src_offset + row_offset + i * BLOCK_SIZE + tid] * scale : 0.0; + } + +#ifdef FWHT_SHMEM + [[unroll]] + for (uint h = 1; h < BLOCK_SIZE; h <<= 1) { + [[unroll]] + for (uint i = 0; i < EL_W; ++i) { + shmem[shmem_base + i * BLOCK_SIZE + tid] = reg[i]; + } + barrier(); + [[unroll]] + for (uint j = 0; j < EL_W; ++j) { + const float val = reg[j]; + const float other = shmem[shmem_base + j * BLOCK_SIZE + (tid ^ h)]; + reg[j] = (tid & h) == 0 ? val + other : other - val; + } + barrier(); + } +#else + [[unroll]] + for (uint h = 1; h < BLOCK_SIZE; h <<= 1) { + [[unroll]] + for (uint j = 0; j < EL_W; ++j) { + const float val = reg[j]; + const float val2 = subgroupShuffleXor(val, h); + reg[j] = (tid & h) == 0 ? val + val2 : val2 - val; + } + } +#endif + + [[unroll]] + for (uint h = BLOCK_SIZE; h < N; h <<= 1) { + const uint step = h / BLOCK_SIZE; + [[unroll]] + for (uint j = 0; j < EL_W; j += 2 * step) { + [[unroll]] + for (uint k = 0; k < step; ++k) { + const float x = reg[j + k]; + const float y = reg[j + k + step]; + reg[j + k] = x + y; + reg[j + k + step] = x - y; + } + } + } + +#ifdef FWHT_SHMEM + if (row < n_rows) { +#endif + [[unroll]] + for (uint i = 0; i < EL_W; ++i) { + data_d[dst_offset + row_offset + i * BLOCK_SIZE + tid] = reg[i]; + } +#ifdef FWHT_SHMEM + } + barrier(); +#endif + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp index 33c3202dbb7..0e384330b9b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp @@ -102,8 +102,8 @@ void main() { const uint iq3 = seq_id / rq3; const uint state_size = S_V * S_V; - // input state layout (D, K, n_seqs): per-seq stride is K*H*D; we read slot 0. - const uint state_in_base = (seq_id * K * H + head_id) * state_size; + // input state holds s0 only [S_v, S_v, H, n_seqs]: per-seq stride is H*D. + const uint state_in_base = (seq_id * H + head_id) * state_size; // output state layout per slot: same per-(seq,head) offset as the single-slot case. const uint state_out_base = (seq_id * H + head_id) * state_size; const uint state_size_per_snap = state_size * H * n_seqs; @@ -113,9 +113,8 @@ void main() { s_shard[r] = FLOAT_TYPE(data_state[state_in_base + col * S_V + r * LANES_PER_COLUMN + lane]); } - // snapshot slot mapping: target_slot = t - shift. When n_tokens < K, only the last - // n_tokens slots are written; earlier slots are left untouched (caller-owned). - const int shift = int(n_tokens) - int(K); + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + // When n_tokens < K, only slots 0..n_tokens-1 are written; older slots are caller-owned. uint attn_off = (seq_id * n_tokens * H + head_id) * S_V; @@ -172,7 +171,7 @@ void main() { attn_off += S_V * H; if (K > 1u) { - const int target_slot = int(t) - shift; + const int target_slot = int(n_tokens) - 1 - int(t); if (target_slot >= 0 && target_slot < int(K)) { const uint slot_base = s_off + uint(target_slot) * state_size_per_snap + state_out_base; [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp index 2271be4021b..5a9d0e778fd 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp @@ -10,12 +10,38 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; #if !defined(DATA_A_F32) && !defined(DATA_A_F16) && !defined(DATA_A_BF16) #define K_PER_ITER 8 #else -#define K_PER_ITER 2 +#define K_PER_ITER 4 #endif uint a_offset, b_offset, d_offset, y_offset; +vec4 load_b(const uint j, const uint iybs, const uint iqs, const bool lastiter, out bool OOB_y, out bool OOB_z, out bool OOB_w) { + // Check if the latter elements are OOB, and don't fetch B or accumulate it. + OOB_y = lastiter && (iybs + iqs + y_offset >= p.ncols); + OOB_z = lastiter && (iybs + iqs + y_offset*2 >= p.ncols); + OOB_w = lastiter && (iybs + iqs + y_offset*3 >= p.ncols); + + if (!OOB_w) { + return vec4(FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]), + FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]), + FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset*2]), + FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset*3])); + } else if (!OOB_z) { + return vec4(FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]), + FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]), + FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset*2]), + 0); + } else if (!OOB_y) { + return vec4(FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]), + FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]), + 0, 0); + } else { + return vec4(FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]), + 0, 0, 0); + } +} + void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter) { [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { @@ -25,6 +51,8 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const #if K_PER_ITER == 8 #if QUANT_R == 2 + // Note that we end up fetching bogus elements here, but its fine as they'll be + // within an accessible block. const vec4 bv02 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]); const vec4 bv13 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4]); const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y); @@ -34,18 +62,11 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const const vec4 bv1 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4 + 1]); #endif #else - // Check if the second of the pair of elements is OOB, and don't fetch B or - // accumulate it. We still fetch a pair of elements for A, which is fine for - // quantized formats since they'll be within the same block. We should - // probably skip fetching the second element for F16/F32, but as of now we - // still do. - const bool OOB = lastiter && (iybs + iqs + y_offset >= p.ncols); - - FLOAT_TYPE b0 = 0, b1 = 0; - b0 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]); - if (!OOB) { - b1 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]); - } + bool OOB_y; + bool OOB_z; + bool OOB_w; + + const vec4 b = load_b(j, iybs, iqs, lastiter, OOB_y, OOB_z, OOB_w); #endif uint ibi = first_row*p.ncols; [[unroll]] for (uint n = 0; n < num_rows; ++n) { @@ -71,22 +92,60 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const temp[j][n] += rowtmp; #else - const vec2 v = dequantize(ib, iqs, a_offset); - - // matrix multiplication - temp[j][n] = fma(FLOAT_TYPE(v.x), b0, temp[j][n]); - if (!OOB) { - temp[j][n] = fma(FLOAT_TYPE(v.y), b1, temp[j][n]); + if (!OOB_w) { + const vec4 v = dequantize4(ib, iqs, a_offset); + temp[j][n] += dot(v, b); + } else if (!OOB_z) { + const vec2 v0 = dequantize(ib, iqs, a_offset); + const FLOAT_TYPE v1 = dequantize1(ib + 2/QUANT_R, iqs, a_offset); + const vec3 v = vec3(v0.x, v0.y, v1); + const vec3 b0 = vec3(b.x, b.y, b.z); + temp[j][n] += dot(v, b0); + } else if (!OOB_y) { + const vec2 v0 = dequantize(ib, iqs, a_offset); + const vec2 b0 = vec2(b.x, b.y); + temp[j][n] += dot(v0, b0); + } else { + const FLOAT_TYPE v = dequantize1(ib, iqs, a_offset); + temp[j][n] = fma(v, b.x, temp[j][n]); } #endif } } } +#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16) +void iter_aligned_nonquant(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) +{ + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + const uint col = i*BLOCK_SIZE + K_PER_ITER*tid; + const uint iqs = 0; // quant index + const uint iybs = col; // y block start index + + const vec4 b = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]; + + uint ibi = first_row*p.ncols; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib = (ibi + col)/QUANT_K; // block index + ibi += p.ncols; + + const vec4 v = dequantize4_2aligned(ib, iqs, a_offset); + + // matrix multiplication + temp[j][n] += dot(v, b); + } + } +} +#endif + void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint tid = gl_LocalInvocationID.x; get_offsets(a_offset, b_offset, d_offset); + const bool is_aligned_nonquant = + p.batch_stride_b % 4 == 0 && b_offset % 4 == 0 && + p.ncols % 4 == 0 && BLOCK_SIZE % 4 == 0 && + K_PER_ITER == 4; y_offset = QUANT_R == 1 ? 1 : QUANT_K/2; @@ -105,17 +164,26 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { int unroll_count = 4; uint unrolled_iters = num_iters & ~(unroll_count - 1); -#if K_PER_ITER == 2 + uint i = 0; + +#if K_PER_ITER == 4 // If the K dimension is odd, we need lastiter==true on the last iteration // so OOB is computed correctly. Skip some unrolling to make that happen. - if ((p.ncols & 1) != 0 && + if ((p.ncols & 3) != 0 && unrolled_iters == num_iters && unrolled_iters > 0) { unrolled_iters -= unroll_count; } + if (is_aligned_nonquant) { + while (i < unrolled_iters) { + // Manually partially unroll the loop + [[unroll]] for (uint k = 0; k < unroll_count; ++k) { + iter_aligned_nonquant(temp, first_row, num_rows, tid, i*K_PER_ITER); + i++; + } + } + } else { #endif - - uint i = 0; while (i < unrolled_iters) { // Manually partially unroll the loop [[unroll]] for (uint k = 0; k < unroll_count; ++k) { @@ -123,18 +191,30 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { i++; } } +#if K_PER_ITER == 4 + } +#endif unroll_count = 2; unrolled_iters = num_iters & ~(unroll_count - 1); -#if K_PER_ITER == 2 - if ((p.ncols & 1) != 0 && +#if K_PER_ITER == 4 + if ((p.ncols & 3) != 0 && unrolled_iters == num_iters && unrolled_iters > 0) { unrolled_iters -= unroll_count; } -#endif + if (is_aligned_nonquant) { + while (i < unrolled_iters && is_aligned_nonquant) { + // Manually partially unroll the loop + [[unroll]] for (uint k = 0; k < unroll_count; ++k) { + iter_aligned_nonquant(temp, first_row, num_rows, tid, i*K_PER_ITER); + i++; + } + } + } else { +#endif while (i < unrolled_iters) { // Manually partially unroll the loop [[unroll]] for (uint k = 0; k < unroll_count; ++k) { @@ -142,10 +222,25 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { i++; } } +#if K_PER_ITER == 4 + } +#endif + +#if K_PER_ITER == 4 + if (is_aligned_nonquant) { + while (i < num_iters) { + iter_aligned_nonquant(temp, first_row, num_rows, tid, i*K_PER_ITER); + i++; + } + } else { +#endif while (i < num_iters) { iter(temp, first_row, num_rows, tid, i*K_PER_ITER, true); i++; } +#if K_PER_ITER == 4 + } +#endif reduce_result(temp, d_offset, first_row, num_rows, tid); } @@ -164,6 +259,6 @@ void main() { if (first_row >= p.stride_d) { return; } - compute_outputs(first_row, p.stride_d - first_row); + compute_outputs(first_row, min(NUM_ROWS, p.stride_d - first_row)); } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp index 6fe3e2dc043..fd84c3c91d8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp @@ -4,6 +4,7 @@ #extension GL_EXT_integer_dot_product : require #define MMQ +#define NEEDS_IQ1S_GRID_GPU #define B_TYPE block_q8_1_x4 #include "mul_mat_vec_base.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl index bc580aeeb83..73cf9c79955 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl @@ -212,28 +212,40 @@ i32vec4 repack4(uint ib, uint iqs) { const uint qs_shift = ((iqs_k % 32) / 8) * 2; const uint hm_shift = iqs_k / 8; + const uvec4 qs = uvec4( uint32_t(data_a_packed16[ib_k].qs[qs_idx * 2 ]) | + (uint32_t(data_a_packed16[ib_k].qs[qs_idx * 2 + 1]) << 16), + uint32_t(data_a_packed16[ib_k].qs[qs_idx * 2 + 2]) | + (uint32_t(data_a_packed16[ib_k].qs[qs_idx * 2 + 3]) << 16), + uint32_t(data_a_packed16[ib_k].qs[qs_idx * 2 + 4]) | + (uint32_t(data_a_packed16[ib_k].qs[qs_idx * 2 + 5]) << 16), + uint32_t(data_a_packed16[ib_k].qs[qs_idx * 2 + 6]) | + (uint32_t(data_a_packed16[ib_k].qs[qs_idx * 2 + 7]) << 16)); + + const uvec4 hmask = uvec4( uint32_t(data_a_packed16[ib_k].hmask[iqs * 2 ]) | + (uint32_t(data_a_packed16[ib_k].hmask[iqs * 2 + 1]) << 16), + uint32_t(data_a_packed16[ib_k].hmask[iqs * 2 + 2]) | + (uint32_t(data_a_packed16[ib_k].hmask[iqs * 2 + 3]) << 16), + uint32_t(data_a_packed16[ib_k].hmask[iqs * 2 + 4]) | + (uint32_t(data_a_packed16[ib_k].hmask[iqs * 2 + 5]) << 16), + uint32_t(data_a_packed16[ib_k].hmask[iqs * 2 + 6]) | + (uint32_t(data_a_packed16[ib_k].hmask[iqs * 2 + 7]) << 16)); + // bitwise OR to add 4 if hmask is set, subtract later - const i8vec2 vals00 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 ] >> qs_shift) & uint16_t(0x0303))) | - unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 ] >> hm_shift) & uint16_t(0x0101)) << 2)); - const i8vec2 vals01 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1] >> qs_shift) & uint16_t(0x0303))) | - unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 1] >> hm_shift) & uint16_t(0x0101)) << 2)); - const i8vec2 vals10 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2] >> qs_shift) & uint16_t(0x0303))) | - unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 2] >> hm_shift) & uint16_t(0x0101)) << 2)); - const i8vec2 vals11 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3] >> qs_shift) & uint16_t(0x0303))) | - unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 3] >> hm_shift) & uint16_t(0x0101)) << 2)); - const i8vec2 vals20 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 4] >> qs_shift) & uint16_t(0x0303))) | - unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 4] >> hm_shift) & uint16_t(0x0101)) << 2)); - const i8vec2 vals21 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 5] >> qs_shift) & uint16_t(0x0303))) | - unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 5] >> hm_shift) & uint16_t(0x0101)) << 2)); - const i8vec2 vals30 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 6] >> qs_shift) & uint16_t(0x0303))) | - unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 6] >> hm_shift) & uint16_t(0x0101)) << 2)); - const i8vec2 vals31 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 7] >> qs_shift) & uint16_t(0x0303))) | - unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 7] >> hm_shift) & uint16_t(0x0101)) << 2)); - - return i32vec4(pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y) - int8_t(4)), - pack32(i8vec4(vals10.x, vals10.y, vals11.x, vals11.y) - int8_t(4)), - pack32(i8vec4(vals20.x, vals20.y, vals21.x, vals21.y) - int8_t(4)), - pack32(i8vec4(vals30.x, vals30.y, vals31.x, vals31.y) - int8_t(4))); + const uint vals0 = (( qs.x >> qs_shift) & 0x03030303) | + (((hmask.x >> hm_shift) & 0x01010101) << 2); + const uint vals1 = (( qs.y >> qs_shift) & 0x03030303) | + (((hmask.y >> hm_shift) & 0x01010101) << 2); + const uint vals2 = (( qs.z >> qs_shift) & 0x03030303) | + (((hmask.z >> hm_shift) & 0x01010101) << 2); + const uint vals3 = (( qs.w >> qs_shift) & 0x03030303) | + (((hmask.w >> hm_shift) & 0x01010101) << 2); + + // Subtract 4 by twiddling bits rather than using re-packing as mesa + // compiles repacking poorly. + return i32vec4(int32_t(((vals0 ^ 0x80808080) - 0x04040404) ^ 0x80808080), + int32_t(((vals1 ^ 0x80808080) - 0x04040404) ^ 0x80808080), + int32_t(((vals2 ^ 0x80808080) - 0x04040404) ^ 0x80808080), + int32_t(((vals3 ^ 0x80808080) - 0x04040404) ^ 0x80808080)); } float get_d_scale(uint ib, uint iqs) { @@ -343,27 +355,39 @@ i32vec4 repack4(uint ib, uint iqs) { const uint qh_idx = (iqs_k / 32) * 8 + iqs; const uint qh_shift = ((iqs_k % 32) / 8) * 2; - const i8vec2 vals00 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 ] >> ql_shift) & uint16_t(0x0F0F))) | - unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 ] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); - const i8vec2 vals01 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> ql_shift) & uint16_t(0x0F0F))) | - unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); - const i8vec2 vals10 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 2] >> ql_shift) & uint16_t(0x0F0F))) | - unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 2] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); - const i8vec2 vals11 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 3] >> ql_shift) & uint16_t(0x0F0F))) | - unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 3] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); - const i8vec2 vals20 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 4] >> ql_shift) & uint16_t(0x0F0F))) | - unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 4] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); - const i8vec2 vals21 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 5] >> ql_shift) & uint16_t(0x0F0F))) | - unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 5] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); - const i8vec2 vals30 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 6] >> ql_shift) & uint16_t(0x0F0F))) | - unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 6] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); - const i8vec2 vals31 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 7] >> ql_shift) & uint16_t(0x0F0F))) | - unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 7] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); - - return i32vec4(pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y)), - pack32(i8vec4(vals10.x, vals10.y, vals11.x, vals11.y)), - pack32(i8vec4(vals20.x, vals20.y, vals21.x, vals21.y)), - pack32(i8vec4(vals30.x, vals30.y, vals31.x, vals31.y))); + const uvec4 ql = uvec4( uint32_t(data_a_packed16[ib_k].ql[ql_idx * 2 ]) | + (uint32_t(data_a_packed16[ib_k].ql[ql_idx * 2 + 1]) << 16), + uint32_t(data_a_packed16[ib_k].ql[ql_idx * 2 + 2]) | + (uint32_t(data_a_packed16[ib_k].ql[ql_idx * 2 + 3]) << 16), + uint32_t(data_a_packed16[ib_k].ql[ql_idx * 2 + 4]) | + (uint32_t(data_a_packed16[ib_k].ql[ql_idx * 2 + 5]) << 16), + uint32_t(data_a_packed16[ib_k].ql[ql_idx * 2 + 6]) | + (uint32_t(data_a_packed16[ib_k].ql[ql_idx * 2 + 7]) << 16)); + + const uvec4 qh = uvec4( uint32_t(data_a_packed16[ib_k].qh[qh_idx * 2 ]) | + (uint32_t(data_a_packed16[ib_k].qh[qh_idx * 2 + 1]) << 16), + uint32_t(data_a_packed16[ib_k].qh[qh_idx * 2 + 2]) | + (uint32_t(data_a_packed16[ib_k].qh[qh_idx * 2 + 3]) << 16), + uint32_t(data_a_packed16[ib_k].qh[qh_idx * 2 + 4]) | + (uint32_t(data_a_packed16[ib_k].qh[qh_idx * 2 + 5]) << 16), + uint32_t(data_a_packed16[ib_k].qh[qh_idx * 2 + 6]) | + (uint32_t(data_a_packed16[ib_k].qh[qh_idx * 2 + 7]) << 16)); + + const uint vals0 = (( ql.x >> ql_shift) & 0x0F0F0F0F) | + (((qh.x >> qh_shift) & 0x03030303) << 4); + const uint vals1 = (( ql.y >> ql_shift) & 0x0F0F0F0F) | + (((qh.y >> qh_shift) & 0x03030303) << 4); + const uint vals2 = (( ql.z >> ql_shift) & 0x0F0F0F0F) | + (((qh.z >> qh_shift) & 0x03030303) << 4); + const uint vals3 = (( ql.w >> ql_shift) & 0x0F0F0F0F) | + (((qh.w >> qh_shift) & 0x03030303) << 4); + + // Subtract 32 by twiddling bits rather than using re-packing as mesa + // compiles repacking poorly. + return i32vec4(int32_t(((vals0 ^ 0x80808080) - 0x20202020) ^ 0x80808080), + int32_t(((vals1 ^ 0x80808080) - 0x20202020) ^ 0x80808080), + int32_t(((vals2 ^ 0x80808080) - 0x20202020) ^ 0x80808080), + int32_t(((vals3 ^ 0x80808080) - 0x20202020) ^ 0x80808080)); } float get_d_scale(uint ib, uint iqs) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 89346e48e06..f39410d74f0 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -29,6 +29,7 @@ #endif #include "types.glsl" +#include "dot_product_funcs.glsl" #ifndef LOAD_VEC_A #define LOAD_VEC_A 1 @@ -329,15 +330,8 @@ void main() { [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) { // [WNITER][TN][WMITER][TM / 2] -> [wsic][cc][wsir][cr] const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr; - #if defined(DATA_A_F32) || defined(DATA_A_F16) - sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y), - fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].z), ACC_TYPE(cache_b.z), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].w), ACC_TYPE(cache_b.w), sums[sums_idx].x)))); - sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), - fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].z), ACC_TYPE(cache_b.z), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].w), ACC_TYPE(cache_b.w), sums[sums_idx].y)))); - #else - sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y), sums[sums_idx].x)); - sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), sums[sums_idx].y)); - #endif + sums[sums_idx].x = dot_product(cache_a[wsir * TM + 2 * cr ], cache_b, sums[sums_idx].x); + sums[sums_idx].y = dot_product(cache_a[wsir * TM + 2 * cr + 1], cache_b, sums[sums_idx].y); } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index 497a18ff8a7..2656fe1c3e9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -11,6 +11,9 @@ #extension GL_KHR_memory_scope_semantics : enable #extension GL_KHR_cooperative_matrix : enable #extension GL_NV_cooperative_matrix2 : enable +#ifdef GGML_VULKAN_COOPMAT2_DECODE_VECTOR +#extension GL_NV_cooperative_matrix_decode_vector : enable +#endif #extension GL_EXT_buffer_reference : enable #extension GL_KHR_shader_subgroup_ballot : enable #extension GL_KHR_shader_subgroup_vote : enable @@ -69,12 +72,17 @@ layout (push_constant) uniform parameter layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; +#if defined(MUL_MAT_ID) && defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR) +layout (binding = 1) readonly buffer B4 {B_TYPEV4 data_b_v4[];}; +#endif #if QUANT_K > 1 -#define DECODEFUNCA , dequantFuncA - #include "dequant_funcs_cm2.glsl" - +#if defined(dequantFuncA_v) && defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR) +#define DECODEFUNCA , dequantFuncA, dequantFuncA_v +#else +#define DECODEFUNCA , dequantFuncA +#endif #else #define DECODEFUNCA #endif @@ -111,11 +119,33 @@ B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const i const uint row_i = blockCoords[0]; const u16vec4 row_idx = row_ids[row_i]; - B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]]; +#if defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR) + // The decode-vector path gives B a K-dimension tensor-layout block size of BK. + const uint k = blockCoords[1] * BK + coordInBlock[1]; +#else + const uint k = blockCoords[1]; +#endif + B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + k]; return ret; } +#if defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR) +B_TYPEV4 decodeFuncB_v(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const uint row_i = blockCoords[0]; + + const u16vec4 row_idx = row_ids[row_i]; + const uint k = blockCoords[1] * BK + coordInBlock[1]; + const uint base = row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + k; + + return data_b_v4[base >> 2]; +} +#define DECODEFUNCB , decodeFuncB, decodeFuncB_v +#else +#define DECODEFUNCB , decodeFuncB +#endif + D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t ir, const in uint32_t ic) { uint dr = ir * BM + r; @@ -285,6 +315,9 @@ void main() { tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K); tensorLayoutAClamp = setTensorLayoutBlockSizeNV(tensorLayoutAClamp, 1, QUANT_K); #endif +#if defined(MUL_MAT_ID) && defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR) + tensorLayoutB = setTensorLayoutBlockSizeNV(tensorLayoutB, 1, BK); +#endif // Use end_k rather than p.K as the dimension because that's what // we need to bound check against when using split_k. @@ -497,7 +530,7 @@ void main() { coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b; coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover4, block_k, BK), tensorViewTranspose DECODEFUNCB); sum = coopMatMulAdd(mat_a, mat_b, sum); } else { @@ -505,7 +538,7 @@ void main() { coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b; coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover4, block_k, BK), tensorViewTranspose DECODEFUNCB); sum = coopMatMulAdd(mat_a, mat_b, sum); } @@ -541,7 +574,7 @@ void main() { coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b; coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover2, block_k, BK), tensorViewTranspose DECODEFUNCB); sum = coopMatMulAdd(mat_a, mat_b, sum); } else { @@ -549,7 +582,7 @@ void main() { coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b; coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover2, block_k, BK), tensorViewTranspose DECODEFUNCB); sum = coopMatMulAdd(mat_a, mat_b, sum); } @@ -586,7 +619,7 @@ void main() { coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); #ifdef MUL_MAT_ID - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BN, block_k, BK), tensorViewTranspose, decodeFuncB); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BN, block_k, BK), tensorViewTranspose DECODEFUNCB); #else coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); #endif @@ -598,7 +631,7 @@ void main() { coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); #ifdef MUL_MAT_ID - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BN, block_k, BK), tensorViewTranspose, decodeFuncB); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BN, block_k, BK), tensorViewTranspose DECODEFUNCB); #else coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl index 4bcd97756fd..8c6b20c6889 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl @@ -31,6 +31,7 @@ #else #define A_TYPE float16_t #endif +#define A_TYPE_PACKED32 f16vec2 #endif #if defined(DATA_A_BF16) @@ -44,6 +45,7 @@ #else #define A_TYPE uint16_t #endif +#define A_TYPE_PACKED32 uint32_t #endif #define QUANT_K_Q4_0 32 @@ -596,9 +598,10 @@ const uint[1024] iq1s_grid_const = { 0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557 }; +#if defined(NEEDS_IQ1S_GRID_GPU) // Same content as iq1s_grid_const except each 2-bit value is expanded to 4-bit // and has 1 added to it (allows packed values to be extracted with & 0x0F0F0F0F -// and 0xF0F0F0F0). +// and 0xF0F0F0F0). This is only used by the q8_1/int-dot vector path. const uint32_t[2048] iq1s_grid_gpu_const = { 0x00000000, 0x00000002, 0x00000101, 0x00000200, 0x00000202, 0x00010001, 0x00010101, 0x00020000, 0x00020002, 0x00020200, 0x00020202, 0x01000101, 0x01010001, 0x01010100, 0x01010102, 0x01020101, @@ -857,9 +860,12 @@ const uint32_t[2048] iq1s_grid_gpu_const = { 0x20222020, 0x20222022, 0x20222220, 0x20222222, 0x21212021, 0x21212120, 0x21212122, 0x22202020, 0x22202022, 0x22202220, 0x22202222, 0x22212121, 0x22222020, 0x22222022, 0x22222220, 0x22222222, }; +#endif shared uint16_t iq1s_grid[2048]; +#if defined(NEEDS_IQ1S_GRID_GPU) shared uint32_t iq1s_grid_gpu[2048]; +#endif #define NEEDS_INIT_IQ_SHMEM void init_iq_shmem(uvec3 wgsize) @@ -873,12 +879,14 @@ void init_iq_shmem(uvec3 wgsize) iq1s_grid[2*idx+1] = g.y; } } +#if defined(NEEDS_IQ1S_GRID_GPU) [[unroll]] for (uint i = 0; i < iq1s_grid_gpu_const.length(); i += wgsize.x) { uint idx = i + gl_LocalInvocationIndex.x; if (iq1s_grid_gpu_const.length() % wgsize.x == 0 || idx < iq1s_grid_gpu_const.length()) { iq1s_grid_gpu[idx] = iq1s_grid_gpu_const[idx]; } } +#endif barrier(); } #endif @@ -1722,11 +1730,18 @@ struct block_nvfp4 uint8_t qs[QUANT_K_NVFP4 / 2]; }; +struct block_nvfp4_packed32 +{ + uint32_t d[QUANT_K_NVFP4 / 16 / 4]; + uint32_t qs[QUANT_K_NVFP4 / 2 / 4]; +}; + #if defined(DATA_A_NVFP4) #define QUANT_K QUANT_K_NVFP4 #define QUANT_R QUANT_R_NVFP4 #define QUANT_AUXF 1 #define A_TYPE block_nvfp4 +#define A_TYPE_PACKED32 block_nvfp4_packed32 #endif #if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index a1d735150fd..7bcb1460814 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -336,7 +336,8 @@ void string_to_spv_func(std::string name, std::string in_path, std::string out_p // disable spirv-opt for coopmat shaders for https://github.com/ggml-org/llama.cpp/issues/10734 // disable spirv-opt for bf16 shaders for https://github.com/ggml-org/llama.cpp/issues/15344 // disable spirv-opt for rope shaders for https://github.com/ggml-org/llama.cpp/issues/16860 - if (!coopmat && name.find("bf16") == std::string::npos && name.find("rope") == std::string::npos) { + // disable spirv-opt for dot2 shaders (spirv-opt doesn't recognize SPV_VALVE_mixed_float_dot_product capability) + if (!coopmat && name.find("bf16") == std::string::npos && name.find("rope") == std::string::npos && name.find("_dot2") == std::string::npos) { cmd.push_back("-O"); } @@ -427,10 +428,11 @@ void string_to_spv(std::string name, const std::string& source, const std::map<s generate_dep_file = false; } -void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool coopmat2, bool f16acc) { +void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool coopmat2, bool f16acc, bool dot2 = false) { std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4"; std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4"; std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4"; + std::string dot2_sfx = dot2 ? "_dot2" : ""; std::map<std::string, std::string> base_dict; std::string shader_name = "matmul"; @@ -457,6 +459,15 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c if (coopmat) { base_dict["COOPMAT"] = "1"; } +#if defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT) + if (coopmat2) { + base_dict["GGML_VULKAN_COOPMAT2_DECODE_VECTOR"] = "1"; + } +#endif + + if (dot2) { + base_dict["DOT2_F16"] = "1"; + } const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp"; @@ -523,11 +534,11 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c }; // Shaders with f16 B_TYPE - string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f32_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f32_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); // bf16 { @@ -548,8 +559,10 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c if (!(coopmat || coopmat2)) #endif { - string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + if (!dot2) { + string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + } } } @@ -579,18 +592,18 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c // don't generate f32 variants for coopmat2 if (!coopmat2) { - string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); } if (tname != "f16" && tname != "f32") { - string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); } #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) - // Integer dot mmq performs better with f32 accumulators - if (!f16acc && !coopmat && !coopmat2 && (is_legacy_quant(tname) || is_k_quant(tname) || tname == "mxfp4")) { + // Integer dot mmq performs better with f32 accumulators (different shader, skip for dot2) + if (!f16acc && !coopmat && !coopmat2 && !dot2 && (is_legacy_quant(tname) || is_k_quant(tname) || tname == "mxfp4")) { string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc); } #endif @@ -608,6 +621,10 @@ void process_shaders() { matmul_shaders(true, matmul_id_type, false, false, false); matmul_shaders(true, matmul_id_type, false, false, true); + // dot2 variants (scalar fp16 only) + matmul_shaders(true, matmul_id_type, false, false, false, true); + matmul_shaders(true, matmul_id_type, false, false, true, true); + if (matmul_id_type != MatMulIdType::DEFAULT) { #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) // Coopmat, fp32acc and fp16acc @@ -655,6 +672,12 @@ void process_shaders() { string_to_spv("flash_attn_f32_f16", "flash_attn.comp", merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc); + + if (fp16) { + string_to_spv("flash_attn_f32_f16_dot2", "flash_attn.comp", + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"DOT2_F16", "1"}}), fp16, false, false, f16acc); + } + #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) string_to_spv("flash_attn_f32_f16", "flash_attn.comp", merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"MMQ", "1"}, {"FA_MMQ_MIXED", "1"}}), fp16, false, false, f16acc, "_int8"); @@ -662,6 +685,28 @@ void process_shaders() { } } + const std::map<std::string, std::string> fa_bf16_dict = { + {"FLOAT_TYPE", "bfloat16_t"}, + {"FLOAT_TYPEV2", "bf16vec2"}, + {"FLOAT_TYPEV4", "bf16vec4"}, + {"ACC_TYPE", "float"}, + {"ACC_TYPEV2", "vec2"}, + {"ACC_TYPEV4", "vec4"}, + {"BFLOAT16", "1"}, + }; + +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + string_to_spv("flash_attn_f32_f16_bf16", "flash_attn_cm1.comp", + merge_maps(fa_bf16_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"COOPMAT", "1"}}), + true, true, false, false); +#endif + +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + string_to_spv("flash_attn_f32_f16_bf16", "flash_attn_cm2.comp", + merge_maps(fa_bf16_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), + true, false, true, false); +#endif + std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}}; for (const auto& tname : type_names) { @@ -798,9 +843,11 @@ void process_shaders() { string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("repeat_i32", "repeat.comp", {{"A_TYPE", "int32_t"}, {"D_TYPE", "int32_t"}}); string_to_spv("repeat_back_f32", "repeat_back.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("repeat_i16", "repeat.comp", {{"A_TYPE", "int16_t"}, {"D_TYPE", "int16_t"}}); + string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); @@ -932,6 +979,8 @@ void process_shaders() { string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}})); string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("fwht_f32", "fwht.comp", {}); + string_to_spv("fwht_shmem_f32", "fwht.comp", {{"FWHT_SHMEM", "1"}}); string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}})); string_to_spv("cumsum_f32", "cumsum.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("cumsum_multipass1_f32", "cumsum_multipass1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); @@ -984,8 +1033,16 @@ void process_shaders() { string_to_spv(name + (unroll ? "_unroll" : ""), "conv2d_mm.comp", defines); #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) if (unroll) { - defines["COOPMAT2"] = "1"; - string_to_spv(name, "conv2d_mm.comp", defines, true, false, true); + auto cm2_defines = defines; + cm2_defines["COOPMAT2"] = "1"; + string_to_spv(name, "conv2d_mm.comp", cm2_defines, true, false, true); + } +#endif +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (unroll) { + auto cm1_defines = defines; + cm1_defines["COOPMAT"] = "1"; + string_to_spv(name, "conv2d_mm.comp", cm1_defines, true, true, false); } #endif } diff --git a/ggml/src/ggml-webgpu/CMakeLists.txt b/ggml/src/ggml-webgpu/CMakeLists.txt index 3ccce58aa39..1503a1ef8ba 100644 --- a/ggml/src/ggml-webgpu/CMakeLists.txt +++ b/ggml/src/ggml-webgpu/CMakeLists.txt @@ -10,8 +10,11 @@ file(MAKE_DIRECTORY ${SHADER_OUTPUT_DIR}) message(STATUS "Shader output dir: ${SHADER_OUTPUT_DIR}") -# Find all WGSL files -file(GLOB WGSL_SHADER_FILES "${SHADER_DIR}/*.wgsl") +# Find all WGSL sources +file(GLOB WGSL_SHADER_FILES + "${SHADER_DIR}/*.wgsl" + "${SHADER_DIR}/*.tmpl" +) # Generate the header using a Python script add_custom_command( diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 4c4eda1cbe5..6f877f15ce9 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -18,6 +18,9 @@ #define GGML_WEBGPU_F32_SIZE_BYTES 4 #define GGML_WEBGPU_I32_SIZE_BYTES 4 #define GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES 8u +#define GGML_WEBGPU_FLASH_ATTN_VEC_MAX_SEQ_LEN 20u +#define GGML_WEBGPU_FLASH_ATTN_VEC_MAX_KV_TILE 32u +#define GGML_WEBGPU_FLASH_ATTN_TILE_MAX_KV_TILE 64u #define GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE 128u // Matches GGML_PAD(..., 256) in src/llama-context.cpp for KV cache sizing. #define GGML_WEBGPU_KV_SEQ_PAD 256u @@ -52,7 +55,7 @@ #define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 4 #define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 4 -// default size for legacy matrix multiplication +// default size for reg-tile matrix multiplication #define WEBGPU_MUL_MAT_WG_SIZE 256 // Same hash combine function as in boost @@ -84,15 +87,17 @@ struct ggml_webgpu_shader_lib_context { ggml_tensor * src5; ggml_tensor * dst; - uint32_t max_wg_size; - size_t wg_mem_limit_bytes = 0; - bool supports_subgroups = false; - bool supports_subgroup_matrix = false; - uint32_t sg_mat_m = 0; - uint32_t sg_mat_n = 0; - uint32_t sg_mat_k = 0; - uint32_t min_subgroup_size = 0; - uint32_t max_subgroup_size = 0; + uint32_t max_wg_size; + size_t wg_mem_limit_bytes = 0; + bool supports_subgroups = false; + bool supports_subgroup_matrix = false; + uint32_t sg_mat_m = 0; + uint32_t sg_mat_n = 0; + uint32_t sg_mat_k = 0; + uint32_t min_subgroup_size = 0; + uint32_t max_subgroup_size = 0; + bool supports_dot_product = false; + std::string vendor; }; struct webgpu_pipeline { @@ -164,9 +169,11 @@ struct ggml_webgpu_set_rows_pipeline_key { int dst_type; int vec4; int i64_idx; + int pair_blocks; bool operator==(const ggml_webgpu_set_rows_pipeline_key & other) const { - return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx; + return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx && + pair_blocks == other.pair_blocks; } }; @@ -176,6 +183,7 @@ struct ggml_webgpu_set_rows_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.dst_type); ggml_webgpu_hash_combine(seed, key.vec4); ggml_webgpu_hash_combine(seed, key.i64_idx); + ggml_webgpu_hash_combine(seed, key.pair_blocks); return seed; } }; @@ -183,6 +191,7 @@ struct ggml_webgpu_set_rows_pipeline_key_hash { struct ggml_webgpu_set_rows_shader_decisions { bool vec4; bool i64_idx; + bool pair_blocks; uint32_t wg_size; }; @@ -439,15 +448,19 @@ struct ggml_webgpu_upscale_pipeline_key_hash { /** Concat **/ struct ggml_webgpu_concat_pipeline_key { - int type; + int type; + bool src_overlap; - bool operator==(const ggml_webgpu_concat_pipeline_key & other) const { return type == other.type; } + bool operator==(const ggml_webgpu_concat_pipeline_key & other) const { + return type == other.type && src_overlap == other.src_overlap; + } }; struct ggml_webgpu_concat_pipeline_key_hash { size_t operator()(const ggml_webgpu_concat_pipeline_key & key) const { size_t seed = 0; ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.src_overlap); return seed; } }; @@ -540,16 +553,10 @@ struct ggml_webgpu_unary_pipeline_key_hash { /** FlashAttention */ -enum ggml_webgpu_flash_attn_path : uint32_t { - GGML_WEBGPU_FLASH_ATTN_PATH_NONE = 0u, - GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX = 1u, - GGML_WEBGPU_FLASH_ATTN_PATH_TILE = 2u, - GGML_WEBGPU_FLASH_ATTN_PATH_VEC = 3u, -}; - -struct ggml_webgpu_flash_attn_pipeline_key { +struct ggml_webgpu_flash_attn_common_pipeline_key { ggml_type q_type; - ggml_type kv_type; + ggml_type k_type; + ggml_type v_type; ggml_type dst_type; uint32_t head_dim_qk; uint32_t head_dim_v; @@ -558,93 +565,227 @@ struct ggml_webgpu_flash_attn_pipeline_key { bool has_mask; bool has_sinks; bool uses_logit_softcap; - uint32_t path; + + bool operator==(const ggml_webgpu_flash_attn_common_pipeline_key & other) const { + return q_type == other.q_type && k_type == other.k_type && v_type == other.v_type && + dst_type == other.dst_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && + kv_direct == other.kv_direct && kv_overlap == other.kv_overlap && has_mask == other.has_mask && + has_sinks == other.has_sinks && uses_logit_softcap == other.uses_logit_softcap; + } +}; + +inline void ggml_webgpu_flash_attn_hash_common_pipeline_key(size_t & seed, + const ggml_webgpu_flash_attn_common_pipeline_key & key) { + ggml_webgpu_hash_combine(seed, key.q_type); + ggml_webgpu_hash_combine(seed, key.k_type); + ggml_webgpu_hash_combine(seed, key.v_type); + ggml_webgpu_hash_combine(seed, key.dst_type); + ggml_webgpu_hash_combine(seed, key.head_dim_qk); + ggml_webgpu_hash_combine(seed, key.head_dim_v); + ggml_webgpu_hash_combine(seed, key.kv_direct); + ggml_webgpu_hash_combine(seed, key.kv_overlap); + ggml_webgpu_hash_combine(seed, key.has_mask); + ggml_webgpu_hash_combine(seed, key.has_sinks); + ggml_webgpu_hash_combine(seed, key.uses_logit_softcap); +} + +struct ggml_webgpu_flash_attn_vec_pipeline_key { + ggml_webgpu_flash_attn_common_pipeline_key common; + + bool operator==(const ggml_webgpu_flash_attn_vec_pipeline_key & other) const { return common == other.common; } +}; + +struct ggml_webgpu_flash_attn_vec_pipeline_key_hash { + size_t operator()(const ggml_webgpu_flash_attn_vec_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_flash_attn_hash_common_pipeline_key(seed, key.common); + return seed; + } +}; + +struct ggml_webgpu_flash_attn_pipeline_key { + ggml_webgpu_flash_attn_common_pipeline_key common; + bool use_sg_matrix; bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const { - return q_type == other.q_type && kv_type == other.kv_type && dst_type == other.dst_type && - head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && kv_direct == other.kv_direct && - kv_overlap == other.kv_overlap && has_mask == other.has_mask && has_sinks == other.has_sinks && - uses_logit_softcap == other.uses_logit_softcap && path == other.path; + return common == other.common && use_sg_matrix == other.use_sg_matrix; } }; struct ggml_webgpu_flash_attn_pipeline_key_hash { size_t operator()(const ggml_webgpu_flash_attn_pipeline_key & key) const { size_t seed = 0; - ggml_webgpu_hash_combine(seed, key.q_type); - ggml_webgpu_hash_combine(seed, key.kv_type); - ggml_webgpu_hash_combine(seed, key.dst_type); - ggml_webgpu_hash_combine(seed, key.head_dim_qk); - ggml_webgpu_hash_combine(seed, key.head_dim_v); - ggml_webgpu_hash_combine(seed, key.kv_direct); - ggml_webgpu_hash_combine(seed, key.kv_overlap); - ggml_webgpu_hash_combine(seed, key.has_mask); - ggml_webgpu_hash_combine(seed, key.has_sinks); - ggml_webgpu_hash_combine(seed, key.uses_logit_softcap); - ggml_webgpu_hash_combine(seed, key.path); + ggml_webgpu_flash_attn_hash_common_pipeline_key(seed, key.common); + ggml_webgpu_hash_combine(seed, key.use_sg_matrix); return seed; } }; +struct ggml_webgpu_flash_attn_vec_decisions { + uint32_t kv_tile = 0; + uint32_t wg_size = 0; +}; + struct ggml_webgpu_flash_attn_decisions { - uint32_t path = GGML_WEBGPU_FLASH_ATTN_PATH_NONE; - uint32_t q_tile = 0; - uint32_t kv_tile = 0; - uint32_t wg_size = 0; - bool kv_direct = false; - bool kv_overlap = false; + bool use_sg_matrix = false; + uint32_t q_tile = 0; + uint32_t kv_tile = 0; + uint32_t wg_size = 0; }; inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH = 4u; inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE = 4u; -inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_pipeline_key & key) { - if (key.path != GGML_WEBGPU_FLASH_ATTN_PATH_VEC || key.kv_type != GGML_TYPE_F16 || - key.head_dim_qk != key.head_dim_v) { - return 1u; +inline size_t ggml_webgpu_flash_attn_tensor_offset(const ggml_tensor * tensor) { + constexpr uintptr_t ptr_base_addr = 0x1000u; + const ggml_tensor * base = tensor->view_src != nullptr ? tensor->view_src : tensor; + return reinterpret_cast<uintptr_t>(base->data) - ptr_base_addr + tensor->view_offs; +} + +inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K, size_t storage_offset_alignment) { + const uint32_t offset_elems = + (uint32_t) ((ggml_webgpu_flash_attn_tensor_offset(K) & (storage_offset_alignment - 1)) / + ggml_type_size(K->type)); + return offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u; +} + +inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K, + const ggml_tensor * V, + size_t storage_offset_alignment) { + return ggml_webgpu_flash_attn_float_vec4_aligned(K, storage_offset_alignment) && + ggml_webgpu_flash_attn_float_vec4_aligned(V, storage_offset_alignment); +} + +inline bool ggml_webgpu_flash_attn_kv_direct(const ggml_tensor * Q, + const ggml_tensor * K, + const ggml_tensor * V, + uint32_t kv_direct_align) { + return K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && (Q->ne[0] % kv_direct_align == 0) && + (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); +} + +inline ggml_webgpu_flash_attn_common_pipeline_key ggml_webgpu_flash_attn_make_common_pipeline_key( + const ggml_webgpu_shader_lib_context & context, + uint32_t kv_direct_align) { + ggml_webgpu_flash_attn_common_pipeline_key key = {}; + key.q_type = context.src0->type; + key.k_type = context.src1->type; + key.v_type = context.src2->type; + key.dst_type = context.dst->type; + key.head_dim_qk = (uint32_t) context.src0->ne[0]; + key.head_dim_v = (uint32_t) context.src2->ne[0]; + key.kv_direct = ggml_webgpu_flash_attn_kv_direct(context.src0, context.src1, context.src2, kv_direct_align); + key.kv_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src2); + key.has_mask = context.src3 != nullptr; + key.has_sinks = context.src4 != nullptr; + key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f; + return key; +} + +inline std::vector<std::string> ggml_webgpu_flash_attn_common_defines( + const ggml_webgpu_flash_attn_common_pipeline_key & key, + std::string & variant, + uint32_t q_tile, + uint32_t kv_tile, + uint32_t wg_size) { + std::vector<std::string> defines; + + switch (key.k_type) { + case GGML_TYPE_F32: + defines.push_back("K_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("K_F16"); + break; + case GGML_TYPE_Q4_0: + defines.push_back("K_Q4_0"); + break; + case GGML_TYPE_Q8_0: + defines.push_back("K_Q8_0"); + break; + default: + GGML_ABORT("Unsupported K type for flash attention shader"); + } + variant += std::string("_k") + ggml_type_name(key.k_type); + + switch (key.v_type) { + case GGML_TYPE_F32: + defines.push_back("V_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("V_F16"); + break; + case GGML_TYPE_Q4_0: + defines.push_back("V_Q4_0"); + break; + case GGML_TYPE_Q8_0: + defines.push_back("V_Q8_0"); + break; + default: + GGML_ABORT("Unsupported V type for flash attention shader"); } + variant += std::string("_v") + ggml_type_name(key.v_type); - switch (key.head_dim_qk) { - case 64: - case 192: - case 576: - return 2u; - case 96: - return 4u; + switch (key.q_type) { + case GGML_TYPE_F32: + defines.push_back("Q_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("Q_F16"); + break; default: - return 1u; + GGML_ABORT("Unsupported Q type for flash attention shader"); } -} + variant += std::string("_q") + ggml_type_name(key.q_type); -inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_key( - const ggml_webgpu_shader_lib_context & context, - const ggml_webgpu_flash_attn_decisions & decisions) { - const bool has_mask = context.src3 != nullptr; - const bool has_sinks = context.src4 != nullptr; - bool kv_direct = false; - if (decisions.path != GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { - uint32_t kv_direct_align = GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH; - if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX) { - kv_direct_align = context.sg_mat_k; - } - kv_direct = (context.src1->type == GGML_TYPE_F16) && - (context.src0->ne[0] % std::max(1u, kv_direct_align) == 0) && - (context.src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); - } - - ggml_webgpu_flash_attn_pipeline_key key = {}; - key.q_type = context.src0->type; - key.kv_type = context.src1->type; - key.dst_type = context.dst->type; - key.head_dim_qk = (uint32_t) context.src0->ne[0]; - key.head_dim_v = (uint32_t) context.src2->ne[0]; - key.kv_direct = kv_direct; - key.kv_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src2); - key.has_mask = has_mask; - key.has_sinks = has_sinks; - key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f; - key.path = decisions.path; - return key; + switch (key.dst_type) { + case GGML_TYPE_F32: + defines.push_back("DST_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("DST_F16"); + break; + default: + GGML_ABORT("Unsupported dst type for flash attention shader"); + } + variant += std::string("_dst") + ggml_type_name(key.dst_type); + + if (key.has_mask) { + defines.push_back("MASK"); + variant += "_mask"; + } + if (key.has_sinks) { + defines.push_back("SINKS"); + variant += "_sinks"; + } + if (key.uses_logit_softcap) { + defines.push_back("LOGIT_SOFTCAP"); + variant += "_lgsc"; + } + if (key.kv_direct) { + defines.push_back("KV_DIRECT"); + variant += "_kvdirect"; + } + if (key.kv_overlap) { + defines.push_back("KV_OVERLAP"); + variant += "_kv_overlap"; + } + + defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk)); + variant += std::string("_hsqk") + std::to_string(key.head_dim_qk); + + defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v)); + variant += std::string("_hsv") + std::to_string(key.head_dim_v); + + defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile)); + defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile)); + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + + if (ggml_is_quantized(key.k_type) || ggml_is_quantized(key.v_type)) { + defines.push_back("U32_DEQUANT_HELPERS"); + } + + return defines; } struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key { @@ -682,29 +823,18 @@ struct ggml_webgpu_flash_attn_blk_pipeline_key_hash { } }; -// This is exposed because it's necessary in supports_op +// Note: this will slightly overestimate memory usage for vec path +// since row_max and exp_sum shmem are not needed. inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, uint32_t kv_tile, uint32_t head_dim_qk, uint32_t head_dim_v, bool has_mask, - bool kv_direct, - uint32_t path = GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX) { + bool kv_direct) { const uint32_t max_head_dim = std::max(head_dim_qk, head_dim_v); size_t f16_elems = 0; size_t f32_elems = 0; - if (path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { - f32_elems += head_dim_qk; // q_shmem - if (!kv_direct) { - f32_elems += kv_tile * max_head_dim; // kv_shmem - } - f32_elems += head_dim_v; // o_shmem - if (has_mask) { - f32_elems += kv_tile; // mask_shmem - } - f32_elems += kv_tile; // inter_shmem - return f32_elems * GGML_WEBGPU_F32_SIZE_BYTES; - } + f32_elems += q_tile * head_dim_qk; // q_shmem if (!kv_direct) { f32_elems += kv_tile * max_head_dim; // kv_shmem @@ -719,25 +849,20 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES; } -inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_context & context, - const ggml_webgpu_flash_attn_pipeline_key & key) { - const size_t limit_bytes = context.wg_mem_limit_bytes; - uint32_t q_tile = context.sg_mat_m; - uint32_t kv_granularity = std::max(1u, context.sg_mat_n); - if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { - q_tile = GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE; - kv_granularity = 1u; - } else if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { - q_tile = 1u; - kv_granularity = 8u; - } - const size_t base_q_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 0, key.head_dim_qk, key.head_dim_v, - key.has_mask, key.kv_direct, key.path); +inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(size_t limit_bytes, + uint32_t q_tile, + uint32_t kv_granularity, + uint32_t head_dim_qk, + uint32_t head_dim_v, + bool has_mask, + bool kv_direct) { + const size_t base_q_bytes = + ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 0, head_dim_qk, head_dim_v, has_mask, kv_direct); if (limit_bytes <= base_q_bytes) { return 0; } - const size_t one_kv_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 1, key.head_dim_qk, key.head_dim_v, - key.has_mask, key.kv_direct, key.path); + const size_t one_kv_bytes = + ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 1, head_dim_qk, head_dim_v, has_mask, kv_direct); const size_t bytes_per_kv = one_kv_bytes - base_q_bytes; if (bytes_per_kv == 0) { return 0; @@ -746,135 +871,45 @@ inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_ return (uint32_t) ((max_kv_tile / kv_granularity) * kv_granularity); } -inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( - const ggml_webgpu_shader_lib_context & context, - size_t storage_offset_alignment) { - ggml_webgpu_flash_attn_decisions decisions = {}; - const size_t alignment = std::max<size_t>(1u, storage_offset_alignment); - const auto * K = context.src1; - const auto * V = context.src2; - GGML_ASSERT(K != nullptr); - GGML_ASSERT(V != nullptr); - - const auto flash_attn_tensor_offset = [](const ggml_tensor * tensor) -> size_t { - constexpr uintptr_t ptr_base_addr = 0x1000u; - const ggml_tensor * base = tensor->view_src != nullptr ? tensor->view_src : tensor; - return reinterpret_cast<uintptr_t>(base->data) - ptr_base_addr + tensor->view_offs; - }; - - const uint32_t k_offset_elems = - (uint32_t) ((flash_attn_tensor_offset(K) & (alignment - 1)) / ggml_type_size(K->type)); - const uint32_t v_offset_elems = - (uint32_t) ((flash_attn_tensor_offset(V) & (alignment - 1)) / ggml_type_size(V->type)); - const bool f16_vec4_aligned = (k_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u) && - (v_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u); - const bool kv_vec_type_supported = - K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; - const uint32_t kv_vec_head_align = K->type == GGML_TYPE_F16 ? GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : - (uint32_t) ggml_blck_size(K->type); - const bool kv_vec_head_dims_aligned = context.src0->ne[0] % kv_vec_head_align == 0 && - context.src2->ne[0] % kv_vec_head_align == 0; - // Compile with enough invocations to cover the largest reported subgroup. - const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && - kv_vec_head_dims_aligned && kv_vec_type_supported && - (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && - (context.src2->type == K->type); - const bool tile_can_dispatch_all_q_rows = - context.max_subgroup_size > 0 && - context.max_wg_size >= GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size; - const bool use_subgroup_matrix = - context.supports_subgroup_matrix && context.sg_mat_k > 0 && context.sg_mat_n > 0 && - context.src0->ne[0] % context.sg_mat_k == 0 && context.src2->ne[0] % context.sg_mat_n == 0; - const bool use_tile = context.supports_subgroups && !use_subgroup_matrix && K->type == GGML_TYPE_F16 && - V->type == GGML_TYPE_F16 && f16_vec4_aligned && - (context.src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && - (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && - tile_can_dispatch_all_q_rows && !use_vec; - - decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC : - use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE : - use_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX : - GGML_WEBGPU_FLASH_ATTN_PATH_NONE; - - if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_NONE) { - return decisions; - } - - const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions); - decisions.kv_direct = key.kv_direct; - const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key); - // invalidate if even the smallest kv_tile doesn't fit in shared memory - if (max_kv_tile == 0) { - decisions.path = GGML_WEBGPU_FLASH_ATTN_PATH_NONE; - return decisions; - } - - if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { - decisions.q_tile = 1u; - decisions.kv_tile = std::max(8u, std::min(32u, max_kv_tile)); - decisions.kv_tile = (decisions.kv_tile / 8u) * 8u; - decisions.wg_size = context.max_subgroup_size; - if (decisions.kv_direct) { - decisions.kv_tile = std::min(decisions.kv_tile, GGML_WEBGPU_KV_SEQ_PAD); - while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) { - decisions.kv_tile -= 8u; - } +inline uint32_t ggml_webgpu_flash_attn_get_vec_kv_tile(size_t wg_mem_limit_bytes, + uint32_t head_dim_qk, + uint32_t head_dim_v, + bool has_mask, + bool kv_direct) { + const uint32_t max_kv_tile = + ggml_webgpu_flash_attn_max_kv_tile(wg_mem_limit_bytes, 1u, 1u, head_dim_qk, head_dim_v, has_mask, kv_direct); + GGML_ASSERT(max_kv_tile > 0); + + uint32_t kv_tile = std::min(GGML_WEBGPU_FLASH_ATTN_VEC_MAX_KV_TILE, max_kv_tile); + if (kv_direct) { + kv_tile = std::min(kv_tile, GGML_WEBGPU_KV_SEQ_PAD); + while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { + kv_tile -= 1u; } - return decisions; } - decisions.q_tile = - decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE : context.sg_mat_m; - decisions.kv_tile = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? - std::min(64u, max_kv_tile) : - std::min(max_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); - decisions.wg_size = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? - std::min(std::max(1u, context.max_wg_size), - std::max(GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE, - GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size)) : - std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); - - if (decisions.kv_tile == 0) { - return decisions; - } + return kv_tile; +} - if (decisions.kv_direct) { - GGML_ASSERT(decisions.kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); - while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) { - decisions.kv_tile -= - decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? context.min_subgroup_size : context.sg_mat_n; - } - } - return decisions; +inline bool ggml_webgpu_flash_attn_can_use_subgroup_matrix_path(bool supports_subgroup_matrix, + uint32_t sg_mat_k, + uint32_t sg_mat_n, + const ggml_tensor * Q, + const ggml_tensor * V) { + return supports_subgroup_matrix && Q->ne[0] % sg_mat_k == 0 && V->ne[0] % sg_mat_n == 0; } /** Matrix Multiplication **/ -struct ggml_webgpu_legacy_mul_mat_pipeline_key { - ggml_type src0_type; - ggml_type src1_type; - - bool operator==(const ggml_webgpu_legacy_mul_mat_pipeline_key & other) const { - return src0_type == other.src0_type && src1_type == other.src1_type; - } -}; - -struct ggml_webgpu_legacy_mul_mat_pipeline_key_hash { - size_t operator()(const ggml_webgpu_legacy_mul_mat_pipeline_key & key) const { - size_t seed = 0; - ggml_webgpu_hash_combine(seed, key.src0_type); - ggml_webgpu_hash_combine(seed, key.src1_type); - return seed; - } -}; - struct ggml_webgpu_mul_mat_vec_pipeline_key { ggml_type src0_type; ggml_type src1_type; int vectorized; + bool use_mmvq; bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const { - return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized; + return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized && + use_mmvq == other.use_mmvq; } }; @@ -884,6 +919,7 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.src0_type); ggml_webgpu_hash_combine(seed, key.src1_type); ggml_webgpu_hash_combine(seed, key.vectorized); + ggml_webgpu_hash_combine(seed, key.use_mmvq); return seed; } }; @@ -894,6 +930,20 @@ struct ggml_webgpu_mul_mat_vec_shader_decisions { uint32_t vec_size; }; +struct ggml_webgpu_quantize_q8_pipeline_key { + ggml_type src0_type; + + bool operator==(const ggml_webgpu_quantize_q8_pipeline_key & other) const { return src0_type == other.src0_type; } +}; + +struct ggml_webgpu_quantize_q8_pipeline_key_hash { + size_t operator()(const ggml_webgpu_quantize_q8_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.src0_type); + return seed; + } +}; + struct ggml_webgpu_mul_mat_pipeline_key { ggml_type src0_type; ggml_type src1_type; @@ -1051,6 +1101,36 @@ struct ggml_webgpu_soft_max_pipeline_key_hash { } }; +/** MMVQ **/ + +inline bool ggml_webgpu_can_use_mmvq(const ggml_tensor * src0, + const ggml_tensor * src1, + bool supports_dot_product, + const std::string & vendor) { + if (src1->ne[1] == 1) { + bool supports_dp4a = vendor == "amd" || vendor == "intel" || vendor == "nvidia"; + if (supports_dp4a && supports_dot_product) { + switch (src1->type) { + case GGML_TYPE_F32: + switch (src0->type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q4_K: + return src0->ne[0] % 4 == 0; + default: + break; + } + break; + default: + break; + } + } + } + return false; +} + class ggml_webgpu_shader_lib { wgpu::Device device; pre_wgsl::Preprocessor preprocessor; @@ -1089,6 +1169,10 @@ class ggml_webgpu_shader_lib { concat_pipelines; // type std::unordered_map<ggml_webgpu_repeat_pipeline_key, webgpu_pipeline, ggml_webgpu_repeat_pipeline_key_hash> repeat_pipelines; // type + std::unordered_map<ggml_webgpu_flash_attn_vec_pipeline_key, + webgpu_pipeline, + ggml_webgpu_flash_attn_vec_pipeline_key_hash> + flash_attn_vec_pipelines; std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash> flash_attn_pipelines; std::unordered_map<ggml_webgpu_flash_attn_vec_reduce_pipeline_key, @@ -1099,14 +1183,12 @@ class ggml_webgpu_shader_lib { webgpu_pipeline, ggml_webgpu_flash_attn_blk_pipeline_key_hash> flash_attn_blk_pipelines; - std::unordered_map<ggml_webgpu_legacy_mul_mat_pipeline_key, - webgpu_pipeline, - ggml_webgpu_legacy_mul_mat_pipeline_key_hash> - mul_mat_legacy_pipelines; // legacy mul_mat (non-subgroup/non-regtile/non-vec) std::unordered_map<ggml_webgpu_mul_mat_vec_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_vec_pipeline_key_hash> - mul_mat_vec_pipelines; // fast mat-vec (n==1) + mul_mat_vec_pipelines; // fast mat-vec (n==1) std::unordered_map<ggml_webgpu_mul_mat_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_pipeline_key_hash> - mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup) + mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup) + std::unordered_map<ggml_webgpu_quantize_q8_pipeline_key, webgpu_pipeline, ggml_webgpu_quantize_q8_pipeline_key_hash> + quantize_q8_pipelines; std::unordered_map<int, webgpu_pipeline> mul_mat_id_gather_pipelines; // key is fixed std::unordered_map<ggml_webgpu_mul_mat_id_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_id_pipeline_key_hash> mul_mat_id_pipelines; // src0_type/src1_type @@ -1235,10 +1317,13 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_set_rows_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_set_rows_pipeline_key key = {}; - key.dst_type = context.dst->type; - key.vec4 = context.src0->ne[0] % 4 == 0; - key.i64_idx = context.src1->type == GGML_TYPE_I64; + const bool quantized = ggml_is_quantized(context.dst->type); + ggml_webgpu_set_rows_pipeline_key key = {}; + key.dst_type = context.dst->type; + key.vec4 = + (context.dst->type == GGML_TYPE_F32 || context.dst->type == GGML_TYPE_F16) && context.src0->ne[0] % 4 == 0; + key.i64_idx = context.src1->type == GGML_TYPE_I64; + key.pair_blocks = quantized && ((context.src0->ne[0] / ggml_blck_size(context.dst->type)) % 2 == 0); auto it = set_rows_pipelines.find(key); if (it != set_rows_pipelines.end()) { @@ -1257,6 +1342,14 @@ class ggml_webgpu_shader_lib { defines.push_back("DST_F16"); variant += "_dstf16"; break; + case GGML_TYPE_Q8_0: + defines.push_back("DST_Q8_0"); + variant += "_dstq8_0"; + break; + case GGML_TYPE_Q4_0: + defines.push_back("DST_Q4_0"); + variant += "_dstq4_0"; + break; default: GGML_ABORT("Unsupported dst type for set_rows shader"); } @@ -1269,13 +1362,19 @@ class ggml_webgpu_shader_lib { defines.push_back("I64_IDX"); variant += "_i64idx"; } + if (key.pair_blocks) { + defines.push_back("PAIR_BLOCKS"); + variant += "_pair_blocks"; + } defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); - auto processed = preprocessor.preprocess(wgsl_set_rows, defines); - auto decisions = std::make_shared<ggml_webgpu_set_rows_shader_decisions>(); + const auto & shader_source = quantized ? wgsl_set_rows_quant : wgsl_set_rows; + auto processed = preprocessor.preprocess(shader_source, defines); + auto decisions = std::make_shared<ggml_webgpu_set_rows_shader_decisions>(); decisions->vec4 = key.vec4; decisions->i64_idx = key.i64_idx; + decisions->pair_blocks = key.pair_blocks; decisions->wg_size = context.max_wg_size; set_rows_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant); set_rows_pipelines[key].context = decisions; @@ -1631,7 +1730,7 @@ class ggml_webgpu_shader_lib { key.type = context.dst->type; key.d_state = (int) context.src0->ne[0]; key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) && - ggml_webgpu_tensor_overlap(context.src1, context.src5); + ggml_webgpu_tensor_overlap(context.src1, context.src5); auto it = ssm_scan_pipelines.find(key); if (it != ssm_scan_pipelines.end()) { @@ -1744,14 +1843,54 @@ class ggml_webgpu_shader_lib { return pad_pipelines[key]; } + webgpu_pipeline get_quantize_q8_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_quantize_q8_pipeline_key key = {}; + key.src0_type = context.src0->type; + + auto it = quantize_q8_pipelines.find(key); + if (it != quantize_q8_pipelines.end()) { + return it->second; + } + const char * shader_src = wgsl_quantize_q8; + std::vector<std::string> defines; + std::string variant = "quantize_q8"; + + uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE; + + defines.push_back("SRC1_INNER_TYPE=f32"); + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + + const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type); + std::string src0_name = src0_traits->type_name; + std::string type_upper = src0_name; + variant += "_" + src0_name; + std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); + + defines.push_back("MUL_ACC_" + type_upper); + defines.push_back("Q8_1_T"); + + defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION"); + variant += context.supports_subgroups ? "_sg_reduce" : "_wg_reduce"; + + auto processed = preprocessor.preprocess(shader_src, defines); + auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>(); + decisions->wg_size = wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + quantize_q8_pipelines[key] = pipeline; + return quantize_q8_pipelines[key]; + } + webgpu_pipeline get_mul_mat_vec_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_mul_mat_vec_pipeline_key key = {}; key.src0_type = context.src0->type; key.src1_type = context.src1->type; - key.vectorized = (context.src0->ne[0] % 4 == 0 && + key.vectorized = (context.src0->ne[0] % 4 == 0 && (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0; + 1 : + 0; + key.use_mmvq = + ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor); auto it = mul_mat_vec_pipelines.find(key); if (it != mul_mat_vec_pipelines.end()) { @@ -1788,6 +1927,19 @@ class ggml_webgpu_shader_lib { defines.push_back("U32_DEQUANT_HELPERS"); defines.push_back("SRC0_INNER_TYPE=u32"); switch (context.src0->type) { + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + if (key.use_mmvq) { + defines.push_back("LEGACY_QUANTS"); + } + break; + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q4_K: + if (key.use_mmvq) { + defines.push_back("K_QUANTS"); + } + break; case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ2_S: @@ -1840,6 +1992,11 @@ class ggml_webgpu_shader_lib { outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG; } + if (key.use_mmvq) { + defines.push_back("MMVQ"); + defines.push_back("Q8_1_T"); + } + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg)); defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION"); @@ -1864,11 +2021,11 @@ class ggml_webgpu_shader_lib { ggml_webgpu_mul_mat_pipeline_key key = {}; key.src0_type = context.src0->type; key.src1_type = context.src1->type; - key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && - (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0; - key.use_subgroup_matrix = context.supports_subgroup_matrix; + key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && + (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? + 1 : + 0; + key.use_subgroup_matrix = context.supports_subgroup_matrix; auto it = mul_mat_fast_pipelines.find(key); if (it != mul_mat_fast_pipelines.end()) { @@ -2018,100 +2175,6 @@ class ggml_webgpu_shader_lib { return mul_mat_fast_pipelines[key]; } - webgpu_pipeline get_mul_mat_legacy_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_legacy_mul_mat_pipeline_key key = {}; - key.src0_type = context.src0->type; - key.src1_type = context.src1->type; - - auto it = mul_mat_legacy_pipelines.find(key); - if (it != mul_mat_legacy_pipelines.end()) { - return it->second; - } - - std::vector<std::string> defines; - std::string variant = "mul_mat"; - - switch (context.src1->type) { - case GGML_TYPE_F32: - defines.push_back("SRC1_TYPE=f32"); - variant += "_f32"; - break; - case GGML_TYPE_F16: - defines.push_back("SRC1_TYPE=f16"); - variant += "_f16"; - break; - default: - GGML_ABORT("Unsupported src1 type for mul_mat legacy shader"); - } - - const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type); - const char * src0_name = src0_traits->type_name; - - switch (context.src0->type) { - case GGML_TYPE_F32: - defines.push_back("SRC0_TYPE=f32"); - defines.push_back("FLOAT"); - variant += "_f32"; - break; - case GGML_TYPE_F16: - defines.push_back("SRC0_TYPE=f16"); - defines.push_back("FLOAT"); - variant += "_f16"; - break; - default: - { - std::string type_upper = src0_name; - std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); - - switch (context.src0->type) { - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q3_K: - case GGML_TYPE_Q6_K: - case GGML_TYPE_IQ2_XXS: - case GGML_TYPE_IQ2_XS: - case GGML_TYPE_IQ2_S: - case GGML_TYPE_IQ3_XXS: - case GGML_TYPE_IQ3_S: - case GGML_TYPE_IQ1_S: - case GGML_TYPE_IQ4_NL: - case GGML_TYPE_MXFP4: - { - // Quantized types using u32 buffers for portability. - defines.push_back("SRC0_TYPE=u32"); - defines.push_back("U32_DEQUANT_HELPERS"); - break; - } - default: - { - defines.push_back(std::string("SRC0_TYPE=") + src0_name); - } - } - - defines.push_back("BYTE_HELPERS"); - defines.push_back(type_upper + "_T"); - defines.push_back(type_upper); - defines.push_back(type_upper + "_SCALE_MIN"); - defines.push_back(type_upper + "_TABLES"); - defines.push_back(type_upper + "_GRID"); - - variant += std::string("_") + src0_name; - break; - } - } - - auto processed = preprocessor.preprocess(wgsl_mul_mat, defines); - - auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>(); - decisions->wg_size = WEBGPU_MUL_MAT_WG_SIZE; - - webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); - pipeline.context = decisions; - mul_mat_legacy_pipelines[key] = pipeline; - return mul_mat_legacy_pipelines[key]; - } - webgpu_pipeline get_mul_mat_id_gather_pipeline(const ggml_webgpu_shader_lib_context & context) { auto it = mul_mat_id_gather_pipelines.find(1); if (it != mul_mat_id_gather_pipelines.end()) { @@ -2135,10 +2198,10 @@ class ggml_webgpu_shader_lib { key.src0_type = context.src0->type; key.src1_type = context.src1->type; key.n_experts = context.src0->ne[2]; - key.vectorized = (context.src0->ne[0] % 4 == 0 && context.src0->ne[1] % 4 == 0 && + key.vectorized = (context.src0->ne[0] % 4 == 0 && context.src0->ne[1] % 4 == 0 && (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0; + 1 : + 0; auto it = mul_mat_id_pipelines.find(key); if (it != mul_mat_id_pipelines.end()) { @@ -2258,10 +2321,10 @@ class ggml_webgpu_shader_lib { key.src0_type = context.src0->type; key.src1_type = context.src1->type; key.n_experts = context.src0->ne[2]; - key.vectorized = (context.src0->ne[0] % 4 == 0 && + key.vectorized = (context.src0->ne[0] % 4 == 0 && (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0; + 1 : + 0; auto it = mul_mat_id_vec_pipelines.find(key); if (it != mul_mat_id_vec_pipelines.end()) { @@ -2578,6 +2641,7 @@ class ggml_webgpu_shader_lib { webgpu_pipeline get_concat_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_concat_pipeline_key key = {}; key.type = context.dst->type; + key.src_overlap = ggml_webgpu_tensor_overlap(context.src0, context.src1); auto it = concat_pipelines.find(key); if (it != concat_pipelines.end()) { @@ -2600,11 +2664,17 @@ class ggml_webgpu_shader_lib { GGML_ABORT("Unsupported type for concat shader"); } + if (key.src_overlap) { + defines.push_back("SRC_OVERLAP"); + variant += "_src_overlap"; + } + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); auto processed = preprocessor.preprocess(wgsl_concat, defines); - auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>(); + auto decisions = std::make_shared<ggml_webgpu_binary_shader_decisions>(); decisions->wg_size = context.max_wg_size; + decisions->src_overlap = key.src_overlap; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); pipeline.context = decisions; concat_pipelines[key] = pipeline; @@ -2651,119 +2721,62 @@ class ggml_webgpu_shader_lib { return repeat_pipelines[key]; } - webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context, - size_t storage_offset_alignment) { - const ggml_webgpu_flash_attn_decisions decisions = - ggml_webgpu_flash_attn_get_decisions(context, storage_offset_alignment); - GGML_ASSERT(decisions.path != GGML_WEBGPU_FLASH_ATTN_PATH_NONE); - ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions); - auto it = flash_attn_pipelines.find(key); - if (it != flash_attn_pipelines.end()) { - return it->second; - } - std::vector<std::string> defines; - std::string variant = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC ? "flash_attn_vec" : - decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? "flash_attn_tile" : - "flash_attn"; - - switch (key.kv_type) { - case GGML_TYPE_F32: - defines.push_back("KV_F32"); - break; - case GGML_TYPE_F16: - defines.push_back("KV_F16"); - break; - case GGML_TYPE_Q4_0: - defines.push_back("KV_Q4_0"); - break; - case GGML_TYPE_Q8_0: - defines.push_back("KV_Q8_0"); - break; - default: - GGML_ABORT("Unsupported KV type for flash attention shader"); - } - variant += std::string("_") + ggml_type_name(key.kv_type); - - switch (key.q_type) { - case GGML_TYPE_F32: - defines.push_back("Q_F32"); - break; - case GGML_TYPE_F16: - defines.push_back("Q_F16"); - break; - default: - GGML_ABORT("Unsupported Q type for flash attention shader"); - } - variant += std::string("_q") + ggml_type_name(key.q_type); - - switch (key.dst_type) { - case GGML_TYPE_F32: - defines.push_back("DST_F32"); - break; - case GGML_TYPE_F16: - defines.push_back("DST_F16"); - break; - default: - GGML_ABORT("Unsupported dst type for flash attention shader"); - } - variant += std::string("_dst") + ggml_type_name(key.dst_type); - - if (key.has_mask) { - defines.push_back("MASK"); - if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { - defines.push_back("BLK"); - variant += "_mask_blk"; - } else { - variant += "_mask"; + webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) { + const bool can_use_subgroup_matrix = ggml_webgpu_flash_attn_can_use_subgroup_matrix_path( + context.supports_subgroup_matrix, context.sg_mat_k, context.sg_mat_n, context.src0, context.src2); + ggml_webgpu_flash_attn_decisions decisions = {}; + decisions.use_sg_matrix = can_use_subgroup_matrix; + decisions.q_tile = decisions.use_sg_matrix ? context.sg_mat_m : GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE; + + ggml_webgpu_flash_attn_pipeline_key key = {}; + key.common = + ggml_webgpu_flash_attn_make_common_pipeline_key(context, decisions.use_sg_matrix ? context.sg_mat_k : 1u); + key.common.kv_direct = decisions.use_sg_matrix && key.common.kv_direct; + key.use_sg_matrix = decisions.use_sg_matrix; + + const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile( + context.wg_mem_limit_bytes, decisions.q_tile, decisions.use_sg_matrix ? context.sg_mat_n : 1u, + key.common.head_dim_qk, key.common.head_dim_v, key.common.has_mask, key.common.kv_direct); + GGML_ASSERT(max_kv_tile > 0); + + decisions.kv_tile = decisions.use_sg_matrix ? + std::min(max_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES) : + std::min(GGML_WEBGPU_FLASH_ATTN_TILE_MAX_KV_TILE, max_kv_tile); + decisions.wg_size = + decisions.use_sg_matrix ? + std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE) : + std::min(context.max_wg_size, std::max(GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE, + GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size)); + + if (key.common.kv_direct) { + decisions.kv_tile = std::min(decisions.kv_tile, GGML_WEBGPU_KV_SEQ_PAD); + while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) { + decisions.kv_tile -= decisions.use_sg_matrix ? context.sg_mat_n : context.min_subgroup_size; } } - if (key.has_sinks) { - defines.push_back("SINKS"); - variant += "_sinks"; - } - if (key.uses_logit_softcap) { - defines.push_back("LOGIT_SOFTCAP"); - variant += "_lgsc"; - } - if (key.kv_direct) { - defines.push_back("KV_DIRECT"); - variant += "_kvdirect"; - } - if (key.kv_overlap) { - defines.push_back("KV_OVERLAP"); - variant += "_kv_overlap"; - } - defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk)); - variant += std::string("_hsqk") + std::to_string(key.head_dim_qk); - - defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v)); - variant += std::string("_hsv") + std::to_string(key.head_dim_v); + auto it = flash_attn_pipelines.find(key); + if (it != flash_attn_pipelines.end()) { + return it->second; + } - const char * shader_src = wgsl_flash_attn; - if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { - defines.push_back("KV_GRANULARITY=8"); - defines.push_back(std::string("VEC_NE=") + std::to_string(ggml_webgpu_flash_attn_pick_vec_ne(key)) + "u"); - shader_src = wgsl_flash_attn_vec_split; - } else if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { + std::string variant = decisions.use_sg_matrix ? "flash_attn" : "flash_attn_tile"; + std::vector<std::string> defines = ggml_webgpu_flash_attn_common_defines(key.common, variant, decisions.q_tile, + decisions.kv_tile, decisions.wg_size); + const char * shader_src = nullptr; + if (!key.use_sg_matrix) { shader_src = wgsl_flash_attn_tile; defines.push_back("MIN_SUBGROUP_SIZE=" + std::to_string(context.min_subgroup_size) + "u"); defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u"); - defines.push_back("KV_STAGE_STRIDE=" + std::to_string(std::max(key.head_dim_qk, key.head_dim_v))); variant += "_tile_sg" + std::to_string(context.min_subgroup_size) + "_" + std::to_string(context.max_subgroup_size); } else { + shader_src = wgsl_flash_attn; defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); } - - auto pipeline_decisions = std::make_shared<ggml_webgpu_flash_attn_decisions>(decisions); - pipeline_decisions->kv_overlap = key.kv_overlap; - defines.push_back(std::string("Q_TILE=") + std::to_string(decisions.q_tile)); - defines.push_back(std::string("KV_TILE=") + std::to_string(decisions.kv_tile)); - defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions.wg_size)); - + auto pipeline_decisions = std::make_shared<ggml_webgpu_flash_attn_decisions>(decisions); webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant); pipeline.context = pipeline_decisions; @@ -2771,6 +2784,55 @@ class ggml_webgpu_shader_lib { return flash_attn_pipelines[key]; } + webgpu_pipeline get_flash_attn_vec_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_flash_attn_vec_pipeline_key key = {}; + key.common = ggml_webgpu_flash_attn_make_common_pipeline_key(context, GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH); + + auto it = flash_attn_vec_pipelines.find(key); + if (it != flash_attn_vec_pipelines.end()) { + return it->second; + } + + ggml_webgpu_flash_attn_vec_decisions decisions = {}; + decisions.kv_tile = + ggml_webgpu_flash_attn_get_vec_kv_tile(context.wg_mem_limit_bytes, key.common.head_dim_qk, + key.common.head_dim_v, key.common.has_mask, key.common.kv_direct); + decisions.wg_size = context.max_subgroup_size; + + std::string variant = "flash_attn_vec"; + std::vector<std::string> defines = + ggml_webgpu_flash_attn_common_defines(key.common, variant, 1u, decisions.kv_tile, decisions.wg_size); + if (key.common.has_mask) { + defines.push_back("BLK"); + variant.resize(variant.size() - (sizeof("_mask") - 1)); + variant += "_mask_blk"; + } + uint32_t vec_ne = 1u; + if (key.common.k_type == GGML_TYPE_F16 && key.common.v_type == GGML_TYPE_F16 && + key.common.head_dim_qk == key.common.head_dim_v) { + switch (key.common.head_dim_qk) { + case 64: + case 192: + case 576: + vec_ne = 2u; + break; + case 96: + vec_ne = 4u; + break; + default: + break; + } + } + defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u"); + + auto pipeline_decisions = std::make_shared<ggml_webgpu_flash_attn_vec_decisions>(decisions); + webgpu_pipeline pipeline = + ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_split, defines), variant); + pipeline.context = pipeline_decisions; + flash_attn_vec_pipelines[key] = pipeline; + return flash_attn_vec_pipelines[key]; + } + webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_shader_lib_context & context, uint32_t kv_tile) { ggml_webgpu_flash_attn_blk_pipeline_key key = {}; key.kv_tile = kv_tile; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 921c12b41ac..0b605fa86ba 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -94,14 +94,6 @@ static inline uint32_t ggml_webgpu_u32_from_f32(float value) { #define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4 #define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4 -// For operations which process a row in parallel, this seems like a reasonable -// default -#define WEBGPU_ROW_SPLIT_WG_SIZE 64 - -// Track https://github.com/gpuweb/gpuweb/issues/5315 for fixes to -// implementations so this can be removed, necessary only for get_rows right now -#define WEBGPU_MAX_WG_SIZE 288 - /* End Constants */ // This is a "fake" base pointer, since WebGPU buffers do not have pointers to @@ -181,6 +173,7 @@ struct webgpu_capabilities { wgpu::Limits limits; bool supports_subgroups = false; bool supports_subgroup_matrix = false; + bool supports_dot_product = false; uint32_t sg_mat_m = 0; uint32_t sg_mat_n = 0; @@ -210,6 +203,8 @@ struct webgpu_global_context_struct { wgpu::Buffer memset_params_buf; webgpu_pipeline memset_pipeline; + std::string vendor; + // TODO: We should rework the CPU profiling time handling to make it more useful. ref: https://github.com/ggml-org/llama.cpp/pull/22050 #ifdef GGML_WEBGPU_CPU_PROFILE // Profiling: labeled CPU time in ms (total) @@ -259,6 +254,7 @@ struct webgpu_context_struct { wgpu::Buffer set_rows_host_error_buf; wgpu::CommandEncoder active_command_encoder; wgpu::ComputePassEncoder active_compute_pass; + bool batch_compute_passes = true; size_t memset_bytes_per_thread; @@ -590,9 +586,18 @@ static webgpu_encoded_op ggml_backend_webgpu_build_multi(webgpu_context & } #else for (size_t i = 0; i < dispatches.size(); i++) { - ctx->active_compute_pass.SetPipeline(dispatches[i].pipeline.pipeline); - ctx->active_compute_pass.SetBindGroup(0, bind_groups[i]); - ctx->active_compute_pass.DispatchWorkgroups(dispatches[i].workgroups.first, dispatches[i].workgroups.second, 1); + if (ctx->batch_compute_passes) { + ctx->active_compute_pass.SetPipeline(dispatches[i].pipeline.pipeline); + ctx->active_compute_pass.SetBindGroup(0, bind_groups[i]); + ctx->active_compute_pass.DispatchWorkgroups(dispatches[i].workgroups.first, dispatches[i].workgroups.second, + 1); + } else { + wgpu::ComputePassEncoder pass = ctx->active_command_encoder.BeginComputePass(); + pass.SetPipeline(dispatches[i].pipeline.pipeline); + pass.SetBindGroup(0, bind_groups[i]); + pass.DispatchWorkgroups(dispatches[i].workgroups.first, dispatches[i].workgroups.second, 1); + pass.End(); + } } #endif @@ -616,10 +621,11 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx, uint32_t value, size_t offset, size_t size) { - std::vector<uint32_t> params = { (uint32_t) offset, (uint32_t) size, value }; - std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_bind_group_entry(0, buf, 0, buf.GetSize()) }; - size_t bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->capabilities.memset_bytes_per_thread; - uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg); + std::vector<uint32_t> params = { (uint32_t) offset, (uint32_t) size, value }; + std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_bind_group_entry(0, buf, 0, buf.GetSize()) }; + size_t bytes_per_wg = + ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup * ctx->capabilities.memset_bytes_per_thread; + uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg); ctx->queue.WriteBuffer(ctx->memset_params_buf, 0, params.data(), params.size() * sizeof(uint32_t)); @@ -736,8 +742,11 @@ static webgpu_encoded_op ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst), }; - uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + uint32_t wg_x; + uint32_t wg_y; + uint32_t total_wg = CEIL_DIV(ne, decisions->wg_size); + compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } static webgpu_encoded_op ggml_webgpu_set(webgpu_context & ctx, @@ -961,9 +970,10 @@ static webgpu_encoded_op ggml_webgpu_conv_2d(webgpu_context & ctx, auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get()); + uint32_t wg_x; + uint32_t wg_y; uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size); - uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg); - uint32_t wg_y = CEIL_DIV(total_wg, wg_x); + compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y); return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } @@ -1051,9 +1061,10 @@ static webgpu_encoded_op ggml_webgpu_im2col(webgpu_context & ctx, auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get()); + uint32_t wg_x; + uint32_t wg_y; uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size); - uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg); - uint32_t wg_y = CEIL_DIV(total_wg, wg_x); + compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y); return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } @@ -1234,7 +1245,7 @@ static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx, const uint32_t h = (uint32_t) src2->ne[1]; const uint32_t n_tokens = (uint32_t) src2->ne[2]; const uint32_t n_seqs = (uint32_t) src2->ne[3]; - const uint32_t K = (uint32_t) src5->ne[1]; + const uint32_t K = (uint32_t) ggml_get_op_params_i32(dst, 0); const float scale = 1.0f / sqrtf((float) s_v); uint32_t scale_u32; memcpy(&scale_u32, &scale, sizeof(scale_u32)); @@ -1321,7 +1332,11 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_set_rows(webgpu_context & ct } uint32_t threads; - if (decisions->vec4) { + if (ggml_is_quantized(dst->type)) { + const uint32_t blocks_per_row = src->ne[0] / ggml_blck_size(dst->type); + threads = + (src->ne[1] * src->ne[2] * src->ne[3]) * (decisions->pair_blocks ? (blocks_per_row / 2) : blocks_per_row); + } else if (decisions->vec4) { threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4); } else { threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3]; @@ -1348,7 +1363,7 @@ static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx, shader_lib_ctx.src0 = src; shader_lib_ctx.src1 = nullptr; shader_lib_ctx.dst = dst; - shader_lib_ctx.max_wg_size = WEBGPU_MAX_WG_SIZE; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_get_rows_pipeline(shader_lib_ctx); auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get()); @@ -1384,6 +1399,58 @@ static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx, return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } +static void ggml_webgpu_quantize_q8_dispatch(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst, + std::vector<webgpu_dispatch_desc> & dispatches) { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups; + + webgpu_pipeline qq8_pipeline = ctx->shader_lib->get_quantize_q8_pipeline(shader_lib_ctx); + + // quantize_q8 pipeline + const size_t dst_offset = ggml_webgpu_tensor_offset(dst); + const size_t q8_src1_align_offset = ROUNDUP_POW2( + dst_offset + ggml_nbytes(dst), ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); + const size_t q8_src1_binding_size = + ROUNDUP_POW2(src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)), + WEBGPU_STORAGE_BUF_BINDING_MULT); + + std::vector<uint32_t> q8_params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), + (uint32_t) src1->ne[0], + (uint32_t) src1->ne[2], + (uint32_t) src1->ne[3], + }; + + std::vector<wgpu::BindGroupEntry> q8_entries = { + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src1), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), q8_src1_align_offset, q8_src1_binding_size) + }; + + auto q8_decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(qq8_pipeline.context.get()); + + uint32_t q8_wg_size = q8_decisions->wg_size; + uint32_t q8_wg_x = 1; + uint32_t q8_wg_y = 1; + const uint32_t wg_per_vec = (src0->ne[0] / 4 + (q8_wg_size - 1)) / q8_wg_size; + const uint32_t q8_total_wg = src1->ne[2] * src1->ne[3] * wg_per_vec; + const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + compute_2d_workgroups(q8_total_wg, max_wg_per_dim, q8_wg_x, q8_wg_y); + + dispatches.push_back({ + qq8_pipeline, std::move(q8_params), std::move(q8_entries), { q8_wg_x, q8_wg_y } + }); +} + static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, @@ -1391,47 +1458,9 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, // Determine if this is a mat-vec operation bool is_vec = (dst->ne[1] == 1); - // Determine if we should use fast path - bool use_fast = false; - switch (src1->type) { - case GGML_TYPE_F16: - use_fast = (src0->type == GGML_TYPE_F16); - break; - case GGML_TYPE_F32: - // TODO: implement better mat-mat for k-quants, mat-vec for all k-quants except q6_K - switch (src0->type) { - case GGML_TYPE_F32: - case GGML_TYPE_F16: - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q6_K: - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: - case GGML_TYPE_Q3_K: - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q1_0: - case GGML_TYPE_IQ1_S: - case GGML_TYPE_IQ1_M: - case GGML_TYPE_IQ2_XXS: - case GGML_TYPE_IQ2_XS: - case GGML_TYPE_IQ2_S: - case GGML_TYPE_IQ3_XXS: - case GGML_TYPE_IQ3_S: - case GGML_TYPE_IQ4_NL: - case GGML_TYPE_IQ4_XS: - case GGML_TYPE_MXFP4: - use_fast = true; - break; - default: - break; - } - break; - default: - break; - } + // use MMVQ path for mat-vec + bool use_mmvq = ggml_webgpu_can_use_mmvq(src0, src1, ctx->global_ctx->capabilities.supports_dot_product, + ctx->global_ctx->vendor); ggml_webgpu_shader_lib_context shader_lib_ctx = {}; @@ -1446,16 +1475,20 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k; shader_lib_ctx.min_subgroup_size = ctx->global_ctx->capabilities.min_subgroup_size; shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size; + shader_lib_ctx.supports_dot_product = ctx->global_ctx->capabilities.supports_dot_product; + shader_lib_ctx.vendor = ctx->global_ctx->vendor; // Get or create pipeline - webgpu_pipeline pipeline; + webgpu_pipeline pipeline; + std::vector<webgpu_dispatch_desc> dispatches; - if (use_fast && is_vec) { + if (is_vec) { + if (use_mmvq) { + ggml_webgpu_quantize_q8_dispatch(ctx, src0, src1, dst, dispatches); + } pipeline = ctx->shader_lib->get_mul_mat_vec_pipeline(shader_lib_ctx); - } else if (use_fast) { - pipeline = ctx->shader_lib->get_mul_mat_fast_pipeline(shader_lib_ctx); } else { - pipeline = ctx->shader_lib->get_mul_mat_legacy_pipeline(shader_lib_ctx); + pipeline = ctx->shader_lib->get_mul_mat_fast_pipeline(shader_lib_ctx); } // Build params @@ -1479,25 +1512,31 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, }; // Build bind group entries - std::vector<wgpu::BindGroupEntry> entries = { - ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), - ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), - ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), - }; + std::vector<wgpu::BindGroupEntry> entries = {}; + + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0)); + if (use_mmvq) { + auto & mmvq_qq8_entry = dispatches[0].bind_group_entries[1]; + entries.push_back(ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), mmvq_qq8_entry.offset, + mmvq_qq8_entry.size)); + } else { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1)); + } + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst)); // Calculate workgroup dimensions uint32_t wg_x = 1; uint32_t wg_y = 1; const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; - if (use_fast && is_vec) { + if (is_vec) { auto * decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get()); uint32_t batches = dst->ne[2] * dst->ne[3]; uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg); uint32_t total_wg = output_groups * batches; compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); - } else if (use_fast) { + } else { auto * decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get()); // Fast-path tiled/subgroup calculations @@ -1518,15 +1557,13 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, } uint32_t total_wg = wg_m * wg_n * dst->ne[2] * dst->ne[3]; compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); - - } else { // legacy - auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get()); - uint32_t wg_size = decisions->wg_size; - uint32_t total_wg = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size); - compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); } - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); + dispatches.push_back({ + pipeline, std::move(params), std::move(entries), { wg_x, wg_y } + }); + + return ggml_backend_webgpu_build_multi(ctx, dispatches); } static webgpu_encoded_op ggml_webgpu_mul_mat_id_vec(webgpu_context & ctx, @@ -1654,14 +1691,11 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, gathered_count_ids_binding_size), }; - const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; - - const uint32_t gather_total_wg = param_n_expert; - const uint32_t gather_wg_x = std::min(gather_total_wg, max_wg_per_dim); - const uint32_t gather_wg_y = CEIL_DIV(gather_total_wg, gather_wg_x); + // n_expert is much less than maxComputeWorkgroupsPerDimension (e.g., n_exeprt=256 at Qwen3.5-35B-A3B) + const uint32_t gather_wg_x = param_n_expert; dispatches.push_back({ - gather_pipeline, std::move(gather_params), std::move(gather_entries), { gather_wg_x, gather_wg_y } + gather_pipeline, std::move(gather_params), std::move(gather_entries), { gather_wg_x, 1 } }); // params for mul_mat_id.wgsl @@ -1713,7 +1747,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, uint32_t max_wg_n = CEIL_DIV(total_gathered, tile_n_s) + max_active_experts; uint32_t total_wg = wg_m * max_wg_n; - compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); + compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y); dispatches.push_back({ main_pipeline, std::move(main_params), std::move(main_entries), { wg_x, wg_y } @@ -1722,13 +1756,50 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, return ggml_backend_webgpu_build_multi(ctx, dispatches); } -static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, - ggml_tensor * Q, - ggml_tensor * K, - ggml_tensor * V, - ggml_tensor * mask, - ggml_tensor * sinks, - ggml_tensor * dst) { +struct ggml_webgpu_flash_attn_op { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + std::vector<uint32_t> params; + std::vector<wgpu::BindGroupEntry> entries; + size_t kv_bind_offset = 0; + size_t kv_bind_size = 0; + bool has_mask = false; + bool has_sinks = false; + bool kv_overlap = false; +}; + +static bool ggml_webgpu_flash_attn_use_vec_path(const webgpu_global_context & global_ctx, + const ggml_tensor * Q, + const ggml_tensor * K, + const ggml_tensor * V) { + const size_t storage_offset_alignment = global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; + const bool k_float_vec4_aligned = (K->type != GGML_TYPE_F16 && K->type != GGML_TYPE_F32) || + ggml_webgpu_flash_attn_float_vec4_aligned(K, storage_offset_alignment); + const bool v_float_vec4_aligned = (V->type != GGML_TYPE_F16 && V->type != GGML_TYPE_F32) || + ggml_webgpu_flash_attn_float_vec4_aligned(V, storage_offset_alignment); + const bool k_vec_type_supported = + K->type == GGML_TYPE_F32 || K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; + const bool v_vec_type_supported = + V->type == GGML_TYPE_F32 || V->type == GGML_TYPE_F16 || V->type == GGML_TYPE_Q4_0 || V->type == GGML_TYPE_Q8_0; + const uint32_t k_vec_head_align = (K->type == GGML_TYPE_F32 || K->type == GGML_TYPE_F16) ? + GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : + (uint32_t) ggml_blck_size(K->type); + const uint32_t v_vec_head_align = (V->type == GGML_TYPE_F32 || V->type == GGML_TYPE_F16) ? + GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : + (uint32_t) ggml_blck_size(V->type); + const bool kv_vec_head_dims_aligned = Q->ne[0] % k_vec_head_align == 0 && V->ne[0] % v_vec_head_align == 0; + + return global_ctx->capabilities.supports_subgroups && (Q->ne[1] < GGML_WEBGPU_FLASH_ATTN_VEC_MAX_SEQ_LEN) && + kv_vec_head_dims_aligned && k_vec_type_supported && v_vec_type_supported && k_float_vec4_aligned && + v_float_vec4_aligned; +} + +static ggml_webgpu_flash_attn_op ggml_webgpu_flash_attn_prepare(webgpu_context & ctx, + ggml_tensor * Q, + ggml_tensor * K, + ggml_tensor * V, + ggml_tensor * mask, + ggml_tensor * sinks, + ggml_tensor * dst) { float scale = ggml_get_op_params_f32(dst, 0); float max_bias = ggml_get_op_params_f32(dst, 1); float logit_softcap = ggml_get_op_params_f32(dst, 2); @@ -1739,47 +1810,43 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, float m0 = powf(2.0f, -(max_bias) / n_head_log2); float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - ggml_webgpu_shader_lib_context shader_lib_ctx = {}; - shader_lib_ctx.src0 = Q; - shader_lib_ctx.src1 = K; - shader_lib_ctx.src2 = V; - shader_lib_ctx.src3 = mask; - shader_lib_ctx.src4 = sinks; - shader_lib_ctx.dst = dst; - shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups; - shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix; - shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; - shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; - shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m; - shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n; - shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k; - shader_lib_ctx.min_subgroup_size = ctx->global_ctx->capabilities.min_subgroup_size; - shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size; - webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline( - shader_lib_ctx, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); - auto * decisions = static_cast<ggml_webgpu_flash_attn_decisions *>(pipeline.context.get()); - const int has_mask = (mask != nullptr); - const int has_sinks = (sinks != nullptr); - const bool kv_overlap = decisions->kv_overlap; - - uint32_t offset_k = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)); - uint32_t offset_v = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)); - size_t kv_bind_offset = 0; - size_t kv_bind_size = 0; - if (kv_overlap) { + ggml_webgpu_flash_attn_op op = {}; + op.shader_lib_ctx.src0 = Q; + op.shader_lib_ctx.src1 = K; + op.shader_lib_ctx.src2 = V; + op.shader_lib_ctx.src3 = mask; + op.shader_lib_ctx.src4 = sinks; + op.shader_lib_ctx.dst = dst; + op.shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups; + op.shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix; + op.shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + op.shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; + op.shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m; + op.shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n; + op.shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k; + op.shader_lib_ctx.min_subgroup_size = ctx->global_ctx->capabilities.min_subgroup_size; + op.shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size; + + op.has_mask = mask != nullptr; + op.has_sinks = sinks != nullptr; + op.kv_overlap = ggml_webgpu_tensor_overlap(K, V); + + uint32_t offset_k = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)); + uint32_t offset_v = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)); + if (op.kv_overlap) { const ggml_webgpu_merged_binding_range merged_range = ggml_webgpu_tensor_merged_binding_range(ctx, { K, V }); - kv_bind_offset = merged_range.offset; - kv_bind_size = merged_range.size; + op.kv_bind_offset = merged_range.offset; + op.kv_bind_size = merged_range.size; offset_k = ggml_webgpu_tensor_merged_element_offset(K, merged_range); offset_v = ggml_webgpu_tensor_merged_element_offset(V, merged_range); } - std::vector<uint32_t> params = { + op.params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)), offset_k, offset_v, - has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0, - has_sinks ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0, + op.has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0, + op.has_sinks ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), (uint32_t) Q->ne[2], // number of heads (uint32_t) Q->ne[1], // sequence length (Q) @@ -1793,7 +1860,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, (uint32_t) (V->nb[1] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 1 (uint32_t) (V->nb[2] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 2 (uint32_t) (V->nb[3] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 3 - has_mask ? (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)) : 0, // stride of mask dim 3 + op.has_mask ? (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)) : 0, // stride of mask dim 3 (uint32_t) (Q->ne[2] / K->ne[2]), // repeat factor for K/V in dim 2 (MHA/MQA/GQA) ggml_webgpu_u32_from_f32(scale), // scale (possibly adjusted for logit softcap) ggml_webgpu_u32_from_f32(max_bias), @@ -1801,32 +1868,56 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, ggml_webgpu_u32_from_f32(n_head_log2), ggml_webgpu_u32_from_f32(m0), ggml_webgpu_u32_from_f32(m1) - }; - std::vector<wgpu::BindGroupEntry> entries = { + op.entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, Q), }; - if (kv_overlap) { - entries.push_back( - ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), kv_bind_offset, kv_bind_size)); + if (op.kv_overlap) { + op.entries.push_back( + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), op.kv_bind_offset, op.kv_bind_size)); } else { - entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, K)); - entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, V)); + op.entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, K)); + op.entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, V)); } - uint32_t binding_index = kv_overlap ? 2u : 3u; - if (has_mask) { - entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, mask)); + uint32_t binding_index = op.kv_overlap ? 2u : 3u; + if (op.has_mask) { + op.entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, mask)); } - if (has_sinks) { - entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, sinks)); + if (op.has_sinks) { + op.entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, sinks)); } - entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, dst)); + op.entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, dst)); - if (decisions->path != GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { - uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile); - uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + return op; +} + +static uint32_t ggml_webgpu_flash_attn_vec_nwg(uint32_t vec_nwg_cap, uint32_t kv_tile, uint32_t seq_len_kv) { + uint32_t nwg = 1u; + const uint64_t kv_span = (uint64_t) kv_tile; + while ((2u * nwg * kv_span) < (uint64_t) seq_len_kv && nwg < vec_nwg_cap) { + nwg <<= 1; } + return std::min(nwg, vec_nwg_cap); +} + +static webgpu_encoded_op ggml_webgpu_flash_attn_direct(webgpu_context & ctx, const ggml_webgpu_flash_attn_op & op) { + webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(op.shader_lib_ctx); + auto * decisions = static_cast<ggml_webgpu_flash_attn_decisions *>(pipeline.context.get()); + uint32_t wg_per_head = CEIL_DIV(op.shader_lib_ctx.src0->ne[1], decisions->q_tile); + uint32_t wg_x = wg_per_head * op.shader_lib_ctx.src0->ne[2] * op.shader_lib_ctx.src0->ne[3]; + return ggml_backend_webgpu_build(ctx, pipeline, op.params, op.entries, wg_x); +} + +static webgpu_encoded_op ggml_webgpu_flash_attn_vec(webgpu_context & ctx, + ggml_tensor * Q, + ggml_tensor * K, + ggml_tensor * V, + ggml_tensor * mask, + ggml_tensor * sinks, + ggml_tensor * dst, + ggml_webgpu_flash_attn_op op) { + webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_vec_pipeline(op.shader_lib_ctx); + auto * decisions = static_cast<ggml_webgpu_flash_attn_vec_decisions *>(pipeline.context.get()); wgpu::Buffer blk_buf = {}; uint64_t blk_size_bytes = 0; @@ -1835,13 +1926,8 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, uint32_t blk_batch_count = 0; const uint32_t vec_nwg_cap = ctx->global_ctx->capabilities.min_subgroup_size; - uint32_t nwg = 1u; - const uint64_t kv_span = (uint64_t) std::max(1u, decisions->kv_tile); - while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { - nwg <<= 1; - } - nwg = std::min(nwg, vec_nwg_cap); - const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; + uint32_t nwg = ggml_webgpu_flash_attn_vec_nwg(vec_nwg_cap, decisions->kv_tile, (uint32_t) K->ne[1]); + const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; const bool use_vec_reduce = nwg > 1u; GGML_ASSERT(nrows <= UINT32_MAX); @@ -1877,7 +1963,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, webgpu_pipeline blk_pipeline; std::vector<uint32_t> blk_params; std::vector<wgpu::BindGroupEntry> blk_entries; - if (has_mask) { + if (op.has_mask) { blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], decisions->kv_tile); blk_nblk1 = (uint32_t) Q->ne[1]; blk_buf = ggml_webgpu_tensor_buf(dst); @@ -1885,7 +1971,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u; const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); - const ggml_webgpu_shader_lib_context blk_shader_ctx = shader_lib_ctx; + const ggml_webgpu_shader_lib_context blk_shader_ctx = op.shader_lib_ctx; blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx, decisions->kv_tile); blk_params = { @@ -1905,8 +1991,8 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, scratch_offset = ROUNDUP_POW2(scratch_offset + blk_size_bytes, align_bytes); } - std::vector<uint32_t> split_params = params; - if (has_mask) { + std::vector<uint32_t> split_params = op.params; + if (op.has_mask) { split_params.push_back(0u); // blk_base split_params.push_back(blk_nblk0); // blk_nblk0 split_params.push_back(blk_nblk1); // blk_nblk1 @@ -1919,9 +2005,9 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(Q), ggml_webgpu_tensor_align_offset(ctx, Q), ggml_webgpu_tensor_binding_size(ctx, Q)), }; - if (kv_overlap) { + if (op.kv_overlap) { split_entries.push_back( - ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), kv_bind_offset, kv_bind_size)); + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), op.kv_bind_offset, op.kv_bind_size)); } else { split_entries.push_back(ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), ggml_webgpu_tensor_align_offset(ctx, K), @@ -1930,18 +2016,18 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, ggml_webgpu_tensor_align_offset(ctx, V), ggml_webgpu_tensor_binding_size(ctx, V))); } - uint32_t split_binding_index = kv_overlap ? 2u : 3u; - if (has_mask) { + uint32_t split_binding_index = op.kv_overlap ? 2u : 3u; + if (op.has_mask) { split_entries.push_back(ggml_webgpu_make_bind_group_entry(split_binding_index++, ggml_webgpu_tensor_buf(mask), ggml_webgpu_tensor_align_offset(ctx, mask), ggml_webgpu_tensor_binding_size(ctx, mask))); } - if (has_sinks) { + if (op.has_sinks) { split_entries.push_back(ggml_webgpu_make_bind_group_entry(split_binding_index++, ggml_webgpu_tensor_buf(sinks), ggml_webgpu_tensor_align_offset(ctx, sinks), ggml_webgpu_tensor_binding_size(ctx, sinks))); } - if (has_mask) { + if (op.has_mask) { split_entries.push_back( ggml_webgpu_make_bind_group_entry(split_binding_index++, blk_buf, blk_entries[1].offset, blk_size_bytes)); } @@ -1956,11 +2042,11 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, std::vector<wgpu::BindGroupEntry> reduce_entries; if (use_vec_reduce) { const uint32_t reduce_sg_size = ctx->global_ctx->capabilities.max_subgroup_size; - const uint32_t reduce_wg_size = - std::max(reduce_sg_size, (uint32_t) std::min<uint64_t>( - (uint64_t) nwg * reduce_sg_size, - ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup)); - ggml_webgpu_shader_lib_context reduce_shader_ctx = shader_lib_ctx; + const uint32_t reduce_wg_size = std::max( + reduce_sg_size, + (uint32_t) std::min<uint64_t>((uint64_t) nwg * reduce_sg_size, + ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup)); + ggml_webgpu_shader_lib_context reduce_shader_ctx = op.shader_lib_ctx; reduce_shader_ctx.max_wg_size = reduce_wg_size; reduce_pipeline = ctx->shader_lib->get_flash_attn_vec_reduce_pipeline(reduce_shader_ctx); @@ -1987,7 +2073,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, std::vector<webgpu_dispatch_desc> dispatches; - if (has_mask) { + if (op.has_mask) { dispatches.push_back({ blk_pipeline, std::move(blk_params), std::move(blk_entries), { blk_nblk0, blk_nblk1 * blk_batch_count } }); @@ -2004,6 +2090,20 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, return ggml_backend_webgpu_build_multi(ctx, dispatches); } +static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, + ggml_tensor * Q, + ggml_tensor * K, + ggml_tensor * V, + ggml_tensor * mask, + ggml_tensor * sinks, + ggml_tensor * dst) { + ggml_webgpu_flash_attn_op op = ggml_webgpu_flash_attn_prepare(ctx, Q, K, V, mask, sinks, dst); + if (ggml_webgpu_flash_attn_use_vec_path(ctx->global_ctx, Q, K, V)) { + return ggml_webgpu_flash_attn_vec(ctx, Q, K, V, mask, sinks, dst, std::move(op)); + } + return ggml_webgpu_flash_attn_direct(ctx, op); +} + static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool is_unary = dst->op == GGML_OP_UNARY; @@ -2070,8 +2170,10 @@ static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); } - uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + uint32_t wg_x, wg_y; + uint32_t total_wg = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); + compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx, @@ -2145,8 +2247,10 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx, } } - uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + uint32_t wg_x, wg_y; + uint32_t total_wg = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); + compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } static webgpu_encoded_op ggml_webgpu_add_id(webgpu_context & ctx, @@ -2206,33 +2310,6 @@ static webgpu_encoded_op ggml_webgpu_concat(webgpu_context & ctx, uint32_t ne = (uint32_t) ggml_nelements(dst); uint32_t dim = (uint32_t) dst->op_params[0]; - std::vector<uint32_t> params = { - ne, - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), - (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)), - (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), - (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), - (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), - (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)), - (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), - (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), - (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), - (uint32_t) dst->ne[0], - (uint32_t) dst->ne[1], - (uint32_t) dst->ne[2], - (uint32_t) dst->ne[3], - dim, - (uint32_t) src0->ne[dim] - }; - - std::vector<wgpu::BindGroupEntry> entries = { - ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), - ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), - ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), - }; - ggml_webgpu_shader_lib_context shader_lib_ctx = {}; shader_lib_ctx.src0 = src0; shader_lib_ctx.src1 = src1; @@ -2240,8 +2317,52 @@ static webgpu_encoded_op ggml_webgpu_concat(webgpu_context & ctx, shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx); - auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get()); - uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); + auto * decisions = static_cast<ggml_webgpu_binary_shader_decisions *>(pipeline.context.get()); + + uint32_t offset_src0 = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)); + uint32_t offset_src1 = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)); + size_t merged_offset = 0; + size_t merged_size = 0; + if (decisions->src_overlap) { + const ggml_webgpu_merged_binding_range merged_range = + ggml_webgpu_tensor_merged_binding_range(ctx, { src0, src1 }); + merged_offset = merged_range.offset; + merged_size = merged_range.size; + offset_src0 = ggml_webgpu_tensor_merged_element_offset(src0, merged_range); + offset_src1 = ggml_webgpu_tensor_merged_element_offset(src1, merged_range); + } + + std::vector<uint32_t> params = { ne, + offset_src0, + offset_src1, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + (uint32_t) dst->ne[3], + dim, + (uint32_t) src0->ne[dim] }; + + std::vector<wgpu::BindGroupEntry> entries = {}; + if (decisions->src_overlap) { + entries.push_back( + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), merged_offset, merged_size)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); + } else { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst)); + } + + uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } @@ -2574,8 +2695,10 @@ static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * s entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); } - uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + uint32_t wg_x, wg_y; + uint32_t total_wg = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); + compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } static webgpu_encoded_op ggml_webgpu_soft_max(webgpu_context & ctx, @@ -2736,10 +2859,12 @@ static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * block_size, npr, nrows }; - const uint32_t total_wg_init = npr * nrows; - const uint32_t max_wg = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; - const uint32_t wg_x_init = std::min(total_wg_init, max_wg); - const uint32_t wg_y_init = CEIL_DIV(total_wg_init, wg_x_init); + uint32_t wg_x_init; + uint32_t wg_y_init; + const uint32_t total_wg_init = npr * nrows; + const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + compute_2d_workgroups(total_wg_init, max_wg_per_dim, wg_x_init, wg_y_init); + std::vector<wgpu::BindGroupEntry> init_entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), init_align_offset, init_binding_size) @@ -2796,9 +2921,11 @@ static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(dst), align_out, size_out) }; + uint32_t wg_x_merge; + uint32_t wg_y_merge; const uint32_t total_wg_merge = nm * nrows; - const uint32_t wg_x_merge = std::min(total_wg_merge, max_wg); - const uint32_t wg_y_merge = CEIL_DIV(total_wg_merge, wg_x_merge); + compute_2d_workgroups(total_wg_merge, max_wg_per_dim, wg_x_merge, wg_y_merge); + dispatches.push_back({ argsort_merge_pipeline, std::move(merge_params), std::move(merge_entries), { wg_x_merge, wg_y_merge } }); @@ -2918,9 +3045,12 @@ static webgpu_encoded_op ggml_webgpu_upscale(webgpu_context ctx, ggml_tensor * s webgpu_pipeline pipeline = ctx->shader_lib->get_upscale_pipeline(shader_lib_ctx); auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get()); - uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size); - uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg); - uint32_t wg_y = CEIL_DIV(total_wg, wg_x); + + uint32_t wg_x; + uint32_t wg_y; + uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size); + compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } @@ -3110,18 +3240,16 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str uint32_t num_batched_kernels = 0; uint32_t num_inflight_batches = 0; bool contains_set_rows = false; - bool batch_compute_passes = true; int num_encoded_ops = 1; int node_idx = 0; #ifdef GGML_WEBGPU_GPU_PROFILE ctx->profile_timestamp_query_count = 0; - batch_compute_passes = false; std::vector<std::string> profile_pipeline_names; #endif ctx->active_command_encoder = ctx->global_ctx->device.CreateCommandEncoder(); - if (batch_compute_passes) { + if (ctx->batch_compute_passes) { ctx->active_compute_pass = ctx->active_command_encoder.BeginComputePass(); } @@ -3148,7 +3276,7 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str // reset state for next batch ctx->active_command_encoder = ctx->global_ctx->device.CreateCommandEncoder(); - if (batch_compute_passes) { + if (ctx->batch_compute_passes) { ctx->active_compute_pass = ctx->active_command_encoder.BeginComputePass(); } ctx->param_arena.reset(); @@ -3515,70 +3643,59 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer break; case GGML_OP_FLASH_ATTN_EXT: { - const ggml_tensor * Q = tensor->src[0]; - const ggml_tensor * K = tensor->src[1]; - const ggml_tensor * V = tensor->src[2]; - const ggml_tensor * mask = tensor->src[3]; - const ggml_tensor * sinks = tensor->src[4]; - if (Q && K && V) { - ggml_webgpu_shader_lib_context shader_lib_ctx = {}; - shader_lib_ctx.src0 = const_cast<ggml_tensor *>(Q); - shader_lib_ctx.src1 = const_cast<ggml_tensor *>(K); - shader_lib_ctx.src2 = const_cast<ggml_tensor *>(V); - shader_lib_ctx.src3 = const_cast<ggml_tensor *>(mask); - shader_lib_ctx.src4 = const_cast<ggml_tensor *>(sinks); - shader_lib_ctx.dst = const_cast<ggml_tensor *>(tensor); - shader_lib_ctx.max_wg_size = - ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; - shader_lib_ctx.wg_mem_limit_bytes = - ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; - shader_lib_ctx.supports_subgroups = ctx->webgpu_global_ctx->capabilities.supports_subgroups; - shader_lib_ctx.supports_subgroup_matrix = - ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix; - shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m; - shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n; - shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k; - shader_lib_ctx.min_subgroup_size = ctx->webgpu_global_ctx->capabilities.min_subgroup_size; - shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size; - - const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions( - shader_lib_ctx, ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); - - if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { - const uint32_t kv_tile = decisions.kv_tile; - - const uint32_t vec_nwg_cap = ctx->webgpu_global_ctx->capabilities.min_subgroup_size; - uint32_t nwg = 1u; - const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile); - while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { - nwg <<= 1; - } - nwg = std::min(nwg, vec_nwg_cap); - - const size_t align = - ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; - const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; - if (nwg > 1u) { - const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg; - const uint64_t tmp_stats_elems = nrows * 2u * nwg; - const size_t tmp_size_bytes = ROUNDUP_POW2( - (tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT); - res += tmp_size_bytes + align; - } else { - res += WEBGPU_STORAGE_BUF_BINDING_MULT + align; - } - if (mask != nullptr) { - const uint32_t blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], kv_tile); - const uint32_t blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], 1u); - const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)); - const uint32_t blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u; - const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; - const size_t blk_size_bytes = - ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); - res += blk_size_bytes + align; - } - res = ROUNDUP_POW2(res, WEBGPU_STORAGE_BUF_BINDING_MULT); + const ggml_tensor * Q = tensor->src[0]; + const ggml_tensor * K = tensor->src[1]; + const ggml_tensor * V = tensor->src[2]; + const ggml_tensor * mask = tensor->src[3]; + const auto & capabilities = ctx->webgpu_global_ctx->capabilities; + if (ggml_webgpu_flash_attn_use_vec_path(ctx->webgpu_global_ctx, Q, K, V)) { + const bool kv_direct = + ggml_webgpu_flash_attn_kv_direct(Q, K, V, GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH); + const uint32_t kv_tile = ggml_webgpu_flash_attn_get_vec_kv_tile( + capabilities.limits.maxComputeWorkgroupStorageSize, (uint32_t) Q->ne[0], (uint32_t) V->ne[0], + mask != nullptr, kv_direct); + + const uint32_t vec_nwg_cap = capabilities.min_subgroup_size; + uint32_t nwg = ggml_webgpu_flash_attn_vec_nwg(vec_nwg_cap, kv_tile, (uint32_t) K->ne[1]); + + const size_t align = capabilities.limits.minStorageBufferOffsetAlignment; + const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; + if (nwg > 1u) { + const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg; + const uint64_t tmp_stats_elems = nrows * 2u * nwg; + const size_t tmp_size_bytes = ROUNDUP_POW2((tmp_data_elems + tmp_stats_elems) * sizeof(float), + WEBGPU_STORAGE_BUF_BINDING_MULT); + res += tmp_size_bytes + align; + } else { + res += WEBGPU_STORAGE_BUF_BINDING_MULT + align; + } + if (mask != nullptr) { + const uint32_t blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], kv_tile); + const uint32_t blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], 1u); + const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)); + const uint32_t blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u; + const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; + const size_t blk_size_bytes = + ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); + res += blk_size_bytes + align; } + res = ROUNDUP_POW2(res, WEBGPU_STORAGE_BUF_BINDING_MULT); + } + } + break; + case GGML_OP_MUL_MAT: + { + const ggml_tensor * src0 = tensor->src[0]; + const ggml_tensor * src1 = tensor->src[1]; + bool use_mmvq = + ggml_webgpu_can_use_mmvq(src0, src1, ctx->webgpu_global_ctx->capabilities.supports_dot_product, + ctx->webgpu_global_ctx->vendor); + if (use_mmvq) { + const size_t q8_src1_size = + src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)); + res = ROUNDUP_POW2(res + q8_src1_size + + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment, + WEBGPU_STORAGE_BUF_BINDING_MULT); } } break; @@ -3658,19 +3775,20 @@ static ggml_guid_t ggml_backend_webgpu_guid(void) { static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) { // we use the maximum workgroup size for the memset pipeline - size_t max_threads = WEBGPU_MAX_WG_SIZE * ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + size_t max_threads = ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup * + ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; // Size the bytes_per_thread so that the largest buffer size can be handled ctx->capabilities.memset_bytes_per_thread = CEIL_DIV(ctx->capabilities.limits.maxStorageBufferBindingSize, max_threads); std::vector<wgpu::ConstantEntry> constants(2); constants[0].key = "wg_size"; - constants[0].value = WEBGPU_MAX_WG_SIZE; + constants[0].value = ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; constants[1].key = "bytes_per_thread"; constants[1].value = ctx->capabilities.memset_bytes_per_thread; ctx->memset_pipeline = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants); } -static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { +static void create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { wgpu::RequestAdapterOptions options = {}; #ifndef __EMSCRIPTEN__ @@ -3707,12 +3825,12 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { ctx->webgpu_global_ctx->adapter.GetInfo(&info); ctx->webgpu_global_ctx->command_submit_batch_size = ggml_backend_webgpu_get_command_submit_batch_size(); ctx->webgpu_global_ctx->max_inflight_batches = ggml_backend_webgpu_get_max_inflight_batches(); - wgpu::SupportedFeatures features; - ctx->webgpu_global_ctx->adapter.GetFeatures(&features); - // we require f16 support - GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16)); + ctx->webgpu_global_ctx->vendor = info.vendor; ctx->webgpu_global_ctx->capabilities.supports_subgroups = ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::Subgroups); + // for dot4I8packed + ctx->webgpu_global_ctx->capabilities.supports_dot_product = ctx->webgpu_global_ctx->instance.HasWGSLLanguageFeature( + wgpu::WGSLLanguageFeatureName::Packed4x8IntegerDotProduct); bool valid_subgroup_matrix_config = false; #ifndef __EMSCRIPTEN__ @@ -3819,7 +3937,6 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { "device_desc: %s\n", info.vendorID, std::string(info.vendor).c_str(), std::string(info.architecture).c_str(), info.deviceID, std::string(info.device).c_str(), std::string(info.description).c_str()); - return true; } static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { @@ -3839,6 +3956,7 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf"); #ifdef GGML_WEBGPU_GPU_PROFILE + webgpu_ctx->batch_compute_passes = false; ggml_webgpu_create_buffer( webgpu_ctx->global_ctx->device, webgpu_ctx->profile_timestamp_dev_buf, WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES, wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc, "profile_timestamp_dev_buf"); @@ -3991,8 +4109,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_I32); break; case GGML_OP_SET_ROWS: - supports_op = ((op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32) && src0->type == GGML_TYPE_F32 && - (src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32)); + supports_op = ((op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_Q8_0 || + op->type == GGML_TYPE_Q4_0) && + src0->type == GGML_TYPE_F32 && (src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32)); break; case GGML_OP_GET_ROWS: if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_webgpu_supported_qtype(src0->type)) { @@ -4084,70 +4203,63 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const break; case GGML_OP_FLASH_ATTN_EXT: { + // conservative support checks for whether the more resource-intensive shader paths + // can be used, to avoid cases where flash_attn is assigned to the CPU later on supports_op = src0->type == GGML_TYPE_F32 && (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) && - src2->type == src1->type && op->type == GGML_TYPE_F32; + (src2->type == GGML_TYPE_F32 || src2->type == GGML_TYPE_F16 || + src2->type == GGML_TYPE_Q4_0 || src2->type == GGML_TYPE_Q8_0) && + op->type == GGML_TYPE_F32; if (!supports_op) { break; } - ggml_webgpu_shader_lib_context shader_lib_ctx = {}; - shader_lib_ctx.src0 = src0; - shader_lib_ctx.src1 = src1; - shader_lib_ctx.src2 = src2; - shader_lib_ctx.src3 = op->src[3]; - shader_lib_ctx.src4 = op->src[4]; - shader_lib_ctx.dst = const_cast<ggml_tensor *>(op); - shader_lib_ctx.supports_subgroups = ctx->webgpu_global_ctx->capabilities.supports_subgroups; - shader_lib_ctx.supports_subgroup_matrix = ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix; - shader_lib_ctx.max_wg_size = - ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; - shader_lib_ctx.wg_mem_limit_bytes = - ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; - shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m; - shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n; - shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k; - shader_lib_ctx.min_subgroup_size = ctx->webgpu_global_ctx->capabilities.min_subgroup_size; - shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size; - - const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions( - shader_lib_ctx, ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); - const size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; - const bool has_mask = op->src[3] != nullptr; - if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_NONE) { + if (ggml_webgpu_tensor_overlap(src1, src2) && src1->type != src2->type && + !ggml_is_quantized(src1->type) && !ggml_is_quantized(src2->type)) { supports_op = false; break; } - if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { - const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( - decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, - decisions.kv_direct, decisions.path); - if (min_bytes > limit_bytes) { - supports_op = false; - } - break; - } - - if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { - const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( - decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, - decisions.kv_direct, decisions.path); - if (min_bytes > limit_bytes) { - supports_op = false; - } - break; - } - - if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) { + const auto & capabilities = ctx->webgpu_global_ctx->capabilities; + const size_t storage_offset_alignment = capabilities.limits.minStorageBufferOffsetAlignment; + + // subgroup matrix path requirements + const bool use_subgroup_matrix = ggml_webgpu_flash_attn_can_use_subgroup_matrix_path( + capabilities.supports_subgroup_matrix, capabilities.sg_mat_k, capabilities.sg_mat_n, src0, src2); + + // tile path requirements + const bool float_vec4_aligned = + ((src1->type != GGML_TYPE_F16 && src1->type != GGML_TYPE_F32) || + ggml_webgpu_flash_attn_float_vec4_aligned(src1, storage_offset_alignment)) && + ((src2->type != GGML_TYPE_F16 && src2->type != GGML_TYPE_F32) || + ggml_webgpu_flash_attn_float_vec4_aligned(src2, storage_offset_alignment)); + const uint32_t k_tile_head_align = (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16) ? + GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : + (uint32_t) ggml_blck_size(src1->type); + const uint32_t v_tile_head_align = (src2->type == GGML_TYPE_F32 || src2->type == GGML_TYPE_F16) ? + GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : + (uint32_t) ggml_blck_size(src2->type); + const bool tile_kv_head_dims_aligned = + src0->ne[0] % k_tile_head_align == 0 && src2->ne[0] % v_tile_head_align == 0; + const bool tile_can_dispatch_all_q_rows = + capabilities.limits.maxComputeInvocationsPerWorkgroup >= + GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * capabilities.max_subgroup_size; + const bool use_tile = !use_subgroup_matrix && capabilities.supports_subgroups && float_vec4_aligned && + tile_kv_head_dims_aligned && tile_can_dispatch_all_q_rows; + + if (!use_subgroup_matrix && !use_tile) { supports_op = false; break; } - const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( - decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, - decisions.kv_direct, decisions.path); - if (min_bytes > limit_bytes) { - supports_op = false; - } + const uint32_t q_tile = + use_subgroup_matrix ? capabilities.sg_mat_m : GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE; + const uint32_t kv_granularity = use_subgroup_matrix ? capabilities.sg_mat_n : 1u; + const bool kv_direct = use_subgroup_matrix ? + ggml_webgpu_flash_attn_kv_direct(src0, src1, src2, capabilities.sg_mat_k) : + false; + const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile( + capabilities.limits.maxComputeWorkgroupStorageSize, q_tile, kv_granularity, (uint32_t) src0->ne[0], + (uint32_t) src2->ne[0], op->src[3] != nullptr, kv_direct); + supports_op = max_kv_tile > 0; break; } case GGML_OP_RMS_NORM: @@ -4447,7 +4559,12 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { UINT64_MAX); } - if (adapter != nullptr) { + // WebGPU backend requires f16 support and, on native, implicit device synchronization. + if (adapter != nullptr && adapter.HasFeature(wgpu::FeatureName::ShaderF16) +#ifndef __EMSCRIPTEN__ + && adapter.HasFeature(wgpu::FeatureName::ImplicitDeviceSynchronization) +#endif + ) { ctx->device_count = 1; } @@ -4455,8 +4572,11 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { } ggml_backend_t ggml_backend_webgpu_init(void) { - ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_webgpu_reg(), 0); - + ggml_backend_reg_t reg = ggml_backend_webgpu_reg(); + if (ggml_backend_reg_dev_count(reg) == 0) { + return nullptr; + } + ggml_backend_dev_t dev = ggml_backend_reg_dev_get(reg, 0); return ggml_backend_webgpu_backend_init(dev, nullptr); } diff --git a/ggml/src/ggml-webgpu/pre_wgsl.hpp b/ggml/src/ggml-webgpu/pre_wgsl.hpp index 4d4359463ca..fb41a961d74 100644 --- a/ggml/src/ggml-webgpu/pre_wgsl.hpp +++ b/ggml/src/ggml-webgpu/pre_wgsl.hpp @@ -37,15 +37,33 @@ static std::string trim(const std::string & s) { } static std::string trim_value(std::istream & is) { - std::string str; - std::getline(is, str); - return trim(str); + std::ostringstream ss; + ss << is.rdbuf(); + return trim(ss.str()); } static bool isIdentChar(char c) { return std::isalnum(static_cast<unsigned char>(c)) || c == '_'; } +static bool endsWithContinuation(const std::string & line) { + size_t i = line.size(); + while (i > 0 && std::isspace((unsigned char) line[i - 1])) { + i--; + } + return i > 0 && line[i - 1] == '\\'; +} + +static void stripContinuation(std::string & line) { + size_t i = line.size(); + while (i > 0 && std::isspace((unsigned char) line[i - 1])) { + i--; + } + if (i > 0 && line[i - 1] == '\\') { + line.erase(i - 1); + } +} + static std::string expandMacrosRecursiveInternal(const std::string & line, const std::unordered_map<std::string, std::string> & macros, std::unordered_set<std::string> & visiting); @@ -595,19 +613,31 @@ class Preprocessor { std::string line; while (std::getline(in, line)) { - std::string t = trim(line); + std::string logical = line; + std::string t = trim(logical); + if (!t.empty() && t[0] == '#') { + while (endsWithContinuation(logical)) { + stripContinuation(logical); + if (!std::getline(in, line)) { + break; + } + logical += "\n"; + logical += line; + } + t = trim(logical); + } if (!t.empty() && t[0] == '#') { bool handled = handleDirective(t, out, macros, predefined_macros, cond, include_stack, mode); if (mode == DirectiveMode::IncludesOnly && !handled) { - out << line << "\n"; + out << logical << "\n"; } } else { if (mode == DirectiveMode::IncludesOnly) { - out << line << "\n"; + out << logical << "\n"; } else if (condActive(cond)) { // Expand macros in the line before outputting - std::string expanded = expandMacrosRecursive(line, macros); + std::string expanded = expandMacrosRecursive(logical, macros); out << expanded << "\n"; } } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl index 605de7aa7be..f262c4a8f6a 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl @@ -130,10 +130,13 @@ fn update(dst_i: u32, src0_i: u32, src1_i: u32) { } @compute @workgroup_size(WG_SIZE) -fn main(@builtin(global_invocation_id) gid: vec3<u32>) { - if (gid.x < params.ne) { - let src0_i = params.offset_src0 + src0_index(gid.x); - let src1_i = params.offset_src1 + src1_index(gid.x); - update(params.offset_dst + gid.x, src0_i, src1_i); +fn main(@builtin(global_invocation_id) gid: vec3<u32>, + @builtin(num_workgroups) num_wg: vec3<u32>) { + let threads_per_group = u32(WG_SIZE); + let i = gid.x + (num_wg.x * threads_per_group) * gid.y; + if (i < params.ne) { + let src0_i = params.offset_src0 + src0_index(i); + let src1_i = params.offset_src1 + src1_index(i); + update(params.offset_dst + i, src0_i, src1_i); } } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl index 372ea79bf9d..758efa17d77 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl @@ -95,11 +95,10 @@ struct q5_1 { }; #endif - #ifdef Q8_1_T struct q8_1 { d: f16, - m: f16, + s: f16, // d * sum(qs[i]) qs: array<u32, 8> }; #endif diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl index a22d245d2cc..eb901bf0547 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl @@ -31,6 +31,16 @@ struct Params { #define DataType i32 #endif +#ifdef SRC_OVERLAP +@group(0) @binding(0) +var<storage, read_write> merged_src: array<DataType>; + +@group(0) @binding(1) +var<storage, read_write> dst: array<DataType>; + +@group(0) @binding(2) +var<uniform> params: Params; +#else @group(0) @binding(0) var<storage, read_write> src0: array<DataType>; @@ -42,7 +52,7 @@ var<storage, read_write> dst: array<DataType>; @group(0) @binding(3) var<uniform> params: Params; - +#endif @compute @workgroup_size(WG_SIZE) fn main(@builtin(global_invocation_id) gid: vec3<u32>) { @@ -62,14 +72,22 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) { ni[1] * params.stride_src0_1 + ni[2] * params.stride_src0_2 + ni[3] * params.stride_src0_3; +#ifdef SRC_OVERLAP + dst[params.offset_dst + gid.x] = merged_src[params.offset_src0 + src_i]; +#else dst[params.offset_dst + gid.x] = src0[params.offset_src0 + src_i]; +#endif } else { ni[params.dim] -= params.src0_nedim; let src_i = ni[0] * params.stride_src1_0 + ni[1] * params.stride_src1_1 + ni[2] * params.stride_src1_2 + ni[3] * params.stride_src1_3; +#ifdef SRC_OVERLAP + dst[params.offset_dst + gid.x] = merged_src[params.offset_src1 + src_i]; +#else dst[params.offset_dst + gid.x] = src1[params.offset_src1 + src_i]; +#endif } } } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl index fa3bdf4e393..67f1dc0928f 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl @@ -49,7 +49,9 @@ struct Params{ var<uniform> params: Params; @compute @workgroup_size(WG_SIZE) -fn main(@builtin(global_invocation_id) gid: vec3<u32>) { +fn main( + @builtin(global_invocation_id) gid: vec3<u32>, +) { if (gid.x >= params.ne) { return; } @@ -78,4 +80,3 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) { dst[params.offset_dst + dst_idx] = DST_TYPE((src[params.offset_src + src_idx])); } - diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl index 6d5d69fb8de..9767ca3d754 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl @@ -4,12 +4,23 @@ enable f16; enable subgroups; enable chromium_experimental_subgroup_matrix; -#ifdef KV_F32 -#define KV_TYPE f32 -#elif defined(KV_Q4_0) || defined(KV_Q8_0) -#define KV_TYPE u32 +#define BYTE_HELPERS +#include "common_decls.tmpl" + +#ifdef K_F32 +#define K_TYPE f32 +#elif defined(K_Q4_0) || defined(K_Q8_0) +#define K_TYPE u32 +#else +#define K_TYPE f16 +#endif + +#ifdef V_F32 +#define V_TYPE f32 +#elif defined(V_Q4_0) || defined(V_Q8_0) +#define V_TYPE u32 #else -#define KV_TYPE f16 +#define V_TYPE f16 #endif // Default values @@ -30,76 +41,6 @@ enable chromium_experimental_subgroup_matrix; // Number of subgroup-matrix-width blocks that span the KV tile. SG_MAT_N must divide KV_TILE. #define KV_BLOCKS (KV_TILE / SG_MAT_N) -// Quantization constants/helpers -#define BLOCK_SIZE 32 -#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE) -#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE) -// number of quantized elements processed per thread -#if defined(KV_Q4_0) -#define NQ 16 -// Q4_0 has 32 elements, 1 f16 for scale, 8 f16 for 4-bit weights -#define F16_PER_BLOCK 9 -#define BLOCK_SIZE_BYTES 18u -#define WEIGHTS_PER_F16 4 -#elif defined(KV_Q8_0) -#define NQ 8 -// Q8_0 has 32 elements, 1 f16 for scale, 16 f16 for 8-bit weights -#define F16_PER_BLOCK 17 -#define BLOCK_SIZE_BYTES 34u -#define WEIGHTS_PER_F16 2 -#endif -#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16) - -// Ok not to put these in a define block, compiler will remove if unused -fn get_byte(value: u32, index: u32) -> u32 { - return (value >> (index * 8)) & 0xFF; -} - -fn get_byte_i32(value: u32, index: u32) -> i32 { - return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24; -} - -#if defined(KV_Q4_0) || defined(KV_Q8_0) -fn load_k_u16_at(byte_offset: u32) -> u32 { - let word = K[byte_offset / 4u]; - let shift = (byte_offset & 2u) * 8u; - return (word >> shift) & 0xFFFFu; -} - -fn load_k_u32_at(byte_offset: u32) -> u32 { - let word_idx = byte_offset / 4u; - let shift = (byte_offset & 3u) * 8u; - let lo = K[word_idx]; - if (shift == 0u) { - return lo; - } - let hi = K[word_idx + 1u]; - return (lo >> shift) | (hi << (32u - shift)); -} - -fn load_v_u16_at(byte_offset: u32) -> u32 { - let word = V[byte_offset / 4u]; - let shift = (byte_offset & 2u) * 8u; - return (word >> shift) & 0xFFFFu; -} - -fn load_v_u32_at(byte_offset: u32) -> u32 { - let word_idx = byte_offset / 4u; - let shift = (byte_offset & 3u) * 8u; - let lo = V[word_idx]; - if (shift == 0u) { - return lo; - } - let hi = V[word_idx + 1u]; - return (lo >> shift) | (hi << (32u - shift)); -} - -fn f16_from_u16(bits: u32) -> f16 { - let packed = unpack2x16float(bits); - return f16(packed[0]); -} -#endif - struct Params { offset_q: u32, offset_k: u32, @@ -139,11 +80,11 @@ struct Params { @group(0) @binding(0) var<storage, read_write> Q: array<f32>; #ifdef KV_OVERLAP -@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>; +@group(0) @binding(1) var<storage, read_write> K: array<K_TYPE>; #define V K #else -@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>; -@group(0) @binding(2) var<storage, read_write> V: array<KV_TYPE>; +@group(0) @binding(1) var<storage, read_write> K: array<K_TYPE>; +@group(0) @binding(2) var<storage, read_write> V: array<V_TYPE>; #endif #if defined(MASK) && defined(SINKS) @@ -238,10 +179,47 @@ fn load_f32x4(buf: ptr<storage, array<vec4<f32>>, read_write>, scalar_index: u32 return (*buf)[scalar_index >> 2u]; } -fn load_kvx4(buf: ptr<storage, array<vec4<KV_TYPE>>, read_write>, scalar_index: u32) -> vec4<KV_TYPE> { +fn load_kx4(buf: ptr<storage, array<vec4<K_TYPE>>, read_write>, scalar_index: u32) -> vec4<K_TYPE> { return (*buf)[scalar_index >> 2u]; } +#ifndef KV_DIRECT +#define QUANT_SHMEM kv_shmem +#define QUANT_OUT_TYPE f16 +#include "quant_inner_loops.tmpl" +#include "flash_attn_quant_staging.tmpl" + +#if !defined(K_Q4_0) && !defined(K_Q8_0) +fn load_k_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, k_head_offset: u32) { + for (var elem_idx = local_x; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) { + let k_row = elem_idx / HEAD_DIM_QK; + let k_col = elem_idx % HEAD_DIM_QK; + let global_k_row = kv_tile + k_row; + let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1; + kv_shmem[elem_idx] = f16(select( + 0.0, + K[global_k_row_offset + k_col], + global_k_row < params.seq_len_kv && k_col < HEAD_DIM_QK)); + } +} +#endif + +#if !defined(V_Q4_0) && !defined(V_Q8_0) +fn load_v_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, v_head_offset: u32) { + for (var elem_idx = local_x; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE) { + let v_row = elem_idx / HEAD_DIM_V; + let v_col = elem_idx % HEAD_DIM_V; + let global_v_row = kv_tile + v_row; + let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1; + kv_shmem[elem_idx] = f16(select( + 0.0, + V[global_v_row_offset + v_col], + global_v_row < params.seq_len_kv && v_col < HEAD_DIM_V)); + } +} +#endif +#endif + @compute @workgroup_size(WG_SIZE) fn main(@builtin(workgroup_id) wg_id: vec3<u32>, @builtin(local_invocation_id) local_id: vec3<u32>, @@ -311,77 +289,15 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>, } for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) { + let kv_count = min(KV_TILE, params.seq_len_kv - kv_tile); // clear inter_shmem to ensure zero-initialized accumulators for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) { inter_shmem[elem_idx] = 0.0; } // load k tile into shared memory -#if defined(KV_Q4_0) - for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { - let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; - let k_row = blck_idx / BLOCKS_K; - let global_k_row = kv_tile + k_row; - let block_k = blck_idx % BLOCKS_K; - let row_offset = k_row * HEAD_DIM_QK; - - if (global_k_row < params.seq_len_kv) { - let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; - let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES; - let d = f16_from_u16(load_k_u16_at(block_byte_base)); - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_k_u32_at(q_byte_offset); - for (var k = 0u; k < 4u; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; - let q_lo = (f16(q_byte & 0xF) - 8.0) * d; - let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; - kv_shmem[row_offset + idx] = q_lo; - kv_shmem[row_offset + idx + 16u] = q_hi; - } - } - } - } -#elif defined(KV_Q8_0) - for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { - let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; - let k_row = blck_idx / BLOCKS_K; - let global_k_row = kv_tile + k_row; - let block_k = blck_idx % BLOCKS_K; - let row_offset = k_row * HEAD_DIM_QK; - - if (global_k_row < params.seq_len_kv) { - let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; - let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES; - let d = f16_from_u16(load_k_u16_at(block_byte_base)); - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_k_u32_at(q_byte_offset); - for (var k = 0u; k < 4u; k++) { - let q_byte = get_byte_i32(q_packed, k); - let q_val = f16(q_byte) * d; - let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; - kv_shmem[row_offset + idx] = q_val; - } - } - } - } -#elif defined(KV_DIRECT) - // Direct global loads for KV -#else - for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) { - let k_row = elem_idx / HEAD_DIM_QK; - let k_col = elem_idx % HEAD_DIM_QK; - let global_k_row = kv_tile + k_row; - let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1; - kv_shmem[elem_idx] = f16(select( - 0.0, - K[global_k_row_offset + k_col], - global_k_row < params.seq_len_kv && k_col < HEAD_DIM_QK)); - } +#ifndef KV_DIRECT + load_k_tile_block(local_id.x, kv_count, kv_tile, k_head_offset); #endif workgroupBarrier(); @@ -520,71 +436,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>, } // load v tile into shared memory -#if defined(KV_Q4_0) - for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { - let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; - let v_row = blck_idx / BLOCKS_V; - let global_v_row = kv_tile + v_row; - let block_k = blck_idx % BLOCKS_V; - let row_offset = v_row * HEAD_DIM_V; - - if (global_v_row < params.seq_len_kv) { - let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; - let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES; - let d = f16_from_u16(load_v_u16_at(block_byte_base)); - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_v_u32_at(q_byte_offset); - for (var k = 0u; k < 4u; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; - let q_lo = (f16(q_byte & 0xF) - 8.0) * d; - let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; - kv_shmem[row_offset + idx] = q_lo; - kv_shmem[row_offset + idx + 16u] = q_hi; - } - } - } - } -#elif defined(KV_Q8_0) - for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { - let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; - let v_row = blck_idx / BLOCKS_V; - let global_v_row = kv_tile + v_row; - let block_k = blck_idx % BLOCKS_V; - let row_offset = v_row * HEAD_DIM_V; - - if (global_v_row < params.seq_len_kv) { - let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; - let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES; - let d = f16_from_u16(load_v_u16_at(block_byte_base)); - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_v_u32_at(q_byte_offset); - for (var k = 0u; k < 4u; k++) { - let q_byte = get_byte_i32(q_packed, k); - let q_val = f16(q_byte) * d; - let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; - kv_shmem[row_offset + idx] = q_val; - } - } - } - } -#elif defined(KV_DIRECT) - // Direct global loads for KV -#else - for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE) { - let v_row = elem_idx / HEAD_DIM_V; - let v_col = elem_idx % HEAD_DIM_V; - let global_v_row = kv_tile + v_row; - let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1; - kv_shmem[elem_idx] = f16(select( - 0.0, - V[global_v_row_offset + v_col], - global_v_row < params.seq_len_kv && v_col < HEAD_DIM_V)); - } +#ifndef KV_DIRECT + load_v_tile_block(local_id.x, kv_count, kv_tile, v_head_offset); #endif workgroupBarrier(); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl new file mode 100644 index 00000000000..8f41eb7bfdb --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl @@ -0,0 +1,124 @@ +#define BLOCK_SIZE 32 +#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE) +#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE) + +#if defined(K_Q4_0) +#define K_NQ 16 +#define K_BLOCK_SIZE_BYTES 18u +#define K_BYTES_PER_THREAD 8u +#define K_BYTES_PER_INNER_LOOP 4u +#elif defined(K_Q8_0) +#define K_NQ 16 +#define K_BLOCK_SIZE_BYTES 34u +#define K_BYTES_PER_THREAD 16u +#define K_BYTES_PER_INNER_LOOP 4u +#endif + +#if defined(V_Q4_0) +#define V_NQ 16 +#define V_BLOCK_SIZE_BYTES 18u +#define V_BYTES_PER_THREAD 8u +#define V_BYTES_PER_INNER_LOOP 4u +#elif defined(V_Q8_0) +#define V_NQ 16 +#define V_BLOCK_SIZE_BYTES 34u +#define V_BYTES_PER_THREAD 16u +#define V_BYTES_PER_INNER_LOOP 4u +#endif + +#if defined(K_Q4_0) || defined(K_Q8_0) +fn load_k_u16_at(byte_offset: u32) -> u32 { + let word = K[byte_offset / 4u]; + let shift = (byte_offset & 2u) * 8u; + return (word >> shift) & 0xFFFFu; +} + +fn load_k_u32_at(byte_offset: u32) -> u32 { + let word_idx = byte_offset / 4u; + let shift = (byte_offset & 3u) * 8u; + let lo = K[word_idx]; + if (shift == 0u) { + return lo; + } + let hi = K[word_idx + 1u]; + return (lo >> shift) | (hi << (32u - shift)); +} +#endif + +#if defined(V_Q4_0) || defined(V_Q8_0) +fn load_v_u16_at(byte_offset: u32) -> u32 { + let word = V[byte_offset / 4u]; + let shift = (byte_offset & 2u) * 8u; + return (word >> shift) & 0xFFFFu; +} + +fn load_v_u32_at(byte_offset: u32) -> u32 { + let word_idx = byte_offset / 4u; + let shift = (byte_offset & 3u) * 8u; + let lo = V[word_idx]; + if (shift == 0u) { + return lo; + } + let hi = V[word_idx + 1u]; + return (lo >> shift) | (hi << (32u - shift)); +} +#endif + +fn f16_from_u16(bits: u32) -> f16 { + let packed = unpack2x16float(bits); + return f16(packed[0]); +} + +#if defined(K_Q4_0) || defined(K_Q8_0) +fn load_k_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, k_head_offset: u32) { + for (var elem_idx = local_x * K_NQ; elem_idx < kv_count * HEAD_DIM_QK; elem_idx += WG_SIZE * K_NQ) { + let blck_idx = elem_idx / BLOCK_SIZE; + let block_offset = (elem_idx % BLOCK_SIZE) / K_NQ; + let k_row = blck_idx / BLOCKS_K; + let global_k_row = kv_tile + k_row; + let block_k = blck_idx % BLOCKS_K; + let row_offset = k_row * HEAD_DIM_QK; + let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; + let block_byte_base = global_block_idx * K_BLOCK_SIZE_BYTES; + let d = f16_from_u16(load_k_u16_at(block_byte_base)); + let thread_byte_offset = block_offset * K_BYTES_PER_THREAD; + let shmem_idx = row_offset + block_k * BLOCK_SIZE + thread_byte_offset; + for (var j = 0u; j < K_BYTES_PER_THREAD / K_BYTES_PER_INNER_LOOP; j += 1u) { + let q_byte_offset = block_byte_base + 2u + thread_byte_offset + j * K_BYTES_PER_INNER_LOOP; + let q_packed = load_k_u32_at(q_byte_offset); +#if defined(K_Q4_0) + dequant_q4_0_packed_to_shmem(q_packed, d, shmem_idx + j * K_BYTES_PER_INNER_LOOP); +#elif defined(K_Q8_0) + dequant_q8_0_packed_to_shmem(q_packed, d, shmem_idx + j * K_BYTES_PER_INNER_LOOP); +#endif + } + } +} +#endif + +#if defined(V_Q4_0) || defined(V_Q8_0) +fn load_v_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, v_head_offset: u32) { + for (var elem_idx = local_x * V_NQ; elem_idx < kv_count * HEAD_DIM_V; elem_idx += WG_SIZE * V_NQ) { + let blck_idx = elem_idx / BLOCK_SIZE; + let block_offset = (elem_idx % BLOCK_SIZE) / V_NQ; + let v_row = blck_idx / BLOCKS_V; + let global_v_row = kv_tile + v_row; + let block_k = blck_idx % BLOCKS_V; + let row_offset = v_row * HEAD_DIM_V; + let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; + let block_byte_base = global_block_idx * V_BLOCK_SIZE_BYTES; + let d = f16_from_u16(load_v_u16_at(block_byte_base)); + let thread_byte_offset = block_offset * V_BYTES_PER_THREAD; + let shmem_idx = row_offset + block_k * BLOCK_SIZE + thread_byte_offset; + for (var j = 0u; j < V_BYTES_PER_THREAD / V_BYTES_PER_INNER_LOOP; j += 1u) { + let q_byte_offset = block_byte_base + 2u + thread_byte_offset + j * V_BYTES_PER_INNER_LOOP; + let q_packed = load_v_u32_at(q_byte_offset); +#if defined(V_Q4_0) + dequant_q4_0_packed_to_shmem(q_packed, d, shmem_idx + j * V_BYTES_PER_INNER_LOOP); +#elif defined(V_Q8_0) + dequant_q8_0_packed_to_shmem(q_packed, d, shmem_idx + j * V_BYTES_PER_INNER_LOOP); +#endif + } + } +} +#endif diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl index 4133f0ab564..e68934113fc 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl @@ -1,16 +1,29 @@ enable f16; enable subgroups; +#define BYTE_HELPERS +#include "common_decls.tmpl" + #ifdef Q_F16 #define Q_TYPE f16 #else #define Q_TYPE f32 #endif -#ifdef KV_F32 -#define KV_TYPE f32 +#ifdef K_F32 +#define K_TYPE f32 +#elif defined(K_Q4_0) || defined(K_Q8_0) +#define K_TYPE u32 +#else +#define K_TYPE f16 +#endif + +#ifdef V_F32 +#define V_TYPE f32 +#elif defined(V_Q4_0) || defined(V_Q8_0) +#define V_TYPE u32 #else -#define KV_TYPE f16 +#define V_TYPE f16 #endif #ifdef DST_F16 @@ -21,7 +34,6 @@ enable subgroups; #define HEAD_DIM_QK 64 #define HEAD_DIM_V 64 -#define KV_STAGE_STRIDE 64 #define Q_TILE 4 #define KV_TILE 64 #define WG_SIZE 128 @@ -64,11 +76,23 @@ struct Params { @group(0) @binding(0) var<storage, read_write> Q: array<Q_TYPE>; #ifdef KV_OVERLAP -@group(0) @binding(1) var<storage, read_write> K: array<vec4<KV_TYPE>>; +#if defined(K_Q4_0) || defined(K_Q8_0) +@group(0) @binding(1) var<storage, read_write> K: array<K_TYPE>; +#else +@group(0) @binding(1) var<storage, read_write> K: array<vec4<K_TYPE>>; +#endif #define V K #else -@group(0) @binding(1) var<storage, read_write> K: array<vec4<KV_TYPE>>; -@group(0) @binding(2) var<storage, read_write> V: array<vec4<KV_TYPE>>; +#if defined(K_Q4_0) || defined(K_Q8_0) +@group(0) @binding(1) var<storage, read_write> K: array<K_TYPE>; +#else +@group(0) @binding(1) var<storage, read_write> K: array<vec4<K_TYPE>>; +#endif +#if defined(V_Q4_0) || defined(V_Q8_0) +@group(0) @binding(2) var<storage, read_write> V: array<V_TYPE>; +#else +@group(0) @binding(2) var<storage, read_write> V: array<vec4<V_TYPE>>; +#endif #endif #if defined(MASK) && defined(SINKS) @@ -121,10 +145,50 @@ const Q_CHUNKS: u32 = HEAD_DIM_QK / 4u; const V_CHUNKS: u32 = HEAD_DIM_V / 4u; const SCORE_REGS_PER_LANE: u32 = (KV_TILE + MIN_SUBGROUP_SIZE - 1u) / MIN_SUBGROUP_SIZE; const OUT_REGS_PER_LANE: u32 = (V_CHUNKS + MIN_SUBGROUP_SIZE - 1u) / MIN_SUBGROUP_SIZE; +const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V); var<workgroup> q_shmem: array<Q_TYPE, Q_TILE * HEAD_DIM_QK>; -var<workgroup> kv_shmem: array<KV_TYPE, KV_TILE * KV_STAGE_STRIDE>; -var<workgroup> p_shmem: array<KV_TYPE, Q_TILE * KV_TILE>; +var<workgroup> kv_shmem: array<f16, kv_shmem_size>; +var<workgroup> p_shmem: array<f16, Q_TILE * KV_TILE>; + +#define QUANT_SHMEM kv_shmem +#define QUANT_OUT_TYPE f16 +#include "quant_inner_loops.tmpl" +#include "flash_attn_quant_staging.tmpl" + +#if !defined(K_Q4_0) && !defined(K_Q8_0) +fn load_k_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, k_head_offset: u32) { + for (var vec_idx_local = local_x; vec_idx_local < kv_count * Q_CHUNKS; vec_idx_local += WG_SIZE) { + let kv_local = vec_idx_local / Q_CHUNKS; + let chunk = vec_idx_local % Q_CHUNKS; + let global_k_row = kv_tile + kv_local; + let k_vec_index = (k_head_offset + global_k_row * params.stride_k1 + chunk * 4u) >> 2u; + let k4 = K[k_vec_index]; + let kv_off = kv_local * HEAD_DIM_QK + chunk * 4u; + kv_shmem[kv_off + 0u] = f16(k4.x); + kv_shmem[kv_off + 1u] = f16(k4.y); + kv_shmem[kv_off + 2u] = f16(k4.z); + kv_shmem[kv_off + 3u] = f16(k4.w); + } +} +#endif + +#if !defined(V_Q4_0) && !defined(V_Q8_0) +fn load_v_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, v_head_offset: u32) { + for (var vec_idx_local = local_x; vec_idx_local < kv_count * V_CHUNKS; vec_idx_local += WG_SIZE) { + let kv_local = vec_idx_local / V_CHUNKS; + let chunk = vec_idx_local % V_CHUNKS; + let global_v_row = kv_tile + kv_local; + let v_vec_index = (v_head_offset + global_v_row * params.stride_v1 + chunk * 4u) >> 2u; + let v4 = V[v_vec_index]; + let kv_off = kv_local * HEAD_DIM_V + chunk * 4u; + kv_shmem[kv_off + 0u] = f16(v4.x); + kv_shmem[kv_off + 1u] = f16(v4.y); + kv_shmem[kv_off + 2u] = f16(v4.z); + kv_shmem[kv_off + 3u] = f16(v4.w); + } +} +#endif @compute @workgroup_size(WG_SIZE) fn main(@builtin(workgroup_id) wg_id: vec3<u32>, @@ -206,18 +270,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>, local_scores[slot] = FLOAT_MIN; } - for (var vec_idx_local = local_id.x; vec_idx_local < kv_count * Q_CHUNKS; vec_idx_local += WG_SIZE) { - let kv_local = vec_idx_local / Q_CHUNKS; - let chunk = vec_idx_local % Q_CHUNKS; - let global_k_row = kv_tile + kv_local; - let k_vec_index = (k_head_offset + global_k_row * params.stride_k1 + chunk * 4u) >> 2u; - let k4 = K[k_vec_index]; - let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; - kv_shmem[kv_off + 0u] = KV_TYPE(k4.x); - kv_shmem[kv_off + 1u] = KV_TYPE(k4.y); - kv_shmem[kv_off + 2u] = KV_TYPE(k4.z); - kv_shmem[kv_off + 3u] = KV_TYPE(k4.w); - } +#ifndef KV_DIRECT + load_k_tile_block(local_id.x, kv_count, kv_tile, k_head_offset); +#endif workgroupBarrier(); @@ -238,8 +293,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>, q_shmem[q_off + 1u], q_shmem[q_off + 2u], q_shmem[q_off + 3u]); - let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; - let kv = vec4<KV_TYPE>( + let kv_off = kv_local * HEAD_DIM_QK + chunk * 4u; + let kv = vec4<f16>( kv_shmem[kv_off + 0u], kv_shmem[kv_off + 1u], kv_shmem[kv_off + 2u], @@ -271,25 +326,16 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>, let kv_local = sg_inv_id + slot * subgroup_size; if (row_active && kv_local < kv_count) { let p = exp(local_scores[slot] - new_max); - p_shmem[subgroup_p_offset + kv_local] = KV_TYPE(p); + p_shmem[subgroup_p_offset + kv_local] = f16(p); local_sum += p; } } workgroupBarrier(); - for (var vec_idx_local = local_id.x; vec_idx_local < kv_count * V_CHUNKS; vec_idx_local += WG_SIZE) { - let kv_local = vec_idx_local / V_CHUNKS; - let chunk = vec_idx_local % V_CHUNKS; - let global_v_row = kv_tile + kv_local; - let v_vec_index = (v_head_offset + global_v_row * params.stride_v1 + chunk * 4u) >> 2u; - let v4 = V[v_vec_index]; - let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; - kv_shmem[kv_off + 0u] = KV_TYPE(v4.x); - kv_shmem[kv_off + 1u] = KV_TYPE(v4.y); - kv_shmem[kv_off + 2u] = KV_TYPE(v4.z); - kv_shmem[kv_off + 3u] = KV_TYPE(v4.w); - } +#ifndef KV_DIRECT + load_v_tile_block(local_id.x, kv_count, kv_tile, v_head_offset); +#endif workgroupBarrier(); @@ -306,14 +352,14 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>, var acc = out_regs[reg_idx]; for (var kv_local = 0u; kv_local < kv_count; kv_local += 1u) { - let p = p_shmem[subgroup_p_offset + kv_local]; - let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; - let v4 = vec4<KV_TYPE>( + let p = f32(p_shmem[subgroup_p_offset + kv_local]); + let kv_off = kv_local * HEAD_DIM_V + chunk * 4u; + let v4 = vec4<f16>( kv_shmem[kv_off + 0u], kv_shmem[kv_off + 1u], kv_shmem[kv_off + 2u], kv_shmem[kv_off + 3u]); - acc += f32(p) * vec4<f32>(v4); + acc += p * vec4<f32>(v4); } out_regs[reg_idx] = acc; } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl index 30ebbebe772..30ed97cca0c 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl @@ -2,10 +2,23 @@ diagnostic(off, subgroup_uniformity); enable f16; enable subgroups; -#ifdef KV_F32 -#define KV_TYPE f32 +#define BYTE_HELPERS +#include "common_decls.tmpl" + +#ifdef K_F32 +#define K_TYPE f32 +#elif defined(K_Q4_0) || defined(K_Q8_0) +#define K_TYPE u32 #else -#define KV_TYPE f16 +#define K_TYPE f16 +#endif + +#ifdef V_F32 +#define V_TYPE f32 +#elif defined(V_Q4_0) || defined(V_Q8_0) +#define V_TYPE u32 +#else +#define V_TYPE f16 #endif #ifdef Q_F16 @@ -32,28 +45,6 @@ enable subgroups; #define KV_BLOCKS (KV_TILE / KV_GRANULARITY) -#define BLOCK_SIZE 32 -#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE) -#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE) -#if defined(KV_Q4_0) -#define NQ 16 -#define F16_PER_BLOCK 9 -#define WEIGHTS_PER_F16 4 -#elif defined(KV_Q8_0) -#define NQ 8 -#define F16_PER_BLOCK 17 -#define WEIGHTS_PER_F16 2 -#endif -#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16) - -fn get_byte(value: u32, index: u32) -> u32 { - return (value >> (index * 8)) & 0xFF; -} - -fn get_byte_i32(value: u32, index: u32) -> i32 { - return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24; -} - struct Params { offset_q: u32, offset_k: u32, @@ -103,22 +94,22 @@ struct Params { @group(0) @binding(0) var<storage, read_write> Q: array<Q_TYPE>; #ifdef KV_OVERLAP -#if defined(KV_Q4_0) || defined(KV_Q8_0) -@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>; +#if defined(K_Q4_0) || defined(K_Q8_0) +@group(0) @binding(1) var<storage, read_write> K: array<K_TYPE>; #else -@group(0) @binding(1) var<storage, read_write> K: array<vec4<KV_TYPE>>; +@group(0) @binding(1) var<storage, read_write> K: array<vec4<K_TYPE>>; #endif #define V K #else -#if defined(KV_Q4_0) || defined(KV_Q8_0) -@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>; +#if defined(K_Q4_0) || defined(K_Q8_0) +@group(0) @binding(1) var<storage, read_write> K: array<K_TYPE>; #else -@group(0) @binding(1) var<storage, read_write> K: array<vec4<KV_TYPE>>; +@group(0) @binding(1) var<storage, read_write> K: array<vec4<K_TYPE>>; #endif -#if defined(KV_Q4_0) || defined(KV_Q8_0) -@group(0) @binding(2) var<storage, read_write> V: array<KV_TYPE>; +#if defined(V_Q4_0) || defined(V_Q8_0) +@group(0) @binding(2) var<storage, read_write> V: array<V_TYPE>; #else -@group(0) @binding(2) var<storage, read_write> V: array<vec4<KV_TYPE>>; +@group(0) @binding(2) var<storage, read_write> V: array<vec4<V_TYPE>>; #endif #endif #if defined(MASK) && defined(SINKS) @@ -244,6 +235,49 @@ fn calc_softmax_term(kv_idx: u32, slope: f32, has_bias: bool, apply_mask: bool) return v; } +#ifndef KV_DIRECT +#define QUANT_SHMEM kv_shmem +#define QUANT_OUT_TYPE f32 +#include "quant_inner_loops.tmpl" +#include "flash_attn_quant_staging.tmpl" + +#if !defined(K_Q4_0) && !defined(K_Q8_0) +fn load_k_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, k_head_offset: u32) { + for (var elem_idx = local_x * 4u; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * 4u) { + let k_row = elem_idx / HEAD_DIM_QK; + let k_col = elem_idx % HEAD_DIM_QK; + let global_k_row = kv_tile + k_row; + let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1; + let in_bounds = global_k_row < params.seq_len_kv && (k_col + 3u) < HEAD_DIM_QK; + let vec_idx = (global_k_row_offset + k_col) >> 2u; + let k4 = select(vec4<K_TYPE>(0.0), K[vec_idx], in_bounds); + kv_shmem[elem_idx + 0u] = f32(k4.x); + kv_shmem[elem_idx + 1u] = f32(k4.y); + kv_shmem[elem_idx + 2u] = f32(k4.z); + kv_shmem[elem_idx + 3u] = f32(k4.w); + } +} +#endif + +#if !defined(V_Q4_0) && !defined(V_Q8_0) +fn load_v_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, v_head_offset: u32) { + for (var elem_idx = local_x * 4u; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * 4u) { + let v_row = elem_idx / HEAD_DIM_V; + let v_col = elem_idx % HEAD_DIM_V; + let global_v_row = kv_tile + v_row; + let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1; + let in_bounds = global_v_row < params.seq_len_kv && (v_col + 3u) < HEAD_DIM_V; + let vec_idx = (global_v_row_offset + v_col) >> 2u; + let v4 = select(vec4<V_TYPE>(0.0), V[vec_idx], in_bounds); + kv_shmem[elem_idx + 0u] = f32(v4.x); + kv_shmem[elem_idx + 1u] = f32(v4.y); + kv_shmem[elem_idx + 2u] = f32(v4.z); + kv_shmem[elem_idx + 3u] = f32(v4.w); + } +} +#endif +#endif + @compute @workgroup_size(WG_SIZE) fn main(@builtin(workgroup_id) wg_id: vec3<u32>, @builtin(local_invocation_id) local_id: vec3<u32>, @@ -308,6 +342,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>, } for (var kv_tile = iwg * KV_TILE; kv_tile < params.seq_len_kv; kv_tile += KV_TILE * params.nwg) { + let kv_count = min(KV_TILE, params.seq_len_kv - kv_tile); #ifdef BLK let q_blk = q_row_start; let kv_blk = kv_tile / KV_TILE; @@ -324,76 +359,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>, } // load k tile into shared memory -#if defined(KV_Q4_0) - for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { - let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; - let k_row = blck_idx / BLOCKS_K; - let global_k_row = kv_tile + k_row; - let block_k = blck_idx % BLOCKS_K; - let row_offset = k_row * HEAD_DIM_QK; - - if (global_k_row < params.seq_len_kv) { - let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; - let base_idx = global_block_idx * F16_PER_BLOCK; - let d = K[base_idx]; - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = K[base_idx + 1u + block_offset + j]; - let q_1 = K[base_idx + 1u + block_offset + j + 1]; - let q_packed = bitcast<u32>(vec2(q_0, q_1)); - for (var k = 0u; k < 4u; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * f32(d); - let q_lo = (f32(q_byte & 0xF) - 8.0) * f32(d); - let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; - kv_shmem[row_offset + idx] = q_lo; - kv_shmem[row_offset + idx + 16u] = q_hi; - } - } - } - } -#elif defined(KV_Q8_0) - for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { - let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; - let k_row = blck_idx / BLOCKS_K; - let global_k_row = kv_tile + k_row; - let block_k = blck_idx % BLOCKS_K; - let row_offset = k_row * HEAD_DIM_QK; - - if (global_k_row < params.seq_len_kv) { - let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; - let base_idx = global_block_idx * F16_PER_BLOCK; - let d = K[base_idx]; - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = K[base_idx + 1u + block_offset + j]; - let q_1 = K[base_idx + 1u + block_offset + j + 1]; - let q_packed = bitcast<u32>(vec2(q_0, q_1)); - for (var k = 0u; k < 4u; k++) { - let q_byte = get_byte_i32(q_packed, k); - let q_val = f32(q_byte) * f32(d); - let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; - kv_shmem[row_offset + idx] = q_val; - } - } - } - } -#elif defined(KV_DIRECT) - // Direct global loads for KV -#else - for (var elem_idx = local_id.x * 4u; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * 4u) { - let k_row = elem_idx / HEAD_DIM_QK; - let k_col = elem_idx % HEAD_DIM_QK; - let global_k_row = kv_tile + k_row; - let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1; - let in_bounds = global_k_row < params.seq_len_kv && (k_col + 3u) < HEAD_DIM_QK; - let vec_idx = (global_k_row_offset + k_col) >> 2u; - let k4 = select(vec4<KV_TYPE>(0.0), K[vec_idx], in_bounds); - kv_shmem[elem_idx + 0u] = f32(k4.x); - kv_shmem[elem_idx + 1u] = f32(k4.y); - kv_shmem[elem_idx + 2u] = f32(k4.z); - kv_shmem[elem_idx + 3u] = f32(k4.w); - } +#ifndef KV_DIRECT + load_k_tile_block(local_id.x, kv_count, kv_tile, k_head_offset); #endif workgroupBarrier(); @@ -510,76 +477,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>, } // load v tile into shared memory -#if defined(KV_Q4_0) - for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { - let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; - let v_row = blck_idx / BLOCKS_V; - let global_v_row = kv_tile + v_row; - let block_k = blck_idx % BLOCKS_V; - let row_offset = v_row * HEAD_DIM_V; - - if (global_v_row < params.seq_len_kv) { - let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; - let base_idx = global_block_idx * F16_PER_BLOCK; - let d = V[base_idx]; - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = V[base_idx + 1u + block_offset + j]; - let q_1 = V[base_idx + 1u + block_offset + j + 1]; - let q_packed = bitcast<u32>(vec2(q_0, q_1)); - for (var k = 0u; k < 4u; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * f32(d); - let q_lo = (f32(q_byte & 0xF) - 8.0) * f32(d); - let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; - kv_shmem[row_offset + idx] = q_lo; - kv_shmem[row_offset + idx + 16u] = q_hi; - } - } - } - } -#elif defined(KV_Q8_0) - for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { - let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; - let v_row = blck_idx / BLOCKS_V; - let global_v_row = kv_tile + v_row; - let block_k = blck_idx % BLOCKS_V; - let row_offset = v_row * HEAD_DIM_V; - - if (global_v_row < params.seq_len_kv) { - let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; - let base_idx = global_block_idx * F16_PER_BLOCK; - let d = V[base_idx]; - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = V[base_idx + 1u + block_offset + j]; - let q_1 = V[base_idx + 1u + block_offset + j + 1]; - let q_packed = bitcast<u32>(vec2(q_0, q_1)); - for (var k = 0u; k < 4u; k++) { - let q_byte = get_byte_i32(q_packed, k); - let q_val = f32(q_byte) * f32(d); - let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; - kv_shmem[row_offset + idx] = q_val; - } - } - } - } -#elif defined(KV_DIRECT) - // Direct global loads for KV -#else - for (var elem_idx = local_id.x * 4u; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * 4u) { - let v_row = elem_idx / HEAD_DIM_V; - let v_col = elem_idx % HEAD_DIM_V; - let global_v_row = kv_tile + v_row; - let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1; - let in_bounds = global_v_row < params.seq_len_kv && (v_col + 3u) < HEAD_DIM_V; - let vec_idx = (global_v_row_offset + v_col) >> 2u; - let v4 = select(vec4<KV_TYPE>(0.0), V[vec_idx], in_bounds); - kv_shmem[elem_idx + 0u] = f32(v4.x); - kv_shmem[elem_idx + 1u] = f32(v4.y); - kv_shmem[elem_idx + 2u] = f32(v4.z); - kv_shmem[elem_idx + 3u] = f32(v4.w); - } +#ifndef KV_DIRECT + load_v_tile_block(local_id.x, kv_count, kv_tile, v_head_offset); #endif workgroupBarrier(); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl index d68520f8282..7d7b3475549 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl @@ -63,10 +63,10 @@ fn main( let iq3 = seq_id / params.rq3; let state_size = S_V * S_V; - let state_in_base = (seq_id * params.K * params.h + head_id) * state_size; + // input state holds s0 only [S_v, S_v, H, n_seqs]: per-seq stride is H*D. + let state_in_base = (seq_id * params.h + head_id) * state_size; let state_out_base = (seq_id * params.h + head_id) * state_size; let state_size_per_snap = state_size * params.h * params.n_seqs; - let shift = i32(params.n_tokens) - i32(params.K); var state: array<f32, S_V>; for (var i = 0u; i < S_V; i++) { @@ -128,7 +128,8 @@ fn main( attn_off += S_V * params.h; if (params.K > 1u) { - let target_slot = i32(t) - shift; + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + let target_slot = i32(params.n_tokens) - 1 - i32(t); if (target_slot >= 0 && target_slot < i32(params.K)) { let slot_base = params.s_off + u32(target_slot) * state_size_per_snap + state_out_base; for (var i = 0u; i < S_V; i++) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl deleted file mode 100644 index fcbefdeb802..00000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +++ /dev/null @@ -1,747 +0,0 @@ -enable f16; - -#define DECLARE_BYTE_LOADERS_SRC0 -#include "common_decls.tmpl" - - -#ifdef FLOAT -const BLOCK_SIZE = 1u; - -#elif defined(Q4_0) || defined(Q4_1) || defined(Q5_0) || defined(Q5_1) || defined(Q8_0) || defined(Q8_1) || defined(IQ4_NL) -const BLOCK_SIZE = 32u; - -#elif defined(Q2_K) || defined(Q3_K) || defined(Q4_K) || defined(Q5_K) || defined(Q6_K) || defined(IQ2_XXS) || defined(IQ2_XS) || defined(IQ2_S) || defined(IQ3_XXS) || defined(IQ3_S) || defined(IQ1_S) || defined(IQ1_M) || defined(IQ4_XS) -const BLOCK_SIZE = 256u; -#endif - -#ifdef FLOAT -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - return f32(src0[src0_idx_base + offset]) * f32(src1[src1_idx_base + offset]); -} -#endif - -#ifdef Q4_0 -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes - let d = load_f16_as_f32_at_src0(block_byte_base); - var sum: f32 = 0.0; - for (var j: u32 = 0; j < 4; j++) { - let q_byte_offset = block_byte_base + 2 + j * 4; - let q_packed = load_u32_at_src0(q_byte_offset); - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d; - let q_lo = (f32(q_byte & 0xF) - 8.0f) * d; - let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; - sum += q_lo * f32(src1[src1_offset]); - sum += q_hi * f32(src1[src1_offset + 16]); - } - } - return sum; -} -#endif - -#ifdef Q4_1 -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_q4_1 = src0[src0_idx_base + offset]; - let d = f32(block_q4_1.d); - let m = f32(block_q4_1.m); - var sum: f32 = 0.0; - for (var j: u32 = 0; j < 4; j++) { - let q_packed = block_q4_1.qs[j]; - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = f32((q_byte >> 4) & 0xF) * d + m; - let q_lo = f32(q_byte & 0xF) * d + m; - let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; - sum += q_lo * f32(src1[src1_offset]); - sum += q_hi * f32(src1[src1_offset + 16]); - } - } - return sum; -} -#endif - -#ifdef Q5_0 -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 22; // Block stride: 22 bytes - let d = load_f16_as_f32_at_src0(block_byte_base); - var sum: f32 = 0.0; - let qh_packed = load_u32_at_src0(block_byte_base + 2); - for (var j: u32 = 0; j < 4; j++) { - let q_byte_offset = block_byte_base + 6 + j * 4; - let q_packed = load_u32_at_src0(q_byte_offset); - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10; - let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d; - let qh_lo = ((qh_packed >> (j * 4 + k)) << 4) & 0x10; - let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d; - let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; - sum += q_lo * f32(src1[src1_offset]); - sum += q_hi * f32(src1[src1_offset + 16]); - } - } - return sum; -} -#endif - -#ifdef Q5_1 -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_q5_1 = src0[src0_idx_base + offset]; - let d = f32(block_q5_1.d); - let m = f32(block_q5_1.m); - var sum: f32 = 0.0; - for (var j: u32 = 0; j < 4; j++) { - let q_packed = block_q5_1.qs[j]; - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - let qh_hi = (block_q5_1.qh >> (j * 4 + k + 12)) & 0x10; - let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + m; - let qh_lo = ((block_q5_1.qh >> (j * 4 + k)) << 4) & 0x10; - let q_lo = f32((q_byte & 0xF) | qh_lo) * d + m; - let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; - sum += q_lo * f32(src1[src1_offset]); - sum += q_hi * f32(src1[src1_offset + 16]); - } - } - return sum; -} -#endif - -#ifdef Q8_0 -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 34; // Block stride: 34 bytes - let d = load_f16_as_f32_at_src0(block_byte_base); - var sum: f32 = 0.0; - for (var j: u32 = 0; j < 8; j++) { - let q_byte_offset = block_byte_base + 2 + j * 4; - let q_packed = load_u32_at_src0(q_byte_offset); - for (var k: u32 = 0u; k < 4u; k++) { - let q_byte = get_byte_i32(q_packed, k); - let q_val = f32(q_byte) * d; - let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; - sum += q_val * f32(src1[src1_offset]); - } - } - return sum; -} -#endif - -#ifdef Q8_1 -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_q8_1 = src0[src0_idx_base + offset]; - let d = f32(block_q8_1.d); - let m = f32(block_q8_1.m); - var sum: f32 = 0.0; - for (var j: u32 = 0; j < 8; j++) { - let q_packed = block_q8_1.qs[j]; - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte_i32(q_packed, k); - let q_val = f32(q_byte) * d + m; - let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; - sum += q_val * f32(src1[src1_offset]); - } - } - return sum; -} -#endif - -#ifdef Q2_K -// 16 blocks of 16 elements each -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - let m = f32(block.dmin); - var sum = 0.0; - var src1_i = src1_idx_base + offset * 256; - var is: u32 = 0; - // 2 halves of the block (128 elements each) - for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) { - // 4 groups (each group has 2 blocks of 16 elements) - for (var shift: u32 = 0; shift < 8; shift += 2) { - // 2 blocks - for (var k: u32 = 0; k < 32; k += 16) { - let sc = get_byte(block.scales[is / 4], is % 4); - is++; - let dl = d * f32(sc & 0xF); - let ml = m * f32(sc >> 4); - for (var l: u32 = 0u; l < 16; l++) { - let q_idx = q_b_idx + k + l; - let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4); - let qs_val = (q_byte >> shift) & 3; - sum += (f32(qs_val) * dl - ml) * src1[src1_i]; - src1_i++; - } - } - } - } - return sum; -} -#endif - -#ifdef Q3_K -// 16 blocks of 16 elements each -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes - - // Bytes 108-109: f16 scale 'd' - let d = load_f16_as_f32_at_src0(block_byte_base + 108); - - // extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale, - // and 2-bits from the last 4 bytes - // Bytes 96-107: 12 bytes of scales (3 u32s) - let kmask1: u32 = 0x03030303; - let kmask2: u32 = 0x0f0f0f0f; - var scale_vals: array<u32, 4>; - scale_vals[0] = load_u32_at_src0(block_byte_base + 96); - scale_vals[1] = load_u32_at_src0(block_byte_base + 100); - scale_vals[2] = load_u32_at_src0(block_byte_base + 104); - - var tmp: u32 = scale_vals[2]; - scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); - scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); - scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4); - scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); - - // Bytes 0-31: 32 bytes of hmask (8 u32s) - var hmask_vals: array<u32, 8>; - for (var i: u32 = 0; i < 8; i++) { - hmask_vals[i] = load_u32_at_src0(block_byte_base + i * 4); - } - - // Bytes 32-95: 64 bytes of qs (16 u32s) - var qs_vals: array<u32, 16>; - for (var i: u32 = 0u; i < 16; i++) { - qs_vals[i] = load_u32_at_src0(block_byte_base + 32 + i * 4); - } - - var sum = 0.0; - var src1_i = src1_idx_base + offset * 256; - var is: u32 = 0; - var m: u32 = 1; - // 2 halves of the block (128 elements each) - for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) { - // 4 groups (each group has 2 blocks of 16 elements) - for (var shift: u32 = 0; shift < 8; shift += 2) { - // 2 blocks - for (var k: u32 = 0; k < 32; k += 16) { - let sc = get_byte(scale_vals[is / 4], is % 4); - is++; - let dl = d * (f32(sc) - 32.0); - for (var l: u32 = 0u; l < 16u; l++) { - let q_idx = q_b_idx + k + l; - let hm_idx = k + l; - let q_byte = get_byte(qs_vals[q_idx / 4], q_idx % 4); - let hmask_byte = get_byte(hmask_vals[hm_idx / 4], hm_idx % 4); - let hm = select(4.0, 0.0, (hmask_byte & m) != 0); - let qs_val = (q_byte >> shift) & 3; - sum += ((f32(qs_val) - hm) * dl) * src1[src1_i]; - src1_i++; - } - } - m <<= 1; - } - } - return sum; -} -#endif - -#ifdef Q4_K -// 8 blocks of 32 elements each -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - let m = f32(block.dmin); - var sum = 0.0; - var src1_i = src1_idx_base + offset * 256; - var is: u32 = 0; - // 2 blocks each iteration - for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) { - for (var shift: u32 = 0; shift < 8; shift += 4) { - let scale_min = get_scale_min(is, block.scales); - is++; - let dl = d * scale_min.x; - let ml = m * scale_min.y; - for (var l: u32 = 0; l < 32; l++) { - let q_idx = q_b_idx + l; - let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4); - let qs_val = (q_byte >> shift) & 0xF; - sum += (f32(qs_val) * dl - ml) * src1[src1_i]; - src1_i++; - } - } - } - return sum; -} -#endif - -#ifdef Q5_K -// 8 blocks of 32 elements each -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - let m = f32(block.dmin); - var sum = 0.0; - var src1_i = src1_idx_base + offset * 256; - var is: u32 = 0; - var u: u32 = 1; - // 2 blocks each iteration - for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) { - for (var shift: u32 = 0; shift < 8; shift += 4) { - let scale_min = get_scale_min(is, block.scales); - is++; - let dl = d * scale_min.x; - let ml = m * scale_min.y; - for (var l: u32 = 0; l < 32; l++) { - let q_idx = q_b_idx + l; - let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4); - let qh_byte = get_byte(block.qh[l / 4], l % 4); - let qs_val = (q_byte >> shift) & 0xF; - let qh_val = select(0.0, 16.0, (qh_byte & u) != 0); - sum += ((f32(qs_val) + qh_val) * dl - ml) * src1[src1_i]; - src1_i++; - } - u <<= 1; - } - } - return sum; -} -#endif - -#ifdef Q6_K -// 16 blocks of 16 elements each -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 210; // Block stride: 210 bytes - - // Bytes 208-209: f16 scale 'd' - let d = load_f16_as_f32_at_src0(block_byte_base + 208); - - // Bytes 0-127: 128 bytes of ql (32 u32s) - var ql_vals: array<u32, 32>; - for (var i: u32 = 0; i < 32; i++) { - ql_vals[i] = load_u32_at_src0(block_byte_base + i * 4); - } - - // Bytes 128-191: 64 bytes of qh (16 u32s) - var qh_vals: array<u32, 16>; - for (var i: u32 = 0; i < 16; i++) { - qh_vals[i] = load_u32_at_src0(block_byte_base + 128 + i * 4); - } - - // Bytes 192-207: 16 bytes of scales (4 u32s) - var scale_vals: array<u32, 4>; - for (var i: u32 = 0; i < 4; i++) { - scale_vals[i] = load_u32_at_src0(block_byte_base + 192 + i * 4); - } - - var sum = 0.0; - var src1_i = src1_idx_base + offset * 256; - var qh_b_idx: u32 = 0; - var sc_b_idx: u32 = 0; - for (var ql_b_idx: u32 = 0; ql_b_idx < 128; ql_b_idx += 64) { - for (var l: u32 = 0; l < 32; l++) { - let ql13_b = get_byte(ql_vals[(ql_b_idx + l) / 4], (ql_b_idx + l) % 4); - let ql24_b = get_byte(ql_vals[(ql_b_idx + l + 32) / 4], (ql_b_idx + l + 32) % 4); - let qh_b = get_byte(qh_vals[(qh_b_idx + l) / 4], (qh_b_idx + l) % 4); - - let q1 = f32((ql13_b & 0xF) | ((qh_b & 3) << 4)) - 32.0; - let q2 = f32((ql24_b & 0xF) | (((qh_b >> 2) & 3) << 4)) - 32.0; - let q3 = f32((ql13_b >> 4) | (((qh_b >> 4) & 3) << 4)) - 32.0; - let q4 = f32((ql24_b >> 4) | (((qh_b >> 6) & 3) << 4)) - 32.0; - - let is = l/16; - let is1 = sc_b_idx + is; - let sc1 = get_byte_i32(scale_vals[is1 / 4], is1 % 4); - let is2 = sc_b_idx + is + 2; - let sc2 = get_byte_i32(scale_vals[is2 / 4], is2 % 4); - let is3 = sc_b_idx + is + 4; - let sc3 = get_byte_i32(scale_vals[is3 / 4], is3 % 4); - let is4 = sc_b_idx + is + 6; - let sc4 = get_byte_i32(scale_vals[is4 / 4], is4 % 4); - - sum += d * f32(sc1) * q1 * src1[src1_i + l]; - sum += d * f32(sc2) * q2 * src1[src1_i + l + 32]; - sum += d * f32(sc3) * q3 * src1[src1_i + l + 64]; - sum += d * f32(sc4) * q4 * src1[src1_i + l + 96]; - } - src1_i += 128; - qh_b_idx += 32; - sc_b_idx += 8; - } - return sum; -} -#endif - -#ifdef IQ2_XXS -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 66; // Block stride: 66 bytes - let d = load_f16_as_f32_at_src0(block_byte_base); - var src1_i = src1_idx_base + offset * 256; - var sum = 0.0; - for (var ib: u32 = 0; ib < 32; ib += 4) { - let aux0_offset = block_byte_base + 2 + ib * 2; - let aux1_offset = block_byte_base + 2 + (ib + 2) * 2; - let aux0 = load_u32_at_src0(aux0_offset); - let aux1 = load_u32_at_src0(aux1_offset); - let db = d * (0.5 + f32(aux1 >> 28)) * 0.25; - for (var l: u32 = 0; l < 4; l++) { - let ig = get_byte(aux0, l) * 8; - let is = (aux1 >> (7 * l)) & 127; - let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); - for (var j: u32 = 0; j < 8; j++) { - let g = get_byte(iq2xxs_grid[(ig + j) / 4], (ig + j) % 4); - let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0); - sum += db * f32(g) * m * src1[src1_i]; - src1_i++; - } - } - } - return sum; -} -#endif - -#ifdef IQ2_XS -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 74; // Block stride: 74 bytes - let d = load_f16_as_f32_at_src0(block_byte_base); - var src1_i = src1_idx_base + offset * 256; - - var scale_vals = array<u32, 2>( - load_u32_at_src0(block_byte_base + 66), - load_u32_at_src0(block_byte_base + 70) - ); - - var sum = 0.0; - for (var ib: u32 = 0; ib < 32; ib += 4) { - let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4); - let db = array<f32, 2>( - d * (0.5 + f32(s & 0xF)) * 0.25, - d * (0.5 + f32(s >> 4)) * 0.25 - ); - for (var l: u32 = 0; l < 4; l++) { - let qs_offset = block_byte_base + 2 + (ib + l) * 2; - let qs_val = load_u32_at_src0(qs_offset) & 0xFFFF; - let ig = (qs_val & 511) * 8; - let is = qs_val >> 9; - let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); - let dl = db[l/2]; - for (var j: u32 = 0; j < 8; j++) { - let g = get_byte(iq2xs_grid[(ig + j) / 4], (ig + j) % 4); - let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0); - sum += dl * f32(g) * m * src1[src1_i]; - src1_i++; - } - } - } - return sum; -} -#endif - -#ifdef IQ2_S -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 82; // Block stride: 82 bytes - let d = load_f16_as_f32_at_src0(block_byte_base); - var src1_i = src1_idx_base + offset * 256; - - var qs_vals : array<u32, 16>; - for (var i: u32 = 0; i < 16; i++) { - qs_vals[i] = load_u32_at_src0(block_byte_base + 2 + i * 4); - } - - var qh_vals: array<u32, 2>; - qh_vals[0] = load_u32_at_src0(block_byte_base + 66); - qh_vals[1] = load_u32_at_src0(block_byte_base + 70); - - var scale_vals: array<u32, 2>; - scale_vals[0] = load_u32_at_src0(block_byte_base + 74); - scale_vals[1] = load_u32_at_src0(block_byte_base + 78); - - var sum = 0.0; - for (var ib: u32 = 0; ib < 8; ib ++) { - let s = get_byte(scale_vals[ib / 4], ib % 4); - let db = array<f32, 2>( - d * (0.5 + f32(s & 0xF)) * 0.25, - d * (0.5 + f32(s >> 4)) * 0.25 - ); - let qs_w = qs_vals[ib]; - for (var l: u32 = 0; l < 4; l++) { - let qh_b = (get_byte(qh_vals[ib / 4], ib % 4) << (8 - 2 * l)) & 0x300; - let ig = (get_byte(qs_w, l) | qh_b) * 8; - let signs = get_byte(qs_vals[ib + 8], l); - let dl = db[l/2]; - for (var j: u32 = 0; j < 8; j++) { - let g = get_byte(iq2s_grid[(ig + j) / 4], (ig + j) % 4); - let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0); - sum += dl * f32(g) * m * src1[src1_i]; - src1_i++; - } - } - } - return sum; -} -#endif - -#ifdef IQ3_XXS -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 98; // Block stride: 98 bytes - let d = load_f16_as_f32_at_src0(block_byte_base); - var src1_i = src1_idx_base + offset * 256; - var sum = 0.0; - for (var ib: u32 = 0; ib < 16; ib += 2) { - let sc_sign_offset = block_byte_base + 2 + (ib + 32) * 2; - let sc_sign = load_u32_at_src0(sc_sign_offset); - let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5; - for (var l: u32 = 0; l < 4; l++) { - let is = (sc_sign >> (7 * l)) & 127; - let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); - let ig_val = load_u32_at_src0(block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF; - let ig1 = get_byte(ig_val, 0); - let ig2 = get_byte(ig_val, 1); - for (var j: u32 = 0; j < 4; j++) { - let g1 = get_byte(iq3xxs_grid[ig1], j); - let g2 = get_byte(iq3xxs_grid[ig2], j); - let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0); - let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0); - sum += db * f32(g1) * m1 * src1[src1_i]; - sum += db * f32(g2) * m2 * src1[src1_i + 4]; - src1_i++; - } - src1_i += 4; - } - } - return sum; -} -#endif - -#ifdef IQ3_S -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes - let d = load_f16_as_f32_at_src0(block_byte_base); - var src1_i = src1_idx_base + offset * 256; - - var qh_vals = array<u32, 2>( - load_u32_at_src0(block_byte_base + 66), - load_u32_at_src0(block_byte_base + 70) - ); - - var sign_vals: array<u32, 8>; - for (var i: u32 = 0; i < 8; i++) { - sign_vals[i] = load_u32_at_src0(block_byte_base + 74 + i * 4); - } - - var scale_vals = load_u32_at_src0(block_byte_base + 106); - - var sum = 0.0; - for (var ib: u32 = 0; ib < 4; ib++) { - let s = get_byte(scale_vals, ib); - let db = array<f32, 2>( - d * (1.0 + 2.0 * f32(s & 0xF)), - d * (1.0 + 2.0 * f32(s >> 4)) - ); - for (var k: u32 = 0; k < 2; k++) { - let dl = db[k]; - let qh_byte = get_byte(qh_vals[ib / 2], (ib % 2) * 2 + k); - let sign_w = sign_vals[ib * 2 + k]; - for (var l: u32 = 0; l < 4; l++) { - let signs = get_byte(sign_w, l); - let ig_val = load_u32_at_src0(block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF; - let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256); - let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256); - for (var j: u32 = 0; j < 4; j++) { - let g1 = get_byte(iq3s_grid[ig1], j); - let g2 = get_byte(iq3s_grid[ig2], j); - let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0); - let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0); - sum += dl * f32(g1) * m1 * src1[src1_i]; - sum += dl * f32(g2) * m2 * src1[src1_i + 4]; - src1_i++; - } - src1_i += 4; - } - } - } - return sum; -} -#endif - -#ifdef IQ1_S -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 50; // Block stride: 50 bytes - let d = load_f16_as_f32_at_src0(block_byte_base); - var src1_i = src1_idx_base + offset * 256; - var sum = 0.0; - for (var ib: u32 = 0; ib < 8; ib++) { - let qh = load_u32_at_src0(block_byte_base + 34 + ib * 2) & 0xFFFF; - let dl = d * (2.0 * f32((qh >> 12) & 7) + 1.0); - let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0); - let qs_w = load_u32_at_src0(block_byte_base + 2 + ib * 4); - for (var l: u32 = 0; l < 4; l++) { - let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8; - for (var j: u32 = 0; j < 8; j++) { - let gw = iq1_grid[(ig + j) / 16]; - let g = (gw >> (((ig + j) % 16) * 2)) & 3; - let gs = bitcast<i32>(g << 30) >> 30; - sum += dl * (f32(gs) + delta) * src1[src1_i]; - src1_i++; - } - } - } - return sum; -} -#endif - - -#ifdef IQ1_M -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - - let scale = ((block.scales[0] >> 12) & 0xF) | ((block.scales[0] >> 24) & 0x00F0) | ((block.scales[1] >> 4) & 0x0F00) | ((block.scales[1] >> 16) & 0xF000); - let d = f32(bitcast<vec2<f16>>(scale).x); - var src1_i = src1_idx_base + offset * 256; - var sum = 0.0; - for (var ib: u32 = 0; ib < 8; ib++) { - let sw = (block.scales[ib / 4] >> (16 * ((ib / 2) % 2))) & 0xFFFF; - let s1 : u32 = (sw >> (6 * (ib % 2))) & 0x7; - let s2 : u32 = (sw >> (6 * (ib % 2) + 3)) & 0x7; - var dl = array<f32, 2>( - d * f32(2 * s1 + 1), - d * f32(2 * s2 + 1) - ); - - let qh = block.qh[ib / 2] >> (16 * (ib % 2)); - var idx = array<u32, 4>( - get_byte(block.qs[ib], 0) | ((qh << 8) & 0x700), - get_byte(block.qs[ib], 1) | ((qh << 4) & 0x700), - get_byte(block.qs[ib], 2) | ((qh) & 0x700), - get_byte(block.qs[ib], 3) | ((qh >> 4) & 0x700) - ); - var delta = array<f32, 4>( - select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x08) != 0), - select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x80) != 0), - select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x08) != 0), - select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x80) != 0) - ); - for (var l: u32 = 0; l < 4; l++) { - let ig = idx[l] * 8; - for (var j: u32 = 0; j < 8; j++) { - let gw = iq1_grid[(ig + j) / 16]; - let g = (gw >> (((ig + j) % 16) * 2)) & 3; - let gs = bitcast<i32>(g << 30) >> 30; - sum += dl[l/2] * (f32(gs) + delta[l]) * src1[src1_i]; - src1_i++; - } - } - } - return sum; -} -#endif - -#ifdef IQ4_NL -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes - let d = load_f16_as_f32_at_src0(block_byte_base); - var src1_i = src1_idx_base + offset * 32; - var sum = 0.0; - var qs: array<u32, 4>; - for (var i: u32 = 0; i < 4; i++) { - qs[i] = load_u32_at_src0(block_byte_base + 2 + i * 4); - } - for (var j: u32 = 0; j < 16; j++) { - let qsb = get_byte(qs[j / 4], j % 4); - sum += d * f32(kvalues_iq4nl[qsb & 0xF]) * src1[src1_i]; - sum += d * f32(kvalues_iq4nl[qsb >> 4]) * src1[src1_i + 16]; - src1_i++; - } - return sum; -} -#endif - -#ifdef IQ4_XS -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = unpack2x16float(block.d_scales_h)[0]; - let scales_h = block.d_scales_h >> 16; - var src1_i = src1_idx_base + offset * 256; - var sum = 0.0; - for (var ib: u32 = 0; ib < 8; ib++) { - let ls = ((get_byte(block.scales_l, ib / 2) >> (4 * (ib % 2))) & 0xF) | (((scales_h >> (2 * ib)) & 3) << 4); - let dl = d * (f32(ls) - 32.0); - for (var j: u32 = 0; j < 16; j++) { - let iqs = ib * 16 + j; - let qsb = get_byte(block.qs[iqs / 4], iqs % 4); - sum += dl * f32(kvalues_iq4nl[qsb & 0xF]) * src1[src1_i]; - sum += dl * f32(kvalues_iq4nl[qsb >> 4]) * src1[src1_i + 16]; - src1_i++; - } - src1_i += 16; - } - return sum; -} -#endif - -struct MulMatParams { - offset_src0: u32, // in elements/blocks - offset_src1: u32, // in elements/blocks - offset_dst: u32, // in elements/blocks - m: u32, - n: u32, - k: u32, - // all strides are in elements/blocks - stride_01: u32, - stride_11: u32, - stride_02: u32, - stride_12: u32, - stride_03: u32, - stride_13: u32, - - bs02: u32, - bs03: u32, - broadcast2: u32, - broadcast3: u32 -}; - -@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns -@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed) -@group(0) @binding(2) var<storage, read_write> dst: array<f32>; // M rows, N columns - -@group(0) @binding(3) var<uniform> params: MulMatParams; - -@compute @workgroup_size(256) -fn main(@builtin(local_invocation_id) local_id: vec3<u32>, - @builtin(workgroup_id) wg_id: vec3<u32>, - @builtin(num_workgroups) num_wg: vec3<u32>) { - let wg_linear = wg_id.y * num_wg.x + wg_id.x; - let global_idx = wg_linear * 256u + local_id.x; - - let total = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; - if (global_idx >= total) { - return; - } - - let dst2_stride = params.m * params.n; - let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; - - let dst3_idx = global_idx / dst3_stride; - let src03_idx = dst3_idx / params.broadcast3; // src0 may be broadcast along the third dimension - let src13_idx = dst3_idx; // src1 is not broadcast - let dst3_rem = global_idx % dst3_stride; - - let dst2_idx = dst3_rem / dst2_stride; - let src02_idx = dst2_idx / params.broadcast2; // src0 may also be broadcast along the second dimension - let src12_idx = dst2_idx; // src1 is not broadcast - - let dst2_rem = dst3_rem % dst2_stride; - - let row = dst2_rem / params.m; // output row - let col = dst2_rem % params.m; // output column - - let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + col * params.stride_01; - let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11; - - var sum = 0.0; - for (var i: u32 = 0u; i < params.k/BLOCK_SIZE; i = i + 1u) { - sum += multiply_add(src0_idx_base, src1_idx_base, i); - } - dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.m + col] = sum; -} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl index eb2a8368f43..ed4a6b13bbf 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl @@ -25,6 +25,10 @@ fn store_shmem(val: f16, idx: u32) { } #endif // SCALAR +#define QUANT_SHMEM shmem +#define QUANT_OUT_TYPE f16 +#include "quant_inner_loops.tmpl" + #ifdef INIT_SRC0_SHMEM_FLOAT fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) { @@ -94,79 +98,50 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 } #endif // INIT_SRC0_SHMEM_Q1_0 -#ifdef INIT_SRC0_SHMEM_Q4_0 +#if defined(INIT_SRC0_SHMEM_Q4_0) || defined(INIT_SRC0_SHMEM_Q4_1) || defined(INIT_SRC0_SHMEM_Q5_0) || defined(INIT_SRC0_SHMEM_Q5_1) || defined(INIT_SRC0_SHMEM_Q8_0) || defined(INIT_SRC0_SHMEM_Q8_1) || defined(INIT_SRC0_SHMEM_MXFP4) const BLOCK_SIZE = 32u; -const BLOCK_SIZE_BYTES = 18u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. override BLOCKS_K = TILE_K/BLOCK_SIZE; const NQ = 16u; +#if defined(INIT_SRC0_SHMEM_Q8_0) || defined(INIT_SRC0_SHMEM_Q8_1) +const BYTES_PER_THREAD = 16u; // NQ(16) weights use 16 bytes of q +#else const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q +#endif const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { - let blck_idx = i / BLOCK_SIZE; + let block_idx = i / BLOCK_SIZE; let block_offset = (i % BLOCK_SIZE) / NQ; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; + let shmem_idx = block_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; - let tile_m = blck_idx / BLOCKS_K; + let tile_m = block_idx / BLOCKS_K; let global_m = offset_m + tile_m; - let block_k = blck_idx % BLOCKS_K; + let block_k = block_idx % BLOCKS_K; let global_block_k = k_outer / BLOCK_SIZE + block_k; if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + +#ifdef INIT_SRC0_SHMEM_Q4_0 + let block_byte_base = src0_idx * 18u; // BLOCK_SIZE_BYTES = 18u; let d = load_f16_at_src0(block_byte_base); - // store NQ(16) weights + // load NQ(16) weights for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { - let q_byte_offset = block_byte_base + 2u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; let q_packed = load_u32_at_src0(q_byte_offset); - - for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; - let q_lo = (f16(q_byte & 0xF) - 8.0) * d; - shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_lo; - shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi; - } + dequant_q4_0_packed_to_shmem(q_packed, d, shmem_idx + j * BYTES_PER_INNER_LOOP); } - } - } -} -#endif // INIT_SRC0_SHMEM_Q4_0 - -#ifdef INIT_SRC0_SHMEM_Q4_1 -const BLOCK_SIZE = 32u; -const BLOCK_SIZE_BYTES = 20u; -// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. -override BLOCKS_K = TILE_K/BLOCK_SIZE; -const NQ = 16u; -const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q -const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) - -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / NQ; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; - - let tile_m = blck_idx / BLOCKS_K; - let global_m = offset_m + tile_m; - let block_k = blck_idx % BLOCKS_K; - let global_block_k = k_outer / BLOCK_SIZE + block_k; - - if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { - let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_at_src0(block_byte_base); - let m = load_f16_at_src0(block_byte_base + 2u); +#elif INIT_SRC0_SHMEM_Q4_1 + let block_byte_base = src0_idx * 20u; // BLOCK_SIZE_BYTES = 20u; + let dm = unpack2x16float(load_u32_at_src0_aligned(block_byte_base)); + let d = f16(dm[0]); + let m = f16(dm[1]); - // store NQ(16) weights + // load NQ(16) weights for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { - let q_byte_offset = block_byte_base + 4u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; let q_packed = load_u32_at_src0(q_byte_offset); @@ -178,41 +153,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi; } } - } - } -} -#endif // INIT_SRC0_SHMEM_Q4_1 - -#ifdef INIT_SRC0_SHMEM_Q5_0 -const BLOCK_SIZE = 32u; -const BLOCK_SIZE_BYTES = 22u; -// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. -// tile_k is defined as 32u, so blocks_k ends up being 1 always -override BLOCKS_K = TILE_K / BLOCK_SIZE; -const NQ = 16u; -const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q -const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) - -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - - for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / NQ; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; - - let tile_m = blck_idx / BLOCKS_K; - let global_m = offset_m + tile_m; - let block_k = blck_idx % BLOCKS_K; - let global_block_k = k_outer / BLOCK_SIZE + block_k; - - if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { - let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; +#elif INIT_SRC0_SHMEM_Q5_0 + let block_byte_base = src0_idx * 22u; // BLOCK_SIZE_BYTES = 22u; let d = load_f16_at_src0(block_byte_base); let qh_packed = load_u32_at_src0(block_byte_base + 2u); - // store NQ(16) weights + // load NQ(16) weights for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { let q_byte_offset = block_byte_base + 6u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; let q_packed = load_u32_at_src0(q_byte_offset); @@ -229,44 +176,18 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi; } } - } - } -} -#endif // INIT_SRC0_SHMEM_Q5_0 - -#ifdef INIT_SRC0_SHMEM_Q5_1 -const BLOCK_SIZE = 32u; -const BLOCK_SIZE_BYTES = 24u; -// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. -override BLOCKS_K = TILE_K / BLOCK_SIZE; -const NQ = 16u; -const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q -const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) - -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - - for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / NQ; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; - - let tile_m = blck_idx / BLOCKS_K; - let global_m = offset_m + tile_m; - let block_k = blck_idx % BLOCKS_K; - let global_block_k = k_outer / BLOCK_SIZE + block_k; +#elif INIT_SRC0_SHMEM_Q5_1 + let block_byte_base = src0_idx * 24u; // BLOCK_SIZE_BYTES = 24u; - if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { - let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - - let d = load_f16_at_src0(block_byte_base); - let m = load_f16_at_src0(block_byte_base + 2u); - let qh_packed = load_u32_at_src0(block_byte_base + 4u); + let dm = unpack2x16float(load_u32_at_src0_aligned(block_byte_base)); + let d = f16(dm[0]); + let m = f16(dm[1]); + let qh_packed = load_u32_at_src0_aligned(block_byte_base + 4u); - // store NQ(16) weights + // load NQ(16) weights for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { let q_byte_offset = block_byte_base + 8u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; - let q_packed = load_u32_at_src0(q_byte_offset); + let q_packed = load_u32_at_src0_aligned(q_byte_offset); for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { let q_byte = get_byte(q_packed, k); @@ -280,241 +201,73 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi; } } - } - } -} -#endif // INIT_SRC0_SHMEM_Q5_1 - -#ifdef INIT_SRC0_SHMEM_Q8_0 -const BLOCK_SIZE = 32u; -const BLOCK_SIZE_BYTES = 34u; -// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. -override BLOCKS_K = TILE_K/BLOCK_SIZE; -const NQ = 16u; -const BYTES_PER_THREAD = 16u; // NQ(16) weights use 16 bytes of q -const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) - -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / NQ; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; - - let tile_m = blck_idx / BLOCKS_K; - let global_m = offset_m + tile_m; - let block_k = blck_idx % BLOCKS_K; - let global_block_k = k_outer / BLOCK_SIZE + block_k; - - if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { - let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; +#elif INIT_SRC0_SHMEM_Q8_0 + let block_byte_base = src0_idx * 34u; // BLOCK_SIZE_BYTES = 34u; let d = load_f16_at_src0(block_byte_base); - // store NQ(16) weights + // load NQ(16) weights for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { let q_byte_offset = block_byte_base + 2u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; let q_packed = load_u32_at_src0(q_byte_offset); - for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { - let q_byte = get_byte_i32(q_packed, k); - - let q_val = f16(q_byte) * d; - shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_val; - } + dequant_q8_0_packed_to_shmem(q_packed, d, shmem_idx + j * BYTES_PER_INNER_LOOP); } - } - } -} -#endif // INIT_SRC0_SHMEM_Q8_0 - -#ifdef INIT_SRC0_SHMEM_Q8_1 -const BLOCK_SIZE = 32u; -const BLOCK_SIZE_BYTES = 36u; -// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. -override BLOCKS_K = TILE_K/BLOCK_SIZE; -const NQ = 16u; -const BYTES_PER_THREAD = 16u; // NQ(16) weights use 16 bytes of q -const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) - -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / NQ; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; - - let tile_m = blck_idx / BLOCKS_K; - let global_m = offset_m + tile_m; - let block_k = blck_idx % BLOCKS_K; - let global_block_k = k_outer / BLOCK_SIZE + block_k; - - if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { - let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_at_src0(block_byte_base); - let m = load_f16_at_src0(block_byte_base + 2u); +#elif INIT_SRC0_SHMEM_Q8_1 + let block_byte_base = src0_idx * 36u; // BLOCK_SIZE_BYTES = 36u; + let dm = unpack2x16float(load_u32_at_src0_aligned(block_byte_base)); + let d = f16(dm[0]); + let m = f16(dm[1]); - // store NQ(16) weights + // load NQ(16) weights for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { let q_byte_offset = block_byte_base + 4u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; let q_packed = load_u32_at_src0(q_byte_offset); for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { let q_byte = get_byte_i32(q_packed, k); - let q_val = f16(q_byte) * d + m; shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_val; } } - } - } -} -#endif // INIT_SRC0_SHMEM_Q8_1 - -#ifdef INIT_SRC0_SHMEM_Q2_K -const BLOCK_SIZE = 256u; -const BLOCK_SIZE_BYTES = 84u; - -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - // Use standard thread layout instead of lane/row_group - for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { - let tile_m = elem_idx / TILE_K; - let tile_k = elem_idx % TILE_K; - - let global_m = offset_m + tile_m; - let global_k = k_outer + tile_k; +#elif INIT_SRC0_SHMEM_MXFP4 + let block_byte_base = src0_idx * 17u; + let eu8 = get_byte(load_u32_at_src0_aligned(block_byte_base), block_byte_base & 3u); + let e = ldexp(1.0, i32(eu8) - 128); - if (global_m >= params.m || global_k >= params.k) { - shmem[elem_idx] = f16(0.0); - continue; + // load NQ(16) weights + for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { + let q_byte_offset = block_byte_base + 1u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; + let q_packed = load_u32_at_src0(q_byte_offset); + for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = f32(kvalues_mxfp4[(q_byte >> 4) & 0xF]) * e; + let q_lo = f32(kvalues_mxfp4[q_byte & 0xF]) * e; + shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = f16(q_lo); + shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = f16(q_hi); + } + } +#endif } - - let block_k = global_k / BLOCK_SIZE; - let k_in_block = global_k % BLOCK_SIZE; - - let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - - let d = load_f16_at_src0(block_byte_base + 80u); - let dmin = load_f16_at_src0(block_byte_base + 82u); - - // Decode the element at position k_in_block - let block_of_32 = k_in_block / 32u; - let pos_in_32 = k_in_block % 32u; - - let q_b_idx = (block_of_32 / 4u) * 32u; - let shift = (block_of_32 % 4u) * 2u; - let k = (pos_in_32 / 16u) * 16u; - let l = pos_in_32 % 16u; - - let is = k_in_block / 16u; - - let sc_packed = load_u32_at_src0(block_byte_base + 4u * (is / 4u)); - let sc = get_byte(sc_packed, is % 4u); - - let dl = d * f16(sc & 0xFu); - let ml = dmin * f16(sc >> 4u); - - let q_idx = q_b_idx + k + l; - let q_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (q_idx / 4u)); - let q_byte = get_byte(q_packed, q_idx % 4u); - let qs_val = (q_byte >> shift) & 3u; - - let q_val = f16(qs_val) * dl - ml; - shmem[elem_idx] = q_val; } } -#endif // INIT_SRC0_SHMEM_Q2_K +#endif -#ifdef INIT_SRC0_SHMEM_Q3_K +// k-quants +#if defined(INIT_SRC0_SHMEM_Q2_K) || defined(INIT_SRC0_SHMEM_Q3_K) || defined(INIT_SRC0_SHMEM_Q4_K) || defined(INIT_SRC0_SHMEM_Q5_K) || defined(INIT_SRC0_SHMEM_Q6_K) const BLOCK_SIZE = 256u; -const BLOCK_SIZE_BYTES = 110u; - -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { - let tile_m = elem_idx / TILE_K; - let tile_k = elem_idx % TILE_K; - - let global_m = offset_m + tile_m; - let global_k = k_outer + tile_k; - - if (global_m >= params.m || global_k >= params.k) { - shmem[elem_idx] = f16(0.0); - continue; - } - - let block_k = global_k / BLOCK_SIZE; - let k_in_block = global_k % BLOCK_SIZE; - - let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - - let d = load_f16_at_src0(block_byte_base + 108u); - - // Load and unpack scales - let kmask1: u32 = 0x03030303u; - let kmask2: u32 = 0x0f0f0f0fu; - - var scale_vals: array<u32, 4>; - for (var i: u32 = 0u; i < 4u; i++) { - scale_vals[i] = load_u32_at_src0(block_byte_base + 96u + 4u * i); - } - - var tmp: u32 = scale_vals[2]; - scale_vals[2] = ((scale_vals[0] >> 4u) & kmask2) | (((tmp >> 4u) & kmask1) << 4u); - scale_vals[3] = ((scale_vals[1] >> 4u) & kmask2) | (((tmp >> 6u) & kmask1) << 4u); - scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4u); - scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2u) & kmask1) << 4u); - - // Load hmask and qs arrays - var hmask_vals: array<u32, 8>; - for (var i: u32 = 0u; i < 8u; i++) { - hmask_vals[i] = load_u32_at_src0(block_byte_base + 4u * i); - } - - var qs_vals: array<u32, 16>; - for (var i: u32 = 0u; i < 16u; i++) { - qs_vals[i] = load_u32_at_src0(block_byte_base + 32u + 4u * i); - } - - let half = k_in_block / 128u; // 0 or 1 - let pos_in_half = k_in_block % 128u; // 0-127 - let shift_group = pos_in_half / 32u; // 0-3 - let pos_in_32 = pos_in_half % 32u; // 0-31 - let k_group = pos_in_32 / 16u; // 0 or 1 - let l = pos_in_32 % 16u; // 0-15 - - let q_b_idx = half * 32u; // 0 or 32 - let shift = shift_group * 2u; // 0, 2, 4, 6 - let k = k_group * 16u; // 0 or 16 - let is = k_in_block / 16u; // 0-15 - - // m increments every 32 elements across entire 256 element block - let m_shift = k_in_block / 32u; // 0-7 - let m: u32 = 1u << m_shift; // 1,2,4,8,16,32,64,128 - - let sc = get_byte(scale_vals[is / 4u], is % 4u); - let dl = d * (f16(sc) - 32.0); - - let q_idx = q_b_idx + k + l; - let hm_idx = k + l; +const NQ = 4u; - let q_byte = get_byte(qs_vals[q_idx / 4u], q_idx % 4u); - let hmask_byte = get_byte(hmask_vals[hm_idx / 4u], hm_idx % 4u); - - let hm = select(4.0, 0.0, (hmask_byte & m) != 0); - let qs_val = (q_byte >> shift) & 3u; - - let q_val = (f16(qs_val) - f16(hm)) * dl; - shmem[elem_idx] = q_val; - } +fn store_shmem_kquants(val: vec4<f16>, idx: u32) { + shmem[idx] = val.x; + shmem[idx + 1] = val.y; + shmem[idx + 2] = val.z; + shmem[idx + 3] = val.w; } -#endif // INIT_SRC0_SHMEM_Q3_K - -#ifdef INIT_SRC0_SHMEM_Q4_K -const BLOCK_SIZE = 256u; -const BLOCK_SIZE_BYTES = 144u; +fn load_byte_at_src0_aligned(byte_offset: u32) -> u32 { + return get_byte(load_u32_at_src0_aligned(byte_offset), byte_offset % 4u); +} fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + for (var elem_idx = thread_id * NQ; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * NQ) { let tile_m = elem_idx / TILE_K; let tile_k = elem_idx % TILE_K; @@ -522,224 +275,232 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let global_k = k_outer + tile_k; if (global_m >= params.m || global_k >= params.k) { - shmem[elem_idx] = f16(0.0); + store_shmem_kquants(vec4<f16>(f16(0.0), f16(0.0), f16(0.0), f16(0.0)), elem_idx); continue; } - let block_k = global_k / BLOCK_SIZE; - let k_in_block = global_k % BLOCK_SIZE; + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; // k_in_block % 4 == 0; let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_at_src0(block_byte_base); - let dmin = load_f16_at_src0(block_byte_base + 2u); - - // Map k_in_block to loop structure: - // Outer loop over 64-element groups (alternating q_b_idx) - // Inner loop over 2 shifts per group - let group_of_64 = k_in_block / 64u; // 0-3 (maps to q_b_idx) - let pos_in_64 = k_in_block % 64u; // 0-63 - let shift_group = pos_in_64 / 32u; // 0 or 1 - let l = pos_in_64 % 32u; // 0-31 - - let q_b_idx = group_of_64 * 32u; // 0, 32, 64, 96 - let shift = shift_group * 4u; // 0 or 4 - let is = k_in_block / 32u; // 0-7 +#ifdef INIT_SRC0_SHMEM_Q2_K + let block_byte_base = src0_idx * 84u; // BLOCK_SIZE_BYTES = 84u; + let scales_byte_base = block_byte_base; + let qs_byte_base = block_byte_base + 16u; + let dm_byte_base = block_byte_base + 80u; + + let d_packed = unpack2x16float(load_u32_at_src0_aligned(dm_byte_base)); + let d = f16(d_packed[0]); + let dmin = f16(d_packed[1]); + + let chunk = k_in_block / 128u; + let pos_in_chunk = k_in_block % 32u; + let sub_block = k_in_block / 16u; + let shift_phase = (k_in_block % 128u) / 32u; + + // whole 2 bits (4 elems) + let qs_word = load_u32_at_src0_aligned(qs_byte_base + 32u * chunk + 1u * pos_in_chunk); + let qs_vec4 = vec4<f16>( + f16((qs_word >> (2u * shift_phase + 0u)) & 0x3u), + f16((qs_word >> (2u * shift_phase + 8u)) & 0x3u), + f16((qs_word >> (2u * shift_phase + 16u)) & 0x3u), + f16((qs_word >> (2u * shift_phase + 24u)) & 0x3u), + ); + + let scale = load_byte_at_src0_aligned(scales_byte_base + sub_block); + + let dl = d * f16(scale & 0xFu); + let ml = dmin * f16(scale >> 4u); + + store_shmem_kquants(qs_vec4 * dl - ml, elem_idx); +#elif INIT_SRC0_SHMEM_Q3_K + let block_byte_base = src0_idx * 110u; // BLOCK_SIZE_BYTES = 110u; + let hmask_byte_base = block_byte_base + 0u; + let qs_byte_base = block_byte_base + 32u; + let scales_byte_base = block_byte_base + 96u; + + let d_all = load_f16_at_src0(block_byte_base + 108u); + + let chunk = k_in_block / 128u; + let pos_in_chunk = k_in_block % 32u; + let sub_block = k_in_block / 16u; + let shift_phase = (k_in_block % 128u) / 32u; + + let hmask_block = pos_in_chunk; + let hmask_shift_phase = k_in_block / 32u; + + // low 2 bits (4 elems) + let q_lo2_word = load_u32_at_src0(qs_byte_base + 32u * chunk + 1u * hmask_block); + let q_lo2_vec4 = vec4<f16>( + f16((q_lo2_word >> (2u * shift_phase + 0u)) & 3u), + f16((q_lo2_word >> (2u * shift_phase + 8u)) & 3u), + f16((q_lo2_word >> (2u * shift_phase + 16u)) & 3u), + f16((q_lo2_word >> (2u * shift_phase + 24u)) & 3u) + ); + + // high 1 bit (4 elems) + let q_hi1_word = load_u32_at_src0(hmask_byte_base + pos_in_chunk); + let q_hi1_vec4 = vec4<f16>( + f16(select(4.0, 0.0, ((q_hi1_word >> (1u * hmask_shift_phase + 0u)) & 1u) == 1u)), + f16(select(4.0, 0.0, ((q_hi1_word >> (1u * hmask_shift_phase + 8u)) & 1u) == 1u)), + f16(select(4.0, 0.0, ((q_hi1_word >> (1u * hmask_shift_phase + 16u)) & 1u) == 1u)), + f16(select(4.0, 0.0, ((q_hi1_word >> (1u * hmask_shift_phase + 24u)) & 1u) == 1u)) + ); + + let q_vec4 = q_lo2_vec4 - q_hi1_vec4; + + let scale_low4 = (load_byte_at_src0_aligned(scales_byte_base + (sub_block % 8u)) >> (4u * (sub_block / 8u))) & 0xFu; + let scale_hi2 = (load_byte_at_src0_aligned(scales_byte_base + 8u + (sub_block % 4u)) >> (2u * (sub_block / 4u))) & 3u; + let dl = d_all * (f16((scale_hi2 << 4u) | scale_low4) - 32.0); + + store_shmem_kquants(dl * q_vec4, elem_idx); +#elif INIT_SRC0_SHMEM_Q4_K + let block_byte_base = src0_idx * 144u; // BLOCK_SIZE_BYTES = 144u; + let dm_byte_base = block_byte_base + 0u; + let scale_byte_base = block_byte_base + 4u; + let qs_byte_base = block_byte_base + 16u; + + let dm = unpack2x16float(load_u32_at_src0_aligned(dm_byte_base)); + let d = f16(dm[0]); + let dmin = f16(dm[1]); + + let chunk = k_in_block / 64u; + let pos_in_chunk = (k_in_block % 64u) % 32u; + let sub_block = k_in_block / 32u; + let shift_phase = sub_block & 1u; + + // whole 4 bits (4 elems) + let qs_word = load_u32_at_src0_aligned(qs_byte_base + 32u * chunk + 1u * pos_in_chunk); + let qs_vec4 = vec4<f16>( + f16((qs_word >> (4u * shift_phase + 0u)) & 0xFu), + f16((qs_word >> (4u * shift_phase + 8u)) & 0xFu), + f16((qs_word >> (4u * shift_phase + 16u)) & 0xFu), + f16((qs_word >> (4u * shift_phase + 24u)) & 0xFu) + ); var sc: u32; var mn: u32; - let scale_base = block_byte_base + 4u; - - if (is < 4u) { - let sc_byte = get_byte(load_u32_at_src0(scale_base), is % 4u); - let min_byte = get_byte(load_u32_at_src0(scale_base + 4), is % 4u); - sc = sc_byte & 63u; - mn = min_byte & 63u; + if (sub_block < 4u) { + let sc_byte = get_byte(load_u32_at_src0_aligned(scale_byte_base), sub_block % 4u); + let min_byte = get_byte(load_u32_at_src0_aligned(scale_byte_base + 4), sub_block % 4u); + sc = sc_byte & 63u; + mn = min_byte & 63u; } else { - let sc_min_lo = get_byte(load_u32_at_src0(scale_base + 8), (is + 4u) % 4u); - let sc_hi = get_byte(load_u32_at_src0(scale_base), (is - 4u) % 4u); - let min_hi = get_byte(load_u32_at_src0(scale_base + 4), is % 4u); - - sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u); - mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u); + let sc_min_lo = get_byte(load_u32_at_src0_aligned(scale_byte_base + 8), (sub_block + 4u) % 4u); + let sc_hi = get_byte(load_u32_at_src0_aligned(scale_byte_base), (sub_block - 4u) % 4u); + let min_hi = get_byte(load_u32_at_src0_aligned(scale_byte_base + 4), sub_block % 4u); + sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u); + mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u); } let dl = d * f16(sc); let ml = dmin * f16(mn); - let q_idx = q_b_idx + l; - let q_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (q_idx / 4u)); - - let q_byte = get_byte(q_packed, q_idx % 4u); - let qs_val = (q_byte >> shift) & 0xFu; - - let q_val = f16(qs_val) * dl - ml; - shmem[elem_idx] = q_val; - } -} -#endif // INIT_SRC0_SHMEM_Q4_K - -#ifdef INIT_SRC0_SHMEM_Q5_K -const BLOCK_SIZE = 256u; -const BLOCK_SIZE_BYTES = 176u; - -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { - let tile_m = elem_idx / TILE_K; - let tile_k = elem_idx % TILE_K; - - let global_m = offset_m + tile_m; - let global_k = k_outer + tile_k; - - if (global_m >= params.m || global_k >= params.k) { - shmem[elem_idx] = f16(0.0); - continue; - } - - let block_k = global_k / BLOCK_SIZE; - let k_in_block = global_k % BLOCK_SIZE; - - let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - - let d = load_f16_at_src0(block_byte_base); - let dmin = load_f16_at_src0(block_byte_base + 2u); - - - // The original loop processes elements in groups of 64 - // Each group of 64: q_b_idx cycles through [0,32,64,96], shift cycles [0,4] - // But u increments EVERY 32 elements (after each l loop) - let group_of_64 = k_in_block / 64u; // 0-3 - let pos_in_64 = k_in_block % 64u; // 0-63 - let shift_group = pos_in_64 / 32u; // 0 or 1 - let l = pos_in_64 % 32u; // 0-31 - - let q_b_idx = group_of_64 * 32u; // 0, 32, 64, 96 - let shift = shift_group * 4u; // 0 or 4 - let is = k_in_block / 32u; // 0-7 - - // u increments every 32 elements (0->1, 1->2, 2->4, 3->8, 4->16, 5->32, 6->64, 7->128) - let u_shift = k_in_block / 32u; // 0-7 - let u: u32 = 1u << u_shift; + store_shmem_kquants(dl * qs_vec4 - vec4(ml, ml, ml, ml), elem_idx); +#elif INIT_SRC0_SHMEM_Q5_K + let block_byte_base = src0_idx * 176u; // BLOCK_SIZE_BYTES = 176u; + let dm_byte_base = block_byte_base + 0u; + let scale_byte_base = block_byte_base + 4u; + let qh_byte_base = block_byte_base + 16u; + let qs_byte_base = block_byte_base + 48u; + + let dm = unpack2x16float(load_u32_at_src0_aligned(dm_byte_base)); + let d = f16(dm[0]); + let dmin = f16(dm[1]); + + let chunk = k_in_block / 64u; + let pos_in_chunk = (k_in_block % 64u) % 32u; + let sub_block = k_in_block / 32u; + let shift_phase = sub_block & 1u; + + let qh_block = k_in_block % 32u; + let qh_shift_phase = sub_block; + + // low 4 bits (4 elems) + let qs_word = load_u32_at_src0_aligned(qs_byte_base + 32u * chunk + 1u * pos_in_chunk); + let qs_lo4_vec4 = vec4<f16>( + f16((qs_word >> (4u * shift_phase + 0u)) & 0xFu), + f16((qs_word >> (4u * shift_phase + 8u)) & 0xFu), + f16((qs_word >> (4u * shift_phase + 16u)) & 0xFu), + f16((qs_word >> (4u * shift_phase + 24u)) & 0xFu) + ); + + // high 1 bit (4 elems) + let qh_word = load_u32_at_src0_aligned(qh_byte_base + qh_block); + let qh_vec4 = vec4<f16>( + f16(select(0.0, 16.0, ((qh_word >> (1u * qh_shift_phase + 0u)) & 1u) == 1u)), + f16(select(0.0, 16.0, ((qh_word >> (1u * qh_shift_phase + 8u)) & 1u) == 1u)), + f16(select(0.0, 16.0, ((qh_word >> (1u * qh_shift_phase + 16u)) & 1u) == 1u)), + f16(select(0.0, 16.0, ((qh_word >> (1u * qh_shift_phase + 24u)) & 1u) == 1u)) + ); var sc: u32; var mn: u32; - let scale_base = block_byte_base + 4u; - - if (is < 4u) { - let sc_byte = get_byte(load_u32_at_src0(scale_base), is % 4u); - let min_byte = get_byte(load_u32_at_src0(scale_base + 4), is % 4u); - sc = sc_byte & 63u; - mn = min_byte & 63u; + if (sub_block < 4u) { + let sc_byte = get_byte(load_u32_at_src0_aligned(scale_byte_base), sub_block % 4u); + let min_byte = get_byte(load_u32_at_src0_aligned(scale_byte_base + 4), sub_block % 4u); + sc = sc_byte & 63u; + mn = min_byte & 63u; } else { - let sc_min_lo = get_byte(load_u32_at_src0(scale_base + 8), (is + 4u) % 4u); - let sc_hi = get_byte(load_u32_at_src0(scale_base), (is - 4u) % 4u); - let min_hi = get_byte(load_u32_at_src0(scale_base + 4), is % 4u); - - sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u); - mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u); + let sc_min_lo = get_byte(load_u32_at_src0_aligned(scale_byte_base + 8), (sub_block + 4u) % 4u); + let sc_hi = get_byte(load_u32_at_src0_aligned(scale_byte_base), (sub_block - 4u) % 4u); + let min_hi = get_byte(load_u32_at_src0_aligned(scale_byte_base + 4), sub_block % 4u); + sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u); + mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u); } let dl = d * f16(sc); let ml = dmin * f16(mn); - let q_idx = q_b_idx + l; - let q_packed = load_u32_at_src0(block_byte_base + 48u + 4u * (q_idx / 4u)); - - let q_byte = get_byte(q_packed, q_idx % 4u); - - let qh_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (l / 4u)); - - let qh_byte = get_byte(qh_packed, l % 4u); - - let qs_val = (q_byte >> shift) & 0xFu; - let qh_val = select(0.0, 16.0, (qh_byte & u) != 0); - - let q_val = (f16(qs_val) + f16(qh_val)) * dl - ml; - shmem[elem_idx] = q_val; - } -} - -#endif // INIT_SRC0_SHMEM_Q5_K - -#ifdef INIT_SRC0_SHMEM_Q6_K -const BLOCK_SIZE = 256u; -const BLOCK_SIZE_BYTES = 210u; - -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { - let tile_m = elem_idx / TILE_K; - let tile_k = elem_idx % TILE_K; - - let global_m = offset_m + tile_m; - let global_k = k_outer + tile_k; - - if (global_m >= params.m || global_k >= params.k) { - shmem[elem_idx] = f16(0.0); - continue; - } - - let block_k = global_k / BLOCK_SIZE; - let k_in_block = global_k % BLOCK_SIZE; - - let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - - let half = k_in_block / 128u; - let pos_in_half = k_in_block % 128u; - let quarter = pos_in_half / 32u; - let l = pos_in_half % 32u; - - let ql_b_idx = half * 64u; - let qh_b_idx = half * 32u; - let sc_b_idx = half * 8u; - - // Load only ql13 word needed - let ql13_flat = ql_b_idx + l; - let ql13 = load_u32_at_src0(block_byte_base + ql13_flat); - let ql13_b = get_byte(ql13, 0u); - - // Load only ql24 word needed - let ql24_flat = ql_b_idx + l + 32u; - let ql24 = load_u32_at_src0(block_byte_base + ql24_flat); - let ql24_b = get_byte(ql24, 0u); - - // Load only qh word needed - let qh_flat = qh_b_idx + l; - let qh = load_u32_at_src0(block_byte_base + 128u + qh_flat); - let qh_b = get_byte(qh, 0u); - - let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0); - let q2 = f16((ql24_b & 0xFu) | (((qh_b >> 2u) & 3u) << 4u)) - f16(32.0); - let q3 = f16((ql13_b >> 4u) | (((qh_b >> 4u) & 3u) << 4u)) - f16(32.0); - let q4 = f16((ql24_b >> 4u) | (((qh_b >> 6u) & 3u) << 4u)) - f16(32.0); - - // Load only the scale word needed - let is = l / 16u; - let sc_idx = sc_b_idx + is + quarter * 2u; - let sc = load_u32_at_src0(block_byte_base + 192u + sc_idx); - let sc_val = get_byte_i32(sc, 0u); - - let d = load_f16_at_src0(block_byte_base + 208u); - - var q_val: f16; - if (quarter == 0u) { - q_val = q1; - } else if (quarter == 1u) { - q_val = q2; - } else if (quarter == 2u) { - q_val = q3; - } else { - q_val = q4; - } - - shmem[elem_idx] = d * f16(sc_val) * q_val; + store_shmem_kquants((qh_vec4 + qs_lo4_vec4) * dl - vec4<f16>(ml, ml, ml, ml), elem_idx); +#elif INIT_SRC0_SHMEM_Q6_K + let block_byte_base = src0_idx * 210u; // BLOCK_SIZE_BYTES = 210u; + let ql_byte_base = block_byte_base; + let qh_byte_base = block_byte_base + 128u; + let scales_byte_base = block_byte_base + 192u; + let d_byte_base = block_byte_base + 208u; + + let d = load_f16_at_src0(d_byte_base); + + let chunk = k_in_block / 128u; + let ql_pos_in_chunk = (k_in_block % 128u) % 64u; + let qh_pos_in_chunk = (k_in_block % 128u) % 32u; + let sub_block = k_in_block / 16u; + let ql_shift_phase = (k_in_block % 128u) / 64u; + let qh_shift_phase = (k_in_block % 128u) / 32u; + + // low 4 bits (4 elems) + let ql_word = load_u32_at_src0(ql_byte_base + 64u * chunk + 1u * ql_pos_in_chunk); + let ql_lo4_vec4 = vec4<u32>( + (ql_word >> (4u * ql_shift_phase + 0u)) & 0xFu, + (ql_word >> (4u * ql_shift_phase + 8u)) & 0xFu, + (ql_word >> (4u * ql_shift_phase + 16u)) & 0xFu, + (ql_word >> (4u * ql_shift_phase + 24u)) & 0xFu + ); + + // hi 2 bits (4 elems) + let qh_word = load_u32_at_src0(qh_byte_base + 32u * chunk + 1u * qh_pos_in_chunk); + let qh_hi2_vec4 = vec4<u32>( + ((qh_word >> (2u * qh_shift_phase + 0u)) & 0x3u) << 4u, + ((qh_word >> (2u * qh_shift_phase + 8u)) & 0x3u) << 4u, + ((qh_word >> (2u * qh_shift_phase + 16u)) & 0x3u) << 4u, + ((qh_word >> (2u * qh_shift_phase + 24u)) & 0x3u) << 4u, + ); + + let q_vec4 = vec4<f16>(qh_hi2_vec4 | ql_lo4_vec4) - vec4<f16>(32.0, 32.0, 32.0, 32.0); + + let scale_byte = scales_byte_base + 1u * sub_block; + let scale_word = load_u32_at_src0_aligned(scale_byte); + let scale = get_byte_i32(scale_word, scale_byte & 3u); + + store_shmem_kquants(d * q_vec4 * f16(scale), elem_idx); +#endif } } -#endif // INIT_SRC0_SHMEM_Q6_K +#endif // k-quants #ifdef INIT_SRC0_SHMEM_IQ4_NL const BLOCK_SIZE = 32u; @@ -1163,48 +924,3 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 } } #endif // INIT_SRC0_SHMEM_IQ3_S - -#ifdef INIT_SRC0_SHMEM_MXFP4 -const BLOCK_SIZE = 32u; -const BLOCK_SIZE_BYTES = 17u; -// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. -override BLOCKS_K = TILE_K/BLOCK_SIZE; -const NQ = 16u; -const BYTES_PER_THREAD = 8u; // NQ(16) weights uses 8 bytes of q -const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) - -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / NQ; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; - - let tile_m = blck_idx / BLOCKS_K; - let global_m = offset_m + tile_m; - let block_k = blck_idx % BLOCKS_K; - let global_block_k = k_outer / BLOCK_SIZE + block_k; - - if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { - let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let eu8 = get_byte(load_u32_at_src0(block_byte_base), 0); - let e = ldexp(1.0, i32(eu8) - 128); - - // store NQ(16) weights - for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { - - let q_byte_offset = block_byte_base + 1u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; - let q_packed = load_u32_at_src0(q_byte_offset); - - for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = f32(kvalues_mxfp4[(q_byte >> 4) & 0xF]) * e; - let q_lo = f32(kvalues_mxfp4[q_byte & 0xF]) * e; - shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = f16(q_lo); - shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = f16(q_hi); - } - } - } - } -} -#endif // INIT_SRC0_SHMEM_MXFP4 diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl index d79d5f3f282..581e922709d 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl @@ -21,35 +21,32 @@ var<workgroup> count:atomic<u32>; @compute @workgroup_size(WG_SIZE) fn main(@builtin(workgroup_id) wg_id: vec3<u32>, - @builtin(local_invocation_id) local_id: vec3<u32>, - @builtin(num_workgroups) num_wg: vec3<u32>) { + @builtin(local_invocation_id) local_id: vec3<u32>) { let thread_id = local_id.x; - let own_expert = wg_id.y * num_wg.x + wg_id.x; // the expert assigned to this workgroup + let own_expert = wg_id.x; // the expert assigned to this workgroup - if (own_expert < params.n_expert) { - if (thread_id == 0u) { - atomicStore(&count, 0); - } + if (thread_id == 0u) { + atomicStore(&count, 0); + } - workgroupBarrier(); - - for (var i = thread_id;i < params.n_expert_used * params.n_tokens;i += WG_SIZE) { - let row = i / params.n_expert_used; - let col = i % params.n_expert_used; - let expert = u32(ids[params.offset_ids + row * params.stride_ids_1 + col]); - if (own_expert == expert) { - let pos = atomicAdd(&count, 1u); - let gathered_id = own_expert * params.n_tokens + pos; - global_gathered_expert_used[gathered_id] = col; - global_gathered_tokens[gathered_id] = row; - } + workgroupBarrier(); + + for (var i = thread_id;i < params.n_expert_used * params.n_tokens;i += WG_SIZE) { + let row = i / params.n_expert_used; + let col = i % params.n_expert_used; + let expert = u32(ids[params.offset_ids + row * params.stride_ids_1 + col]); + if (own_expert == expert) { + let pos = atomicAdd(&count, 1u); + let gathered_id = own_expert * params.n_tokens + pos; + global_gathered_expert_used[gathered_id] = col; + global_gathered_tokens[gathered_id] = row; } + } - workgroupBarrier(); + workgroupBarrier(); - if (thread_id == 0u) { - gathered_count_ids[own_expert] = atomicLoad(&count); - } + if (thread_id == 0u) { + gathered_count_ids[own_expert] = atomicLoad(&count); } } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl index a194cf40468..f0a7fbd059a 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl @@ -3,10 +3,18 @@ enable subgroups; #endif enable f16; +#ifdef MMVQ +requires packed_4x8_integer_dot_product; +#endif + #define DECLARE_BYTE_LOADERS_SRC0 #include "common_decls.tmpl" +#ifdef MMVQ +#include "mul_mat_vec_q_acc.tmpl" +#else #include "mul_mat_vec_acc.tmpl" +#endif struct MulMatParams { offset_src0: u32, @@ -28,9 +36,14 @@ struct MulMatParams { }; @group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; + +#ifdef MMVQ +@group(0) @binding(1) var<storage, read_write> src1q: array<q8_1>; +#else @group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; -@group(0) @binding(2) var<storage, read_write> dst: array<f32>; +#endif +@group(0) @binding(2) var<storage, read_write> dst: array<f32>; // "mul_mat_vec_acc.tmpl" requires params.k, params.m, params.stride_01 @group(0) @binding(3) var<uniform> params: MulMatParams; @@ -75,10 +88,15 @@ fn main( let src12_idx = dst2_idx; let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02; - let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; let dst_idx_base = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row_base; +#ifdef MMVQ + let src1q_idx_base = (src13_idx * params.bs02 * params.broadcast2 + src12_idx) * (params.k / 32u); + let acc = accumulate_vec_q_dot(thread_id, row_base, src0_batch_offset, src1q_idx_base); +#else + let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; let acc = accumulate_vec_dot(thread_id, row_base, src0_batch_offset, src1_idx_base); +#endif #ifdef USE_SUBGROUP_REDUCTION for (var row = 0u; row < OUTPUTS_PER_WG; row++) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl index 711c7e829d8..08753b9d643 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl @@ -436,7 +436,6 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src } #endif - #ifdef MUL_ACC_Q3_K #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 110 diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl new file mode 100644 index 00000000000..3ef2f77ebe0 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl @@ -0,0 +1,303 @@ +#ifdef U32_DEQUANT_HELPERS +#define SRC0_TYPE u32 + +fn byte_of(v: u32, b: u32) -> u32 { + return (v >> (b * 8u)) & 0xFFu; +} + +fn sbyte_of(v: u32, b: u32) -> i32 { + let raw = i32((v >> (b * 8u)) & 0xFFu); + return select(raw, raw - 256, raw >= 128); +} +#endif + +#define SRC0_TYPE SRC0_INNER_TYPE +#define SRC1_TYPE SRC1_INNER_TYPE + +#ifdef LEGACY_QUANTS +#define BLOCK_SIZE 32 +#define THREADS_PER_BLOCK 4 +#elif K_QUANTS +#define BLOCK_SIZE 256 +#define THREADS_PER_BLOCK 16 +#endif + +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) +#define Q8_BLOCK_SIZE 32 + +#ifdef MUL_ACC_Q4_0 +#define BLOCK_SIZE_BYTES 18 +#define B_DS_TYPE vec2<f32> +fn repack_a(block_byte_base: u32, inner_id: u32) -> vec2<u32> { + let qs_packed = load_u32_at_src0(block_byte_base + 2u + 4u * inner_id); + + return vec2<u32>( + qs_packed & 0x0F0F0F0Fu, + (qs_packed >> 4u) & 0x0F0F0F0Fu + ); +} +fn repack_b_qs(block:u32, inner_id: u32) -> vec2<u32> { + return vec2<u32>( + src1q[block].qs[inner_id], + src1q[block].qs[inner_id + 4u], + ); +} +fn repack_b_dm(block: u32) -> B_DS_TYPE { + return B_DS_TYPE( + f32(src1q[block].d), + f32(src1q[block].s) + ); +} +fn get_dm(block_byte_base: u32) -> f32 { + return f32(load_f16_at_src0(block_byte_base)); +} +fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 { + return f32(row_sum) * (da * b_ds.x) - 8.0 * da * b_ds.y / THREADS_PER_BLOCK; +} +#endif + +#ifdef MUL_ACC_Q4_1 +#define BLOCK_SIZE_BYTES 20 +#define B_DS_TYPE vec2<f32> +fn repack_a(block_byte_base: u32, inner_id: u32) -> vec2<u32> { + let qs_packed = load_u32_at_src0(block_byte_base + 4u + 4u * inner_id); + + return vec2<u32>( + qs_packed & 0x0F0F0F0Fu, + (qs_packed >> 4u) & 0x0F0F0F0Fu + ); +} +fn repack_b_qs(block:u32, inner_id: u32) -> vec2<u32> { + return vec2<u32>( + src1q[block].qs[inner_id], + src1q[block].qs[inner_id + 4u], + ); +} +fn repack_b_dm(block: u32) -> B_DS_TYPE { + return B_DS_TYPE( + f32(src1q[block].d), + f32(src1q[block].s) + ); +} +fn get_dm(block_byte_base: u32) -> vec2<f32> { + return vec2<f32>( + f32(load_f16_at_src0(block_byte_base)), + f32(load_f16_at_src0(block_byte_base + 2u)) + ); +} +fn mul_q8_1(row_sum: i32, dma: vec2<f32>, b_ds: B_DS_TYPE) -> f32 { + return f32(row_sum) * (dma.x * b_ds.x) + dma.y * b_ds.y / THREADS_PER_BLOCK; +} +#endif + +#ifdef MUL_ACC_Q8_0 +#define BLOCK_SIZE_BYTES 34 +#define B_DS_TYPE f32 +fn repack_a(block_byte_base: u32, inner_id: u32) -> vec2<u32> { + return vec2<u32>( + load_u32_at_src0(block_byte_base + 2u + 4u * (inner_id * 2u)), + load_u32_at_src0(block_byte_base + 2u + 4u * (inner_id * 2u + 1)) + ); +} +fn repack_b_qs(block:u32, inner_id: u32) -> vec2<u32> { + return vec2<u32>( + src1q[block].qs[inner_id * 2u], + src1q[block].qs[inner_id * 2u + 1], + ); +} +fn repack_b_dm(block: u32) -> B_DS_TYPE { + return B_DS_TYPE(src1q[block].d); +} +fn get_dm(block_byte_base: u32) -> f32 { + return f32(load_f16_at_src0(block_byte_base)); +} +fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 { + return f32(row_sum) * (da * b_ds); +} +#endif + +#ifdef LEGACY_QUANTS +fn mmvq_dot_product(a_byte_base: u32, b_inner_id: u32, b_repacked: vec2<u32>, b_ds: B_DS_TYPE) -> f32 { + var row_sum = 0; + let a_repacked = repack_a(a_byte_base, b_inner_id); + + row_sum += dot4I8Packed(a_repacked[0], b_repacked[0]); + row_sum += dot4I8Packed(a_repacked[1], b_repacked[1]); + + return mul_q8_1(row_sum, get_dm(a_byte_base), b_ds); +} + +fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> { + var acc: array<f32, OUTPUTS_PER_WG>; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let b_inner_id = thread_id % THREADS_PER_BLOCK; + let b_block_idx = src1q_idx_base + block; + + let b_repacked = repack_b_qs(b_block_idx, b_inner_id); + let b_ds = repack_b_dm(b_block_idx); + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + acc[row] += mmvq_dot_product(block_byte_base, b_inner_id, b_repacked, b_ds); + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q2_K +#define BLOCK_SIZE_BYTES 84 +#define B_DS_TYPE f32 +fn repack_a(block_byte_base: u32, tid: u32) -> vec4<u32> { + let ih2 = tid / 8u; + let phase = tid % 2u; + let iq4_idx = 2u * ih2 + phase; + let qs_byte_base = block_byte_base + 16u + 16u * iq4_idx; + let qs_shift = tid & 6u; + return vec4<u32>( + (load_u32_at_src0_aligned(qs_byte_base) >> qs_shift) & 0x03030303u, + (load_u32_at_src0_aligned(qs_byte_base + 4u) >> qs_shift) & 0x03030303u, + (load_u32_at_src0_aligned(qs_byte_base + 8u) >> qs_shift) & 0x03030303u, + (load_u32_at_src0_aligned(qs_byte_base + 12u) >> qs_shift) & 0x03030303u, + ); +} +fn repack_b_qs(q8_block_idx: u32, tid: u32) -> vec4<u32> { + let phase = tid % 2u; + return vec4<u32>( + src1q[q8_block_idx].qs[4u * phase], + src1q[q8_block_idx].qs[4u * phase + 1u], + src1q[q8_block_idx].qs[4u * phase + 2u], + src1q[q8_block_idx].qs[4u * phase + 3u], + ); +} +fn repack_b_dm(q8_block_idx: u32) -> B_DS_TYPE { + return B_DS_TYPE(src1q[q8_block_idx].d); +} +fn get_dm(block_byte_base: u32) -> vec2<f32> { + return vec2<f32>( + f32(load_f16_at_src0(block_byte_base + 80u)), + f32(load_f16_at_src0(block_byte_base + 82u)), + ); +} +fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2<f32> { + let scale_byte = block_byte_base + tid; + let scale = byte_of(load_u32_at_src0_aligned(scale_byte), scale_byte & 3u); + return vec2<f32>(f32(scale & 0xFu), f32(scale >> 4u)); +} +fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4<u32>, b_ds: B_DS_TYPE) -> f32 { + let a_repacked = repack_a(a_byte_base, tid); + let dm = get_dm(a_byte_base); + let scale_min = get_scale_min(a_byte_base, tid); + + let scale_q = i32(scale_min.x); + let scale_m_i8x4 = u32(scale_min.y) * 0x01010101u; + + let row_sum_d = (dot4I8Packed(b_repacked[0], a_repacked[0]) + dot4I8Packed(b_repacked[1], a_repacked[1]) + + dot4I8Packed(b_repacked[2], a_repacked[2]) + dot4I8Packed(b_repacked[3], a_repacked[3])) * scale_q; + let row_sum_m = dot4I8Packed(b_repacked[0], scale_m_i8x4) + dot4I8Packed(b_repacked[1], scale_m_i8x4) + + dot4I8Packed(b_repacked[2], scale_m_i8x4) + dot4I8Packed(b_repacked[3], scale_m_i8x4); + + return b_ds * (dm.x * f32(row_sum_d) - dm.y * f32(row_sum_m)); +} +#endif + +#ifdef MUL_ACC_Q4_K +#define BLOCK_SIZE_BYTES 144 +#define B_DS_TYPE vec2<f32> +fn repack_a(block_byte_base: u32, tid: u32) -> vec4<u32> { + let iq4 = tid / 4u; + let phase = tid % 2u; + let nibble = (tid >> 1u) % 2u; + let q_qs_byte_base = block_byte_base + 16u + 32u * iq4 + 16u * phase; + let qs_shift = 4u * nibble; + return vec4<u32>( + (load_u32_at_src0_aligned(q_qs_byte_base) >> qs_shift) & 0x0F0F0F0Fu, + (load_u32_at_src0_aligned(q_qs_byte_base + 4u) >> qs_shift) & 0x0F0F0F0Fu, + (load_u32_at_src0_aligned(q_qs_byte_base + 8u) >> qs_shift) & 0x0F0F0F0Fu, + (load_u32_at_src0_aligned(q_qs_byte_base + 12u) >> qs_shift) & 0x0F0F0F0Fu, + ); +} +fn repack_b_qs(q8_block_idx: u32, tid: u32) -> vec4<u32> { + let phase = tid % 2u; + return vec4<u32>( + src1q[q8_block_idx].qs[4u * phase], + src1q[q8_block_idx].qs[4u * phase + 1u], + src1q[q8_block_idx].qs[4u * phase + 2u], + src1q[q8_block_idx].qs[4u * phase + 3u], + ); +} +fn repack_b_dm(q8_block_idx: u32) -> B_DS_TYPE { + return B_DS_TYPE( + f32(src1q[q8_block_idx].d), + f32(src1q[q8_block_idx].s), + ); +} +fn get_dm(block_byte_base: u32) -> vec2<f32> { + return vec2<f32>( + f32(load_f16_at_src0(block_byte_base + 0u)), + f32(load_f16_at_src0(block_byte_base + 2u)), + ); +} +fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2<f32> { + let sc_m_idx = tid / 2u; + let scales_byte_base = block_byte_base + 4u; + let scales0_3 = load_u32_at_src0_aligned(scales_byte_base); + let scales4_7 = load_u32_at_src0_aligned(scales_byte_base + 4u); + let scales8_11 = load_u32_at_src0_aligned(scales_byte_base + 8u); + + let byte_idx = sc_m_idx & 3u; + let is_high = sc_m_idx >= 4u; + + let sc_low = byte_of(scales0_3, byte_idx) & 0x3Fu; + let sc_high = (byte_of(scales8_11, byte_idx) & 0x0Fu) | ((byte_of(scales0_3, byte_idx) & 0xC0u) >> 2u); + let scale = f32(select(sc_low, sc_high, is_high)); + + let mn_low = byte_of(scales4_7, byte_idx) & 0x3Fu; + let mn_high = (byte_of(scales8_11, byte_idx) >> 4u) | ((byte_of(scales4_7, byte_idx) & 0xC0u) >> 2u); + let min_val = f32(select(mn_low, mn_high, is_high)); + + return vec2<f32>(scale, min_val); +} +fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4<u32>, b_ds: B_DS_TYPE) -> f32 { + let a_repacked = repack_a(a_byte_base, tid); + let dm = get_dm(a_byte_base); + let scale_min = get_scale_min(a_byte_base, tid); + + let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1]) + + dot4I8Packed(a_repacked[2], b_repacked[2]) + dot4I8Packed(a_repacked[3], b_repacked[3]); + + // Each thread covers half of the Q8_1 block, so add only b_ds.y/2. + return b_ds.x * dm.x * scale_min.x * f32(row_sum) - dm.y * scale_min.y * (b_ds.y / (Q8_BLOCK_SIZE / ELEMS_PER_THREAD)); +} +#endif + +#ifdef K_QUANTS +fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> { + var acc: array<f32, OUTPUTS_PER_WG>; + + let tid = thread_id % THREADS_PER_BLOCK; + + for (var block = thread_id / THREADS_PER_BLOCK; block < params.k / BLOCK_SIZE; block += WG_SIZE / THREADS_PER_BLOCK) { + let src1q_idx = src1q_idx_base + (block * BLOCK_SIZE + ELEMS_PER_THREAD * tid) / Q8_BLOCK_SIZE; + let b_repacked = repack_b_qs(src1q_idx, tid); + let b_ds = repack_b_dm(src1q_idx); + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + acc[row] += mmvq_dot_product(block_byte_base, tid, b_repacked, b_ds); + } + } + } + + return acc; +} +#endif diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl new file mode 100644 index 00000000000..d1da4608434 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl @@ -0,0 +1,21 @@ +#ifdef U32_DEQUANT_HELPERS +fn dequant_q4_0_packed_to_shmem(q_packed: u32, d: f16, dst_idx: u32) { + let scale = QUANT_OUT_TYPE(d); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = (QUANT_OUT_TYPE((q_byte >> 4) & 0xFu) - QUANT_OUT_TYPE(8.0)) * scale; + let q_lo = (QUANT_OUT_TYPE(q_byte & 0xFu) - QUANT_OUT_TYPE(8.0)) * scale; + QUANT_SHMEM[dst_idx + k] = q_lo; + QUANT_SHMEM[dst_idx + k + 16u] = q_hi; + } +} + +fn dequant_q8_0_packed_to_shmem(q_packed: u32, d: f16, dst_idx: u32) { + let scale = QUANT_OUT_TYPE(d); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte_i32(q_packed, k); + let q_val = QUANT_OUT_TYPE(q_byte) * scale; + QUANT_SHMEM[dst_idx + k] = q_val; + } +} +#endif diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl new file mode 100644 index 00000000000..b3f1fa04b80 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl @@ -0,0 +1,173 @@ +#ifdef USE_SUBGROUP_REDUCTION +enable subgroups; +#endif +enable f16; + +requires packed_4x8_integer_dot_product; + +#include "common_decls.tmpl" + +struct Params { + offset_src1: u32, + stride_12: u32, + stride_13: u32, + ne0: u32, + ne2: u32, + ne3: u32, +}; + +#define SRC1_TYPE vec4<SRC1_INNER_TYPE> + +@group(0) @binding(0) var<storage, read_write> src1: array<SRC1_TYPE>; +@group(0) @binding(1) var<storage, read_write> src1q: array<q8_1>; + +@group(0) @binding(2) var<uniform> params: Params; + +#ifdef USE_SUBGROUP_REDUCTION +fn cluster_max_8(v: f32) -> f32 { + var r = v; + r = max(r, subgroupShuffleXor(r, 1u)); + r = max(r, subgroupShuffleXor(r, 2u)); + r = max(r, subgroupShuffleXor(r, 4u)); + return r; +} + +#if defined(MUL_ACC_Q4_0) || defined(MUL_ACC_Q4_1) || defined(MUL_ACC_Q4_K) +fn cluster_add_i4x8(v: i32) -> i32 { + var r= v; + r += subgroupShuffleXor(r, 1u); + r += subgroupShuffleXor(r, 2u); + r += subgroupShuffleXor(r, 4u); + return r; +} +#endif +#endif + +#ifdef USE_WORKGROUP_REDUCTION +#define CLUSTER_SIZE 8 + +var<workgroup> partial_amaxs: array<array<f32, CLUSTER_SIZE>, WG_SIZE / CLUSTER_SIZE>; +var<workgroup> partial_sums: array<array<i32, CLUSTER_SIZE>, WG_SIZE / CLUSTER_SIZE>; +#endif + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(local_invocation_id) local_id: vec3<u32>, + @builtin(workgroup_id) wg_id: vec3<u32>, + @builtin(num_workgroups) num_wg: vec3<u32> +) { + let thread_id = local_id.x; + let num_vec4 = params.ne0 / 4u; + + let wg_per_vec = (num_vec4 + (WG_SIZE - 1u)) / WG_SIZE; + let total_batches = wg_per_vec * params.ne2 * params.ne3; + + let wg_linear = wg_id.y * num_wg.x + wg_id.x; + if (wg_linear >= total_batches) { + return; + } + + let src13_idx = wg_linear / (params.ne2 * wg_per_vec); + let src12_idx = (wg_linear - src13_idx * (params.ne2 * wg_per_vec)) / wg_per_vec; + let src11_wg_idx = wg_linear % wg_per_vec; + let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; + let src1_idx_vec4_base = src1_idx_base / 4u; + + let blocks_per_row = params.ne0 / 32u; + let blocks_per_wg = (WG_SIZE * 4u) / 32u; + let src1q_idx_base = (src13_idx * params.ne2 + src12_idx) * blocks_per_row; + let src1q_idx = src1q_idx_base + src11_wg_idx * blocks_per_wg + thread_id / 8u; + let qs_idx = thread_id % 8u; + + // reduction + var q4 = vec4<f32>(0.0); + var q4_quants = 0u; + var thread_amax = 0.0; + + let src11_vec4_idx = src11_wg_idx * WG_SIZE + thread_id; + let is_valid = src11_vec4_idx < num_vec4; + +#ifdef USE_SUBGROUP_REDUCTION + + var d = 0.0; + + if (is_valid) { + q4 = src1[src1_idx_vec4_base + src11_vec4_idx]; + let abs_q4 = abs(q4); + thread_amax = max(max(abs_q4[0u], abs_q4[1u]), max(abs_q4[2], abs_q4[3])); + } + + d = cluster_max_8(thread_amax) / 127.0; + + if (is_valid) { + let id = select(0.0, 1.0 / d, d > 0.0); + q4_quants = pack4xI8(vec4<i32>(round(q4 * id))); + if (qs_idx == 0u) { + src1q[src1q_idx].d = f16(d); + } + src1q[src1q_idx].qs[qs_idx] = q4_quants; + } + +#if defined(MUL_ACC_Q4_0) || defined(MUL_ACC_Q4_1) || defined(MUL_ACC_Q4_K) + let q4_quants_sum = dot4I8Packed(q4_quants, 0x01010101u); + let s = f16(d * f32(cluster_add_i4x8(q4_quants_sum))); + + if (is_valid) { + if (qs_idx == 0u) { + src1q[src1q_idx].s = s; + } + } +#endif +#endif + +#ifdef USE_WORKGROUP_REDUCTION + + var d = 0.0; + let cluster_id = thread_id / 8u; + + if (is_valid) { + q4 = src1[src1_idx_vec4_base + src11_vec4_idx]; + let abs_q4 = abs(q4); + thread_amax = max(max(abs_q4[0], abs_q4[1]), max(abs_q4[2], abs_q4[3])); + partial_amaxs[cluster_id][qs_idx] = thread_amax; + } + + workgroupBarrier(); + + if (is_valid) { + let amax = max( + max( + max(partial_amaxs[cluster_id][0], partial_amaxs[cluster_id][1]), max(partial_amaxs[cluster_id][2], partial_amaxs[cluster_id][3])), + max( + max(partial_amaxs[cluster_id][4], partial_amaxs[cluster_id][5]), max(partial_amaxs[cluster_id][6], partial_amaxs[cluster_id][7])) + ); + + d = amax / 127.0; + let id = select(0.0f, 1.0f / d, d > 0.0f); + + q4_quants = pack4xI8(vec4<i32>(round(q4 * id))); + src1q[src1q_idx].qs[qs_idx] = q4_quants; + + if (qs_idx == 0u) { + src1q[src1q_idx].d = f16(d); + } + } + +#if defined(MUL_ACC_Q4_0) || defined(MUL_ACC_Q4_1) || defined(MUL_ACC_Q4_K) + + partial_sums[cluster_id][qs_idx] = dot4I8Packed(q4_quants, 0x01010101u); + + workgroupBarrier(); + + if (is_valid) { + if (qs_idx == 0u) { + let s = d * f32(partial_sums[cluster_id][0] + partial_sums[cluster_id][1] + partial_sums[cluster_id][2] + partial_sums[cluster_id][3] + + partial_sums[cluster_id][4] + partial_sums[cluster_id][5] + partial_sums[cluster_id][6] + partial_sums[cluster_id][7]); + src1q[src1q_idx].s = f16(s); + } + } + +#endif +#endif + +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl index 3b70a876d70..6c76ed69e45 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl @@ -43,12 +43,14 @@ struct Params { var<storage, read_write> src: array<f32>; @compute @workgroup_size(WG_SIZE) -fn main(@builtin(global_invocation_id) gid: vec3<u32>) { - if (gid.x >= params.ne) { +fn main( + @builtin(global_invocation_id) gid: vec3<u32>, + @builtin(num_workgroups) num_wg: vec3<u32>) { + let threads_per_group = u32(WG_SIZE); + var i = gid.x + (num_wg.x * threads_per_group) * gid.y; + if (i >= params.ne) { return; } - - var i = gid.x; let i3 = i / (params.ne2 * params.ne1 * params.ne0); i = i % (params.ne2 * params.ne1 * params.ne0); let i2 = i / (params.ne1 * params.ne0); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl index 99e9192c71a..09f2f0eddb3 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl @@ -71,7 +71,6 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) { return; } - // getting the row from gid let elems_per_row = params.ne0 / VEC_SIZE; var i = gid.x / elems_per_row; @@ -104,6 +103,6 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) { let i_dst_row = params.offset_dst + idx_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3; let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3; - let col_idx = (gid.x % elems_per_row); - dst[i_dst_row/VEC_SIZE + col_idx] = DST_TYPE(src[i_src_row/VEC_SIZE + col_idx]); + let col_idx = gid.x % elems_per_row; + dst[i_dst_row / VEC_SIZE + col_idx] = DST_TYPE(src[i_src_row / VEC_SIZE + col_idx]); } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl new file mode 100644 index 00000000000..876e65b6ae1 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl @@ -0,0 +1,224 @@ +#ifdef DST_Q8_0 +#define BLOCK_SIZE 32u +#define BLOCK_BYTES 34u +#define QS_WORDS 8u +#elif defined(DST_Q4_0) +#define BLOCK_SIZE 32u +#define BLOCK_BYTES 18u +#define QS_WORDS 4u +#endif + +@group(0) @binding(0) +var<storage, read_write> src: array<f32>; + +@group(0) @binding(1) +var<storage, read_write> idx: array<u32>; + +@group(0) @binding(2) +#ifdef PAIR_BLOCKS +var<storage, read_write> dst: array<u32>; +#else +var<storage, read_write> dst: array<atomic<u32>>; +#endif + +#ifdef I64_IDX +@group(0) @binding(3) +var<storage, read_write> error: atomic<u32>; +#define PARAMS_BINDING 4 +#else +#define PARAMS_BINDING 3 +#endif + +struct Params { + offset_src: u32, // in elements + offset_idx: u32, // in elements + offset_dst: u32, // in blocks + + // Strides (in elements / blocks) + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + stride_idx0: u32, + stride_idx1: u32, + stride_idx2: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // Shape of src + ne0: u32, + n_rows: u32, + ne2: u32, + ne3: u32, + + // Shape of idx + idx1: u32, + idx2: u32, +}; + +@group(0) @binding(PARAMS_BINDING) +var<uniform> params: Params; + +// if the quantization type is unaligned and there are an odd number of blocks per row, we need to store atomically +#ifndef PAIR_BLOCKS +fn merge_store_dst_word(word_idx: u32, mask: u32, bits: u32) { + loop { + let old = atomicLoad(&dst[word_idx]); + let merged = (old & ~mask) | (bits & mask); + let result = atomicCompareExchangeWeak(&dst[word_idx], old, merged); + if (result.exchanged) { + return; + } + } +} +#else +fn merge_store_dst_word(word_idx: u32, mask: u32, bits: u32) { + let old = dst[word_idx]; + dst[word_idx] = (old & ~mask) | (bits & mask); +} +#endif + +fn store_u16(dst_word_idx: u32, block_byte_offset: u32, byte_offset: u32, value: u32) { + let total_byte_offset = block_byte_offset + byte_offset; + let word_idx = dst_word_idx + total_byte_offset / 4u; + let shift = (total_byte_offset & 2u) * 8u; + let mask = 0xFFFFu << shift; + merge_store_dst_word(word_idx, mask, (value & 0xFFFFu) << shift); +} + +fn store_u32(dst_word_idx: u32, block_byte_offset: u32, byte_offset: u32, value: u32) { + let total_byte_offset = block_byte_offset + byte_offset; + let word_idx = dst_word_idx + total_byte_offset / 4u; + let shift = (total_byte_offset & 3u) * 8u; + + if (shift == 0u) { +#ifdef PAIR_BLOCKS + dst[word_idx] = value; +#else + atomicStore(&dst[word_idx], value); +#endif + return; + } + + let lo_mask = 0xFFFFFFFFu << shift; + let hi_mask = (1u << shift) - 1u; + merge_store_dst_word(word_idx, lo_mask, value << shift); + merge_store_dst_word(word_idx + 1u, hi_mask, value >> (32u - shift)); +} + +fn quantize_block_params(src_block: u32) -> vec2<f32> { +#ifdef DST_Q8_0 + var amax = 0.0; + for (var j: u32 = 0u; j < BLOCK_SIZE; j++) { + amax = max(amax, abs(src[src_block + j])); + } + + let d = amax / 127.0; + let id = select(0.0, 1.0 / d, d > 0.0); + return vec2(d, id); +#elif defined(DST_Q4_0) + var amax = 0.0; + var max_val = 0.0; + for (var j: u32 = 0u; j < BLOCK_SIZE; j++) { + let v = src[src_block + j]; + let av = abs(v); + if (amax < av) { + amax = av; + max_val = v; + } + } + + let d = max_val / -8.0; + let id = select(0.0, 1.0 / d, d != 0.0); + return vec2(d, id); +#endif +} + +fn quantize_block_word(src_block: u32, j: u32, id: f32) -> u32 { +#ifdef DST_Q8_0 + let base = src_block + j * 4u; + return (u32(i32(round(src[base + 0u] * id)) & 0xFF) << 0u) | + (u32(i32(round(src[base + 1u] * id)) & 0xFF) << 8u) | + (u32(i32(round(src[base + 2u] * id)) & 0xFF) << 16u) | + (u32(i32(round(src[base + 3u] * id)) & 0xFF) << 24u); +#elif defined(DST_Q4_0) + var packed_q = 0u; + for (var k: u32 = 0u; k < 4u; k++) { + let x0 = src[src_block + j * 4u + k] * id; + let x1 = src[src_block + 16u + j * 4u + k] * id; + let q0 = u32(clamp(i32(x0 + 8.5), 0, 15)); + let q1 = u32(clamp(i32(x1 + 8.5), 0, 15)); + packed_q |= (q0 & 0xFu) << (8u * k); + packed_q |= (q1 & 0xFu) << (8u * k + 4u); + } + return packed_q; +#endif +} + +fn quantize_block(src_block: u32, dst_word_idx: u32, block_byte_offset: u32) { + let params = quantize_block_params(src_block); + let d = params.x; + let id = params.y; + let packed_d = pack2x16float(vec2(d, 0.0)) & 0xFFFFu; + store_u16(dst_word_idx, block_byte_offset, 0u, packed_d); + + for (var j: u32 = 0u; j < QS_WORDS; j++) { + store_u32(dst_word_idx, block_byte_offset, 2u + j * 4u, quantize_block_word(src_block, j, id)); + } +} + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(global_invocation_id) gid: vec3<u32>) { + let blocks_per_row = params.ne0 / BLOCK_SIZE; +#ifdef PAIR_BLOCKS + let blocks_per_invocation = 2u; +#else + let blocks_per_invocation = 1u; +#endif + let invocations_per_row = blocks_per_row / blocks_per_invocation; + let total_invocations = params.ne3 * params.ne2 * params.n_rows * invocations_per_row; + if (gid.x >= total_invocations) { + return; + } + + var i = gid.x / invocations_per_row; + let block_in_row = (gid.x % invocations_per_row) * blocks_per_invocation; + + let i_src3 = i / (params.ne2 * params.n_rows); + i = i % (params.ne2 * params.n_rows); + let i_src2 = i / params.n_rows; + let i_src1 = i % params.n_rows; + + let i_idx2 = i_src3 % params.idx2; + let i_idx1 = i_src2 % params.idx1; + let i_idx0 = i_src1; + +#ifdef I64_IDX + let idx_high = (params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2) * 2u; + let idx_val = idx[idx_high]; + let idx_low_val = idx[idx_high + 1u]; + + if (idx_low_val != 0u) { + atomicStore(&error, 1u); + return; + } +#else + let idx_i = params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2; + let idx_val = idx[idx_i]; +#endif + + let dst_row_blocks = params.offset_dst + idx_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3; + let src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3; + let src_block = src_row + block_in_row * BLOCK_SIZE; + let dst_block_byte = (dst_row_blocks + block_in_row) * BLOCK_BYTES; + + let dst_word_idx = dst_block_byte / 4u; +#ifdef PAIR_BLOCKS + quantize_block(src_block, dst_word_idx, 0u); + quantize_block(src_block + BLOCK_SIZE, dst_word_idx, BLOCK_BYTES); +#else + quantize_block(src_block, dst_word_idx, dst_block_byte & 3u); +#endif +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl index 8e34e1c9ca0..cb342c47263 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl @@ -66,11 +66,14 @@ fn erf_approx(x: TYPE) -> TYPE { } @compute @workgroup_size(WG_SIZE) -fn main(@builtin(global_invocation_id) gid: vec3<u32>) { - if (gid.x >= params.ne) { +fn main(@builtin(global_invocation_id) gid: vec3<u32>, + @builtin(num_workgroups) num_wg: vec3<u32>) { + let threads_per_group = u32(WG_SIZE); + let flat_i = gid.x + (num_wg.x * threads_per_group) * gid.y; + if (flat_i >= params.ne) { return; } - var i = gid.x; + var i = flat_i; let ne2 = params.ne2; #ifdef DIAG let ne1 = params.ne0; @@ -205,6 +208,6 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) { #ifdef INPLACE src[params.offset_src + src_idx] = res; #else - dst[params.offset_dst + gid.x] = res; + dst[params.offset_dst + flat_i] = res; #endif } diff --git a/ggml/src/ggml-zendnn/ggml-zendnn.cpp b/ggml/src/ggml-zendnn/ggml-zendnn.cpp index 6051d082003..3c33dcb11a0 100644 --- a/ggml/src/ggml-zendnn/ggml-zendnn.cpp +++ b/ggml/src/ggml-zendnn/ggml-zendnn.cpp @@ -88,7 +88,7 @@ static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int return true; } -static bool ggml_zendnn_sgemm(ggml_backend_zendnn_context * ctx, int64_t m, int64_t n, int64_t k, +static bool ggml_zendnn_gemm(ggml_backend_zendnn_context * ctx, int64_t m, int64_t n, int64_t k, const void * A, int64_t lda, const void * B, int64_t ldb, void * C, int64_t ldc, int Atype, int Btype, int Ctype) { @@ -200,7 +200,7 @@ static void ggml_zendnn_compute_forward_mul_mat( for (int64_t i12 = 0; i12 < ne12; i12++) { const void* wdata = (src1->type == vec_dot_type || src0->type == GGML_TYPE_Q8_0) ? src1->data : work_data; const size_t row_size = ggml_row_size(vec_dot_type, ne10); - if (!ggml_zendnn_sgemm(ctx, + if (!ggml_zendnn_gemm(ctx, ne01, // m ne11, // n ne10, // k @@ -213,7 +213,7 @@ static void ggml_zendnn_compute_forward_mul_mat( src0->type, src0->type == GGML_TYPE_Q8_0 ? GGML_TYPE_F32 : vec_dot_type, dst->type)) - GGML_ABORT("%s: ZenDNN sgemm failed\n", __func__); + GGML_ABORT("%s: ZenDNN gemm failed\n", __func__); } } } @@ -355,7 +355,7 @@ static void ggml_zendnn_compute_forward_mul_mat_id( } // batched gemm for all tokens in this expert - if (!ggml_zendnn_sgemm(ctx, + if (!ggml_zendnn_gemm(ctx, ne01, // m cne1, // n ne10, // k @@ -368,7 +368,7 @@ static void ggml_zendnn_compute_forward_mul_mat_id( src0->type, src0->type == GGML_TYPE_Q8_0 ? GGML_TYPE_F32 : vec_dot_type, dst->type)) { - GGML_ABORT("%s: ZenDNN sgemm failed\n", __func__); + GGML_ABORT("%s: ZenDNN gemm failed\n", __func__); } // scatter output rows to destination diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 476c3079795..b43016c87d2 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1031,6 +1031,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "IM2COL", "IM2COL_BACK", "IM2COL_3D", + "COL2IM_1D", "CONV_2D", "CONV_3D", "CONV_2D_DW", @@ -1080,7 +1081,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GLU", }; -static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96"); +static_assert(GGML_OP_COUNT == 97, "GGML_OP_COUNT != 97"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1141,6 +1142,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "im2col(x)", "im2col_back(x)", "im2col_3d(x)", + "col2im_1d(x)", "conv_2d(x)", "conv_3d(x)", "conv_2d_dw(x)", @@ -1190,7 +1192,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "glu(x)", }; -static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96"); +static_assert(GGML_OP_COUNT == 97, "GGML_OP_COUNT != 97"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -4541,6 +4543,41 @@ struct ggml_tensor * ggml_conv_1d_dw_ph( return ggml_conv_1d_dw(ctx, a, b, s0, a->ne[0] / 2, d0); } +// ggml_col2im_1d + +struct ggml_tensor * ggml_col2im_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int s0, + int oc, + int p0) { + GGML_ASSERT(ggml_is_matrix(a)); + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(a->type == GGML_TYPE_F32 || a->type == GGML_TYPE_F16 || a->type == GGML_TYPE_BF16); + GGML_ASSERT(s0 > 0); + GGML_ASSERT(oc > 0); + GGML_ASSERT(p0 >= 0); + + const int64_t K_OC = a->ne[0]; + const int64_t T_in = a->ne[1]; + const int64_t K = K_OC / oc; + const int64_t T_out = (T_in - 1) * s0 + K - 2 * p0; + + GGML_ASSERT(K_OC == K * oc); // a->ne[0] must be a whole number of oc blocks + GGML_ASSERT(K > 0 && T_out > 0); + + const int64_t ne[4] = { T_out, oc, 1, 1 }; + struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, 2, ne); + + int32_t params[] = { s0, (int32_t)oc, (int32_t)p0 }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_COL2IM_1D; + result->src[0] = a; + + return result; +} + // ggml_conv_transpose_1d static int64_t ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) { @@ -5223,7 +5260,7 @@ static struct ggml_tensor * ggml_fill_impl( struct ggml_tensor * a, float c, bool inplace) { - GGML_ASSERT(a->type == GGML_TYPE_F32); + GGML_ASSERT(a->type == GGML_TYPE_F32 || a->type == GGML_TYPE_F16); GGML_ASSERT(ggml_is_contiguous(a)); struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); @@ -6186,7 +6223,8 @@ struct ggml_tensor * ggml_gated_delta_net( struct ggml_tensor * v, struct ggml_tensor * g, struct ggml_tensor * beta, - struct ggml_tensor * state) { + struct ggml_tensor * state, + int64_t K) { GGML_ASSERT(ggml_is_contiguous_rows(q)); GGML_ASSERT(ggml_is_contiguous_rows(k)); GGML_ASSERT(ggml_is_contiguous_rows(v)); @@ -6210,15 +6248,18 @@ struct ggml_tensor * ggml_gated_delta_net( GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v); GGML_ASSERT(beta->ne[0] == 1); - // state is a 3D tensor (S_v*S_v*H, K, n_seqs). K is the snapshot slot count. - GGML_ASSERT(state->ne[0] == S_v * S_v * H); - GGML_ASSERT(state->ne[2] == n_seqs); - GGML_ASSERT(state->ne[3] == 1); - const int64_t K = state->ne[1]; + // state holds the initial state s0 only: [S_v, S_v, H, n_seqs]. K (snapshot slot count) is an op param. + GGML_ASSERT(state->ne[0] == S_v); + GGML_ASSERT(state->ne[1] == S_v); + GGML_ASSERT(state->ne[2] == H); + GGML_ASSERT(state->ne[3] == n_seqs); + GGML_ASSERT(K >= 1); const int64_t state_rows = K * S_v * n_seqs; const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + state_rows, 1, 1 }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + ggml_set_op_params_i32(result, 0, (int32_t) K); + result->op = GGML_OP_GATED_DELTA_NET; result->src[0] = q; result->src[1] = k; diff --git a/include/parakeet.h b/include/parakeet.h new file mode 100644 index 00000000000..d35aa870adb --- /dev/null +++ b/include/parakeet.h @@ -0,0 +1,342 @@ +#ifndef PARAKEET_H +#define PARAKEET_H + +#include "ggml.h" +#include "ggml-cpu.h" + +#include <stddef.h> +#include <stdint.h> +#include <stdbool.h> + +#ifdef __GNUC__ +# define PARAKEET_DEPRECATED(func, hint) func __attribute__((deprecated(hint))) +#elif defined(_MSC_VER) +# define PARAKEET_DEPRECATED(func, hint) __declspec(deprecated(hint)) func +#else +# define PARAKEET_DEPRECATED(func, hint) func +#endif + +#ifdef PARAKEET_SHARED +# ifdef _WIN32 +# ifdef PARAKEET_BUILD +# define PARAKEET_API __declspec(dllexport) +# else +# define PARAKEET_API __declspec(dllimport) +# endif +# else +# define PARAKEET_API __attribute__ ((visibility ("default"))) +# endif +#else +# define PARAKEET_API +#endif + +#define PARAKEET_SAMPLE_RATE 16000 +#define PARAKEET_HOP_LENGTH 160 + +#ifdef __cplusplus +extern "C" { +#endif + + struct parakeet_context; + struct parakeet_state; + struct parakeet_full_params; + + typedef int32_t parakeet_pos; + typedef int32_t parakeet_token; + typedef int32_t parakeet_seq_id; + + struct parakeet_context_params { + bool use_gpu; + int gpu_device; // CUDA device + }; + + typedef struct parakeet_token_data { + parakeet_token id; // the BPE subword ID (0-8191) + + int duration_idx; // index into the models durations array + int duration_value; // actual duration value + int frame_index; + + float p; + float plog; + + int64_t t0; + int64_t t1; + + bool is_word_start; + } parakeet_token_data; + + typedef struct parakeet_model_loader { + void * context; + + size_t (*read)(void * ctx, void * output, size_t read_size); + bool (*eof)(void * ctx); + void (*close)(void * ctx); + } parakeet_model_loader; + + PARAKEET_API const char * parakeet_version(void); + + // Various functions for loading a ggml parakeet model. + // Allocate (almost) all memory needed for the model. + // Return NULL on failure + PARAKEET_API struct parakeet_context * parakeet_init_from_file_with_params (const char * path_model, struct parakeet_context_params params); + PARAKEET_API struct parakeet_context * parakeet_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct parakeet_context_params params); + PARAKEET_API struct parakeet_context * parakeet_init_with_params (struct parakeet_model_loader * loader, struct parakeet_context_params params); + + // These are the same as the above, but the internal state of the context is not allocated automatically + // It is the responsibility of the caller to allocate the state using parakeet_init_state() (#523) + PARAKEET_API struct parakeet_context * parakeet_init_from_file_with_params_no_state (const char * path_model, struct parakeet_context_params params); + PARAKEET_API struct parakeet_context * parakeet_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct parakeet_context_params params); + PARAKEET_API struct parakeet_context * parakeet_init_with_params_no_state (struct parakeet_model_loader * loader, struct parakeet_context_params params); + + PARAKEET_API struct parakeet_state * parakeet_init_state(struct parakeet_context * ctx); + + // Frees all allocated memory + PARAKEET_API void parakeet_free (struct parakeet_context * ctx); + PARAKEET_API void parakeet_free_state(struct parakeet_state * state); + PARAKEET_API void parakeet_free_params(struct parakeet_full_params * params); + PARAKEET_API void parakeet_free_context_params(struct parakeet_context_params * params); + + // Convert RAW PCM audio to log mel spectrogram. + // The resulting spectrogram is stored inside the default state of the provided parakeet context. + // Returns 0 on success + PARAKEET_API int parakeet_pcm_to_mel( + struct parakeet_context * ctx, + const float * samples, + int n_samples, + int n_threads); + + PARAKEET_API int parakeet_pcm_to_mel_with_state( + struct parakeet_context * ctx, + struct parakeet_state * state, + const float * samples, + int n_samples, + int n_threads); + + // This can be used to set a custom log mel spectrogram inside the default state of the provided parakeet context. + // Use this instead of parakeet_pcm_to_mel() if you want to provide your own log mel spectrogram. + // n_mel must be 128 + // Returns 0 on success + PARAKEET_API int parakeet_set_mel( + struct parakeet_context * ctx, + const float * data, + int n_len, + int n_mel); + + PARAKEET_API int parakeet_set_mel_with_state( + struct parakeet_context * ctx, + struct parakeet_state * state, + const float * data, + int n_len, + int n_mel); + + // Run the Parakeet encoder on the log mel spectrogram stored inside the default state in the provided parakeet context. + // Make sure to call parakeet_pcm_to_mel() or parakeet_set_mel() first. + // offset can be used to specify the offset of the first frame in the spectrogram. + // Returns 0 on success + PARAKEET_API int parakeet_encode( + struct parakeet_context * ctx, + int offset, + int n_threads); + + PARAKEET_API int parakeet_encode_with_state( + struct parakeet_context * ctx, + struct parakeet_state * state, + int offset, + int n_threads); + + // Convert the provided text into tokens. + // The tokens pointer must be large enough to hold the resulting tokens. + // Returns the number of tokens on success, no more than n_max_tokens + // Returns a negative number on failure - the number of tokens that would have been returned + // TODO: not sure if correct + PARAKEET_API int parakeet_tokenize( + struct parakeet_context * ctx, + const char * text, + parakeet_token * tokens, + int n_max_tokens); + + // Return the number of tokens in the provided text + // Equivalent to: -parakeet_tokenize(ctx, text, NULL, 0) + int parakeet_token_count(struct parakeet_context * ctx, const char * text); + + PARAKEET_API int parakeet_n_len (struct parakeet_context * ctx); // mel length + PARAKEET_API int parakeet_n_len_from_state(struct parakeet_state * state); // mel length + PARAKEET_API int parakeet_n_vocab (struct parakeet_context * ctx); + PARAKEET_API int parakeet_n_audio_ctx (struct parakeet_context * ctx); + + PARAKEET_API int parakeet_model_n_vocab (struct parakeet_context * ctx); + PARAKEET_API int parakeet_model_n_audio_ctx (struct parakeet_context * ctx); + PARAKEET_API int parakeet_model_n_audio_state(struct parakeet_context * ctx); + PARAKEET_API int parakeet_model_n_audio_head (struct parakeet_context * ctx); + PARAKEET_API int parakeet_model_n_audio_layer(struct parakeet_context * ctx); + PARAKEET_API int parakeet_model_n_mels (struct parakeet_context * ctx); + PARAKEET_API int parakeet_model_ftype (struct parakeet_context * ctx); + + // Token logits obtained from the last call to parakeet_full/parakeet_chunk + // The logits for the last token are stored in the last row + // Rows: n_tokens + // Cols: n_vocab + PARAKEET_API float * parakeet_get_logits (struct parakeet_context * ctx); + PARAKEET_API float * parakeet_get_logits_from_state(struct parakeet_state * state); + + // Token Id -> String. Uses the vocabulary in the provided context + PARAKEET_API const char * parakeet_token_to_str(struct parakeet_context * ctx, parakeet_token token); + + PARAKEET_API int parakeet_token_to_text(const char * token_str, bool is_first, char * output, int max_len); + + // Special tokens + PARAKEET_API parakeet_token parakeet_token_blank(struct parakeet_context * ctx); + PARAKEET_API parakeet_token parakeet_token_unk (struct parakeet_context * ctx); + PARAKEET_API parakeet_token parakeet_token_bos (struct parakeet_context * ctx); + + // Performance information from the default state. + struct parakeet_timings { + float sample_ms; + float encode_ms; + float decode_ms; + }; + PARAKEET_API struct parakeet_timings * parakeet_get_timings(struct parakeet_context * ctx); + PARAKEET_API void parakeet_print_timings(struct parakeet_context * ctx); + PARAKEET_API void parakeet_reset_timings(struct parakeet_context * ctx); + + // Print system information + PARAKEET_API const char * parakeet_print_system_info(void); + + // Available sampling strategies + enum parakeet_sampling_strategy { + PARAKEET_SAMPLING_GREEDY, + }; + + // Token callback. + // Called for each new predicted token. + // Use the parakeet_full_...() functions to obtain the text segments + typedef void (*parakeet_new_token_callback)( + struct parakeet_context * ctx, + struct parakeet_state * state, + const parakeet_token_data * token_data, + void * user_data); + + // Text segment callback + // Called on every newly generated text segment + // Use the parakeet_full_...() functions to obtain the text segments + typedef void (*parakeet_new_segment_callback)(struct parakeet_context * ctx, struct parakeet_state * state, int n_new, void * user_data); + + // Progress callback + typedef void (*parakeet_progress_callback)(struct parakeet_context * ctx, struct parakeet_state * state, int progress, void * user_data); + + // Encoder begin callback + // If not NULL, called before the encoder starts + // If it returns false, the computation is aborted + typedef bool (*parakeet_encoder_begin_callback)(struct parakeet_context * ctx, struct parakeet_state * state, void * user_data); + + // Parameters for the parakeet_full() function + // If you change the order or add new parameters, make sure to update the default values in parakeet.cpp: + // parakeet_full_default_params() + struct parakeet_full_params { + enum parakeet_sampling_strategy strategy; + + int n_threads; + int offset_ms; // start offset in ms + int duration_ms; // audio duration to process in ms + + bool no_context; // do not use past transcription (if any) as context + + int audio_ctx; // overwrite the audio context size (0 = use default) + + // called for every newly generated text segment + parakeet_new_segment_callback new_segment_callback; + void * new_segment_callback_user_data; + + // called for every newly generated token + parakeet_new_token_callback new_token_callback; + void * new_token_callback_user_data; + + // called on each progress update + parakeet_progress_callback progress_callback; + void * progress_callback_user_data; + + // called each time before the encoder starts + parakeet_encoder_begin_callback encoder_begin_callback; + void * encoder_begin_callback_user_data; + + // called each time before ggml computation starts + ggml_abort_callback abort_callback; + void * abort_callback_user_data; + }; + + // NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see parakeet_free_context_params() & parakeet_free_params() + PARAKEET_API struct parakeet_context_params * parakeet_context_default_params_by_ref(void); + PARAKEET_API struct parakeet_context_params parakeet_context_default_params (void); + + PARAKEET_API struct parakeet_full_params * parakeet_full_default_params_by_ref(enum parakeet_sampling_strategy strategy); + PARAKEET_API struct parakeet_full_params parakeet_full_default_params (enum parakeet_sampling_strategy strategy); + + // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text + // Not thread safe for same context + PARAKEET_API int parakeet_full( + struct parakeet_context * ctx, + struct parakeet_full_params params, + const float * samples, + int n_samples); + + PARAKEET_API int parakeet_full_with_state( + struct parakeet_context * ctx, + struct parakeet_state * state, + struct parakeet_full_params params, + const float * samples, + int n_samples); + + // Process a single chunk of audio data that fits within the model's audio context window. + // This is more efficient than parakeet_full() for short audio clips. + PARAKEET_API int parakeet_chunk( + struct parakeet_context * ctx, + struct parakeet_state * state, + struct parakeet_full_params params, + const float * samples, + int n_samples); + + // Number of generated text segments + PARAKEET_API int parakeet_full_n_segments (struct parakeet_context * ctx); + PARAKEET_API int parakeet_full_n_segments_from_state(struct parakeet_state * state); + + // Get the start and end time of the specified segment + PARAKEET_API int64_t parakeet_full_get_segment_t0 (struct parakeet_context * ctx, int i_segment); + PARAKEET_API int64_t parakeet_full_get_segment_t0_from_state(struct parakeet_state * state, int i_segment); + + PARAKEET_API int64_t parakeet_full_get_segment_t1 (struct parakeet_context * ctx, int i_segment); + PARAKEET_API int64_t parakeet_full_get_segment_t1_from_state(struct parakeet_state * state, int i_segment); + + // Get the text of the specified segment + PARAKEET_API const char * parakeet_full_get_segment_text (struct parakeet_context * ctx, int i_segment); + PARAKEET_API const char * parakeet_full_get_segment_text_from_state(struct parakeet_state * state, int i_segment); + + // Get number of tokens in the specified segment + PARAKEET_API int parakeet_full_n_tokens (struct parakeet_context * ctx, int i_segment); + PARAKEET_API int parakeet_full_n_tokens_from_state(struct parakeet_state * state, int i_segment); + + // Get the token text of the specified token in the specified segment + PARAKEET_API const char * parakeet_full_get_token_text (struct parakeet_context * ctx, int i_segment, int i_token); + PARAKEET_API const char * parakeet_full_get_token_text_from_state(struct parakeet_context * ctx, struct parakeet_state * state, int i_segment, int i_token); + + // Get the token id of the specified token in the specified segment + PARAKEET_API parakeet_token parakeet_full_get_token_id (struct parakeet_context * ctx, int i_segment, int i_token); + PARAKEET_API parakeet_token parakeet_full_get_token_id_from_state(struct parakeet_state * state, int i_segment, int i_token); + + // Get token data for the specified token in the specified segment + PARAKEET_API parakeet_token_data parakeet_full_get_token_data (struct parakeet_context * ctx, int i_segment, int i_token); + PARAKEET_API parakeet_token_data parakeet_full_get_token_data_from_state(struct parakeet_state * state, int i_segment, int i_token); + + // Get the probability of the specified token in the specified segment + PARAKEET_API float parakeet_full_get_token_p (struct parakeet_context * ctx, int i_segment, int i_token); + PARAKEET_API float parakeet_full_get_token_p_from_state(struct parakeet_state * state, int i_segment, int i_token); + + // Control logging output; default behavior is to print to stderr + + PARAKEET_API void parakeet_log_set(ggml_log_callback log_callback, void * user_data); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/media/matmul.png b/media/matmul.png new file mode 100644 index 00000000000..786a20492c0 Binary files /dev/null and b/media/matmul.png differ diff --git a/models/convert-parakeet-to-ggml.py b/models/convert-parakeet-to-ggml.py new file mode 100755 index 00000000000..2d6a6d01554 --- /dev/null +++ b/models/convert-parakeet-to-ggml.py @@ -0,0 +1,337 @@ +#!/usr/bin/env python3 +# Convert Parakeet TDT model from NeMo format to ggml format +# +# Usage: python convert-parakeet-to-ggml.py --model parakeet-model.nemo --output-dir output-dir [--use-f32] +# +# The NeMo file is a tar archive containing: +# - model_weights.ckpt (PyTorch checkpoint) +# - model_config.yaml (model configuration) +# - tokenizer files +# +# This script extracts the NeMo archive, loads the model weights and configuration, +# and saves them in ggml format compatible with whisper.cpp. +# + +import torch +import argparse +import io +import os +import sys +import struct +import tarfile +import tempfile +import shutil +import yaml +import numpy as np +from pathlib import Path +from typing import Optional + +def hz_to_mel(freq): + return 2595.0 * np.log10(1.0 + freq / 700.0) + +def mel_to_hz(mel): + return 700.0 * (10.0**(mel / 2595.0) - 1.0) + +def extract_nemo_archive(nemo_path, extract_dir): + print(f"Extracting {nemo_path} to {extract_dir}") + with tarfile.open(nemo_path, 'r') as tar: + tar.extractall(path=extract_dir) + print("Extraction complete") + +def load_model_config(config_path): + with open(config_path, 'r', encoding='utf-8') as f: + config = yaml.safe_load(f) + return config + +def load_tokenizer(extract_dir, config): + tokenizer_model_path = None + tokenizer_vocab_path = None + + for file in os.listdir(extract_dir): + if file.endswith('_tokenizer.model'): + tokenizer_model_path = os.path.join(extract_dir, file) + elif file.endswith('tokenizer.vocab'): + tokenizer_vocab_path = os.path.join(extract_dir, file) + + if not tokenizer_model_path: + raise FileNotFoundError("Tokenizer model file not found") + + if not tokenizer_vocab_path: + raise FileNotFoundError("Tokenizer vocab file not found") + + tokens = {} + with open(tokenizer_vocab_path, 'r', encoding='utf-8') as f: + for idx, line in enumerate(f): + parts = line.strip().split('\t') + if len(parts) >= 1: + token = parts[0] + tokens[token.encode('utf-8')] = idx + + print(f"Loaded {len(tokens)} tokens from {os.path.basename(tokenizer_vocab_path)}") + + if len(tokens) != 8192: + print(f"WARNING: Expected 8192 tokens, got {len(tokens)}") + + return tokens + +def write_tensor(fout, name, data, use_f16=True, force_f32=False): + if 'pre_encode.conv' in name and 'bias' in name and len(data.shape) == 1: + data = data.reshape(1, -1, 1, 1) + print(f" Reshaped conv bias {name} to {data.shape}") + + n_dims = len(data.shape) + + ftype = 1 if use_f16 and not force_f32 else 0 + if force_f32: + data = data.astype(np.float32) + elif use_f16: + if n_dims < 2 or 'bias' in name or 'norm' in name or \ + ('pre_encode.conv' in name and n_dims == 4) or \ + 'depthwise_conv.weight' in name: + data = data.astype(np.float32) + ftype = 0 + else: + data = data.astype(np.float16) + else: + data = data.astype(np.float32) + + dims_reversed = [data.shape[n_dims - 1 - i] for i in range(n_dims)] + print(f"Processing: {name} {list(data.shape)}, dtype: {data.dtype}, n_dims: {n_dims}, reversed: {dims_reversed}") + name_bytes = name.encode('utf-8') + fout.write(struct.pack("iii", n_dims, len(name_bytes), ftype)) + for i in range(n_dims): + fout.write(struct.pack("i", data.shape[n_dims - 1 - i])) + fout.write(name_bytes) + + data.tofile(fout) + +def convert_parakeet_to_ggml(nemo_path, output_dir, use_f16=True, out_name=None): + nemo_path = Path(nemo_path) + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # Create temporary directory for extraction + with tempfile.TemporaryDirectory() as temp_dir: + extract_nemo_archive(nemo_path, temp_dir) + + config_path = os.path.join(temp_dir, 'model_config.yaml') + config = load_model_config(config_path) + + print("Model configuration:") + print(f" Sample rate: {config['sample_rate']}") + print(f" Encoder layers: {config['encoder']['n_layers']}") + print(f" Encoder d_model: {config['encoder']['d_model']}") + print(f" Mel features: {config['preprocessor']['features']}") + + weights_path = os.path.join(temp_dir, 'model_weights.ckpt') + print(f"\nLoading model weights from {weights_path}") + checkpoint = torch.load(weights_path, map_location='cpu') + + # Extract state dict + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + + print(f"Loaded {len(state_dict)} tensors") + + # Load tokenizer + print("\nLoading tokenizer...") + tokens = load_tokenizer(temp_dir, config) + print(f"Loaded {len(tokens)} tokens") + + # Prepare hyperparameters for the Parakeet ggml format. + hparams = { + 'n_audio_ctx': 5000, + 'n_audio_state': config['encoder']['d_model'], + 'n_audio_head': config['encoder']['n_heads'], + 'n_audio_layer': config['encoder']['n_layers'], + 'n_mels': config['preprocessor']['features'], + 'n_fft': config['preprocessor']['n_fft'], + 'subsampling_factor': config['encoder']['subsampling_factor'], + 'n_subsampling_channels': config['encoder']['subsampling_conv_channels'], + 'n_conv_kernel': config['encoder']['conv_kernel_size'], + + 'n_pred_dim': config['decoder']['prednet']['pred_hidden'], + 'n_pred_layers': config['decoder']['prednet']['pred_rnn_layers'], + 'n_vocab': config['decoder']['vocab_size'], + 'n_tdt_durations': config['model_defaults']['num_tdt_durations'], + 'n_max_tokens': config['decoding']['greedy']['max_symbols'], + } + + print("\nGGML hyperparameters:") + for key, value in hparams.items(): + print(f" {key}: {value}") + + # Create output file + if out_name: + fname_out = output_dir / out_name + else: + fname_out = output_dir / ("ggml-model-f32.bin" if not use_f16 else "ggml-model.bin") + print(f"\nWriting to {fname_out}") + + with open(fname_out, 'wb') as fout: + # Write magic number + fout.write(struct.pack("i", 0x67676d6c)) # 'ggml' in hex + + # Write hyperparameters + fout.write(struct.pack("i", hparams['n_vocab'])) + fout.write(struct.pack("i", hparams['n_audio_ctx'])) + fout.write(struct.pack("i", hparams['n_audio_state'])) + fout.write(struct.pack("i", hparams['n_audio_head'])) + fout.write(struct.pack("i", hparams['n_audio_layer'])) + fout.write(struct.pack("i", hparams['n_mels'])) + fout.write(struct.pack("i", 1 if use_f16 else 0)) + fout.write(struct.pack("i", hparams['n_fft'])) + fout.write(struct.pack("i", hparams['subsampling_factor'])) + fout.write(struct.pack("i", hparams['n_subsampling_channels'])) + fout.write(struct.pack("i", hparams['n_conv_kernel'])) + fout.write(struct.pack("i", hparams['n_pred_dim'])) + fout.write(struct.pack("i", hparams['n_pred_layers'])) + fout.write(struct.pack("i", hparams['n_tdt_durations'])) + fout.write(struct.pack("i", hparams['n_max_tokens'])) + + # Extract mel filterbank from model + fb_key = None + for key in state_dict.keys(): + if 'featurizer.fb' in key or 'filterbank' in key.lower(): + fb_key = key + break + + if not fb_key: + print("\nERROR: Mel filterbank not found in model!") + print("Expected tensor with 'featurizer.fb' or 'filterbank' in name") + print("\nAvailable preprocessor tensors:") + for key in sorted(state_dict.keys()): + if 'preprocessor' in key or 'featurizer' in key: + print(f" {key}: {state_dict[key].shape}") + raise ValueError("Mel filterbank tensor not found in model") + + print(f"\nUsing model's mel filterbank from: {fb_key}") + mel_filters = state_dict[fb_key].squeeze().numpy().astype(np.float32) + print(f" Filterbank shape: {mel_filters.shape}") + print(f" Filterbank min/max values: {mel_filters.min():.6f} / {mel_filters.max():.6f}") + print(f" Filterbank non-zero elements: {np.count_nonzero(mel_filters)} / {mel_filters.size}") + print(f" First row sum: {mel_filters[0].sum():.6f}") + + if len(mel_filters.shape) != 2: + raise ValueError(f"Expected 2D filterbank, got shape {mel_filters.shape}") + + n_mels, n_freqs = mel_filters.shape + fout.write(struct.pack("i", n_mels)) # n_mel + fout.write(struct.pack("i", n_freqs)) # n_fb (frequency bins) + + # Write mel filterbank + for i in range(n_mels): + for j in range(n_freqs): + fout.write(struct.pack("f", mel_filters[i, j])) + + # Extract window function from model + window_key = None + for key in state_dict.keys(): + if 'featurizer.window' in key or 'preproc' in key and 'window' in key: + window_key = key + break + + if not window_key: + print("\nERROR: Window function not found in model!") + print("Expected tensor with 'featurizer.window' in name") + raise ValueError("Window function tensor not found in model") + + print(f"\nUsing model's window function from: {window_key}") + window = state_dict[window_key].squeeze().numpy().astype(np.float32) + print(f" Window shape: {window.shape}") + print(f" Window min/max values: {window.min():.6f} / {window.max():.6f}") + print(f" Window non-zero elements: {np.count_nonzero(window)} / {window.size}") + print(f" Window sum: {window.sum():.6f}") + + if len(window.shape) != 1: + raise ValueError(f"Expected 1D window, got shape {window.shape}") + + n_window = window.shape[0] + fout.write(struct.pack("i", n_window)) + + # Write window function + for i in range(n_window): + fout.write(struct.pack("f", window[i])) + + # Write TDT durations + tdt_durations = config['model_defaults']['tdt_durations'] + if len(tdt_durations) != hparams['n_tdt_durations']: + raise ValueError(f"TDT durations count mismatch: {len(tdt_durations)} vs {hparams['n_tdt_durations']}") + + for duration in tdt_durations: + fout.write(struct.pack("I", duration)) + + fout.write(struct.pack("i", len(tokens))) + for token_bytes, idx in sorted(tokens.items(), key=lambda x: x[1]): + fout.write(struct.pack("i", len(token_bytes))) + fout.write(token_bytes) + + # Pre-collect prediction LSTM input-hidden biases so they can be + # folded into the hidden-hidden bias during the main write loop. + lstm_prefix = 'decoder.prediction.dec_rnn.lstm' + pred_bias_ih = {} + for key, t in state_dict.items(): + if f'{lstm_prefix}.bias_ih_l' in key: + layer_idx = int(key.rsplit('bias_ih_l', 1)[1]) + pred_bias_ih[layer_idx] = t.squeeze().numpy().astype(np.float32) + + print("\nConverting model weights...") + for name, tensor in state_dict.items(): + # Skip the filterbank and window - already written in preprocessing section + if name == fb_key: + continue + if name == window_key: + continue + + # bias_ih is folded into bias_hh below; skip writing it separately + if f'{lstm_prefix}.bias_ih_l' in name: + continue + + # Don't squeeze Conv2d weights - they need to preserve all 4 dimensions + if 'conv' in name and 'weight' in name and len(tensor.shape) == 4: + data = tensor.numpy() + else: + data = tensor.squeeze().numpy() + + # For prediction LSTM weights/biases: + # Fold bias_ih into bias_hh (bias_ih already skipped above). + # Reorder gates (input, forget, cell, output) from PyTorch layout + # [i, f, g, o] to [i, f, o, g] so the three sigmoid-gated outputs + # (i, f, o) are contiguous. + if name.startswith(f'{lstm_prefix}.'): + if f'{lstm_prefix}.bias_hh_l' in name: + layer_idx = int(name.rsplit('bias_hh_l', 1)[1]) + data = data.astype(np.float32) + pred_bias_ih[layer_idx] + name = name.replace('bias_hh_l', 'bias_h_l') + h = data.shape[0] // 4 + data = np.concatenate([data[:h], data[h:2*h], data[3*h:], data[2*h:3*h]], axis=0) + + write_tensor(fout, name, data, use_f16=use_f16) + + print(f"\nConversion complete!") + print(f"Output file: {fname_out}") + print(f"File size: {fname_out.stat().st_size / (1024**2):.2f} MB") + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description='Convert Parakeet TDT model from NeMo format to ggml format' + ) + parser.add_argument('--model', type=str, required=True, + help='Path to Parakeet .nemo model file') + parser.add_argument('--out-dir', type=str, required=True, + help='Directory to write ggml model file') + parser.add_argument('--use-f32', action='store_true', default=False, + help='Use f32 instead of f16 (default: f16)') + parser.add_argument('--out-name', type=str, default=None, + help='Output file name (default: ggml-model.bin or ggml-model-f32.bin)') + + args = parser.parse_args() + + if not os.path.exists(args.model): + print(f"Error: {args.model} not found") + sys.exit(1) + + use_f16 = not args.use_f32 + convert_parakeet_to_ggml(args.model, args.out_dir, use_f16, args.out_name) diff --git a/models/convert-whisper-to-coreml.py b/models/convert-whisper-to-coreml.py index 66827b6d420..7cf07754a89 100644 --- a/models/convert-whisper-to-coreml.py +++ b/models/convert-whisper-to-coreml.py @@ -8,10 +8,19 @@ from typing import Dict from typing import Optional from ane_transformers.reference.layer_norm import LayerNormANE as LayerNormANEBase -from coremltools.models.neural_network.quantization_utils import quantize_weights from whisper.model import Whisper, AudioEncoder, TextDecoder, ResidualAttentionBlock, MultiHeadAttention, ModelDimensions from whisper import load_model + +def _str_to_bool(v): + if isinstance(v, bool): + return v + if v.lower() in ("true", "1", "yes"): + return True + if v.lower() in ("false", "0", "no"): + return False + raise argparse.ArgumentTypeError(f"boolean value expected, got '{v}'") + # Disable PyTorch Scaled Dot-Product Attention (SDPA) to avoid compatibility issues. # The Whisper implementation expects a specific behavior from # torch.nn.functional.scaled_dot_product_attention that differs between PyTorch @@ -258,11 +267,9 @@ def convert_encoder(hparams, model, quantize=False): inputs=[ct.TensorType(name="logmel_data", shape=input_shape)], outputs=[ct.TensorType(name="output")], compute_units=ct.ComputeUnit.ALL, + compute_precision=ct.precision.FLOAT16 if quantize else ct.precision.FLOAT32, ) - if quantize: - model = quantize_weights(model, nbits=16) - return model def convert_decoder(hparams, model, quantize=False): @@ -283,20 +290,18 @@ def convert_decoder(hparams, model, quantize=False): ct.TensorType(name="token_data", shape=tokens_shape, dtype=int), ct.TensorType(name="audio_data", shape=audio_shape) ], + compute_precision=ct.precision.FLOAT16 if quantize else ct.precision.FLOAT32, ) - if quantize: - model = quantize_weights(model, nbits=16) - return model if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, help="model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large-v1, large-v2, large-v3, large-v3-turbo)", required=True) - parser.add_argument("--encoder-only", type=bool, help="only convert encoder", default=False) - parser.add_argument("--quantize", type=bool, help="quantize weights to F16", default=False) - parser.add_argument("--optimize-ane", type=bool, help="optimize for ANE execution (currently broken)", default=False) + parser.add_argument("--encoder-only", type=_str_to_bool, help="only convert encoder", default=False) + parser.add_argument("--quantize", type=_str_to_bool, help="quantize weights to F16", default=False) + parser.add_argument("--optimize-ane", type=_str_to_bool, help="optimize for ANE execution", default=False) args = parser.parse_args() if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "small.en-tdrz", "medium", "medium.en", "large-v1", "large-v2", "large-v3", "large-v3-turbo"]: diff --git a/models/download-ggml-model.sh b/models/download-ggml-model.sh index f1394e98484..0539c8afb3d 100755 --- a/models/download-ggml-model.sh +++ b/models/download-ggml-model.sh @@ -120,7 +120,13 @@ fi if [ -x "$(command -v wget2)" ]; then wget2 --no-config --progress bar -O ggml-"$model".bin $src/$pfx-"$model".bin elif [ -x "$(command -v curl)" ]; then - curl -L --output ggml-"$model".bin $src/$pfx-"$model".bin + curl -L --fail \ + --retry 5 \ + --retry-delay 5 \ + --retry-all-errors \ + --retry-connrefused \ + ${HF_TOKEN:+--header "Authorization: Bearer $HF_TOKEN"} \ + --output ggml-"$model".bin $src/$pfx-"$model".bin elif [ -x "$(command -v wget)" ]; then wget --no-config --quiet --show-progress -O ggml-"$model".bin $src/$pfx-"$model".bin else diff --git a/models/for-tests-ggml-parakeet-tdt.bin b/models/for-tests-ggml-parakeet-tdt.bin new file mode 100644 index 00000000000..8b1dda1feba Binary files /dev/null and b/models/for-tests-ggml-parakeet-tdt.bin differ diff --git a/models/generate-parakeet-test-model.py b/models/generate-parakeet-test-model.py new file mode 100755 index 00000000000..192a96ce627 --- /dev/null +++ b/models/generate-parakeet-test-model.py @@ -0,0 +1,182 @@ +#!/usr/bin/env python3 +import struct +import sys +import numpy as np +from pathlib import Path + +def write_tensor(fout, name, data): + n_dims = len(data.shape) + data = data.astype(np.float32) + ftype = 0 # GGML_TYPE_F32 + + name_bytes = name.encode('utf-8') + fout.write(struct.pack("iii", n_dims, len(name_bytes), ftype)) + for i in range(n_dims): + fout.write(struct.pack("i", data.shape[n_dims - 1 - i])) + fout.write(name_bytes) + data.tofile(fout) + +def generate(output_path): + rng = np.random.default_rng(42) + + hparams = { + 'n_vocab': 10, + 'n_audio_ctx': 3200, + 'n_audio_state': 8, + 'n_audio_head': 2, + 'n_audio_layer': 1, + 'n_mels': 16, + 'ftype': 0, + 'n_fft': 64, + 'subsampling_factor': 8, + 'n_subsampling_channels': 4, + 'n_conv_kernel': 3, + 'n_pred_dim': 8, + 'n_pred_layers': 1, + 'n_tdt_durations': 2, + 'n_max_tokens': 5, + } + + n_vocab = hparams['n_vocab'] + n_state = hparams['n_audio_state'] + n_head = hparams['n_audio_head'] + n_layer = hparams['n_audio_layer'] + n_mels = hparams['n_mels'] + n_fft = hparams['n_fft'] + n_sub_fac = hparams['subsampling_factor'] + n_sub_ch = hparams['n_subsampling_channels'] + n_conv_ker = hparams['n_conv_kernel'] + dec_dim = hparams['n_pred_dim'] + n_pred_l = hparams['n_pred_layers'] + n_tdt = hparams['n_tdt_durations'] + + n_pre_enc = (n_mels // n_sub_fac) * n_sub_ch + n_head_dim = n_state // n_head + n_pred_embed = n_vocab + 1 + n_lstm_gates = 4 * dec_dim + n_joint_out = n_vocab + n_tdt + 1 + n_freqs = n_fft // 2 + 1 + + def f32(*shape): + return rng.standard_normal(shape).astype(np.float32) + + with open(output_path, 'wb') as fout: + fout.write(struct.pack("I", 0x67676d6c)) + + for key in ['n_vocab', + 'n_audio_ctx', + 'n_audio_state', + 'n_audio_head', + 'n_audio_layer', + 'n_mels', + 'ftype', + 'n_fft', + 'subsampling_factor', + 'n_subsampling_channels', + 'n_conv_kernel', + 'n_pred_dim', + 'n_pred_layers', + 'n_tdt_durations', + 'n_max_tokens']: + fout.write(struct.pack("i", hparams[key])) + + fout.write(struct.pack("i", n_mels)) + fout.write(struct.pack("i", n_freqs)) + f32(n_mels, n_freqs).tofile(fout) + + fout.write(struct.pack("i", n_fft)) + f32(n_fft).tofile(fout) + + for d in range(n_tdt): + fout.write(struct.pack("I", d)) + + tokens = ['<unk>', '<s>', '</s>'] + [chr(ord('a') + i) for i in range(n_vocab - 3)] + assert len(tokens) == n_vocab + fout.write(struct.pack("i", n_vocab)) + for tok in tokens: + tok_bytes = tok.encode('utf-8') + fout.write(struct.pack("i", len(tok_bytes))) + fout.write(tok_bytes) + + write_tensor(fout, "encoder.pre_encode.out.weight", f32(n_state, n_pre_enc)) + write_tensor(fout, "encoder.pre_encode.out.bias", f32(n_state)) + + write_tensor(fout, "encoder.pre_encode.conv.0.weight", f32(n_sub_ch, 1, 3, 3)) + write_tensor(fout, "encoder.pre_encode.conv.0.bias", f32(1, n_sub_ch, 1, 1)) + + write_tensor(fout, "encoder.pre_encode.conv.2.weight", f32(n_sub_ch, 1, 3, 3)) + write_tensor(fout, "encoder.pre_encode.conv.2.bias", f32(1, n_sub_ch, 1, 1)) + + write_tensor(fout, "encoder.pre_encode.conv.3.weight", f32(n_sub_ch, n_sub_ch, 1, 1)) + write_tensor(fout, "encoder.pre_encode.conv.3.bias", f32(1, n_sub_ch, 1, 1)) + + write_tensor(fout, "encoder.pre_encode.conv.5.weight", f32(n_sub_ch, 1, 3, 3)) + write_tensor(fout, "encoder.pre_encode.conv.5.bias", f32(1, n_sub_ch, 1, 1)) + + write_tensor(fout, "encoder.pre_encode.conv.6.weight", f32(n_sub_ch, n_sub_ch, 1, 1)) + write_tensor(fout, "encoder.pre_encode.conv.6.bias", f32(1, n_sub_ch, 1, 1)) + + for i in range(n_layer): + p = f"encoder.layers.{i}" + + write_tensor(fout, f"{p}.norm_feed_forward1.weight", f32(n_state)) + write_tensor(fout, f"{p}.norm_feed_forward1.bias", f32(n_state)) + write_tensor(fout, f"{p}.feed_forward1.linear1.weight", f32(4*n_state, n_state)) + write_tensor(fout, f"{p}.feed_forward1.linear2.weight", f32(n_state, 4*n_state)) + + write_tensor(fout, f"{p}.norm_conv.weight", f32(n_state)) + write_tensor(fout, f"{p}.norm_conv.bias", f32(n_state)) + write_tensor(fout, f"{p}.conv.pointwise_conv1.weight", f32(2*n_state, n_state)) + write_tensor(fout, f"{p}.conv.depthwise_conv.weight", f32(n_state, n_conv_ker)) + write_tensor(fout, f"{p}.conv.batch_norm.weight", f32(n_state)) + write_tensor(fout, f"{p}.conv.batch_norm.bias", f32(n_state)) + write_tensor(fout, f"{p}.conv.batch_norm.running_mean", f32(n_state)) + write_tensor(fout, f"{p}.conv.batch_norm.running_var", np.abs(f32(n_state))) + num_batches = np.zeros(1, dtype=np.int32) + write_tensor(fout, f"{p}.conv.batch_norm.num_batches_tracked", num_batches) + write_tensor(fout, f"{p}.conv.pointwise_conv2.weight", f32(n_state, n_state)) + + write_tensor(fout, f"{p}.norm_self_att.weight", f32(n_state)) + write_tensor(fout, f"{p}.norm_self_att.bias", f32(n_state)) + + write_tensor(fout, f"{p}.self_attn.pos_bias_u", f32(n_head, n_head_dim)) + write_tensor(fout, f"{p}.self_attn.pos_bias_v", f32(n_head, n_head_dim)) + write_tensor(fout, f"{p}.self_attn.linear_q.weight", f32(n_state, n_state)) + write_tensor(fout, f"{p}.self_attn.linear_k.weight", f32(n_state, n_state)) + write_tensor(fout, f"{p}.self_attn.linear_v.weight", f32(n_state, n_state)) + write_tensor(fout, f"{p}.self_attn.linear_out.weight", f32(n_state, n_state)) + write_tensor(fout, f"{p}.self_attn.linear_pos.weight", f32(n_state, n_state)) + + write_tensor(fout, f"{p}.norm_feed_forward2.weight", f32(n_state)) + write_tensor(fout, f"{p}.norm_feed_forward2.bias", f32(n_state)) + write_tensor(fout, f"{p}.feed_forward2.linear1.weight", f32(4*n_state, n_state)) + write_tensor(fout, f"{p}.feed_forward2.linear2.weight", f32(n_state, 4*n_state)) + + write_tensor(fout, f"{p}.norm_out.weight", f32(n_state)) + write_tensor(fout, f"{p}.norm_out.bias", f32(n_state)) + + write_tensor(fout, "decoder.prediction.embed.weight", f32(n_pred_embed, dec_dim)) + + def reorder_gates(data): + h = data.shape[0] // 4 + return np.concatenate([data[:h], data[h:2*h], data[3*h:], data[2*h:3*h]], axis=0) + + for i in range(n_pred_l): + base = f"decoder.prediction.dec_rnn.lstm" + write_tensor(fout, f"{base}.weight_ih_l{i}", reorder_gates(f32(n_lstm_gates, dec_dim))) + write_tensor(fout, f"{base}.weight_hh_l{i}", reorder_gates(f32(n_lstm_gates, dec_dim))) + write_tensor(fout, f"{base}.bias_h_l{i}", reorder_gates(f32(n_lstm_gates) + f32(n_lstm_gates))) + + write_tensor(fout, "joint.pred.weight", f32(dec_dim, dec_dim)) + write_tensor(fout, "joint.pred.bias", f32(dec_dim)) + write_tensor(fout, "joint.enc.weight", f32(dec_dim, n_state)) + write_tensor(fout, "joint.enc.bias", f32(dec_dim)) + write_tensor(fout, "joint.joint_net.2.weight", f32(n_joint_out, dec_dim)) + write_tensor(fout, "joint.joint_net.2.bias", f32(n_joint_out)) + + size = Path(output_path).stat().st_size + print(f"Generated {output_path} ({size / 1024:.1f} KB)") + +if __name__ == '__main__': + output = sys.argv[1] if len(sys.argv) > 1 else 'models/for-tests-ggml-parakeet-tdt.bin' + generate(output) diff --git a/models/requirements-parakeet.txt b/models/requirements-parakeet.txt new file mode 100644 index 00000000000..5239ae0af5d --- /dev/null +++ b/models/requirements-parakeet.txt @@ -0,0 +1,3 @@ +torch +numpy +pyyaml diff --git a/scripts/quantize-parakeet.sh b/scripts/quantize-parakeet.sh new file mode 100755 index 00000000000..7816696bfcb --- /dev/null +++ b/scripts/quantize-parakeet.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +set -e + +build_dir=build +modelname=ggml-parakeet-tdt-0.6b-v3 +model=models/${modelname}-f32.bin +cmd=parakeet-quantize + +cmake --build ${build_dir} --target $cmd -j 12 + +${build_dir}/bin/${cmd} $model models/${modelname}-q8_0.bin q8_0 +${build_dir}/bin/${cmd} $model models/${modelname}-q4_0.bin q4_0 +${build_dir}/bin/${cmd} $model models/${modelname}-q4_k.bin q4_k +${build_dir}/bin/${cmd} $model models/${modelname}-q2_k.bin q2_k diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index a4f87b2b9ae..87d353ef452 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -e705c5fed490514458bdd2eaddc43bd098fcce9b +3af5f5760e19a96427f5f7a93b79cbdf3d4b265b diff --git a/scripts/upload-parakeet.py b/scripts/upload-parakeet.py new file mode 100644 index 00000000000..3644bec8bd3 --- /dev/null +++ b/scripts/upload-parakeet.py @@ -0,0 +1,157 @@ +import argparse +import os +from huggingface_hub import HfApi, create_repo + +USER_NAME = "ggml-org" +REPO_ID = f"{USER_NAME}/parakeet-GGUF" + +MODELS = { + "f32": { + "local_path": "models/ggml-parakeet-tdt-0.6b-v3-f32.bin", + "remote_name": "ggml-parakeet-tdt-0.6b-v3-f32.bin", + "description": "Full precision (F32)", + }, + "f16": { + "local_path": "models/ggml-parakeet-tdt-0.6b-v3-f16.bin", + "remote_name": "ggml-parakeet-tdt-0.6b-v3-f16.bin", + "description": "Half precision (F16)", + }, + "q8_0": { + "local_path": "models/ggml-parakeet-tdt-0.6b-v3-q8_0.bin", + "remote_name": "ggml-parakeet-tdt-0.6b-v3-q8_0.bin", + "description": "8-bit quantized (Q8_0)", + }, + "q4_0": { + "local_path": "models/ggml-parakeet-tdt-0.6b-v3-q4_0.bin", + "remote_name": "ggml-parakeet-tdt-0.6b-v3-q4_0.bin", + "description": "4-bit quantized (Q4_0)", + }, + "q4_k": { + "local_path": "models/ggml-parakeet-tdt-0.6b-v3-q4_k.bin", + "remote_name": "ggml-parakeet-tdt-0.6b-v3-q4_k.bin", + "description": "4-bit K-quantized (Q4_k)", + }, +} + +def build_model_card(uploaded_variants): + lines = [ + f"---", + f"license: mit", + f"base_model: nvidia/parakeet-tdt-0.6b-v3", + f"tags:", + f"- gguf", + f"- asr", + f"---", + f"", + f"# Parakeet TDT 0.6B v3 (GGUF)", + f"", + f"GGUF conversions of [nvidia/parakeet-tdt-0.6b-v3](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3) for use with [whisper.cpp](https://github.com/ggml-org/whisper.cpp).", + f"", + f"## Available files", + f"", + ] + + for key, m in MODELS.items(): + if key in uploaded_variants: + lines.append(f"- `{m['remote_name']}` — {m['description']}") + + lines += [ + f"", + f"## Usage", + f"", + f"Build parakeet-cli:", + f"```console", + f"git clone https://github.com/ggml-org/whisper.cpp.git", + f"cd whisper.cpp", + f"cmake -B build -S .", + f"cmake --build build --target parakeet-cli -j $(nproc)", + f"```", + f"", + f"Download a model (e.g. Q8_0):", + f"```console", + f"hf download {REPO_ID} {MODELS['q8_0']['remote_name']} --local-dir models", + f"```", + f"", + f"Run:", + f"```console", + f"./build/bin/parakeet-cli -m models/{MODELS['q8_0']['remote_name']} -f samples/jfk.wav", + f"```", + f"", + ] + + return "\n".join(lines) + + +def upload_variant(api, key): + m = MODELS[key] + local_path = m["local_path"] + + if not os.path.exists(local_path): + print(f" Skipping {key}: {local_path} not found") + return False + + print(f" Uploading {m['remote_name']} ({m['description']})...") + api.upload_file( + path_or_fileobj=local_path, + path_in_repo=m["remote_name"], + repo_id=REPO_ID, + repo_type="model", + commit_message=f"Upload {m['remote_name']}", + ) + return True + + +def main(): + parser = argparse.ArgumentParser(description="Upload parakeet GGUF models to Hugging Face") + parser.add_argument( + "variants", + nargs="*", + default=None, + metavar="{" + ",".join(MODELS.keys()) + "}", + help="Model variants to upload (default: all)", + ) + parser.add_argument( + "--no-model-card", + action="store_true", + help="Skip updating the model card README", + ) + args = parser.parse_args() + + api = HfApi() + create_repo(repo_id=REPO_ID, repo_type="model", exist_ok=True) + + variants = args.variants if args.variants else list(MODELS.keys()) + + unknown = [v for v in variants if v not in MODELS] + if unknown: + parser.error(f"unknown variant(s): {', '.join(unknown)} (choose from {', '.join(MODELS.keys())})") + + uploaded = [] + for key in variants: + if upload_variant(api, key): + uploaded.append(key) + + if not uploaded: + print("No models were uploaded.") + return + + if not args.no_model_card: + print("Updating model card...") + existing = [k for k in MODELS if k in uploaded or + any(f.rfilename == MODELS[k]["remote_name"] + for f in api.list_repo_files(REPO_ID, repo_type="model") + if hasattr(f, "rfilename"))] + card = build_model_card(existing if existing else uploaded) + api.upload_file( + path_or_fileobj=card.encode(), + path_in_repo="README.md", + repo_id=REPO_ID, + repo_type="model", + commit_message="Update README.md", + ) + + print(f"\nDone. Repository: https://huggingface.co/{REPO_ID}") + + +if __name__ == "__main__": + main() diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 3a09c7b9157..839d1775207 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -118,12 +118,27 @@ add_library(whisper whisper.cpp ) +add_library(parakeet + ../include/parakeet.h + parakeet-arch.h + parakeet.cpp + ) + +target_include_directories(parakeet PUBLIC . ../include) +target_compile_features (parakeet PUBLIC cxx_std_11) +target_link_libraries(parakeet PUBLIC ggml Threads::Threads) + # Set the version numbers set_target_properties(whisper PROPERTIES VERSION ${PROJECT_VERSION} SOVERSION ${SOVERSION} ) +set_target_properties(parakeet PROPERTIES + VERSION ${PROJECT_VERSION} + SOVERSION ${SOVERSION} +) + target_include_directories( whisper PUBLIC @@ -133,12 +148,17 @@ target_compile_features (whisper PUBLIC cxx_std_11) # don't bump if (CMAKE_CXX_BYTE_ORDER STREQUAL "BIG_ENDIAN") set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DWHISPER_BIG_ENDIAN) + set(PARAKEET_EXTRA_FLAGS ${PARAKEET_EXTRA_FLAGS} -DPARAKEET_BIG_ENDIAN) endif() if (WHISPER_EXTRA_FLAGS) target_compile_options(whisper PRIVATE ${WHISPER_EXTRA_FLAGS}) endif() +if (PARAKEET_EXTRA_FLAGS) + target_compile_options(parakeet PRIVATE ${PARAKEET_EXTRA_FLAGS}) +endif() + find_package(Threads REQUIRED) target_link_libraries(whisper PUBLIC ggml Threads::Threads) @@ -157,4 +177,7 @@ endif() if (BUILD_SHARED_LIBS) set_target_properties(whisper PROPERTIES POSITION_INDEPENDENT_CODE ON) target_compile_definitions(whisper PRIVATE WHISPER_SHARED WHISPER_BUILD) + + set_target_properties(parakeet PROPERTIES POSITION_INDEPENDENT_CODE ON) + target_compile_definitions(parakeet PRIVATE PARAKEET_SHARED PARAKEET_BUILD) endif() diff --git a/src/parakeet-arch.h b/src/parakeet-arch.h new file mode 100644 index 00000000000..3407a95c9c7 --- /dev/null +++ b/src/parakeet-arch.h @@ -0,0 +1,188 @@ +#pragma once + +#include "ggml.h" + +#include <map> + +enum parakeet_tensor { + // Encoder pre_encode + PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT, + PARAKEET_TENSOR_ENC_PRE_OUT_BIAS, + PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT, + PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS, + PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT, + PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS, + PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT, + PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS, + PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT, + PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS, + PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT, + PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS, + + // Encoder layers (per-layer) + PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_FF1_BIAS, + PARAKEET_TENSOR_ENC_FF1_LINEAR1_WEIGHT, + PARAKEET_TENSOR_ENC_FF1_LINEAR2_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_CONV_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_CONV_BIAS, + PARAKEET_TENSOR_ENC_CONV_PW1_WEIGHT, + PARAKEET_TENSOR_ENC_CONV_DW_WEIGHT, + PARAKEET_TENSOR_ENC_CONV_BN_WEIGHT, + PARAKEET_TENSOR_ENC_CONV_BN_BIAS, + PARAKEET_TENSOR_ENC_CONV_BN_MEAN, + PARAKEET_TENSOR_ENC_CONV_BN_VAR, + PARAKEET_TENSOR_ENC_CONV_BN_NUM_BATCHES, + PARAKEET_TENSOR_ENC_CONV_PW2_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_ATTN_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_ATTN_BIAS, + PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_U, + PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_V, + PARAKEET_TENSOR_ENC_ATTN_Q_WEIGHT, + PARAKEET_TENSOR_ENC_ATTN_K_WEIGHT, + PARAKEET_TENSOR_ENC_ATTN_V_WEIGHT, + PARAKEET_TENSOR_ENC_ATTN_OUT_WEIGHT, + PARAKEET_TENSOR_ENC_ATTN_POS_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_FF2_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_FF2_BIAS, + PARAKEET_TENSOR_ENC_FF2_LINEAR1_WEIGHT, + PARAKEET_TENSOR_ENC_FF2_LINEAR2_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_OUT_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_OUT_BIAS, + + // Prediction network + PARAKEET_TENSOR_PRED_EMBED_WEIGHT, + PARAKEET_TENSOR_PRED_LSTM_WEIGHT_IH, + PARAKEET_TENSOR_PRED_LSTM_WEIGHT_HH, + PARAKEET_TENSOR_PRED_LSTM_BIAS_H, + + // Joint network + PARAKEET_TENSOR_JOINT_PRED_WEIGHT, + PARAKEET_TENSOR_JOINT_PRED_BIAS, + PARAKEET_TENSOR_JOINT_ENC_WEIGHT, + PARAKEET_TENSOR_JOINT_ENC_BIAS, + PARAKEET_TENSOR_JOINT_NET_WEIGHT, + PARAKEET_TENSOR_JOINT_NET_BIAS, +}; + +static const std::map<parakeet_tensor, const char *> PARAKEET_TENSOR_NAMES = { + // Encoder pre_encode + {PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT, "encoder.pre_encode.out.weight"}, + {PARAKEET_TENSOR_ENC_PRE_OUT_BIAS, "encoder.pre_encode.out.bias"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT, "encoder.pre_encode.conv.0.weight"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS, "encoder.pre_encode.conv.0.bias"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT, "encoder.pre_encode.conv.2.weight"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS, "encoder.pre_encode.conv.2.bias"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT, "encoder.pre_encode.conv.3.weight"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS, "encoder.pre_encode.conv.3.bias"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT, "encoder.pre_encode.conv.5.weight"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS, "encoder.pre_encode.conv.5.bias"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT, "encoder.pre_encode.conv.6.weight"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS, "encoder.pre_encode.conv.6.bias"}, + + // Encoder layers (use %d for layer number) + {PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT, "encoder.layers.%d.norm_feed_forward1.weight"}, + {PARAKEET_TENSOR_ENC_NORM_FF1_BIAS, "encoder.layers.%d.norm_feed_forward1.bias"}, + {PARAKEET_TENSOR_ENC_FF1_LINEAR1_WEIGHT, "encoder.layers.%d.feed_forward1.linear1.weight"}, + {PARAKEET_TENSOR_ENC_FF1_LINEAR2_WEIGHT, "encoder.layers.%d.feed_forward1.linear2.weight"}, + {PARAKEET_TENSOR_ENC_NORM_CONV_WEIGHT, "encoder.layers.%d.norm_conv.weight"}, + {PARAKEET_TENSOR_ENC_NORM_CONV_BIAS, "encoder.layers.%d.norm_conv.bias"}, + {PARAKEET_TENSOR_ENC_CONV_PW1_WEIGHT, "encoder.layers.%d.conv.pointwise_conv1.weight"}, + {PARAKEET_TENSOR_ENC_CONV_DW_WEIGHT, "encoder.layers.%d.conv.depthwise_conv.weight"}, + {PARAKEET_TENSOR_ENC_CONV_BN_WEIGHT, "encoder.layers.%d.conv.batch_norm.weight"}, + {PARAKEET_TENSOR_ENC_CONV_BN_BIAS, "encoder.layers.%d.conv.batch_norm.bias"}, + {PARAKEET_TENSOR_ENC_CONV_BN_MEAN, "encoder.layers.%d.conv.batch_norm.running_mean"}, + {PARAKEET_TENSOR_ENC_CONV_BN_VAR, "encoder.layers.%d.conv.batch_norm.running_var"}, + {PARAKEET_TENSOR_ENC_CONV_BN_NUM_BATCHES, "encoder.layers.%d.conv.batch_norm.num_batches_tracked"}, + {PARAKEET_TENSOR_ENC_CONV_PW2_WEIGHT, "encoder.layers.%d.conv.pointwise_conv2.weight"}, + {PARAKEET_TENSOR_ENC_NORM_ATTN_WEIGHT, "encoder.layers.%d.norm_self_att.weight"}, + {PARAKEET_TENSOR_ENC_NORM_ATTN_BIAS, "encoder.layers.%d.norm_self_att.bias"}, + {PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_U, "encoder.layers.%d.self_attn.pos_bias_u"}, + {PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_V, "encoder.layers.%d.self_attn.pos_bias_v"}, + {PARAKEET_TENSOR_ENC_ATTN_Q_WEIGHT, "encoder.layers.%d.self_attn.linear_q.weight"}, + {PARAKEET_TENSOR_ENC_ATTN_K_WEIGHT, "encoder.layers.%d.self_attn.linear_k.weight"}, + {PARAKEET_TENSOR_ENC_ATTN_V_WEIGHT, "encoder.layers.%d.self_attn.linear_v.weight"}, + {PARAKEET_TENSOR_ENC_ATTN_OUT_WEIGHT, "encoder.layers.%d.self_attn.linear_out.weight"}, + {PARAKEET_TENSOR_ENC_ATTN_POS_WEIGHT, "encoder.layers.%d.self_attn.linear_pos.weight"}, + {PARAKEET_TENSOR_ENC_NORM_FF2_WEIGHT, "encoder.layers.%d.norm_feed_forward2.weight"}, + {PARAKEET_TENSOR_ENC_NORM_FF2_BIAS, "encoder.layers.%d.norm_feed_forward2.bias"}, + {PARAKEET_TENSOR_ENC_FF2_LINEAR1_WEIGHT, "encoder.layers.%d.feed_forward2.linear1.weight"}, + {PARAKEET_TENSOR_ENC_FF2_LINEAR2_WEIGHT, "encoder.layers.%d.feed_forward2.linear2.weight"}, + {PARAKEET_TENSOR_ENC_NORM_OUT_WEIGHT, "encoder.layers.%d.norm_out.weight"}, + {PARAKEET_TENSOR_ENC_NORM_OUT_BIAS, "encoder.layers.%d.norm_out.bias"}, + + // Prediction network + {PARAKEET_TENSOR_PRED_EMBED_WEIGHT, "decoder.prediction.embed.weight"}, + {PARAKEET_TENSOR_PRED_LSTM_WEIGHT_IH, "decoder.prediction.dec_rnn.lstm.weight_ih_l%d"}, + {PARAKEET_TENSOR_PRED_LSTM_WEIGHT_HH, "decoder.prediction.dec_rnn.lstm.weight_hh_l%d"}, + {PARAKEET_TENSOR_PRED_LSTM_BIAS_H, "decoder.prediction.dec_rnn.lstm.bias_h_l%d"}, + + // Joint network + {PARAKEET_TENSOR_JOINT_PRED_WEIGHT, "joint.pred.weight"}, + {PARAKEET_TENSOR_JOINT_PRED_BIAS, "joint.pred.bias"}, + {PARAKEET_TENSOR_JOINT_ENC_WEIGHT, "joint.enc.weight"}, + {PARAKEET_TENSOR_JOINT_ENC_BIAS, "joint.enc.bias"}, + {PARAKEET_TENSOR_JOINT_NET_WEIGHT, "joint.joint_net.2.weight"}, + {PARAKEET_TENSOR_JOINT_NET_BIAS, "joint.joint_net.2.bias"}, +}; + +static const std::map<parakeet_tensor, ggml_op> PARAKEET_TENSOR_INFO = { + // Encoder pre_encode + {PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_PRE_OUT_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS, GGML_OP_ADD}, + + // Encoder layers + {PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT, GGML_OP_MUL}, + {PARAKEET_TENSOR_ENC_NORM_FF1_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_FF1_LINEAR1_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_FF1_LINEAR2_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_NORM_CONV_WEIGHT, GGML_OP_MUL}, + {PARAKEET_TENSOR_ENC_NORM_CONV_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_CONV_PW1_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_CONV_DW_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_CONV_BN_WEIGHT, GGML_OP_MUL}, + {PARAKEET_TENSOR_ENC_CONV_BN_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_CONV_BN_MEAN, GGML_OP_SUB}, + {PARAKEET_TENSOR_ENC_CONV_BN_VAR, GGML_OP_DIV}, + {PARAKEET_TENSOR_ENC_CONV_BN_NUM_BATCHES, GGML_OP_NONE}, + {PARAKEET_TENSOR_ENC_CONV_PW2_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_NORM_ATTN_WEIGHT, GGML_OP_MUL}, + {PARAKEET_TENSOR_ENC_NORM_ATTN_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_U, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_V, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_ATTN_Q_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_ATTN_K_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_ATTN_V_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_ATTN_OUT_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_ATTN_POS_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_NORM_FF2_WEIGHT, GGML_OP_MUL}, + {PARAKEET_TENSOR_ENC_NORM_FF2_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_FF2_LINEAR1_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_FF2_LINEAR2_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_NORM_OUT_WEIGHT, GGML_OP_MUL}, + {PARAKEET_TENSOR_ENC_NORM_OUT_BIAS, GGML_OP_ADD}, + + // Prediction network + {PARAKEET_TENSOR_PRED_EMBED_WEIGHT, GGML_OP_GET_ROWS}, + {PARAKEET_TENSOR_PRED_LSTM_WEIGHT_IH, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_PRED_LSTM_WEIGHT_HH, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_PRED_LSTM_BIAS_H, GGML_OP_ADD}, + + // Joint network + {PARAKEET_TENSOR_JOINT_PRED_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_JOINT_PRED_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_JOINT_ENC_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_JOINT_ENC_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_JOINT_NET_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_JOINT_NET_BIAS, GGML_OP_ADD}, +}; diff --git a/src/parakeet.cpp b/src/parakeet.cpp new file mode 100644 index 00000000000..b5da73e985c --- /dev/null +++ b/src/parakeet.cpp @@ -0,0 +1,3838 @@ +#include "parakeet.h" +#include "parakeet-arch.h" + +#include "ggml.h" +#include "ggml-cpp.h" +#include "ggml-alloc.h" +#include "ggml-backend.h" + +#include <atomic> +#include <algorithm> +#include <cassert> +#include <cfloat> +#define _USE_MATH_DEFINES +#include <cmath> +#include <climits> +#include <cstdarg> +#include <cstdio> +#include <cstring> +#include <fstream> +#include <functional> +#include <cctype> +#include <map> +#include <random> +#include <set> +#include <string> +#include <thread> +#include <vector> + +#ifdef _MSC_VER +#include <codecvt> +#endif + +#if defined(PARAKEET_BIG_ENDIAN) +template<typename T> +static T byteswap(T value) { + T value_swapped; + char * source = reinterpret_cast<char *>(&value); + char * target = reinterpret_cast<char *>(&value_swapped); + int size = sizeof(T); + for (int i = 0; i < size; i++) { + target[size - 1 - i] = source[i]; + } + return value_swapped; +} + +template<typename T> +static void byteswap_tensor_data(ggml_tensor * tensor) { + T * datum = reinterpret_cast<T *>(tensor->data); + for (int i = 0; i < ggml_nelements(tensor); i++) { + datum[i] = byteswap(datum[i]); + } +} + +static void byteswap_tensor(ggml_tensor * tensor) { + switch (tensor->type) { + case GGML_TYPE_I16: { + byteswap_tensor_data<int16_t>(tensor); + break; + } + case GGML_TYPE_F16: { + byteswap_tensor_data<ggml_fp16_t>(tensor); + break; + } + case GGML_TYPE_I32: { + byteswap_tensor_data<int32_t>(tensor); + break; + } + case GGML_TYPE_F32: { + byteswap_tensor_data<float>(tensor); + break; + } + default: { // GML_TYPE_I8 + break; + } + } +} + +#define BYTESWAP_VALUE(d) d = byteswap(d) +#define BYTESWAP_FILTERS(f) \ + do { \ + for (auto & datum : f.data) { \ + datum = byteswap(datum); \ + } \ + } while (0) +#define BYTESWAP_TENSOR(t) \ + do { \ + byteswap_tensor(t); \ + } while (0) +#else +#define BYTESWAP_VALUE(d) do {} while (0) +#define BYTESWAP_FILTERS(f) do {} while (0) +#define BYTESWAP_TENSOR(t) do {} while (0) +#endif + +#ifdef __GNUC__ +#ifdef __MINGW32__ +#define PARAKEET_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) +#else +#define PARAKEET_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) +#endif +#else +#define PARAKEET_ATTRIBUTE_FORMAT(...) +#endif + +// +// logging +// + +PARAKEET_ATTRIBUTE_FORMAT(2, 3) +static void parakeet_log_internal (ggml_log_level level, const char * format, ...); +static void parakeet_log_callback_default(ggml_log_level level, const char * text, void * user_data); + +#define PARAKEET_LOG_ERROR(...) parakeet_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__) +#define PARAKEET_LOG_WARN(...) parakeet_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__) +#define PARAKEET_LOG_INFO(...) parakeet_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__) + +// define this to enable verbose trace logging - useful for debugging purposes +//#define PARAKEET_DEBUG + +#if defined(PARAKEET_DEBUG) +#define PARAKEET_LOG_DEBUG(...) parakeet_log_internal(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__) +#else +#define PARAKEET_LOG_DEBUG(...) +#endif + +#define PARAKEET_ASSERT(x) \ + do { \ + if (!(x)) { \ + PARAKEET_LOG_ERROR("PARAKEET_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \ + abort(); \ + } \ + } while (0) + +#define PARAKEET_MAX_NODES 8192 + +// Threshold for when local attention should be used. +// 8192 frames x 80ms = 655 s (about 10.9 mins) +static constexpr int PARAKEET_LOCAL_ATTN_THRESHOLD = 8192; +// Window of context in each director of the current token. +// 128 frames * 80ms = 10.24 s +static constexpr int PARAKEET_LOCAL_ATTN_WINDOW = 128; + +static std::string format(const char * fmt, ...) { + va_list ap; + va_list ap2; + va_start(ap, fmt); + va_copy(ap2, ap); + int size = vsnprintf(NULL, 0, fmt, ap); + GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT + std::vector<char> buf(size + 1); + int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2); + GGML_ASSERT(size2 == size); + va_end(ap2); + va_end(ap); + return std::string(buf.data(), size); +} + +// +// ggml helpers +// + +static bool ggml_graph_compute_helper( + struct ggml_cgraph * graph, + int n_threads, + ggml_abort_callback abort_callback, + void * abort_callback_data) { + ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) }; + + auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get())); + + auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback"); + if (set_abort_callback_fn) { + set_abort_callback_fn(backend.get(), abort_callback, abort_callback_data); + } + + auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads"); + if (ggml_backend_set_n_threads_fn) { + ggml_backend_set_n_threads_fn(backend.get(), n_threads); + } + + return ggml_backend_graph_compute(backend.get(), graph) == GGML_STATUS_SUCCESS; +} + +static bool ggml_graph_compute_helper( + ggml_backend_sched_t sched, + struct ggml_cgraph * graph, + int n_threads, + bool sched_reset = true) { + for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) { + ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i); + ggml_backend_dev_t dev = ggml_backend_get_device(backend); + ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr; + + auto * fn_set_n_threads = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads"); + if (fn_set_n_threads) { + fn_set_n_threads(backend, n_threads); + } + } + + const bool t = (ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS); + + if (!t || sched_reset) { + ggml_backend_sched_reset(sched); + } + + return t; +} + +// TODO: move these functions to ggml-base with support for ggml-backend? + + +struct parakeet_mel { + int n_len = 0; + int n_len_org = 0; + int n_mel = 0; + + std::vector<float> data; +}; + +struct parakeet_filters { + int32_t n_mel = 0; + int32_t n_fb = 0; // number of frequency bins + + std::vector<float> data; +}; + +struct parakeet_vocab { + using id = int32_t; + using token = std::string; + + int n_vocab = 8192; + size_t max_token_length = 0; + + std::map<token, id> token_to_id; + std::map<id, token> id_to_token; + + id token_unk; + id token_bos; + id token_blank; + id token_eos; +}; + +struct parakeet_segment { + int64_t t0; + int64_t t1; + + std::string text; + + std::vector<parakeet_token_data> tokens; +}; + +struct parakeet_batch { + int32_t n_tokens; + + parakeet_token * token; + int32_t * i_time; // index of the audio frame + parakeet_pos * pos; + int32_t * n_seq_id; // always 1, here for consistency with llama.cpp + parakeet_seq_id ** seq_id; // null terminated + int8_t * logits; +}; + +// ggml_backend_sched wrapper for parakeet usage +struct parakeet_sched { + ggml_backend_sched_t sched = nullptr; + + std::vector<uint8_t> meta; +}; + +// TODO: Find out is there a multiple version types. It is not yet clear to me +// at this point. +enum parakeet_arch { + PARAKEET_ARCH_UNKNOWN = 0, + PARAKEET_ARCH_TDT = 1, // NVIDIA Parakeet TDT (RNN-T) +}; + +struct parakeet_hparams { + int32_t n_vocab = 8192; + int32_t n_audio_ctx = 0; // 0 = unlimited, will be set based on input + int32_t n_audio_state = 1024; + int32_t n_audio_head = 8; + int32_t n_audio_layer = 24; + int32_t n_mels = 128; + int32_t ftype = 1; + int32_t n_fft = 512; // FFT size for mel spectrogram + float eps = 1e-5f; + int32_t subsampling_factor = 8; + int32_t n_subsampling_channels = 256; + int32_t n_conv_kernel = 9; + int32_t n_pred_dim = 640; + int32_t n_pred_layers = 2; + int32_t n_tdt_durations = 5; + int32_t n_max_tokens = 10; + + parakeet_arch arch = PARAKEET_ARCH_TDT; +}; + +struct parakeet_layer_encoder { + struct ggml_tensor * norm_ff1_w = nullptr; + struct ggml_tensor * norm_ff1_b = nullptr; + + struct ggml_tensor * ff1_linear1_w = nullptr; + struct ggml_tensor * ff1_linear2_w = nullptr; + + struct ggml_tensor * norm_conv_w = nullptr; + struct ggml_tensor * norm_conv_b = nullptr; + + struct ggml_tensor * conv_pw1_w = nullptr; // pointwise_conv1 + struct ggml_tensor * conv_dw_w = nullptr; // depthwise_conv + struct ggml_tensor * conv_bn_w = nullptr; // batch_norm weight + struct ggml_tensor * conv_bn_b = nullptr; // batch_norm bias + struct ggml_tensor * conv_bn_mean = nullptr; // batch_norm running_mean + struct ggml_tensor * conv_bn_var = nullptr; // batch_norm running_var + struct ggml_tensor * conv_bn_num_batches = nullptr; // batch_norm num_batches_tracked + struct ggml_tensor * conv_pw2_w = nullptr; // pointwise_conv2 + + struct ggml_tensor * norm_attn_w = nullptr; + struct ggml_tensor * norm_attn_b = nullptr; + + struct ggml_tensor * attn_pos_bias_u = nullptr; + struct ggml_tensor * attn_pos_bias_v = nullptr; + struct ggml_tensor * attn_q_w = nullptr; + struct ggml_tensor * attn_k_w = nullptr; + struct ggml_tensor * attn_v_w = nullptr; + struct ggml_tensor * attn_out_w = nullptr; + struct ggml_tensor * attn_pos_w = nullptr; + + struct ggml_tensor * norm_ff2_w = nullptr; + struct ggml_tensor * norm_ff2_b = nullptr; + + struct ggml_tensor * ff2_linear1_w = nullptr; + struct ggml_tensor * ff2_linear2_w = nullptr; + + struct ggml_tensor * norm_out_w = nullptr; + struct ggml_tensor * norm_out_b = nullptr; +}; + +struct parakeet_lsmt_layer { + struct ggml_tensor * ih_w = nullptr; // input-to-hidden weight + struct ggml_tensor * hh_w = nullptr; // hidden-to-hidden weight + struct ggml_tensor * b_h = nullptr; // bias (ih folded into hh at conversion time) +}; + +struct parakeet_prediction_network { + struct ggml_tensor * embed_w = nullptr; + + std::vector<parakeet_lsmt_layer> lstm_layer; +}; + +struct parakeet_joint_network { + struct ggml_tensor * pred_w = nullptr; + struct ggml_tensor * pred_b = nullptr; + struct ggml_tensor * enc_w = nullptr; + struct ggml_tensor * enc_b = nullptr; + struct ggml_tensor * net_w = nullptr; + struct ggml_tensor * net_b = nullptr; +}; + +struct parakeet_model { + parakeet_filters filters; + parakeet_hparams hparams; + + struct ggml_tensor * enc_pre_out_w = nullptr; + struct ggml_tensor * enc_pre_out_b = nullptr; + struct ggml_tensor * enc_pre_conv_0_w = nullptr; + struct ggml_tensor * enc_pre_conv_0_b = nullptr; + struct ggml_tensor * enc_pre_conv_2_w = nullptr; + struct ggml_tensor * enc_pre_conv_2_b = nullptr; + struct ggml_tensor * enc_pre_conv_3_w = nullptr; + struct ggml_tensor * enc_pre_conv_3_b = nullptr; + struct ggml_tensor * enc_pre_conv_5_w = nullptr; + struct ggml_tensor * enc_pre_conv_5_b = nullptr; + struct ggml_tensor * enc_pre_conv_6_w = nullptr; + struct ggml_tensor * enc_pre_conv_6_b = nullptr; + + std::vector<parakeet_layer_encoder> layers; + + parakeet_prediction_network prediction; + + parakeet_joint_network joint; + + std::vector<uint32_t> tdt_durations; + + std::vector<ggml_context *> ctxs; + + std::vector<ggml_backend_buffer_t> buffers; + + int n_loaded = 0; + std::map<std::string, struct ggml_tensor *> tensors; +}; + +struct parakeet_lstm_state_layer { + struct ggml_tensor * h_state = nullptr; + struct ggml_tensor * c_state = nullptr; +}; + +struct parakeet_lstm_state { + std::vector<parakeet_lstm_state_layer> layer; + + std::vector<uint8_t> ctx_buf; + + ggml_backend_buffer_t buffer = nullptr; +}; + +struct parakeet_state { + int64_t t_sample_us = 0; + int64_t t_encode_us = 0; + int64_t t_decode_us = 0; + int64_t t_predict_us = 0; + int64_t t_predict_build_us = 0; // time spent building the prediction graph + int64_t t_predict_alloc_us = 0; // time spent in ggml_backend_sched_alloc_graph + int64_t t_predict_compute_us = 0; // time spent in ggml_graph_compute_helper + int64_t t_mel_us = 0; + + int32_t n_sample = 0; // number of tokens sampled + int32_t n_encode = 0; // number of encoder calls + int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation) + int32_t n_predict = 0; // number of prediction network calls + int32_t n_fail_p = 0; // number of logprob threshold failures + int32_t n_fail_h = 0; // number of entropy threshold failures + + parakeet_mel mel; + + parakeet_batch batch; + + int n_frames = 0; + + std::vector<ggml_backend_t> backends; + + parakeet_sched sched_encode; + parakeet_sched sched_decode; + + // outputs from encoder stages + struct ggml_tensor * enc_out = nullptr; + struct ggml_tensor * pred_out = nullptr; + + std::vector<uint8_t> enc_out_buf; + ggml_backend_buffer_t enc_out_buffer = nullptr; + + std::vector<uint8_t> pred_out_buf; + ggml_backend_buffer_t pred_out_buffer = nullptr; + + struct ggml_tensor * attn_mask = nullptr; + + std::vector<float> inp_mel; + std::vector<float> inp_mask; + + std::vector<float> logits; + + std::vector<parakeet_segment> result_all; + + std::vector<parakeet_token> decoded_tokens; + std::vector<parakeet_token_data> decoded_token_data; + + std::string path_model; + + int32_t n_audio_ctx = 0; + int32_t sched_encode_n_audio_ctx = 0; + + parakeet_lstm_state lstm_state; +}; + +// FFT cache for mel spectrogram computation +struct parakeet_mel_cache { + int n_fft = 0; + + // In FFT, we frequently use sine and cosine operations with the same values. + // We can use precalculated values to speed up the process. + std::vector<float> sin_vals; + std::vector<float> cos_vals; + + // Hann window (Use cosf to eliminate difference) + // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html + // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147 + std::vector<float> hann_window; + + // Window function from model (Parakeet uses actual window from training) + std::vector<float> window; + + void init(int fft_size) { + n_fft = fft_size; + sin_vals.resize(n_fft); + cos_vals.resize(n_fft); + hann_window.resize(n_fft); + + fill_sin_cos_table(); + fill_hann_window(n_fft, true, hann_window.data()); + } + + void fill_sin_cos_table() { + for (int i = 0; i < n_fft; i++) { + double theta = (2 * M_PI * i) / n_fft; + sin_vals[i] = sinf(theta); + cos_vals[i] = cosf(theta); + } + } + + void fill_hann_window(int length, bool periodic, float * output) { + int offset = -1; + if (periodic) { + offset = 0; + } + for (int i = 0; i < length; i++) { + output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset))); + } + } +}; + +struct parakeet_context { + int64_t t_load_us = 0; + int64_t t_start_us = 0; + + ggml_type wtype = ggml_type::GGML_TYPE_F16; + ggml_type itype = ggml_type::GGML_TYPE_F16; + + parakeet_context_params params; + + parakeet_model model; + parakeet_vocab vocab; + + parakeet_state * state = nullptr; + + parakeet_mel_cache mel_cache; + + std::string path_model; +}; + +struct parakeet_global { + // We save the log callback globally + ggml_log_callback log_callback = parakeet_log_callback_default; + void * log_callback_user_data = nullptr; +}; + +static parakeet_global g_state; + +static const std::string PARAKEET_SPM_SPACE = "\xE2\x96\x81"; + +static inline int utf8_codepoint_len(unsigned char c) { + if ((c & 0x80) == 0x00) return 1; + if ((c & 0xE0) == 0xC0) return 2; + if ((c & 0xF0) == 0xE0) return 3; + if ((c & 0xF8) == 0xF0) return 4; + return 1; +} + +static bool is_sentencepiece_control(const std::string & piece) { + return piece == "<unk>" || piece == "<s>" || piece == "</s>" || piece == "[BLANK]"; +} + +static std::string sentencepiece_normalize(const std::string & text) { + std::string normalized; + normalized.reserve(text.size() + PARAKEET_SPM_SPACE.size()); + normalized += PARAKEET_SPM_SPACE; // SentencePiece dummy prefix + + for (unsigned char c : text) { + if (std::isspace(c)) { + normalized += PARAKEET_SPM_SPACE; + } else { + normalized += static_cast<char>(c); + } + } + + return normalized; +} + +static std::string sentencepiece_piece_to_text(const std::string & piece, bool is_first_piece) { + if (is_sentencepiece_control(piece)) { + return ""; + } + + std::string text; + text.reserve(piece.size()); + + size_t pos = 0; + while (pos < piece.size()) { + if (piece.compare(pos, PARAKEET_SPM_SPACE.size(), PARAKEET_SPM_SPACE) == 0) { + if (!is_first_piece || !text.empty()) { + text += ' '; + } + pos += PARAKEET_SPM_SPACE.size(); + continue; + } + + text += piece[pos]; + ++pos; + } + + return text; +} + + +static struct parakeet_batch parakeet_batch_init(int32_t n_tokens) { + parakeet_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, }; + + batch.token = (parakeet_token * ) malloc(sizeof(parakeet_token) * (n_tokens)); + batch.i_time = (int32_t *) malloc(sizeof(int32_t) * (n_tokens)); + batch.pos = (parakeet_pos *) malloc(sizeof(parakeet_pos) * (n_tokens)); + batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * (n_tokens)); + batch.seq_id = (parakeet_seq_id **) malloc(sizeof(parakeet_seq_id *) * (n_tokens + 1)); + for (int i = 0; i < n_tokens; ++i) { + batch.seq_id[i] = (parakeet_seq_id *) malloc(sizeof(parakeet_seq_id)); + } + batch.seq_id[n_tokens] = nullptr; + batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens); + + return batch; +} + +static void parakeet_batch_free(struct parakeet_batch batch) { + if (batch.token) free(batch.token); + if (batch.i_time) free(batch.i_time); + if (batch.pos) free(batch.pos); + if (batch.n_seq_id) free(batch.n_seq_id); + if (batch.seq_id) { + for (int i = 0; batch.seq_id[i]; ++i) { + free(batch.seq_id[i]); + } + free(batch.seq_id); + } + if (batch.logits) free(batch.logits); +} + +static void parakeet_batch_prep_legacy(parakeet_batch & batch, const parakeet_token * tokens, int n_tokens, int n_past, int seq_id) { + batch.n_tokens = n_tokens; + for (int i = 0; i < n_tokens; ++i) { + if (tokens) { + batch.token[i] = tokens[i]; + } + batch.pos [i] = n_past + i; + batch.n_seq_id[i] = 1; + batch.seq_id [i][0] = seq_id; + batch.logits [i] = 0; + } + batch.logits[n_tokens - 1] = 1; +} + + +static size_t parakeet_sched_size(struct parakeet_sched & allocr) { + size_t size = allocr.meta.size(); + for (int i = 0; i < ggml_backend_sched_get_n_backends(allocr.sched); ++i) { + ggml_backend_t backend = ggml_backend_sched_get_backend(allocr.sched, i); + size += ggml_backend_sched_get_buffer_size(allocr.sched, backend); + } + return size; +} + +static bool parakeet_sched_graph_init(struct parakeet_sched & allocr, std::vector<ggml_backend_t> backends, std::function<struct ggml_cgraph *()> && get_graph) { + auto & sched = allocr.sched; + auto & meta = allocr.meta; + + sched = ggml_backend_sched_new(backends.data(), nullptr, backends.size(), PARAKEET_MAX_NODES, false, true); + + if (!sched) { + PARAKEET_LOG_ERROR("%s: failed to create scheduler\n", __func__); + return false; + } + + meta.resize(ggml_tensor_overhead()*PARAKEET_MAX_NODES + ggml_graph_overhead()); + + if (!ggml_backend_sched_alloc_graph(sched, get_graph())) { + PARAKEET_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__); + ggml_backend_sched_free(sched); + sched = nullptr; + return false; + } + + ggml_backend_sched_reset(sched); + + return true; +} + +static void parakeet_sched_free(struct parakeet_sched & sched) { + if (sched.sched) { + ggml_backend_sched_free(sched.sched); + sched.sched = nullptr; + } + + sched.meta.clear(); +} + + +template<typename T> +static void read_safe(parakeet_model_loader * loader, T & dest) { + loader->read(loader->context, &dest, sizeof(T)); + BYTESWAP_VALUE(dest); +} + +static bool parakeet_lstm_state_init( + struct parakeet_state & pstate, + ggml_backend_t backend, + int n_layer, + int n_pred_dim) { + parakeet_lstm_state & lstm_state = pstate.lstm_state; + + lstm_state.ctx_buf.resize(ggml_tensor_overhead() * n_layer * 2); + lstm_state.layer.resize(n_layer); + + struct ggml_init_params params = { + /*.mem_size =*/ lstm_state.ctx_buf.size(), + /*.mem_buffer =*/ lstm_state.ctx_buf.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx = ggml_init(params); + + if (!ctx) { + PARAKEET_LOG_ERROR("%s: failed to allocate memory for the lstm states context\n", __func__); + return false; + } + + + for (int il = 0; il < n_layer; ++il) { + lstm_state.layer[il].h_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_pred_dim); + lstm_state.layer[il].c_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_pred_dim); + } + + lstm_state.buffer = ggml_backend_alloc_ctx_tensors(ctx, backend); + if (!lstm_state.buffer) { + PARAKEET_LOG_ERROR("%s: failed to allocate memory for the lstm states\n", __func__); + return false; + } + + ggml_backend_buffer_clear(lstm_state.buffer, 0); + + ggml_free(ctx); + + return true; +} + +static bool parakeet_pred_state_init( + struct parakeet_state & pstate, + ggml_backend_t backend, + int n_pred_dim) { + pstate.pred_out_buf.resize(ggml_tensor_overhead()); + + struct ggml_init_params params = { + /*.mem_size =*/ pstate.pred_out_buf.size(), + /*.mem_buffer =*/ pstate.pred_out_buf.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx = ggml_init(params); + if (!ctx) { + PARAKEET_LOG_ERROR("%s: failed to allocate memory for pred tensor context\n", __func__); + return false; + } + + pstate.pred_out = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_pred_dim); + pstate.pred_out_buffer = ggml_backend_alloc_ctx_tensors(ctx, backend); + if (!pstate.pred_out_buffer) { + PARAKEET_LOG_ERROR("%s: failed to allocate memory for pred tensor\n", __func__); + ggml_free(ctx); + return false; + } + + ggml_free(ctx); + + return true; +} + +static bool parakeet_enc_state_init( + struct parakeet_state & pstate, + ggml_backend_t backend, + int n_audio_state, + int n_frames_max) { + pstate.enc_out_buf.resize(ggml_tensor_overhead()); + + struct ggml_init_params params = { + /*.mem_size =*/ pstate.enc_out_buf.size(), + /*.mem_buffer =*/ pstate.enc_out_buf.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx = ggml_init(params); + if (!ctx) { + PARAKEET_LOG_ERROR("%s: failed to allocate memory for enc_out tensor context\n", __func__); + return false; + } + + pstate.enc_out = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_frames_max); + pstate.enc_out_buffer = ggml_backend_alloc_ctx_tensors(ctx, backend); + if (!pstate.enc_out_buffer) { + PARAKEET_LOG_ERROR("%s: failed to allocate memory for enc_out tensor\n", __func__); + ggml_free(ctx); + return false; + } + + ggml_free(ctx); + + return true; +} + +static ggml_backend_t parakeet_backend_init_gpu(const parakeet_context_params & params) { + ggml_log_set(g_state.log_callback, g_state.log_callback_user_data); + + ggml_backend_dev_t dev = nullptr; + + int cnt = 0; + if (params.use_gpu) { + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t dev_cur = ggml_backend_dev_get(i); + enum ggml_backend_dev_type dev_type = ggml_backend_dev_type(dev_cur); + const char * dev_name = ggml_backend_dev_name(dev_cur); + PARAKEET_LOG_INFO("%s: device %zu: %s (type: %d)\n", __func__, i, dev_name, dev_type); + if (dev_type == GGML_BACKEND_DEVICE_TYPE_GPU || dev_type == GGML_BACKEND_DEVICE_TYPE_IGPU) { + PARAKEET_LOG_INFO("%s: found GPU device %zu: %s (type: %d, cnt: %d)\n", __func__, i, dev_name, dev_type, cnt); + if (cnt == params.gpu_device) { + dev = dev_cur; + } + + if (++cnt > params.gpu_device) { + break; + } + } + } + } + + if (dev == nullptr) { + PARAKEET_LOG_INFO("%s: no GPU found\n", __func__); + return nullptr; + } + + PARAKEET_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev)); + ggml_backend_t result = ggml_backend_dev_init(dev, nullptr); + if (!result) { + PARAKEET_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev)); + } + + return result; +} + +static std::vector<ggml_backend_t> parakeet_backend_init(const parakeet_context_params & params) { + std::vector<ggml_backend_t> result; + + ggml_backend_t backend_gpu = parakeet_backend_init_gpu(params); + + if (backend_gpu) { + result.push_back(backend_gpu); + } + + // ACCEL backends + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) { + PARAKEET_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev)); + ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr); + if (!backend) { + PARAKEET_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev)); + continue; + } + result.push_back(backend); + } + } + + ggml_backend_t backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); + if (backend_cpu == nullptr) { + throw std::runtime_error("failed to initialize CPU backend"); + } + result.push_back(backend_cpu); + + return result; +} + +using buft_list_t = std::vector<std::pair<ggml_backend_dev_t, ggml_backend_buffer_type_t>>; + +static buft_list_t make_buft_list(parakeet_context_params & params) { + // Prio order: GPU -> CPU Extra -> CPU + buft_list_t buft_list; + + // GPU + if (params.use_gpu) { + int cnt = 0; + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU || ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_IGPU) { + if (cnt == params.gpu_device) { + auto * buft = ggml_backend_dev_buffer_type(dev); + if (buft) { + buft_list.emplace_back(dev, buft); + } + } + + if (++cnt > params.gpu_device) { + break; + } + } + } + } + + // CPU Extra + auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev); + auto get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) + ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts"); + if (get_extra_bufts_fn) { + ggml_backend_buffer_type_t * extra_bufts = get_extra_bufts_fn(cpu_dev); + while (extra_bufts && *extra_bufts) { + buft_list.emplace_back(cpu_dev, *extra_bufts); + ++extra_bufts; + } + } + + // CPU + buft_list.emplace_back(cpu_dev, ggml_backend_cpu_buffer_type()); + + return buft_list; +} + +static bool weight_buft_supported(const parakeet_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) { + bool op_supported = true; + + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU || + ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_IGPU || + (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && buft == ggml_backend_cpu_buffer_type())) { + // GPU and default CPU backend support all operators + op_supported = true; + } else { + switch (op) { + // The current extra_buffer_type implementations only support GGML_OP_MUL_MAT and GGML_OP_GET_ROWS + case GGML_OP_GET_ROWS: + case GGML_OP_MUL_MAT: { + ggml_init_params params = { + /*.mem_size =*/ 2 * ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + + ggml_context_ptr ctx_ptr { ggml_init(params) }; + if (!ctx_ptr) { + throw std::runtime_error("failed to create ggml context"); + } + ggml_context * ctx = ctx_ptr.get(); + + ggml_tensor * op_tensor = nullptr; + + if (op == GGML_OP_MUL_MAT) { + int64_t n_ctx = hparams.n_audio_ctx; + ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]); + op_tensor = ggml_mul_mat(ctx, w, b); + } else if (op == GGML_OP_GET_ROWS) { + int64_t num_indices = 8; + ggml_tensor * indices = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, num_indices); + op_tensor = ggml_get_rows(ctx, w, indices); + } + + // create a temporary dummy buffer for the weight so that supports_op can check the buffer type + GGML_ASSERT(w->buffer == nullptr); + w->buffer = ggml_backend_buft_alloc_buffer(buft, 0); + op_supported = ggml_backend_dev_supports_op(dev, op_tensor); + ggml_backend_buffer_free(w->buffer); + w->buffer = nullptr; + break; + } + default: { + op_supported = false; + break; + } + }; + } + + return op_supported; +} + +static ggml_backend_buffer_type_t select_weight_buft(const parakeet_hparams & hparams, ggml_tensor * w, ggml_op op, buft_list_t buft_list) { + GGML_ASSERT(!buft_list.empty()); + for (const auto & p : buft_list) { + ggml_backend_dev_t dev = p.first; + ggml_backend_buffer_type_t buft = p.second; + if (weight_buft_supported(hparams, w, op, buft, dev)) { + return buft; + } + } + + return nullptr; +} + + +// load the model from a ggml file +// + +// see the convert-parakeet-to-ggml.py script for details +// +static bool parakeet_model_load(struct parakeet_model_loader * loader, parakeet_context & wctx) { + PARAKEET_LOG_INFO("%s: loading model\n", __func__); + + const int64_t t_start_us = ggml_time_us(); + + wctx.t_start_us = t_start_us; + + auto & model = wctx.model; + auto & vocab = wctx.vocab; + + // verify magic + { + uint32_t magic; + read_safe(loader, magic); + if (magic != GGML_FILE_MAGIC) { + PARAKEET_LOG_ERROR("%s: invalid model data (bad magic)\n", __func__); + return false; + } + } + + //load hparams + parakeet_hparams hparams; + { + read_safe(loader, hparams.n_vocab); + read_safe(loader, hparams.n_audio_ctx); + read_safe(loader, hparams.n_audio_state); + read_safe(loader, hparams.n_audio_head); + read_safe(loader, hparams.n_audio_layer); + read_safe(loader, hparams.n_mels); + read_safe(loader, hparams.ftype); + read_safe(loader, hparams.n_fft); + read_safe(loader, hparams.subsampling_factor); + read_safe(loader, hparams.n_subsampling_channels); + read_safe(loader, hparams.n_conv_kernel); + read_safe(loader, hparams.n_pred_dim); + read_safe(loader, hparams.n_pred_layers); + read_safe(loader, hparams.n_tdt_durations); + read_safe(loader, hparams.n_max_tokens); + + hparams.arch = PARAKEET_ARCH_TDT; + wctx.model.hparams = hparams; + + const int32_t qntvr = hparams.ftype / GGML_QNT_VERSION_FACTOR; + + hparams.ftype %= GGML_QNT_VERSION_FACTOR; + + // for the big tensors, we have the option to store the data in 16-bit floats or quantized + // in order to save memory and also to speed up the computation + wctx.wtype = ggml_ftype_to_ggml_type((ggml_ftype) hparams.ftype); + if (wctx.wtype == GGML_TYPE_COUNT) { + PARAKEET_LOG_ERROR("%s: invalid model (bad ftype value %d)\n", __func__, hparams.ftype); + return false; + } + + const char* arch_name = hparams.arch == PARAKEET_ARCH_TDT ? "Parakeet TDT" : "unknown"; + PARAKEET_LOG_INFO("%s: arch = %s\n", __func__, arch_name); + PARAKEET_LOG_INFO("%s: n_vocab = %d\n", __func__, hparams.n_vocab); + PARAKEET_LOG_INFO("%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx); + PARAKEET_LOG_INFO("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state); + PARAKEET_LOG_INFO("%s: n_audio_head = %d\n", __func__, hparams.n_audio_head); + PARAKEET_LOG_INFO("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer); + PARAKEET_LOG_INFO("%s: n_mels = %d\n", __func__, hparams.n_mels); + PARAKEET_LOG_INFO("%s: n_fft = %d\n", __func__, hparams.n_fft); + PARAKEET_LOG_INFO("%s: eps = %f\n", __func__, hparams.eps); + PARAKEET_LOG_INFO("%s: ftype = %d\n", __func__, hparams.ftype); + PARAKEET_LOG_INFO("%s: qntvr = %d\n", __func__, qntvr); + PARAKEET_LOG_INFO("%s: subsampling_factor = %d\n", __func__, hparams.subsampling_factor); + PARAKEET_LOG_INFO("%s: n_subsampling_channels = %d\n", __func__, hparams.n_subsampling_channels); + PARAKEET_LOG_INFO("%s: n_conv_kernel = %d\n", __func__, hparams.n_conv_kernel); + PARAKEET_LOG_INFO("%s: n_pred_dim = %d\n", __func__, hparams.n_pred_dim); + PARAKEET_LOG_INFO("%s: n_pred_layers = %d\n", __func__, hparams.n_pred_layers); + PARAKEET_LOG_INFO("%s: n_tdt_durations = %d\n", __func__, hparams.n_tdt_durations); + PARAKEET_LOG_INFO("%s: n_max_tokens = %d\n", __func__, hparams.n_max_tokens); + } + + // load mel filters + { + auto & filters = wctx.model.filters; + + read_safe(loader, filters.n_mel); + read_safe(loader, filters.n_fb); + + filters.data.resize(filters.n_mel * filters.n_fb); + loader->read(loader->context, filters.data.data(), filters.data.size() * sizeof(float)); + BYTESWAP_FILTERS(filters); + } + + // load window function + { + int32_t n_window = 0; + read_safe(loader, n_window); + + wctx.mel_cache.window.resize(n_window); + loader->read(loader->context, wctx.mel_cache.window.data(), n_window * sizeof(float)); + +#ifdef GGML_BIG_ENDIAN + for (auto & datum : wctx.mel_cache.window) { + datum = byteswap(datum); + } +#endif + + PARAKEET_LOG_INFO("%s: loaded window function with %d samples\n", __func__, n_window); + } + + // load TDT (Token and Duration Transducer) values + { + auto & tdt_durations = wctx.model.tdt_durations; + tdt_durations.resize(hparams.n_tdt_durations); + loader->read(loader->context, tdt_durations.data(), hparams.n_tdt_durations * sizeof(uint32_t)); + + PARAKEET_LOG_INFO("%s: loaded tdt_durations: [", __func__); + for (const auto value : tdt_durations) { + PARAKEET_LOG_INFO("%u ", value); + } + PARAKEET_LOG_INFO("]\n"); + } + + // load vocab + { + int32_t n_vocab = 0; + read_safe(loader, n_vocab); + + std::string word; + std::vector<char> tmp; + + tmp.reserve(128); + + for (int i = 0; i < n_vocab; i++) { + uint32_t len; + read_safe(loader, len); + + if (len > 0) { + tmp.resize(len); + loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer + word.assign(&tmp[0], tmp.size()); + } else { + PARAKEET_LOG_WARN("%s: warning: empty-string token in vocab, i = %d\n", __func__, i); + word = ""; + } + + vocab.token_to_id[word] = i; + vocab.id_to_token[i] = word; + vocab.max_token_length = std::max(vocab.max_token_length, word.size()); + } + // Blank token for transducer is at index n_vocab (8192), outside the vocabulary + int blank_id = n_vocab; + vocab.token_blank = blank_id; + vocab.id_to_token[blank_id] = "[BLANK]"; + vocab.token_to_id["[BLANK]"] = blank_id; + + // Set special token IDs by looking them up in the loaded vocabulary + // These are from the SentencePiece vocab file loaded above + if (vocab.token_to_id.find("<unk>") != vocab.token_to_id.end()) { + vocab.token_unk = vocab.token_to_id.at("<unk>"); + } else { + vocab.token_unk = 0; // Fallback + } + + if (vocab.token_to_id.find("<s>") != vocab.token_to_id.end()) { + vocab.token_bos = vocab.token_to_id.at("<s>"); + } else if (vocab.token_to_id.find("<|startoftranscript|>") != vocab.token_to_id.end()) { + vocab.token_bos = vocab.token_to_id.at("<|startoftranscript|>"); + } else { + vocab.token_bos = 0; // Fallback + } + + if (vocab.token_to_id.find("</s>") != vocab.token_to_id.end()) { + vocab.token_eos = vocab.token_to_id.at("</s>"); + } else if (vocab.token_to_id.find("<|endoftext|>") != vocab.token_to_id.end()) { + vocab.token_eos = vocab.token_to_id.at("<|endoftext|>"); + } else { + vocab.token_eos = 0; // Fallback + } + + vocab.n_vocab = model.hparams.n_vocab; + + PARAKEET_LOG_INFO("%s: loaded vocab with %d tokens (blank_id=%d, unk=%d, bos=%d, eos=%d)\n", + __func__, n_vocab, blank_id, vocab.token_unk, vocab.token_bos, vocab.token_eos); + } + + const ggml_type wtype = wctx.wtype; + + + const int n_audio_layer = hparams.n_audio_layer; + + // Calculate tensor count: pre_encode (12) + encoder layers (29 per layer) + prediction (9) + joint (6) + size_t n_tensors = 12 + (29 * n_audio_layer) + 9 + 6; + + std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map; + auto get_ctx = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + ggml_init_params params = { + /*.mem_size =*/ n_tensors * ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + + ggml_context * ctx = ggml_init(params); + if (!ctx) { + throw std::runtime_error("failed to create ggml context"); + } + + ctx_map[buft] = ctx; + wctx.model.ctxs.emplace_back(ctx); + + return ctx; + } + + return it->second; + }; + + // Create a list of available bufts, in priority order + buft_list_t buft_list = make_buft_list(wctx.params); + + auto create_tensor = [&](parakeet_tensor type, ggml_tensor * meta, int layer = -1) -> ggml_tensor * { + ggml_op op = PARAKEET_TENSOR_INFO.at(type); + ggml_backend_buffer_type_t buft = select_weight_buft(hparams, meta, op, buft_list); + if (!buft) { + throw std::runtime_error(format("failed to find a compatible buffer type for parakeet tensor %s", + PARAKEET_TENSOR_NAMES.at(type))); + } + + ggml_context * ctx = get_ctx(buft); + ggml_tensor * tensor = ggml_dup_tensor(ctx, meta); + + std::string tensor_name; + if (layer >= 0) { + tensor_name = format(PARAKEET_TENSOR_NAMES.at(type), layer); + } else { + tensor_name = PARAKEET_TENSOR_NAMES.at(type); + } + + wctx.model.tensors[tensor_name] = tensor; + + return tensor; + }; + + // prepare tensors for the weights + + ggml_init_params params = { + /*.mem_size =*/ n_tensors * ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + + ggml_context * ctx = ggml_init(params); + + const int n_audio_state = hparams.n_audio_state; + + model.layers.resize(n_audio_layer); + + // Encoder pre_encode + const int n_subsampling_channels = hparams.n_subsampling_channels; + const int n_pre_enc_features = (hparams.n_mels / hparams.subsampling_factor) * n_subsampling_channels; + model.enc_pre_out_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_pre_enc_features, n_audio_state)); + ggml_set_name(model.enc_pre_out_w, "enc_pre_out_w"); + model.enc_pre_out_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_OUT_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state)); + ggml_set_name(model.enc_pre_out_b, "enc_pre_out_b"); + + model.enc_pre_conv_0_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 3, 3, 1, n_subsampling_channels)); + ggml_set_name(model.enc_pre_conv_0_w, "enc_pre_conv_0_w"); + model.enc_pre_conv_0_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1)); + ggml_set_name(model.enc_pre_conv_0_b, "enc_pre_conv_0_b"); + + model.enc_pre_conv_2_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 3, 3, 1, n_subsampling_channels)); + ggml_set_name(model.enc_pre_conv_2_w, "enc_pre_conv_2_w"); + model.enc_pre_conv_2_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1)); + ggml_set_name(model.enc_pre_conv_2_b, "enc_pre_conv_2_b"); + + model.enc_pre_conv_3_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, n_subsampling_channels)); + ggml_set_name(model.enc_pre_conv_3_w, "enc_pre_conv_3_w"); + model.enc_pre_conv_3_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1)); + ggml_set_name(model.enc_pre_conv_3_b, "enc_pre_conv_3_b"); + + model.enc_pre_conv_5_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 3, 3, 1, n_subsampling_channels)); + ggml_set_name(model.enc_pre_conv_5_w, "enc_pre_conv_5_w"); + model.enc_pre_conv_5_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1)); + ggml_set_name(model.enc_pre_conv_5_b, "enc_pre_conv_5_b"); + + model.enc_pre_conv_6_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, n_subsampling_channels)); + ggml_set_name(model.enc_pre_conv_6_w, "enc_pre_conv_6_w"); + model.enc_pre_conv_6_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1)); + ggml_set_name(model.enc_pre_conv_6_b, "enc_pre_conv_6_b"); + + // Encoder layers + for (int i = 0; i < n_audio_layer; ++i) { + auto & layer = model.layers[i]; + + // Feed forward 1 + layer.norm_ff1_w = create_tensor(PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.norm_ff1_b = create_tensor(PARAKEET_TENSOR_ENC_NORM_FF1_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.ff1_linear1_w = create_tensor(PARAKEET_TENSOR_ENC_FF1_LINEAR1_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state), i); + ggml_format_name(layer.ff1_linear1_w, "enc_%d_ff1_linear1_w", i); + layer.ff1_linear2_w = create_tensor(PARAKEET_TENSOR_ENC_FF1_LINEAR2_WEIGHT, ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state), i); + ggml_format_name(layer.ff1_linear2_w, "enc_%d_ff1_linear2_w", i); + + // Convolution module + layer.norm_conv_w = create_tensor(PARAKEET_TENSOR_ENC_NORM_CONV_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + ggml_format_name(layer.norm_conv_w, "enc_%d_norm_conv_w", i); + layer.norm_conv_b = create_tensor(PARAKEET_TENSOR_ENC_NORM_CONV_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + ggml_format_name(layer.norm_conv_b, "enc_%d_norm_conv_b", i); + layer.conv_pw1_w = create_tensor(PARAKEET_TENSOR_ENC_CONV_PW1_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, 2*n_audio_state), i); + ggml_format_name(layer.conv_pw1_w, "enc_%d_conv_pw1_w", i); + layer.conv_dw_w = create_tensor(PARAKEET_TENSOR_ENC_CONV_DW_WEIGHT, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.n_conv_kernel, n_audio_state), i); + ggml_format_name(layer.conv_dw_w, "enc_%d_conv_dw_w", i); + layer.conv_bn_w = create_tensor(PARAKEET_TENSOR_ENC_CONV_BN_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + ggml_format_name(layer.conv_bn_w, "enc_%d_conv_bn_w", i); + layer.conv_bn_b = create_tensor(PARAKEET_TENSOR_ENC_CONV_BN_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + ggml_format_name(layer.conv_bn_b, "enc_%d_conv_bn_b", i); + layer.conv_bn_mean = create_tensor(PARAKEET_TENSOR_ENC_CONV_BN_MEAN, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.conv_bn_var = create_tensor(PARAKEET_TENSOR_ENC_CONV_BN_VAR, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + ggml_format_name(layer.conv_bn_var, "enc_%d_conv_bn_var", i); + layer.conv_bn_num_batches = create_tensor(PARAKEET_TENSOR_ENC_CONV_BN_NUM_BATCHES, ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1), i); + layer.conv_pw2_w = create_tensor(PARAKEET_TENSOR_ENC_CONV_PW2_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i); + ggml_format_name(layer.conv_pw2_w, "enc_%d_conv_pw2_w", i); + + // Self attention + layer.norm_attn_w = create_tensor(PARAKEET_TENSOR_ENC_NORM_ATTN_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.norm_attn_b = create_tensor(PARAKEET_TENSOR_ENC_NORM_ATTN_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.attn_pos_bias_u = create_tensor(PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_U, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.n_audio_state / hparams.n_audio_head, hparams.n_audio_head), i); + layer.attn_pos_bias_v = create_tensor(PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_V, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.n_audio_state / hparams.n_audio_head, hparams.n_audio_head), i); + layer.attn_q_w = create_tensor(PARAKEET_TENSOR_ENC_ATTN_Q_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i); + layer.attn_k_w = create_tensor(PARAKEET_TENSOR_ENC_ATTN_K_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i); + layer.attn_v_w = create_tensor(PARAKEET_TENSOR_ENC_ATTN_V_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i); + layer.attn_out_w = create_tensor(PARAKEET_TENSOR_ENC_ATTN_OUT_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i); + layer.attn_pos_w = create_tensor(PARAKEET_TENSOR_ENC_ATTN_POS_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i); + ggml_format_name(layer.attn_pos_w, "enc_%d_attn_pos_w", i); + + // Feed forward 2 + layer.norm_ff2_w = create_tensor(PARAKEET_TENSOR_ENC_NORM_FF2_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.norm_ff2_b = create_tensor(PARAKEET_TENSOR_ENC_NORM_FF2_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.ff2_linear1_w = create_tensor(PARAKEET_TENSOR_ENC_FF2_LINEAR1_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state), i); + layer.ff2_linear2_w = create_tensor(PARAKEET_TENSOR_ENC_FF2_LINEAR2_WEIGHT, ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state), i); + + // Output norm + layer.norm_out_w = create_tensor(PARAKEET_TENSOR_ENC_NORM_OUT_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.norm_out_b = create_tensor(PARAKEET_TENSOR_ENC_NORM_OUT_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + } + + // Prediction network (decoder) + const int dec_hidden = hparams.n_pred_dim; + const int n_pred_embed = hparams.n_vocab + 1; // vocab + blank token + const int n_lstm_gates = 4 * dec_hidden; // 4 LSTM gates + const int n_joint_out = hparams.n_vocab + hparams.n_tdt_durations + 1; // vocab + durations + blank + + // The prediction/joint hidden dimension is 640, which is not a multiple of the + // K-quant block size (256). For K-quant models, we keep these tensors at F32. + const int blck = ggml_blck_size(wtype); + const ggml_type pred_wtype = (blck > 1 && dec_hidden % blck != 0) ? GGML_TYPE_F32 : wtype; + const ggml_type join_wtype = pred_wtype; + + model.prediction.embed_w = create_tensor(PARAKEET_TENSOR_PRED_EMBED_WEIGHT, ggml_new_tensor_2d(ctx, pred_wtype, dec_hidden, n_pred_embed)); + model.prediction.lstm_layer.resize(hparams.n_pred_layers); + for (int i = 0; i < hparams.n_pred_layers; ++i) { + auto & layer = model.prediction.lstm_layer[i]; + layer.ih_w = create_tensor(PARAKEET_TENSOR_PRED_LSTM_WEIGHT_IH, ggml_new_tensor_2d(ctx, pred_wtype, dec_hidden, n_lstm_gates), i); + ggml_format_name(layer.ih_w, "pred_%d_ih_w", i); + + layer.hh_w = create_tensor(PARAKEET_TENSOR_PRED_LSTM_WEIGHT_HH, ggml_new_tensor_2d(ctx, pred_wtype, dec_hidden, n_lstm_gates), i); + ggml_format_name(layer.hh_w, "pred_%d_hh_w", i); + + layer.b_h = create_tensor(PARAKEET_TENSOR_PRED_LSTM_BIAS_H, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_lstm_gates), i); + ggml_format_name(layer.b_h, "pred_%d_b_h", i); + } + + // Joint network + model.joint.pred_w = create_tensor(PARAKEET_TENSOR_JOINT_PRED_WEIGHT, ggml_new_tensor_2d(ctx, join_wtype, dec_hidden, dec_hidden)); + ggml_set_name(model.joint.pred_w, "pred_w"); + model.joint.pred_b = create_tensor(PARAKEET_TENSOR_JOINT_PRED_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dec_hidden)); + ggml_set_name(model.joint.pred_b, "pred_b"); + model.joint.enc_w = create_tensor(PARAKEET_TENSOR_JOINT_ENC_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, dec_hidden)); + ggml_set_name(model.joint.enc_w, "enc_w"); + model.joint.enc_b = create_tensor(PARAKEET_TENSOR_JOINT_ENC_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dec_hidden)); + ggml_set_name(model.joint.enc_b, "enc_b"); + model.joint.net_w = create_tensor(PARAKEET_TENSOR_JOINT_NET_WEIGHT, ggml_new_tensor_2d(ctx, join_wtype, dec_hidden, n_joint_out)); + ggml_set_name(model.joint.net_w, "net_w"); + model.joint.net_b = create_tensor(PARAKEET_TENSOR_JOINT_NET_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_joint_out)); + ggml_set_name(model.joint.net_b, "net_b"); + + ggml_free(ctx); + + // allocate tensors in the backend buffers + for (auto & p : ctx_map) { + ggml_backend_buffer_type_t buft = p.first; + ggml_context * ctx = p.second; + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (buf) { + wctx.model.buffers.emplace_back(buf); + + size_t size_main = ggml_backend_buffer_get_size(buf); + PARAKEET_LOG_INFO("%s: %12s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(buf), size_main / 1e6); + } + } + + // load weights + { + size_t total_size = 0; + + auto & tensors_map = wctx.model.tensors; + int & n_loaded = wctx.model.n_loaded; + + n_loaded = 0; + + std::vector<char> read_buf; + + while (true) { + int32_t n_dims; + int32_t length; + int32_t ttype; + + read_safe(loader, n_dims); + read_safe(loader, length); + read_safe(loader, ttype); + + if (loader->eof(loader->context)) { + break; + } + + int32_t nelements = 1; + int32_t ne[4] = { 1, 1, 1, 1 }; + for (int i = 0; i < n_dims; ++i) { + read_safe(loader, ne[i]); + nelements *= ne[i]; + } + + std::string name; + std::vector<char> tmp(length); // create a buffer + loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer + name.assign(&tmp[0], tmp.size()); + + if (tensors_map.find(name) == tensors_map.end()) { + PARAKEET_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data()); + return false; + } + + auto tensor = tensors_map[name.data()]; + + if (ggml_nelements(tensor) != nelements) { + PARAKEET_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data()); + PARAKEET_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n", + __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]); + return false; + } + + if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2] || tensor->ne[3] != ne[3]) { + PARAKEET_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d, %d], expected [%d, %d, %d, %d]\n", + __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], (int) tensor->ne[3], ne[0], ne[1], ne[2], ne[3]); + return false; + } + + const size_t bpe = ggml_type_size(ggml_type(ttype)); + + if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) { + PARAKEET_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", + __func__, name.data(), ggml_nbytes(tensor), nelements*bpe); + return false; + } + + if (ggml_backend_buffer_is_host(tensor->buffer)) { + // for the CPU and Metal backend, we can read directly into the tensor + loader->read(loader->context, tensor->data, ggml_nbytes(tensor)); + BYTESWAP_TENSOR(tensor); + } else { + // read into a temporary buffer first, then copy to device memory + read_buf.resize(ggml_nbytes(tensor)); + + loader->read(loader->context, read_buf.data(), read_buf.size()); + + ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor)); + } + + total_size += ggml_nbytes(tensor); + n_loaded++; + } + + PARAKEET_LOG_INFO("%s: model size = %7.2f MB\n", __func__, total_size/1e6); + + if (n_loaded == 0) { + PARAKEET_LOG_WARN("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__); + } else if (n_loaded != (int) tensors_map.size()) { + PARAKEET_LOG_ERROR("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, tensors_map.size(), n_loaded); + return false; + } + } + + auto & buffers = wctx.model.buffers; + for (auto & buf : buffers) { + ggml_backend_buffer_set_usage(buf, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + } + + wctx.t_load_us = ggml_time_us() - t_start_us; + + return true; +} + +// conv subsampling + conformer encoder +static struct ggml_cgraph * parakeet_build_graph_encode(parakeet_context & pctx, parakeet_state & pstate) { + const auto & model = pctx.model; + const auto & hparams = model.hparams; + const int n_mel_time = pstate.n_audio_ctx > 0 ? pstate.n_audio_ctx : hparams.n_audio_ctx; + const int n_mels = hparams.n_mels; + const int n_layer = hparams.n_audio_layer; + const int n_state = hparams.n_audio_state; + const float fc_factor = 0.5f; + + struct ggml_init_params params = { + /*.mem_size =*/ pstate.sched_encode.meta.size(), + /*.mem_buffer =*/ pstate.sched_encode.meta.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx0 = ggml_init(params); + ggml_cgraph * gf = ggml_new_graph_custom(ctx0, PARAKEET_MAX_NODES, false); + + // Conv subsampling + + // [freq, time] + struct ggml_tensor * mel = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_mels, n_mel_time, 1, 1); + ggml_set_name(mel, "mel"); + ggml_set_input(mel); + + // [freq, time, channels, batch] + struct ggml_tensor * cur = ggml_conv_2d(ctx0, model.enc_pre_conv_0_w, mel, 2, 2, 1, 1, 1, 1); + cur = ggml_add(ctx0, cur, model.enc_pre_conv_0_b); + ggml_set_name(cur, "pre_conv_0"); + + cur = ggml_relu(ctx0, cur); + ggml_set_name(cur, "pre_conv_0_relu"); + + // [freq, time, channels, batch] + cur = ggml_conv_2d_dw_direct(ctx0, model.enc_pre_conv_2_w, cur, 2, 2, 1, 1, 1, 1); + cur = ggml_add(ctx0, cur, model.enc_pre_conv_2_b); + ggml_set_name(cur, "pre_conv_2"); + + // [freq, time, channels, batch] + cur = ggml_conv_2d(ctx0, model.enc_pre_conv_3_w, cur, 1, 1, 0, 0, 1, 1); + cur = ggml_add(ctx0, cur, model.enc_pre_conv_3_b); + ggml_set_name(cur, "pre_conv_3"); + + cur = ggml_relu(ctx0, cur); + ggml_set_name(cur, "pre_conv_3_relu"); + + // [freq, time, channels, batch] + cur = ggml_conv_2d_dw_direct(ctx0, model.enc_pre_conv_5_w, cur, 2, 2, 1, 1, 1, 1); + ggml_set_name(cur, "pre_conv_5_direct"); + cur = ggml_add(ctx0, cur, model.enc_pre_conv_5_b); + ggml_set_name(cur, "pre_conv_5"); + + // [freq, time, channels, batch] + cur = ggml_conv_2d(ctx0, model.enc_pre_conv_6_w, cur, 1, 1, 0, 0, 1, 1); + cur = ggml_add(ctx0, cur, model.enc_pre_conv_6_b); + ggml_set_name(cur, "pre_conv_6"); + + cur = ggml_relu(ctx0, cur); + ggml_set_name(cur, "pre_conv_6_relu"); + + // [freq, time, chan] + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + // [freq, chan, time] + cur = ggml_cont(ctx0, cur); + + const int n_freq = cur->ne[0]; // 16 + const int n_chan = cur->ne[1]; // 256 + const int n_frames = cur->ne[2]; // time + + // [freq, time, chan, batch] -> [(freq * chan), time] + cur = ggml_reshape_2d(ctx0, cur, n_freq * n_chan, n_frames); + + cur = ggml_mul_mat(ctx0, model.enc_pre_out_w, cur); + cur = ggml_add(ctx0, cur, model.enc_pre_out_b); + + ggml_set_name(cur, "pre_enc_out"); + + // Encoder + // cur: [n_state, n_enc_time] + + const int n_time = cur->ne[1]; + const bool local_attn = n_time > PARAKEET_LOCAL_ATTN_THRESHOLD; + const int att_left = local_attn ? PARAKEET_LOCAL_ATTN_WINDOW : n_time - 1; + const int att_right = local_attn ? PARAKEET_LOCAL_ATTN_WINDOW : n_time - 1; + const int window_size = local_attn ? att_left + att_right + 1 : 2 * n_time - 1; + const int d_half = n_state / 2; + const int mask_dim = local_attn ? window_size : n_time; + + // mask [key, n_time] + struct ggml_tensor * attn_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, mask_dim, n_time); + ggml_set_name(attn_mask, "attn_mask"); + ggml_set_input(attn_mask); + + struct ggml_tensor * local_mask = nullptr; + if (local_attn) { + const int chunk = att_left + att_right; + local_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, chunk + window_size - 1, chunk); + ggml_set_name(local_mask, "local_mask"); + ggml_set_input(local_mask); + } + + struct ggml_tensor * pos_freqs = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, d_half); + ggml_set_name(pos_freqs, "pos_freqs"); + ggml_set_input(pos_freqs); + + struct ggml_tensor * rel_positions = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, window_size); + ggml_set_name(rel_positions, "rel_positions"); + ggml_set_input(rel_positions); + + struct ggml_tensor * freqs = ggml_repeat_4d(ctx0, pos_freqs, d_half, window_size, 1, 1); + struct ggml_tensor * theta = ggml_mul(ctx0, freqs, rel_positions); + + struct ggml_tensor * sin_t = ggml_reshape_3d(ctx0, ggml_sin(ctx0, theta), 1, d_half, window_size); + struct ggml_tensor * cos_t = ggml_reshape_3d(ctx0, ggml_cos(ctx0, theta), 1, d_half, window_size); + // [n_state, window_size] + struct ggml_tensor * pos_emb = ggml_reshape_2d(ctx0, ggml_cont(ctx0, ggml_concat(ctx0, sin_t, cos_t, 0)), n_state, window_size); + ggml_set_name(pos_emb, "pos_emb"); + + for (int il = 0; il < n_layer; ++il) { + const auto & layer = model.layers[il]; + + // FFN1 + { + struct ggml_tensor * residual = cur; + ggml_format_name(cur, "enc_%d_res", il); + + // norm + cur = ggml_norm(ctx0, cur, hparams.eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.norm_ff1_w), layer.norm_ff1_b); + ggml_format_name(cur, "enc_%d_ffn_norm_1", il); + + // ffn_1 + cur = ggml_mul_mat(ctx0, layer.ff1_linear1_w, cur); + cur = ggml_silu(ctx0, cur); + ggml_format_name(cur, "enc_%d_silu", il); + + cur = ggml_mul_mat(ctx0, layer.ff1_linear2_w, cur); + ggml_format_name(cur, "enc_%d_ffn_1", il); + + cur = ggml_add(ctx0, residual, ggml_scale(ctx0, cur, fc_factor)); + ggml_format_name(cur, "enc_%d_res_ffn", il); + } + + // self attention block using relative positional encoding computed in graph. + { + // [feat, time_frames, 1, 1] + struct ggml_tensor * residual = cur; + + cur = ggml_norm(ctx0, cur, hparams.eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.norm_attn_w), layer.norm_attn_b); + ggml_format_name(cur, "enc_%d_attn_norm", il); + + const int n_head = hparams.n_audio_head; + const int d_head = n_state / n_head; + + // [feat, time_frames, 1, 1] + struct ggml_tensor * Q_cur = ggml_mul_mat(ctx0, layer.attn_q_w, cur); + struct ggml_tensor * K_cur = ggml_mul_mat(ctx0, layer.attn_k_w, cur); + struct ggml_tensor * V_cur = ggml_mul_mat(ctx0, layer.attn_v_w, cur); + + Q_cur = ggml_reshape_3d(ctx0, Q_cur, d_head, n_head, n_time); + K_cur = ggml_reshape_3d(ctx0, K_cur, d_head, n_head, n_time); + V_cur = ggml_reshape_3d(ctx0, V_cur, d_head, n_head, n_time); + + struct ggml_tensor * pos = ggml_mul_mat(ctx0, layer.attn_pos_w, pos_emb); + pos = ggml_reshape_3d(ctx0, pos, d_head, n_head, window_size); + pos = ggml_cont(ctx0, ggml_permute(ctx0, pos, 0, 2, 1, 3)); + + if (local_attn) { + const int chunk = att_left + att_right; + const int n_group = (n_time + chunk - 1) / chunk; + const int n_time_padded = n_group * chunk; + const int n_kv_chunk = chunk + window_size - 1; + const int n_kv_dense = n_kv_chunk * n_group; + const bool need_padding = n_time_padded > n_time; + + Q_cur = ggml_cont(ctx0, ggml_permute(ctx0, Q_cur, 0, 2, 1, 3)); + K_cur = ggml_cont(ctx0, ggml_permute(ctx0, K_cur, 0, 2, 1, 3)); + V_cur = ggml_cont(ctx0, ggml_permute(ctx0, V_cur, 0, 2, 1, 3)); + + // content bias + struct ggml_tensor * bias_u = ggml_reshape_3d(ctx0, layer.attn_pos_bias_u, d_head, 1, n_head); + struct ggml_tensor * Q_u = ggml_add(ctx0, Q_cur, bias_u); + + // position bias + struct ggml_tensor * bias_v = ggml_reshape_3d(ctx0, layer.attn_pos_bias_v, d_head, 1, n_head); + struct ggml_tensor * Q_v = ggml_add(ctx0, Q_cur, bias_v); + + // right pad the time_frame. + struct ggml_tensor * Q_u_padded = need_padding ? + ggml_pad_ext(ctx0, Q_u, 0, 0, 0, n_time_padded - n_time, 0, 0, 0, 0) : Q_u; + Q_u_padded = ggml_reshape_4d(ctx0, Q_u_padded, d_head, chunk, n_group, n_head); + + // Add padding to front and back (for the first timeframe and the last timeframe). + struct ggml_tensor * K_padded = ggml_pad_ext(ctx0, K_cur, 0, 0, att_left, att_right, 0, 0, 0, 0); + + // pad time axis to match n_kv_dense if needed. + if (n_kv_dense > K_padded->ne[1]) { + K_padded = ggml_pad_ext(ctx0, K_padded, 0, 0, 0, n_kv_dense - K_padded->ne[1], 0, 0, 0, 0); + } + + // Create a 4d tensor where each group spans a wide window of + // 512 keys (n_kv_chunk), but moving to the next group (nb[2]) + // only jumps forward by 256 frames (chunk * nb[1]). This creates + // a 256 frame overlap, shared keys in RAM without copies. + struct ggml_tensor * K_chunk = ggml_view_4d(ctx0, K_padded, + d_head, n_kv_chunk, n_group, n_head, + K_padded->nb[1], + (size_t) chunk * K_padded->nb[1], + K_padded->nb[2], + 0); + K_chunk = ggml_cont(ctx0, K_chunk); + + struct ggml_tensor * content_scores = ggml_mul_mat(ctx0, K_chunk, Q_u_padded); + + // The above mul_mat operation, combined with K_chunk's overlapping + // frames, produces a dense matrix. But some of the results in + // this matrix were computed for keys that aren't part of that + // query's window. So we shift each row to keep only the results + // that we want. + content_scores = ggml_view_4d(ctx0, content_scores, + window_size, chunk, n_group, n_head, + (size_t) (chunk + window_size) * content_scores->nb[0], + content_scores->nb[2], + content_scores->nb[3], + 0); + content_scores = ggml_cont(ctx0, content_scores); + + // ungrouping. + content_scores = ggml_reshape_3d(ctx0, content_scores, window_size, n_time_padded, n_head); + + // remove padding if padding was applied (truncating to n_time). + if (need_padding) { + content_scores = ggml_view_3d(ctx0, content_scores, + window_size, n_time, n_head, + content_scores->nb[1], + content_scores->nb[2], + 0); + } + + struct ggml_tensor * rel_pos_scores = ggml_mul_mat(ctx0, pos, Q_v); + + // attention_score = content similarity + relative position scores + struct ggml_tensor * attn_scores = ggml_add(ctx0, content_scores, rel_pos_scores); + + attn_scores = ggml_soft_max_ext(ctx0, attn_scores, attn_mask, 1.0f / std::sqrt(d_head), 0.0f); + + // right pad the probabilites. + struct ggml_tensor * probs_padded = need_padding ? + ggml_pad_ext(ctx0, attn_scores, 0, 0, 0, n_time_padded - n_time, 0, 0, 0, 0) : attn_scores; + + probs_padded = ggml_reshape_4d(ctx0, probs_padded, window_size, chunk, n_group, n_head); + probs_padded = ggml_pad_ext(ctx0, probs_padded, 0, chunk, 0, 0, 0, 0, 0, 0); + probs_padded = ggml_view_4d(ctx0, probs_padded, + n_kv_chunk, chunk, n_group, n_head, + (size_t) n_kv_chunk * probs_padded->nb[0], + probs_padded->nb[2], + probs_padded->nb[3], + 0); + probs_padded = ggml_cont(ctx0, probs_padded); + probs_padded = ggml_mul(ctx0, probs_padded, local_mask); + + // Add padding to front and back (for the first timeframe and the last timeframe). + struct ggml_tensor * V_padded = ggml_pad_ext(ctx0, V_cur, 0, 0, att_left, att_right, 0, 0, 0, 0); + + // pad time axis to match n_kv_dense if needed. + if (n_kv_dense > V_padded->ne[1]) { + V_padded = ggml_pad_ext(ctx0, V_padded, 0, 0, 0, n_kv_dense - V_padded->ne[1], 0, 0, 0, 0); + } + + V_padded = ggml_cont(ctx0, ggml_transpose(ctx0, V_padded)); + + struct ggml_tensor * V_chunk = ggml_view_4d(ctx0, V_padded, + n_kv_chunk, d_head, n_group, n_head, + V_padded->nb[1], + (size_t) chunk * V_padded->nb[0], + V_padded->nb[2], + 0); + V_chunk = ggml_cont(ctx0, V_chunk); + + cur = ggml_mul_mat(ctx0, V_chunk, probs_padded); + // ungroup. + cur = ggml_reshape_3d(ctx0, cur, d_head, n_time_padded, n_head); + // unpad + if (need_padding) { + cur = ggml_view_3d(ctx0, cur, d_head, n_time, n_head, cur->nb[1], cur->nb[2], 0); + } + + cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 0, 2, 1, 3)); + cur = ggml_reshape_2d(ctx0, cur, n_state, n_time); + cur = ggml_mul_mat(ctx0, layer.attn_out_w, cur); + } else { + struct ggml_tensor * Q_u = ggml_add(ctx0, Q_cur, layer.attn_pos_bias_u); + ggml_format_name(Q_u, "enc_%d_attn_q_u", il); + + struct ggml_tensor * K_prep = ggml_permute(ctx0, K_cur, 0, 2, 1, 3); + struct ggml_tensor * Q_prep = ggml_permute(ctx0, Q_u, 0, 2, 1, 3); + struct ggml_tensor * content_scores = ggml_mul_mat(ctx0, K_prep, Q_prep); + ggml_format_name(content_scores, "enc_%d_attn_content_scores", il); + + struct ggml_tensor * Q_v = ggml_add(ctx0, Q_cur, layer.attn_pos_bias_v); + ggml_format_name(Q_v, "enc_%d_attn_q_v", il); + + Q_v = ggml_permute(ctx0, Q_v, 0, 2, 1, 3); + Q_v = ggml_cont(ctx0, Q_v); + ggml_format_name(Q_v, "enc_%d_attn_q_v_perm", il); + + struct ggml_tensor * rel_pos_scores = ggml_mul_mat(ctx0, pos, Q_v); + ggml_format_name(rel_pos_scores, "enc_%d_attn_rel_pos", il); + + // Relative position shifting is performed in the following block. + // Some more details on the operations performed below can be found here: + // https://github.com/danbev/learning-ai/blob/main/notes/whisper/parakeet.md#relative-position-shift + { + const auto pos_window = rel_pos_scores->ne[0]; + const auto n_frame = rel_pos_scores->ne[1]; + const auto n_head_cur = rel_pos_scores->ne[2]; + + rel_pos_scores = ggml_pad(ctx0, rel_pos_scores, 1, 0, 0, 0); + rel_pos_scores = ggml_roll(ctx0, rel_pos_scores, 1, 0, 0, 0); + + rel_pos_scores = ggml_reshape_3d(ctx0, rel_pos_scores, n_frame, pos_window + 1, n_head_cur); + ggml_format_name(rel_pos_scores, "enc_%d_attn_rel_pos_reshaped", il); + + int center = pos_window / 2; + size_t offset = rel_pos_scores->nb[0] * (center+1); + + rel_pos_scores = ggml_view_3d(ctx0, rel_pos_scores, + n_frame, pos_window, n_head_cur, + (pos_window) * 4, + rel_pos_scores->nb[2], + offset); + + ggml_format_name(rel_pos_scores, "enc_%d_attn_rel_pos_shifted", il); + + rel_pos_scores = ggml_view_3d(ctx0, rel_pos_scores, + content_scores->ne[0], + content_scores->ne[1], + rel_pos_scores->ne[2], + rel_pos_scores->nb[1], + rel_pos_scores->nb[2], + 0); + rel_pos_scores = ggml_cont(ctx0, rel_pos_scores); + ggml_format_name(rel_pos_scores, "enc_%d_attn_rel_pos_shifted_view", il); + } + + struct ggml_tensor * attn_scores = ggml_add(ctx0, content_scores, rel_pos_scores); + ggml_format_name(attn_scores, "enc_%d_attn_scores", il); + attn_scores = ggml_scale(ctx0, attn_scores, 1.0f / std::sqrt(d_head)); + attn_scores = ggml_add(ctx0, attn_scores, attn_mask); + ggml_format_name(attn_scores, "enc_%d_attn_scores_scaled", il); + + struct ggml_tensor * probs = ggml_soft_max(ctx0, attn_scores); + ggml_format_name(probs, "enc_%d_attn_probs", il); + + V_cur = ggml_cont(ctx0, ggml_permute(ctx0, V_cur, 1, 2, 0, 3)); + ggml_format_name(V_cur, "enc_%d_attn_v_cur", il); + cur = ggml_mul_mat(ctx0, probs, V_cur); + ggml_format_name(cur, "enc_%d_attn_inp", il); + + cur = ggml_permute(ctx0, cur, 2, 0, 1, 3); + cur = ggml_cont_2d(ctx0, cur, n_state, n_time); + cur = ggml_mul_mat(ctx0, layer.attn_out_w, cur); + } + ggml_format_name(cur, "enc_%d_attn_out", il); + + cur = ggml_add(ctx0, residual, cur); + ggml_format_name(cur, "enc_%d_attn_res", il); + } + + // Convolution + { + struct ggml_tensor * residual = cur; + ggml_format_name(cur, "enc_%d_residual_conv", il); + + cur = ggml_norm(ctx0, cur, hparams.eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.norm_conv_w), layer.norm_conv_b); + ggml_format_name(cur, "enc_%d_norm_conv", il); + + // pointwise 1d convolution: [1024, 138] -> [2048, 138] + cur = ggml_mul_mat(ctx0, layer.conv_pw1_w, cur); + ggml_format_name(cur, "enc_%d_conv_pw1", il); + + { + int64_t d = cur->ne[0] / 2; + struct ggml_tensor * signal = ggml_view_2d(ctx0, cur, d, cur->ne[1], cur->nb[1], 0); + struct ggml_tensor * gate = ggml_view_2d(ctx0, cur, d, cur->ne[1], cur->nb[1], d * cur->nb[0]); + + cur = ggml_mul(ctx0, signal, ggml_sigmoid(ctx0, gate)); + ggml_format_name(cur, "enc_%d_conv_glu", il); + } + + cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur)); + + // use ggml_ssm_conv for f32 precision + const int dw_pad = (hparams.n_conv_kernel - 1) / 2; + cur = ggml_pad(ctx0, cur, dw_pad, 0, 0, 0); + cur = ggml_roll(ctx0, cur, dw_pad, 0, 0, 0); + cur = ggml_pad(ctx0, cur, dw_pad, 0, 0, 0); + ggml_format_name(cur, "enc_%d_conv_dw_pad", il); + + cur = ggml_ssm_conv(ctx0, cur, layer.conv_dw_w); + ggml_format_name(cur, "enc_%d_conv_1d_dw", il); + + cur = ggml_sub(ctx0, cur, layer.conv_bn_mean); + struct ggml_tensor * std = ggml_sqrt(ctx0, layer.conv_bn_var); + cur = ggml_div(ctx0, cur, std); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.conv_bn_w), layer.conv_bn_b); + ggml_format_name(cur, "enc_%d_conv_bn", il); + + cur = ggml_silu(ctx0, cur); + ggml_format_name(cur, "enc_%d_conv_silu", il); + + cur = ggml_mul_mat(ctx0, layer.conv_pw2_w, cur); + ggml_format_name(cur, "enc_%d_conv_pw2", il); + + cur = ggml_add(ctx0, residual, cur); + ggml_format_name(cur, "enc_%d_conv_res", il); + } + + // FFN2 + { + struct ggml_tensor * residual = cur; + cur = ggml_norm(ctx0, cur, hparams.eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.norm_ff2_w), layer.norm_ff2_b); + ggml_format_name(cur, "enc_%d_ffn_norm_2", il); + + cur = ggml_mul_mat(ctx0, layer.ff2_linear1_w, cur); + cur = ggml_silu(ctx0, cur); + cur = ggml_mul_mat(ctx0, layer.ff2_linear2_w, cur); + cur = ggml_add(ctx0, residual, ggml_scale(ctx0, cur, 0.5)); + ggml_format_name(cur, "enc_%d_ffn_res", il); + } + + cur = ggml_norm(ctx0, cur, hparams.eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.norm_out_w), layer.norm_out_b); + } + + ggml_set_name(cur, "encoder_out"); + pstate.n_frames = cur->ne[1]; + + struct ggml_tensor * enc_out_view = ggml_view_2d(ctx0, pstate.enc_out, n_state, pstate.n_frames, pstate.enc_out->nb[1], 0); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, cur, enc_out_view)); + + ggml_free(ctx0); + + return gf; +} + +static bool parakeet_encode_internal( + parakeet_context & pctx, + parakeet_state & pstate, + const int mel_offset, + const int n_threads, + ggml_abort_callback abort_callback, + void * abort_callback_data) { + const int64_t t_start_us = ggml_time_us(); + + auto & sched = pstate.sched_encode.sched; + + ggml_cgraph * gf = parakeet_build_graph_encode(pctx, pstate); + + if (!ggml_backend_sched_alloc_graph(sched, gf)) { + // should never happen as we pre-allocate the memory + return false; + } + + // set mel input + { + struct ggml_tensor * mel = ggml_graph_get_tensor(gf, "mel"); + + const auto & mel_inp = pstate.mel; + const int n_ctx = pstate.n_audio_ctx > 0 ? pstate.n_audio_ctx : pctx.model.hparams.n_audio_ctx; + + assert(mel->type == GGML_TYPE_F32); + assert(mel_inp.n_mel == pctx.model.hparams.n_mels); + + pstate.inp_mel.resize(ggml_nelements(mel)); + + float * dst = pstate.inp_mel.data(); + memset(dst, 0, ggml_nbytes(mel)); + + const int i0 = std::min(mel_offset, mel_inp.n_len); + const int i1 = std::min(mel_offset + n_ctx, mel_inp.n_len); + + memcpy(dst, mel_inp.data.data() + i0 * mel_inp.n_mel, (i1 - i0) * mel_inp.n_mel * sizeof(float)); + + ggml_backend_tensor_set(mel, pstate.inp_mel.data(), 0, ggml_nelements(mel)*sizeof(float)); + } + + // set attention mask + { + struct ggml_tensor * attn_mask = ggml_graph_get_tensor(gf, "attn_mask"); + const int n_q = attn_mask->ne[1]; + const int n_k = attn_mask->ne[0]; + + const int32_t subsampl_factor = pctx.model.hparams.subsampling_factor; + const int n_tokens_real = (pstate.mel.n_len_org + subsampl_factor - 1) / subsampl_factor; + + std::vector<float> mask_data(n_q * n_k); + const float mask_value = -1e30f; + + if (n_k == n_q) { // full attention + for (int q = 0; q < n_q; ++q) { + for (int k = 0; k < n_k; ++k) { + mask_data[q * n_k + k] = (k >= n_tokens_real) ? mask_value : 0.0f; + } + } + } else { // local attention + const int att_left = n_k / 2; + for (int q = 0; q < n_q; ++q) { + for (int k = 0; k < n_k; ++k) { + const int key = q - att_left + k; + mask_data[q * n_k + k] = (key >= 0 && key < n_tokens_real) ? 0.0f : mask_value; + } + } + } + ggml_backend_tensor_set(attn_mask, mask_data.data(), 0, mask_data.size() * sizeof(float)); + } + + // set local attention skew mask + if (struct ggml_tensor * local_mask = ggml_graph_get_tensor(gf, "local_mask")) { + const int n_k = local_mask->ne[0]; + const int n_q = local_mask->ne[1]; + + std::vector<float> mask_data(n_q * n_k); + const int window_size = n_k - n_q + 1; + for (int q = 0; q < n_q; ++q) { + for (int k = 0; k < n_k; ++k) { + const int rel = k - q; + mask_data[q * n_k + k] = (rel >= 0 && rel < window_size) ? 1.0f : 0.0f; + } + } + ggml_backend_tensor_set(local_mask, mask_data.data(), 0, mask_data.size() * sizeof(float)); + } + + // set positional frequency + { + struct ggml_tensor * pos_freqs_t = ggml_graph_get_tensor(gf, "pos_freqs"); + const int d_half = pos_freqs_t->ne[0]; + const int n_state = pctx.model.hparams.n_audio_state; + const float log_10000 = logf(10000.0f); + std::vector<float> freqs(d_half); + for (int k = 0; k < d_half; ++k) { + freqs[k] = expf(-(float(k * 2) * log_10000 / float(n_state))); + } + ggml_backend_tensor_set(pos_freqs_t, freqs.data(), 0, freqs.size() * sizeof(float)); + } + + // set relative position offsets + { + struct ggml_tensor * rel_pos_t = ggml_graph_get_tensor(gf, "rel_positions"); + const int window_size = rel_pos_t->ne[1]; + std::vector<float> pos(window_size); + if (window_size == PARAKEET_LOCAL_ATTN_WINDOW * 2 + 1) { + for (int t = 0; t < window_size; ++t) { + pos[t] = float(PARAKEET_LOCAL_ATTN_WINDOW - t); + } + } else { + const int n_time = (window_size + 1) / 2; + for (int t = 0; t < window_size; ++t) { + pos[t] = float(n_time - 1 - t); + } + } + ggml_backend_tensor_set(rel_pos_t, pos.data(), 0, pos.size() * sizeof(float)); + } + + if (!ggml_graph_compute_helper(sched, gf, n_threads)) { + return false; + } + + pstate.t_encode_us += ggml_time_us() - t_start_us; + pstate.n_encode++; + + return !(abort_callback && abort_callback(abort_callback_data)); +} + +static bool parakeet_ensure_encode_sched( + parakeet_context & pctx, + parakeet_state & pstate, + int n_audio_ctx) { + if (pstate.sched_encode.sched && pstate.sched_encode_n_audio_ctx == n_audio_ctx) { + return true; + } + + parakeet_sched_free(pstate.sched_encode); + + const int32_t prev_n_audio_ctx = pstate.n_audio_ctx; + pstate.n_audio_ctx = n_audio_ctx; + + const int subsampl_factor = pctx.model.hparams.subsampling_factor; + const int n_frames_max = (n_audio_ctx + subsampl_factor - 1) / subsampl_factor; + if (n_frames_max > pstate.enc_out->ne[1]) { + ggml_backend_buffer_free(pstate.enc_out_buffer); + pstate.enc_out_buffer = nullptr; + pstate.enc_out = nullptr; + + if (!parakeet_enc_state_init(pstate, pstate.backends[0], pctx.model.hparams.n_audio_state, n_frames_max)) { + pstate.sched_encode_n_audio_ctx = 0; + pstate.n_audio_ctx = prev_n_audio_ctx; + return false; + } + } + + const bool ok = parakeet_sched_graph_init(pstate.sched_encode, pstate.backends, + [&]() { + return parakeet_build_graph_encode(pctx, pstate); + }); + + if (!ok) { + pstate.sched_encode_n_audio_ctx = 0; + pstate.n_audio_ctx = prev_n_audio_ctx; + return false; + } + + pstate.sched_encode_n_audio_ctx = n_audio_ctx; + return true; +} + +static struct ggml_tensor * parakeet_build_graph_lstm_layer( + struct ggml_context * ctx0, + struct ggml_cgraph * gf, + struct ggml_tensor * x_t, // the current input token embedding + struct ggml_tensor * w_ih, // input to hidden weights (4 weight tensors packed) + struct ggml_tensor * w_hh, // hidden to hidden weights (4 weight tensors packed) + struct ggml_tensor * b_h, // folded ih+hh bias (4 bias tensors packed) + struct ggml_tensor * h_state, // this layers hidden state + struct ggml_tensor * c_state, // this layers cell state + int li) { // layer index (for tensor naming) + + ggml_format_name(x_t, "lstm_layer_%d_x_t", li); + ggml_format_name(h_state, "lstm_layer_%d_h_state", li); + ggml_format_name(c_state, "lstm_layer_%d_c_state", li); + + // The 4 gates (i, f, o, c) are packed in the same weight tensor. + struct ggml_tensor * inp_gates = ggml_mul_mat(ctx0, w_ih, x_t); + + // Hidden-to-Hidden Projections are also packed in the same weight tensor. + // b_h holds the folded ih+hh bias (see parakeet_model_load), so it is + // the only bias that needs to be added here. + struct ggml_tensor * hid_gates = ggml_mul_mat(ctx0, w_hh, h_state); + hid_gates = ggml_add(ctx0, hid_gates, b_h); + + // Combine the input and hidden contributions of the gates. + struct ggml_tensor * gates = ggml_add(ctx0, inp_gates, hid_gates); + ggml_format_name(gates, "lstm_layer_%d_gates", li); + + const int h_dim = h_state->ne[0]; + const size_t row_size = ggml_row_size(gates->type, h_dim); + + // The gates are packed as [i, f, o, c] (reordered at convert time, see + // parakeet_model_load), so the three sigmoid-gated outputs (i, f, o) are + // contiguous and can be computed with a single ggml_sigmoid call. + struct ggml_tensor * ifo = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, gates, 3 * h_dim, 0)); + ggml_format_name(ifo, "lstm_layer_%d_ifo", li); + + // 1. Input Gate at time t. + struct ggml_tensor * i_t = ggml_view_1d(ctx0, ifo, h_dim, 0 * row_size); + ggml_format_name(i_t, "lstm_layer_%d_i_t", li); + + // Forget gate. + struct ggml_tensor * f_t = ggml_view_1d(ctx0, ifo, h_dim, 1 * row_size); + ggml_format_name(f_t, "lstm_layer_%d_f_t", li); + + // Output gate. + struct ggml_tensor * o_t = ggml_view_1d(ctx0, ifo, h_dim, 2 * row_size); + ggml_format_name(o_t, "lstm_layer_%d_o_t", li); + + // Cell gate. + struct ggml_tensor * c_t = ggml_tanh(ctx0, ggml_view_1d(ctx0, gates, h_dim, 3 * row_size)); + ggml_format_name(c_t, "lstm_layer_%d_c_t", li); + + // Calculate the new cell state. + struct ggml_tensor * c_new = ggml_add(ctx0, + ggml_mul(ctx0, f_t, c_state), // apply forget gate to cell state. + ggml_mul(ctx0, i_t, c_t)); // apply input gate to cell gate. + ggml_build_forward_expand(gf, ggml_cpy(ctx0, c_new, c_state)); + + // Calculate the new hidden state. + struct ggml_tensor * h_new = ggml_mul(ctx0, o_t, ggml_tanh(ctx0, c_new)); + ggml_set_output(h_new); + ggml_format_name(h_new, "lstm_layer_%d_h_new", li); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, h_new, h_state)); + + return h_new; +} + +static struct ggml_cgraph * parakeet_build_graph_prediction( + parakeet_context & pctx, + parakeet_state & pstate, + const parakeet_batch & batch, + bool worst_case) { + GGML_UNUSED(worst_case); + const auto & model = pctx.model; + const auto & hparams = model.hparams; + const int n_tokens = batch.n_tokens; + + struct ggml_init_params params = { + /*.mem_size =*/ pstate.sched_decode.meta.size(), + /*.mem_buffer =*/ pstate.sched_decode.meta.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx0 = ggml_init(params); + ggml_cgraph * gf = ggml_new_graph_custom(ctx0, PARAKEET_MAX_NODES, false); + + // Prediction Network + struct ggml_tensor * token = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_name(token, "token_inp"); + ggml_set_input(token); + + struct ggml_tensor * token_embd = ggml_get_rows(ctx0, model.prediction.embed_w, token); + + struct ggml_tensor * inpL = token_embd; + + for (int il = 0; il < hparams.n_pred_layers; ++il) { + inpL = parakeet_build_graph_lstm_layer(ctx0, gf, inpL, + model.prediction.lstm_layer[il].ih_w, + model.prediction.lstm_layer[il].hh_w, + model.prediction.lstm_layer[il].b_h, + pstate.lstm_state.layer[il].h_state, + pstate.lstm_state.layer[il].c_state, + il); + } + + struct ggml_tensor * pred_out = inpL; + ggml_format_name(pred_out, "lstm_pred_out"); + + // Project the prediction network output to the joint network hidden dimension. + struct ggml_tensor * pred = ggml_mul_mat(ctx0, model.joint.pred_w, pred_out); + pred = ggml_add(ctx0, pred, model.joint.pred_b); + ggml_set_name(pred, "h_pred"); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, pred, pstate.pred_out)); + + ggml_free(ctx0); + + return gf; +} + +static struct ggml_cgraph * parakeet_build_graph_joint( + parakeet_context & pctx, + parakeet_state & pstate, + const parakeet_batch & batch, + bool worst_case) { + GGML_UNUSED(worst_case); + const auto & model = pctx.model; + const auto & hparams = model.hparams; + + struct ggml_init_params params = { + /*.mem_size =*/ pstate.sched_decode.meta.size(), + /*.mem_buffer =*/ pstate.sched_decode.meta.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx0 = ggml_init(params); + ggml_cgraph * gf = ggml_new_graph_custom(ctx0, PARAKEET_MAX_NODES, false); + + struct ggml_tensor * pred = pstate.pred_out; + ggml_format_name(pred, "pred"); + + const int t_idx = batch.i_time[0]; + struct ggml_tensor * enc_out = ggml_view_1d(ctx0, pstate.enc_out, hparams.n_audio_state, + (size_t) t_idx * pstate.enc_out->nb[1]); + ggml_format_name(enc_out, "enc_out_view"); + + // Project the encoder output to the joint network hidden dimension. + struct ggml_tensor * enc = ggml_mul_mat(ctx0, model.joint.enc_w, enc_out); + enc = ggml_add(ctx0, enc, model.joint.enc_b); + ggml_set_name(enc, "enc"); + + struct ggml_tensor * joint = ggml_add(ctx0, enc, pred); + ggml_set_name(joint, "joint"); + joint = ggml_relu(ctx0, joint); + + struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.joint.net_w, joint); + logits = ggml_add(ctx0, logits, model.joint.net_b); + ggml_set_output(logits); + ggml_set_name(logits, "logits"); + + struct ggml_tensor * probs = ggml_soft_max(ctx0, logits); + struct ggml_tensor * log_probs = ggml_log(ctx0, probs); + ggml_set_output(log_probs); + ggml_format_name(log_probs, "log_probs"); + + ggml_build_forward_expand(gf, log_probs); + + ggml_free(ctx0); + + return gf; +} + +static bool parakeet_predict( + parakeet_context & pctx, + parakeet_state & pstate, + const parakeet_batch & batch, + const int n_threads, + ggml_abort_callback abort_callback, + void * abort_callback_data) { + + const int n_tokens = batch.n_tokens; + + const int64_t t_start_us = ggml_time_us(); + + { + auto & sched = pstate.sched_decode.sched; + + const int64_t t_build_start_us = ggml_time_us(); + ggml_cgraph * gf = parakeet_build_graph_prediction(pctx, pstate, batch, false); + pstate.t_predict_build_us += ggml_time_us() - t_build_start_us; + + const int64_t t_alloc_start_us = ggml_time_us(); + if (!ggml_backend_sched_alloc_graph(sched, gf)) { + // should never happen as we pre-allocate the memory + return false; + } + pstate.t_predict_alloc_us += ggml_time_us() - t_alloc_start_us; + + // set the inputs + { + struct ggml_tensor * token_inp = ggml_graph_get_tensor(gf, "token_inp"); + ggml_backend_tensor_set(token_inp, batch.token, 0, n_tokens * ggml_element_size(token_inp)); + } + + const int64_t t_compute_start_us = ggml_time_us(); + if (!ggml_graph_compute_helper(sched, gf, n_threads)) { + return false; + } + pstate.t_predict_compute_us += ggml_time_us() - t_compute_start_us; + } + + pstate.t_predict_us += ggml_time_us() - t_start_us; + pstate.n_predict++; + + return !(abort_callback && abort_callback(abort_callback_data)); +} + +static bool parakeet_joint( + parakeet_context & pctx, + parakeet_state & pstate, + const parakeet_batch & batch, + const int n_threads, + ggml_abort_callback abort_callback, + void * abort_callback_data) { + const int64_t t_start_us = ggml_time_us(); + + const auto & model = pctx.model; + const auto & hparams = model.hparams; + const int n_tokens = batch.n_tokens; + + auto & logits_out = pstate.logits; + + struct ggml_tensor * logits; + + { + auto & sched = pstate.sched_decode.sched; + + ggml_cgraph * gf = parakeet_build_graph_joint(pctx, pstate, batch, false); + + if (!ggml_backend_sched_alloc_graph(sched, gf)) { + // should never happen as we pre-allocate the memory + return false; + } + + logits = ggml_graph_node(gf, -1); + + if (!ggml_graph_compute_helper(sched, gf, n_threads)) { + return false; + } + + } + + const int n_logits = hparams.n_vocab + hparams.n_tdt_durations + 1; // one for the blank token + logits_out.resize(n_tokens * n_logits); + for (int i = 0; i < n_tokens; i++) { + if (batch.logits[i] == 0) { + continue; + } + ggml_backend_tensor_get(logits, logits_out.data() + (n_logits*i), sizeof(float)*(n_logits*i), sizeof(float)*n_logits); + } + + if (batch.n_tokens == 1) { + pstate.t_decode_us += ggml_time_us() - t_start_us; + pstate.n_decode++; + } + + return !(abort_callback && abort_callback(abort_callback_data)); +} + +static bool is_word_start_token(parakeet_vocab & vocab, parakeet_token token_id) { + const std::string & token_str = vocab.id_to_token[token_id]; + // check if it starts with the SentencePiece meta-space "▁" (U+2581) or 3-byte UTF-8 character: 0xE2 0x96 0x81 + if (!token_str.empty()) { + if (token_str.find("\xE2\x96\x81") == 0 || token_str[0] == '_') { + return true; + } + } + return false; +} + +static bool is_punctuation_token(parakeet_vocab & vocab, parakeet_token token_id) { + const std::string & token_str = vocab.id_to_token[token_id]; + static const std::string punct_chars = ".,!?;:'\"-()[]{}"; + + if (token_str.empty()) { + return false; + } + + std::string clean_token = token_str; + if (clean_token.find("\xE2\x96\x81") == 0) { + clean_token = clean_token.substr(3); // Remove the 3-byte UTF-8 character + } else if (clean_token[0] == '_') { + clean_token = clean_token.substr(1); + } + + return clean_token.length() == 1 && punct_chars.find(clean_token[0]) != std::string::npos; +} + +// Collapse punctuation timestamps to match the original Parakeet model. +// Punctuations symbols like ',', '.' and others are not spoken words but the +// model will still produce a duration for these tokens. But since these are +// non-spoken we collapse the timestamps so that they don't have an time duration. +static void refine_timestamps_tdt(parakeet_vocab & vocab, std::vector<parakeet_token_data> & tokens) { + if (tokens.empty()) { + return; + } + + int64_t last_non_punct_t1 = -1; + + for (size_t i = 0; i < tokens.size(); ++i) { + if (is_punctuation_token(vocab, tokens[i].id)) { + if (last_non_punct_t1 >= 0) { + tokens[i].t0 = last_non_punct_t1; + tokens[i].t1 = last_non_punct_t1; + } + } else { + last_non_punct_t1 = tokens[i].t1; + } + } +} + +static parakeet_token_data create_token_data( + parakeet_context & pctx, + parakeet_state & pstate, + parakeet_token token_id, + int duration_idx, + int duration_value, + int frame_index, + float token_logit, + int n_vocab_logits) { + + float token_sum = 0.0f; + for (int i = 0; i < n_vocab_logits; ++i) { + token_sum += expf(pstate.logits[i]); + } + float token_p = expf(token_logit) / token_sum; + + parakeet_token_data token_data; + token_data.id = token_id; + token_data.duration_idx = duration_idx; + token_data.duration_value = duration_value; + token_data.frame_index = frame_index; + token_data.p = token_p; + token_data.plog = token_logit; + token_data.t0 = frame_index * pctx.model.hparams.subsampling_factor; + token_data.t1 = (frame_index + duration_value) * pctx.model.hparams.subsampling_factor; + token_data.is_word_start = is_word_start_token(pctx.vocab, token_id); + + return token_data; +} + +static bool parakeet_decode( + parakeet_context & pctx, + parakeet_state & pstate, + parakeet_batch & batch, + const int n_threads, + const parakeet_full_params * params = nullptr) { + const auto & hparams = pctx.model.hparams; + const auto & tdt_durations = pctx.model.tdt_durations; + + const int n_tdt_durations = hparams.n_tdt_durations; + const int n_frames = pstate.n_frames; + const int blank_id = pctx.vocab.token_blank; + const int n_vocab_logits = blank_id + 1; + const int max_tokens_per_timestep = hparams.n_max_tokens; + + // time index into the encoder frame (current time frame) + int t = 0; + // number of symbols emitted for the current time frame + int tokens_emitted = 0; + + // Start with the blank token (8192) + parakeet_token last_token = blank_id; + + PARAKEET_LOG_DEBUG("parakeet_decode: starting decode with n_frames=%d\n", n_frames); + + batch.n_tokens = 1; + batch.token[0] = last_token; + batch.logits[0] = 1; + batch.i_time[0] = 0; + + // run the prediction network for the initial blank token. This will + // initialize the LSTM state and produce an initial hidden state that can + // be used in the joint network below. + if (!parakeet_predict(pctx, pstate, batch, n_threads, + params ? params->abort_callback : nullptr, + params ? params->abort_callback_user_data : nullptr)) { + return false; + } + + // process all time frames of the encoder output + while (t < n_frames) { + batch.n_tokens = 1; + batch.i_time[0] = t; + batch.logits[0] = 1; + + // Use the current encoder frame (t) and the output of the prediction to + // generate probabilities for the next token and duration. batch.i_time + // is used in to select the correct frame from the encoder output. + // The joint network outputs logits for all the tokens in the vocabulary + // plus the blank token, and also n_duration logits for the duration + // tokens which contain information about how many frames to skip/advance forward. + if (!parakeet_joint(pctx, pstate, batch, n_threads, + params ? params->abort_callback : nullptr, + params ? params->abort_callback_user_data : nullptr)) { + return false; + } + + const int64_t t_start_sample_us = ggml_time_us(); + + // find the best token (greedy). + // TODO: implement beam search? + int best_token = 0; + float max_logit = -1e10f; + for (int i = 0; i < n_vocab_logits; ++i) { + if (pstate.logits[i] > max_logit) { + max_logit = pstate.logits[i]; + best_token = i; + } + } + + // find the max index of the duration logits, and look up that index + // value in the tdt_durations array to get the actual duration value. + int best_duration_idx = 0; + float best_duration_logit = -1e10f; + for (int i = 0; i < n_tdt_durations; ++i) { + if (pstate.logits[n_vocab_logits + i] > best_duration_logit) { + best_duration_logit = pstate.logits[n_vocab_logits + i]; + best_duration_idx = i; + } + } + // look up that max duration index value in the tdt_durations array to + // get the actual duration value. + int duration = tdt_durations[best_duration_idx]; + + if (best_token == blank_id) { + if (duration == 0) { + duration = 1; + } + // skip forward by duration time frames. + t += duration; + // reset symbols emitted counter + tokens_emitted = 0; + // continue without predicting. + continue; + } + + // Emit non-blank token at current frame t. + pstate.decoded_tokens.push_back(best_token); + pstate.t_sample_us += ggml_time_us() - t_start_sample_us; + pstate.n_sample++; + + parakeet_token_data token_data = create_token_data( + pctx, pstate, best_token, best_duration_idx, duration, t, + max_logit, n_vocab_logits); + + pstate.decoded_token_data.push_back(token_data); + + // Call token callback if registered (for real-time streaming) + if (params && params->new_token_callback) { + params->new_token_callback(&pctx, &pstate, &token_data, params->new_token_callback_user_data); + } + + last_token = best_token; + + // advance predictor for the non-blank token. + batch.token[0] = last_token; + if (!parakeet_predict(pctx, pstate, batch, n_threads, + params ? params->abort_callback : nullptr, + params ? params->abort_callback_user_data : nullptr)) { + return false; + } + + // if duration greater than 0, continue looping over the encoder frames + // and skip to the updated time frame (t). + if (duration > 0) { + t += duration; + tokens_emitted = 0; + continue; + } + + // if duration is zero we stay on the current time frame. + tokens_emitted++; + if (tokens_emitted >= max_tokens_per_timestep) { + t += 1; // forced blank/time advance behavior + tokens_emitted = 0; + } + } + + return true; +} + +// 500 -> 00:05.000 +// 6000 -> 01:00.000 +// naive Discrete Fourier Transform +// input is real-valued +// output is complex-valued +static void dft(const float* in, int N, float* out, const parakeet_mel_cache & cache) { + const int sin_cos_step = cache.n_fft / N; + + for (int k = 0; k < N; k++) { + float re = 0; + float im = 0; + + for (int n = 0; n < N; n++) { + int idx = (k * n * sin_cos_step) % cache.n_fft; // t = 2*M_PI*k*n/N + re += in[n]*cache.cos_vals[idx]; // cos(t) + im -= in[n]*cache.sin_vals[idx]; // sin(t) + } + + out[k*2 + 0] = re; + out[k*2 + 1] = im; + } +} + +// Cooley-Tukey FFT +// poor man's implementation - use something better +// input is real-valued +// output is complex-valued +static void fft(float* in, int N, float* out, const parakeet_mel_cache & cache) { + if (N == 1) { + out[0] = in[0]; + out[1] = 0; + return; + } + + const int half_N = N / 2; + if (N - half_N*2 == 1) { + dft(in, N, out, cache); + return; + } + + float* even = in + N; + for (int i = 0; i < half_N; ++i) { + even[i]= in[2*i]; + } + float* even_fft = out + 2 * N; + fft(even, half_N, even_fft, cache); + + float* odd = even; + for (int i = 0; i < half_N; ++i) { + odd[i] = in[2*i + 1]; + } + float* odd_fft = even_fft + N; + fft(odd, half_N, odd_fft, cache); + + const int sin_cos_step = cache.n_fft / N; + for (int k = 0; k < half_N; k++) { + int idx = k * sin_cos_step; // t = 2*M_PI*k/N + float re = cache.cos_vals[idx]; // cos(t) + float im = -cache.sin_vals[idx]; // sin(t) + + float re_odd = odd_fft[2*k + 0]; + float im_odd = odd_fft[2*k + 1]; + + out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd; + out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd; + + out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd; + out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd; + } +} + +struct mel_worker_params { + int ith; + int window_size; + int n_samples; + int frame_size; + int frame_step; + int n_threads; +}; + +static void log_mel_spectrogram_worker_thread( + mel_worker_params params, + const float * window_func, + const std::vector<float> & samples, + const parakeet_filters & filters, + parakeet_mel & mel, + const parakeet_mel_cache & cache) { + std::vector<float> fft_in(params.frame_size * 2, 0.0); + std::vector<float> fft_out(params.frame_size * 2 * 2 * 2); + + int n_fb = filters.n_fb; // number of frequency bins + int i = params.ith; + + // make sure n_fb == 1 + (frame_size / 2), bin_0 to bin_nyquist + assert(n_fb == 1 + (params.frame_size / 2)); + + const double eps = 5.960464477539063e-08; + + // calculate FFT only when fft_in are not all zero + for (; i < std::min(params.n_samples / params.frame_step + 1, mel.n_len); i += params.n_threads) { + const int offset = i * params.frame_step; + + const int window_pad_left = (params.frame_size - params.window_size) / 2; + + // Zero-pad left + std::fill(fft_in.begin(), fft_in.begin() + window_pad_left, 0.0f); + + // Apply windowed samples in the center + const int n_to_process = std::min({params.window_size, params.n_samples - offset}); + for (int j = 0; j < n_to_process; j++) { + fft_in[window_pad_left + j] = window_func[j] * samples[offset + window_pad_left + j]; + } + + // Zero-pad right (and any samples we didn't have) + std::fill(fft_in.begin() + window_pad_left + n_to_process, fft_in.begin() + params.frame_size, 0.0f); + + // FFT + fft(fft_in.data(), params.frame_size, fft_out.data(), cache); + + // Calculate modulus^2 of complex numbers + // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting. + for (int j = 0; j < n_fb; j++) { + fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]); + } + + // mel spectrogram + for (int j = 0; j < mel.n_mel; j++) { + double sum = 0.0; + // unroll loop (suggested by GH user @lunixbochs) + int k = 0; + for (k = 0; k < n_fb - 3; k += 4) { + sum += + fft_out[k + 0] * filters.data[j * n_fb + k + 0] + + fft_out[k + 1] * filters.data[j * n_fb + k + 1] + + fft_out[k + 2] * filters.data[j * n_fb + k + 2] + + fft_out[k + 3] * filters.data[j * n_fb + k + 3]; + } + // handle n_fb remainder + for (; k < n_fb; k++) { + sum += fft_out[k] * filters.data[j * n_fb + k]; + } + + mel.data[i * mel.n_mel + j] = std::log(sum + eps); + } + } + + // Otherwise fft_out are all zero - use log(eps) for consistency + const double empty_sum = std::log(eps); + for (; i < mel.n_len; i += params.n_threads) { + for (int j = 0; j < mel.n_mel; j++) { + mel.data[i * mel.n_mel + j] = empty_sum; + } + } +} + +static bool log_mel_spectrogram( + parakeet_state & wstate, + const float * samples, + const int n_samples, + const int /*sample_rate*/, + const int frame_size, + const int frame_step, + const int n_mel, + const int n_threads, + const parakeet_filters & filters, + const bool debug, + parakeet_mel & mel, + const parakeet_mel_cache & cache) { + const int64_t t_start_us = ggml_time_us(); + + const float * window_func = cache.window.empty() ? cache.hann_window.data() : cache.window.data(); + const int window_size = cache.window.empty() ? cache.n_fft : cache.window.size(); + + std::vector<float> samples_preprocessed(samples, samples + n_samples); + + // Apply preemphasis filter (high-pass): x[i] = x[i] - 0.97 * x[i-1] + { + const float preemph = 0.97f; + for (int i = n_samples - 1; i > 0; i--) { + samples_preprocessed[i] = samples_preprocessed[i] - preemph * samples_preprocessed[i - 1]; + } + } + + // Parakeet Pytorch implementation uses centered contant padding. + const size_t pad = (size_t)(frame_size / 2); + std::vector<float> samples_padded(n_samples + 2 * pad, 0.0f); + std::copy(samples_preprocessed.begin(), samples_preprocessed.end(), samples_padded.begin() + pad); + + mel.n_mel = n_mel; + mel.n_len = (samples_padded.size() - frame_size) / frame_step + 1; + mel.n_len_org = mel.n_len; + mel.data.resize(mel.n_mel * mel.n_len); + + // Worker Threads (STFT + Mel + Natural Log) + { + std::vector<std::thread> workers(n_threads - 1); + const mel_worker_params mel_params { 0, window_size, (int)samples_padded.size(), frame_size, frame_step, n_threads }; + + for (int iw = 0; iw < n_threads - 1; ++iw) { + mel_worker_params params = mel_params; + params.ith = iw + 1; + workers[iw] = std::thread(log_mel_spectrogram_worker_thread, + params, + window_func, + std::cref(samples_padded), + std::cref(filters), + std::ref(mel), + std::cref(cache)); + } + + log_mel_spectrogram_worker_thread( + mel_params, + window_func, + samples_padded, + filters, + mel, + cache); + + for (int iw = 0; iw < n_threads - 1; ++iw) { + workers[iw].join(); + } + } + + { + const double eps = 1e-5; + int valid_frames = n_samples / frame_step; + + for (int j = 0; j < mel.n_mel; j++) { + double sum = 0.0; + double sq_diff_sum = 0.0; + + // Calculate Mean ONLY on valid audio frames + for (int i = 0; i < valid_frames; i++) { + sum += (double)mel.data[i * mel.n_mel + j]; + } + double mean = sum / valid_frames; + + // Calculate Variance ONLY on valid audio frames + for (int i = 0; i < valid_frames; i++) { + double diff = (double)mel.data[i * mel.n_mel + j] - mean; + sq_diff_sum += diff * diff; + } + + double std_dev = std::sqrt(sq_diff_sum / (valid_frames - 1.0)); + double denominator = std_dev + eps; + + // Apply to ALL frames (including the padded ones) + for (int i = 0; i < mel.n_len; i++) { + mel.data[i * mel.n_mel + j] = (float)((mel.data[i * mel.n_mel + j] - mean) / denominator); + } + } + } + + wstate.t_mel_us += ggml_time_us() - t_start_us; + + if (debug) { + std::ofstream outFile("log_mel_spectrogram.json"); + outFile << "["; + for (uint64_t i = 0; i < mel.data.size() - 1; i++) { + outFile << mel.data[i] << ", "; + } + outFile << mel.data[mel.data.size() - 1] << "]"; + outFile.close(); + } + + return true; +} + +static std::vector<parakeet_vocab::id> tokenize(const parakeet_vocab & vocab, const std::string & text) { + std::vector<parakeet_vocab::id> tokens; + const std::string normalized = sentencepiece_normalize(text); + + size_t i = 0; + while (i < normalized.size()) { + const size_t remaining = normalized.size() - i; + const size_t max_len = std::min(vocab.max_token_length, remaining); + + bool found = false; + for (size_t len = max_len; len > 0; --len) { + const auto it = vocab.token_to_id.find(normalized.substr(i, len)); + if (it != vocab.token_to_id.end() && !is_sentencepiece_control(it->first)) { + tokens.push_back(it->second); + i += len; + found = true; + break; + } + } + + if (!found) { + if (vocab.token_unk >= 0) { + tokens.push_back(vocab.token_unk); + } + + const unsigned char c = static_cast<unsigned char>(normalized[i]); + i += utf8_codepoint_len(c); + } + } + + return tokens; +} + + +// +// interface implementation +// + +struct parakeet_state * parakeet_init_state(parakeet_context * ctx) { + parakeet_state * state = new parakeet_state; + + state->backends = parakeet_backend_init(ctx->params); + if (state->backends.empty()) { + PARAKEET_LOG_ERROR("%s: parakeet_backend_init() failed\n", __func__); + parakeet_free_state(state); + return nullptr; + } + + const int batch_size = ctx->model.hparams.n_audio_ctx; + + state->logits.reserve(ctx->vocab.n_vocab * batch_size); + + state->batch = parakeet_batch_init(batch_size); + + { + const int n_audio_state = ctx->model.hparams.n_audio_state; + const int subsampl_factor = ctx->model.hparams.subsampling_factor; + const int n_frames_max = (batch_size + subsampl_factor - 1) / subsampl_factor; + + if (!parakeet_enc_state_init(*state, state->backends[0], n_audio_state, n_frames_max)) { + PARAKEET_LOG_ERROR("%s: parakeet_enc_state_init() failed\n", __func__); + parakeet_free_state(state); + return nullptr; + } + + const size_t mem_enc_ctx = state->enc_out_buf.size(); + const size_t mem_enc_out_buf = ggml_backend_buffer_get_size(state->enc_out_buffer); + PARAKEET_LOG_INFO("%s: enc_out state: %7.2f MB (meta) + %7.2f MB (data)\n", __func__, + mem_enc_ctx / 1024.0 / 1024.0, mem_enc_out_buf / 1024.0 / 1024.0); + } + + // conv/encoder allocator + bool ok = parakeet_sched_graph_init(state->sched_encode, state->backends, + [&]() { + return parakeet_build_graph_encode(*ctx, *state); + }); + + if (!ok) { + PARAKEET_LOG_ERROR("%s: failed to init encode allocator\n", __func__); + parakeet_free_state(state); + return nullptr; + } + state->sched_encode_n_audio_ctx = state->n_audio_ctx > 0 ? state->n_audio_ctx : ctx->model.hparams.n_audio_ctx; + + if (!parakeet_lstm_state_init(*state, state->backends[0], ctx->model.hparams.n_pred_layers, ctx->model.hparams.n_pred_dim)) { + PARAKEET_LOG_ERROR("%s: parakeet_lstm_states_init () failed\n", __func__); + parakeet_free_state(state); + return nullptr; + } + + { + const size_t mem_lstm_ctx = state->lstm_state.ctx_buf.size(); + const size_t mem_lstm_buf = ggml_backend_buffer_get_size(state->lstm_state.buffer); + PARAKEET_LOG_INFO("%s: lstm state: %7.2f MB (meta) + %7.2f MB (data)\n", __func__, + mem_lstm_ctx / 1024.0 / 1024.0, mem_lstm_buf / 1024.0 / 1024.0); + } + + if (!parakeet_pred_state_init(*state, state->backends[0], ctx->model.hparams.n_pred_dim)) { + PARAKEET_LOG_ERROR("%s: parakeet_pred_state_init() failed\n", __func__); + parakeet_free_state(state); + return nullptr; + } + + { + const size_t mem_pred_ctx = state->pred_out_buf.size(); + const size_t mem_pred_out_buf = ggml_backend_buffer_get_size(state->pred_out_buffer); + PARAKEET_LOG_INFO("%s: pred state: %7.2f MB (meta) + %7.2f MB (data)\n", __func__, + mem_pred_ctx / 1024.0 / 1024.0, mem_pred_out_buf / 1024.0 / 1024.0); + } + + PARAKEET_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, parakeet_sched_size(state->sched_encode) / 1e6); + + { + bool ok = parakeet_sched_graph_init(state->sched_decode, state->backends, + [&]() { + const auto & hparams = ctx->model.hparams; + const int n_tokens = hparams.n_audio_ctx; // Use audio ctx for Parakeet + + parakeet_batch_prep_legacy(state->batch, nullptr, n_tokens, 0, 0); + + return parakeet_build_graph_prediction(*ctx, *state, state->batch, true); + }); + + if (!ok) { + PARAKEET_LOG_ERROR("%s: failed to init decoder allocator\n", __func__); + parakeet_free_state(state); + return nullptr; + } + + PARAKEET_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, parakeet_sched_size(state->sched_decode) / 1e6); + } + + return state; +} + +struct parakeet_context_params parakeet_context_default_params() { + struct parakeet_context_params result = { + /*.use_gpu =*/ true, + /*.gpu_device =*/ 0, + }; + return result; +} + +struct parakeet_context * parakeet_init_from_file_with_params_no_state(const char * path_model, struct parakeet_context_params params) { + PARAKEET_LOG_INFO("%s: loading model from '%s'\n", __func__, path_model); +#ifdef _MSC_VER + // Convert UTF-8 path to wide string (UTF-16) for Windows, resolving character encoding issues. + std::wstring_convert<std::codecvt_utf8<wchar_t>> converter; + std::wstring path_model_wide = converter.from_bytes(path_model); + auto fin = std::ifstream(path_model_wide, std::ios::binary); +#else + auto fin = std::ifstream(path_model, std::ios::binary); +#endif + if (!fin) { + PARAKEET_LOG_ERROR("%s: failed to open '%s'\n", __func__, path_model); + return nullptr; + } + + parakeet_model_loader loader = {}; + + loader.context = &fin; + + loader.read = [](void * ctx, void * output, size_t read_size) { + std::ifstream * fin = (std::ifstream*)ctx; + fin->read((char *)output, read_size); + return read_size; + }; + + loader.eof = [](void * ctx) { + std::ifstream * fin = (std::ifstream*)ctx; + return fin->eof(); + }; + + loader.close = [](void * ctx) { + std::ifstream * fin = (std::ifstream*)ctx; + fin->close(); + }; + + auto ctx = parakeet_init_with_params_no_state(&loader, params); + + if (ctx) { + ctx->path_model = path_model; + } + + return ctx; +} + +struct parakeet_context * parakeet_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct parakeet_context_params params) { + struct buf_context { + uint8_t* buffer; + size_t size; + size_t current_offset; + }; + + buf_context ctx = { reinterpret_cast<uint8_t*>(buffer), buffer_size, 0 }; + + PARAKEET_LOG_INFO("%s: loading model from buffer\n", __func__); + + parakeet_model_loader loader = {}; + + loader.context = &ctx; + + loader.read = [](void * ctx, void * output, size_t read_size) { + buf_context * buf = reinterpret_cast<buf_context *>(ctx); + + size_t size_to_copy = buf->current_offset + read_size < buf->size ? read_size : buf->size - buf->current_offset; + + memcpy(output, buf->buffer + buf->current_offset, size_to_copy); + buf->current_offset += size_to_copy; + + return size_to_copy; + }; + + loader.eof = [](void * ctx) { + buf_context * buf = reinterpret_cast<buf_context *>(ctx); + + return buf->current_offset >= buf->size; + }; + + loader.close = [](void * /*ctx*/) { }; + + return parakeet_init_with_params_no_state(&loader, params); +} + +struct parakeet_context * parakeet_init_with_params_no_state(struct parakeet_model_loader * loader, struct parakeet_context_params params) { + ggml_time_init(); + + PARAKEET_LOG_INFO("%s: use gpu = %d\n", __func__, params.use_gpu); + PARAKEET_LOG_INFO("%s: gpu_device = %d\n", __func__, params.gpu_device); + PARAKEET_LOG_INFO("%s: devices = %zu\n", __func__, ggml_backend_dev_count()); + PARAKEET_LOG_INFO("%s: backends = %zu\n", __func__, ggml_backend_reg_count()); + + parakeet_context * ctx = new parakeet_context; + ctx->params = params; + + bool model_loaded = false; + try { + model_loaded = parakeet_model_load(loader, *ctx); + } catch (const std::exception & e) { + PARAKEET_LOG_ERROR("%s: exception during model load: %s\n", __func__, e.what()); + } catch (...) { + PARAKEET_LOG_ERROR("%s: unknown exception during model load\n", __func__); + } + + if (!model_loaded) { + loader->close(loader->context); + PARAKEET_LOG_ERROR("%s: failed to load model\n", __func__); + delete ctx; + return nullptr; + } + + loader->close(loader->context); + + // Initialize mel cache with model's FFT size + ctx->mel_cache.init(ctx->model.hparams.n_fft); + PARAKEET_LOG_INFO("%s: initialized mel cache with n_fft = %d\n", __func__, ctx->model.hparams.n_fft); + + return ctx; +} + +struct parakeet_context * parakeet_init_from_file_with_params(const char * path_model, struct parakeet_context_params params) { + parakeet_context * ctx = parakeet_init_from_file_with_params_no_state(path_model, params); + if (!ctx) { + return nullptr; + } + + ctx->state = parakeet_init_state(ctx); + if (!ctx->state) { + parakeet_free(ctx); + return nullptr; + } + + return ctx; +} + +struct parakeet_context * parakeet_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct parakeet_context_params params) { + parakeet_context * ctx = parakeet_init_from_buffer_with_params_no_state(buffer, buffer_size, params); + if (!ctx) { + return nullptr; + } + + ctx->state = parakeet_init_state(ctx); + if (!ctx->state) { + parakeet_free(ctx); + return nullptr; + } + + return ctx; +} + +struct parakeet_context * parakeet_init_with_params(struct parakeet_model_loader * loader, struct parakeet_context_params params) { + parakeet_context * ctx = parakeet_init_with_params_no_state(loader, params); + if (!ctx) { + return nullptr; + } + + ctx->state = parakeet_init_state(ctx); + if (!ctx->state) { + parakeet_free(ctx); + return nullptr; + } + + return ctx; +} + +void parakeet_free_state(struct parakeet_state * state) { + if (state) { + ggml_backend_buffer_free(state->lstm_state.buffer); + ggml_backend_buffer_free(state->pred_out_buffer); + ggml_backend_buffer_free(state->enc_out_buffer); + + parakeet_batch_free(state->batch); + + parakeet_sched_free(state->sched_encode); + parakeet_sched_free(state->sched_decode); + + for (auto & backend : state->backends) { + ggml_backend_free(backend); + } + + delete state; + } +} + +void parakeet_free(struct parakeet_context * ctx) { + if (ctx) { + for (ggml_context * context : ctx->model.ctxs) { + ggml_free(context); + } + + for (ggml_backend_buffer_t buf : ctx->model.buffers) { + ggml_backend_buffer_free(buf); + } + + parakeet_free_state(ctx->state); + + delete ctx; + } +} + +void parakeet_free_context_params(struct parakeet_context_params * params) { + if (params) { + delete params; + } +} + +void parakeet_free_params(struct parakeet_full_params * params) { + if (params) { + delete params; + } +} + +int parakeet_pcm_to_mel_with_state(struct parakeet_context * ctx, struct parakeet_state * state, const float * samples, int n_samples, int n_threads) { + if (!log_mel_spectrogram(*state, + samples, + n_samples, + PARAKEET_SAMPLE_RATE, + ctx->model.hparams.n_fft, + PARAKEET_HOP_LENGTH, + ctx->model.filters.n_mel, + n_threads, + ctx->model.filters, + false, // debug + state->mel, + ctx->mel_cache)) { + PARAKEET_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__); + return -1; + } + + return 0; +} + +int parakeet_pcm_to_mel(struct parakeet_context * ctx, const float * samples, int n_samples, int n_threads) { + return parakeet_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads); +} + +int parakeet_set_mel_with_state( + struct parakeet_context * ctx, + struct parakeet_state * state, + const float * data, + int n_len, + int n_mel) { + if (n_mel != ctx->model.filters.n_mel) { + PARAKEET_LOG_ERROR("%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, ctx->model.filters.n_mel); + return -1; + } + + state->mel.n_len = n_len; + state->mel.n_len_org = n_len; + state->mel.n_mel = n_mel; + + state->mel.data.resize(n_len*n_mel); + memcpy(state->mel.data.data(), data, n_len*n_mel*sizeof(float)); + + return 0; +} + +int parakeet_set_mel( + struct parakeet_context * ctx, + const float * data, + int n_len, + int n_mel) { + return parakeet_set_mel_with_state(ctx, ctx->state, data, n_len, n_mel); +} + +int parakeet_encode_with_state(struct parakeet_context * ctx, struct parakeet_state * state, int offset, int n_threads) { + if (!parakeet_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr)) { + PARAKEET_LOG_ERROR("%s: failed to eval\n", __func__); + return -1; + } + + return 0; +} + +int parakeet_encode(struct parakeet_context * ctx, int offset, int n_threads) { + if (!parakeet_encode_internal(*ctx, *ctx->state, offset, n_threads, nullptr, nullptr)) { + PARAKEET_LOG_ERROR("%s: failed to eval\n", __func__); + return -1; + } + + return 0; +} + +int parakeet_tokenize(struct parakeet_context * ctx, const char * text, parakeet_token * tokens, int n_max_tokens) { + const auto res = tokenize(ctx->vocab, text); + + if (n_max_tokens < (int) res.size()) { + PARAKEET_LOG_ERROR("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens); + return -(int) res.size(); + } + + for (int i = 0; i < (int) res.size(); i++) { + tokens[i] = res[i]; + } + + return res.size(); +} + +int parakeet_token_count(struct parakeet_context * ctx, const char * text) { + return -parakeet_tokenize(ctx, text, NULL, 0); +} + +int parakeet_model_n_vocab(struct parakeet_context * ctx) { + return ctx->model.hparams.n_vocab; +} + +int parakeet_model_n_audio_ctx(struct parakeet_context * ctx) { + return ctx->model.hparams.n_audio_ctx; +} + +int parakeet_model_n_audio_state(struct parakeet_context * ctx) { + return ctx->model.hparams.n_audio_state; +} + +int parakeet_model_n_audio_head(struct parakeet_context * ctx) { + return ctx->model.hparams.n_audio_head; +} + +int parakeet_model_n_audio_layer(struct parakeet_context * ctx) { + return ctx->model.hparams.n_audio_layer; +} + +int parakeet_model_n_mels(struct parakeet_context * ctx) { + return ctx->model.hparams.n_mels; +} + +int parakeet_model_ftype(struct parakeet_context * ctx) { + return ctx->model.hparams.ftype; +} + +int parakeet_n_len_from_state(struct parakeet_state * state) { + return state->mel.n_len_org; +} + +int parakeet_n_len(struct parakeet_context * ctx) { + return ctx->state->mel.n_len_org; +} + +int parakeet_n_vocab(struct parakeet_context * ctx) { + return ctx->vocab.n_vocab; +} + +int parakeet_n_audio_ctx(struct parakeet_context * ctx) { + return ctx->model.hparams.n_audio_ctx; +} + +float * parakeet_get_logits(struct parakeet_context * ctx) { + return ctx->state->logits.data(); +} + +float * parakeet_get_logits_from_state(struct parakeet_state * state) { + return state->logits.data(); +} + +const char * parakeet_token_to_str(struct parakeet_context * ctx, parakeet_token token) { + return ctx->vocab.id_to_token.at(token).c_str(); +} + +int parakeet_token_to_text(const char * token_str, bool is_first, char * output, int max_len) { + std::string text = sentencepiece_piece_to_text(token_str, is_first); + + if (output == nullptr) { + return text.size(); + } + + int bytes_to_copy = std::min((int)text.size(), max_len - 1); + if (bytes_to_copy > 0) { + memcpy(output, text.c_str(), bytes_to_copy); + output[bytes_to_copy] = '\0'; + } else if (max_len > 0) { + output[0] = '\0'; + } + + return text.size(); +} + +parakeet_token parakeet_token_bos(struct parakeet_context * ctx) { + return ctx->vocab.token_bos; +} + +parakeet_token parakeet_token_unk(struct parakeet_context * ctx) { + return ctx->vocab.token_unk; +} + +parakeet_token parakeet_token_blank(struct parakeet_context * ctx) { + return ctx->vocab.token_blank; +} + +struct parakeet_timings * parakeet_get_timings(struct parakeet_context * ctx) { + if (ctx->state == nullptr) { + return nullptr; + } + parakeet_timings * timings = new parakeet_timings; + timings->sample_ms = 1e-3f * ctx->state->t_sample_us / std::max(1, ctx->state->n_sample); + timings->encode_ms = 1e-3f * ctx->state->t_encode_us / std::max(1, ctx->state->n_encode); + timings->decode_ms = 1e-3f * ctx->state->t_decode_us / std::max(1, ctx->state->n_decode); + return timings; +} + +void parakeet_print_timings(struct parakeet_context * ctx) { + const int64_t t_end_us = ggml_time_us(); + + PARAKEET_LOG_INFO("\n"); + PARAKEET_LOG_INFO("%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f); + if (ctx->state != nullptr) { + + const int32_t n_sample = std::max(1, ctx->state->n_sample); + const int32_t n_encode = std::max(1, ctx->state->n_encode); + const int32_t n_decode = std::max(1, ctx->state->n_decode); + const int32_t n_predict = std::max(1, ctx->state->n_predict); + + PARAKEET_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h); + PARAKEET_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f); + PARAKEET_LOG_INFO("%s: sample time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample); + PARAKEET_LOG_INFO("%s: encode time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode); + PARAKEET_LOG_INFO("%s: decode time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode); + PARAKEET_LOG_INFO("%s: predict time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_predict_us, n_predict, 1e-3f * ctx->state->t_predict_us / n_predict); + PARAKEET_LOG_INFO("%s: - build = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_predict_build_us, n_predict, 1e-3f * ctx->state->t_predict_build_us / n_predict); + PARAKEET_LOG_INFO("%s: - alloc = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_predict_alloc_us, n_predict, 1e-3f * ctx->state->t_predict_alloc_us / n_predict); + PARAKEET_LOG_INFO("%s: - compute = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_predict_compute_us, n_predict, 1e-3f * ctx->state->t_predict_compute_us / n_predict); + + } + PARAKEET_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f); +} + +void parakeet_reset_timings(struct parakeet_context * ctx) { + ctx->t_start_us = ggml_time_us(); + if (ctx->state != nullptr) { + ctx->state->t_mel_us = 0; + ctx->state->t_sample_us = 0; + ctx->state->t_encode_us = 0; + ctx->state->t_decode_us = 0; + ctx->state->t_predict_us = 0; + ctx->state->t_predict_build_us = 0; + ctx->state->t_predict_alloc_us = 0; + ctx->state->t_predict_compute_us = 0; + + ctx->state->n_sample = 0; + ctx->state->n_encode = 0; + ctx->state->n_decode = 0; + ctx->state->n_predict = 0; + } +} + +const char * parakeet_print_system_info(void) { + static std::string s; + + s = ""; + s += "PARAKEET : "; + + for (size_t i = 0; i < ggml_backend_reg_count(); i++) { + auto * reg = ggml_backend_reg_get(i); + auto * get_features_fn = (ggml_backend_get_features_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_get_features"); + if (get_features_fn) { + ggml_backend_feature * features = get_features_fn(reg); + s += ggml_backend_reg_name(reg); + s += " : "; + for (; features->name; features++) { + s += features->name; + s += " = "; + s += features->value; + s += " | "; + } + } + } + return s.c_str(); +} + +struct parakeet_context_params * parakeet_context_default_params_by_ref(void) { + struct parakeet_context_params params = parakeet_context_default_params(); + + struct parakeet_context_params* result = new parakeet_context_params(); + *result = params; + return result; +} + +struct parakeet_full_params * parakeet_full_default_params_by_ref(enum parakeet_sampling_strategy strategy) { + struct parakeet_full_params params = parakeet_full_default_params(strategy); + + struct parakeet_full_params* result = new parakeet_full_params(); + *result = params; + return result; +} + +struct parakeet_full_params parakeet_full_default_params(enum parakeet_sampling_strategy strategy) { + struct parakeet_full_params result = { + /*.strategy =*/ strategy, + /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), + /*.offset_ms =*/ 0, + /*.duration_ms =*/ 0, + /*.no_context =*/ true, + /*.audio_ctx =*/ 0, + /*.new_token_callback =*/ nullptr, + /*.new_token_callback_user_data =*/ nullptr, + /*.new_segment_callback =*/ nullptr, + /*.new_segment_callback_user_data =*/ nullptr, + /*.progress_callback =*/ nullptr, + /*.progress_callback_user_data =*/ nullptr, + /*.encoder_begin_callback =*/ nullptr, + /*.encoder_begin_callback_user_data =*/ nullptr, + /*.abort_callback =*/ nullptr, + /*.abort_callback_user_data =*/ nullptr, + }; + + return result; +} + +static void parakeet_reset_state(struct parakeet_state * state) { + state->decoded_tokens.clear(); + state->decoded_token_data.clear(); + + if (state->lstm_state.buffer) { + ggml_backend_buffer_clear(state->lstm_state.buffer, 0); + } + +} + +// Encode and decode the mel spectrogram already in state, without recomputing it. +static int parakeet_chunk_with_state( + struct parakeet_context * ctx, + struct parakeet_state * state, + struct parakeet_full_params params) { + return parakeet_chunk(ctx, state, params, nullptr, 0); +} + +int parakeet_full_with_state( + struct parakeet_context * ctx, + struct parakeet_state * state, + struct parakeet_full_params params, + const float * samples, + int n_samples) { + state->result_all.clear(); + + if (params.no_context) { + parakeet_reset_state(state); + } + + if (n_samples > 0) { + if (parakeet_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) { + PARAKEET_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__); + return -2; + } + } + + const int n_mel_total = state->mel.n_len; + const int n_audio_ctx = ctx->model.hparams.n_audio_ctx; + + if (n_mel_total <= n_audio_ctx) { + if (params.progress_callback) { + params.progress_callback(ctx, state, 0, params.progress_callback_user_data); + } + return parakeet_chunk_with_state(ctx, state, params); + } + + PARAKEET_LOG_DEBUG("%s: audio too long (%d mel > n_audio_ctx=%d), using dynamic encoder graph\n", + __func__, n_mel_total, n_audio_ctx); + + if (params.encoder_begin_callback) { + if (!params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data)) { + PARAKEET_LOG_ERROR("%s: encoder_begin_callback returned false\n", __func__); + return -6; + } + } + + if (params.progress_callback) { + params.progress_callback(ctx, state, 0, params.progress_callback_user_data); + } + + if (!parakeet_ensure_encode_sched(*ctx, *state, n_mel_total)) { + PARAKEET_LOG_ERROR("%s: failed to allocate dynamic encoder graph for %d mel frames\n", + __func__, n_mel_total); + return -6; + } + + state->n_audio_ctx = n_mel_total; + + if (!parakeet_encode_internal(*ctx, *state, 0, params.n_threads, + params.abort_callback, params.abort_callback_user_data)) { + PARAKEET_LOG_ERROR("%s: failed to encode\n", __func__); + return -6; + } + + if (params.progress_callback) { + params.progress_callback(ctx, state, 100, params.progress_callback_user_data); + } + + const size_t tokens_before = state->decoded_tokens.size(); + + if (!parakeet_decode(*ctx, *state, state->batch, params.n_threads, ¶ms)) { + PARAKEET_LOG_ERROR("%s: failed to decode\n", __func__); + return -7; + } + + const size_t tokens_after = state->decoded_tokens.size(); + const size_t new_token_count = tokens_after - tokens_before; + + if (new_token_count > 0) { + std::string text; + std::vector<parakeet_token_data> result_tokens; + + for (size_t i = tokens_before; i < tokens_after; i++) { + const auto token_id = state->decoded_tokens[i]; + const char * tok_str = parakeet_token_to_str(ctx, token_id); + if (tok_str) { + const bool is_first = (tokens_before == 0) && text.empty(); + text += sentencepiece_piece_to_text(tok_str, is_first); + } + result_tokens.push_back(state->decoded_token_data[i]); + } + + refine_timestamps_tdt(ctx->vocab, result_tokens); + + if (!text.empty()) { + parakeet_segment seg; + seg.t0 = 0; + seg.t1 = state->n_frames; + seg.text = text; + seg.tokens = result_tokens; + state->result_all.push_back(std::move(seg)); + + if (params.new_segment_callback) { + params.new_segment_callback(ctx, state, 1, params.new_segment_callback_user_data); + } + } + } + + return 0; +} + +int parakeet_full( + struct parakeet_context * ctx, + struct parakeet_full_params params, + const float * samples, + int n_samples) { + return parakeet_full_with_state(ctx, ctx->state, params, samples, n_samples); +} + +int parakeet_chunk( + struct parakeet_context * ctx, + struct parakeet_state * state, + struct parakeet_full_params params, + const float * samples, + int n_samples) { + + if (params.no_context) { + parakeet_reset_state(state); + } + + if (n_samples > 0) { + if (parakeet_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) { + PARAKEET_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__); + return -2; + } + } + + if (params.audio_ctx == 0) { + const int total_len = parakeet_n_len_from_state(state); + const int model_max_ctx = parakeet_n_audio_ctx(ctx); + params.audio_ctx = std::min(total_len, model_max_ctx); + PARAKEET_LOG_DEBUG("Processing audio: total_frames=%d, chunk_size=%d\n", total_len, params.audio_ctx); + } + state->n_audio_ctx = params.audio_ctx; + + const int n_frames = parakeet_n_len_from_state(state); + + if (!parakeet_ensure_encode_sched(*ctx, *state, state->n_audio_ctx)) { + PARAKEET_LOG_ERROR("%s: failed to allocate encoder graph for %d mel frames\n", + __func__, state->n_audio_ctx); + return -6; + } + + if (params.encoder_begin_callback) { + if (!params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data)) { + PARAKEET_LOG_ERROR("%s: encoder_begin_callback returned false - aborting\n", __func__); + return -6; + } + } + if (!parakeet_encode_internal(*ctx, *state, 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) { + PARAKEET_LOG_ERROR("%s: failed to encode\n", __func__); + return -6; + } + + const size_t tokens_before = state->decoded_tokens.size(); + + if (!parakeet_decode(*ctx, *state, state->batch, params.n_threads, ¶ms)) { + PARAKEET_LOG_ERROR("%s: failed to decode\n", __func__); + return -7; + } + + const size_t tokens_after = state->decoded_tokens.size(); + const size_t new_token_count = tokens_after - tokens_before; + + if (new_token_count > 0) { + std::string text; + std::vector<parakeet_token_data> result_tokens; + + for (size_t i = tokens_before; i < tokens_after; i++) { + const auto token_id = state->decoded_tokens[i]; + const char * token_str = parakeet_token_to_str(ctx, token_id); + if (token_str) { + const bool is_first_piece = (tokens_before == 0) && text.empty(); + text += sentencepiece_piece_to_text(token_str, is_first_piece); + } + + // Use the stored token data from parakeet_decode + result_tokens.push_back(state->decoded_token_data[i]); + } + + refine_timestamps_tdt(ctx->vocab, result_tokens); + + if (!text.empty()) { + parakeet_segment segment; + segment.t0 = 0; // Caller tracks timing + segment.t1 = n_frames; + segment.text = text; + segment.tokens = result_tokens; + + state->result_all.push_back(std::move(segment)); + + if (params.new_segment_callback) { + params.new_segment_callback(ctx, state, 1, params.new_segment_callback_user_data); + } + } + } + + return 0; +} + +int parakeet_full_n_segments_from_state(struct parakeet_state * state) { + return state->result_all.size(); +} + +int parakeet_full_n_segments(struct parakeet_context * ctx) { + return ctx->state->result_all.size(); +} + +int64_t parakeet_full_get_segment_t0_from_state(struct parakeet_state * state, int i_segment) { + return state->result_all[i_segment].t0; +} + +int64_t parakeet_full_get_segment_t1_from_state(struct parakeet_state * state, int i_segment) { + return state->result_all[i_segment].t1; +} + +int64_t parakeet_full_get_segment_t0(struct parakeet_context * ctx, int i_segment) { + return parakeet_full_get_segment_t0_from_state(ctx->state, i_segment); +} + +int64_t parakeet_full_get_segment_t1(struct parakeet_context * ctx, int i_segment) { + return parakeet_full_get_segment_t1_from_state(ctx->state, i_segment); +} + +const char * parakeet_full_get_segment_text_from_state(struct parakeet_state * state, int i_segment) { + return state->result_all[i_segment].text.c_str(); +} + +const char * parakeet_full_get_segment_text(struct parakeet_context * ctx, int i_segment) { + return ctx->state->result_all[i_segment].text.c_str(); +} + +int parakeet_full_n_tokens_from_state(struct parakeet_state * state, int i_segment) { + return state->result_all[i_segment].tokens.size(); +} + +int parakeet_full_n_tokens(struct parakeet_context * ctx, int i_segment) { + return ctx->state->result_all[i_segment].tokens.size(); +} + +const char * parakeet_full_get_token_text_from_state(struct parakeet_context * ctx, struct parakeet_state * state, int i_segment, int i_token) { + return ctx->vocab.id_to_token[state->result_all[i_segment].tokens[i_token].id].c_str(); +} + +const char* parakeet_full_get_token_text(struct parakeet_context * ctx, int i_segment, int i_token) { + return ctx->vocab.id_to_token[ctx->state->result_all[i_segment].tokens[i_token].id].c_str(); +} + +parakeet_token parakeet_full_get_token_id_from_state(struct parakeet_state * state, int i_segment, int i_token) { + return state->result_all[i_segment].tokens[i_token].id; +} + +parakeet_token parakeet_full_get_token_id(struct parakeet_context * ctx, int i_segment, int i_token) { + return ctx->state->result_all[i_segment].tokens[i_token].id; +} + +struct parakeet_token_data parakeet_full_get_token_data_from_state(struct parakeet_state * state, int i_segment, int i_token) { + return state->result_all[i_segment].tokens[i_token]; +} + +struct parakeet_token_data parakeet_full_get_token_data(struct parakeet_context * ctx, int i_segment, int i_token) { + return ctx->state->result_all[i_segment].tokens[i_token]; +} + +float parakeet_full_get_token_p_from_state(struct parakeet_state * state, int i_segment, int i_token) { + return state->result_all[i_segment].tokens[i_token].p; +} + +float parakeet_full_get_token_p(struct parakeet_context * ctx, int i_segment, int i_token) { + return ctx->state->result_all[i_segment].tokens[i_token].p; +} + +void parakeet_log_set(ggml_log_callback log_callback, void * user_data) { + g_state.log_callback = log_callback ? log_callback : parakeet_log_callback_default; + g_state.log_callback_user_data = user_data; + ggml_log_set(g_state.log_callback, g_state.log_callback_user_data); +} + +const char * parakeet_version(void) { + return PARAKEET_VERSION; +} + +GGML_ATTRIBUTE_FORMAT(2, 3) +static void parakeet_log_internal(ggml_log_level level, const char * format, ...) { + va_list args; + va_start(args, format); + char buffer[1024]; + int len = vsnprintf(buffer, 1024, format, args); + if (len < 1024) { + g_state.log_callback(level, buffer, g_state.log_callback_user_data); + } else { + char* buffer2 = new char[len+1]; + vsnprintf(buffer2, len+1, format, args); + buffer2[len] = 0; + g_state.log_callback(level, buffer2, g_state.log_callback_user_data); + delete[] buffer2; + } + va_end(args); +} + +static void parakeet_log_callback_default(ggml_log_level level, const char * text, void * user_data) { + (void) level; + (void) user_data; +#ifndef PARAKEET_DEBUG + if (level == GGML_LOG_LEVEL_DEBUG) { + return; + } +#endif + fputs(text, stderr); + fflush(stderr); +} diff --git a/src/whisper.cpp b/src/whisper.cpp index b39f205307e..e403e6dc0dc 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -3811,7 +3811,21 @@ struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_ whisper_context * ctx = new whisper_context; ctx->params = params; - if (!whisper_model_load(loader, *ctx)) { + // A C++ exception escaping this extern "C" function aborts non-C++ callers + // (Rust via whisper-rs, Go via cgo, ...). whisper_model_load can throw + // (std::runtime_error here; vk::SystemError from the Vulkan backend during + // device/buffer allocation), so funnel any throw into the existing + // NULL-return failure path instead of letting it cross the C ABI. + bool model_loaded = false; + try { + model_loaded = whisper_model_load(loader, *ctx); + } catch (const std::exception & e) { + WHISPER_LOG_ERROR("%s: exception during model load: %s\n", __func__, e.what()); + } catch (...) { + WHISPER_LOG_ERROR("%s: unknown exception during model load\n", __func__); + } + + if (!model_loaded) { loader->close(loader->context); WHISPER_LOG_ERROR("%s: failed to load model\n", __func__); delete ctx; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 6b80b023ffb..9405113c2c4 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -78,7 +78,7 @@ add_test(NAME ${TEST_TARGET} -f ${PROJECT_SOURCE_DIR}/samples/jfk.wav) set_tests_properties(${TEST_TARGET} PROPERTIES LABELS "large") -if (WHISPER_FFMPEG) +if (WHISPER_COMMON_FFMPEG) set(TEST_TARGET test-whisper-cli-tiny-mp3) # Check with reviewers: any way to check the output transcription via ctest (diff, ...)? add_test(NAME ${TEST_TARGET} @@ -88,6 +88,14 @@ if (WHISPER_FFMPEG) set_tests_properties(${TEST_TARGET} PROPERTIES LABELS "tiny;mp3") endif() +# UTF-8 helper unit test +set(UTF8_TEST test-common-utf8) +add_executable(${UTF8_TEST} ${UTF8_TEST}.cpp) +target_include_directories(${UTF8_TEST} PRIVATE ../examples) +target_link_libraries(${UTF8_TEST} PRIVATE common) +add_test(NAME ${UTF8_TEST} COMMAND ${UTF8_TEST}) +set_tests_properties(${UTF8_TEST} PROPERTIES LABELS "unit") + # VAD test tests VAD in isolation set(VAD_TEST test-vad) add_executable(${VAD_TEST} ${VAD_TEST}.cpp) @@ -125,3 +133,62 @@ target_compile_definitions(${VAD_TEST} PRIVATE SAMPLE_PATH="${PROJECT_SOURCE_DIR}/samples/jfk.wav") add_test(NAME ${VAD_TEST} COMMAND ${VAD_TEST}) set_tests_properties(${VAD_TEST} PROPERTIES LABELS "unit;vad") + +# Parakeet model loading test +set(PARAKEET_TEST test-parakeet) +add_executable(${PARAKEET_TEST} ${PARAKEET_TEST}.cpp) +target_include_directories(${PARAKEET_TEST} PRIVATE ../include ../ggml/include ../examples) +target_link_libraries(${PARAKEET_TEST} PRIVATE parakeet common) +target_compile_definitions(${PARAKEET_TEST} PRIVATE + PARAKEET_MODEL_PATH="${PROJECT_SOURCE_DIR}/models/for-tests-ggml-parakeet-tdt.bin" + SAMPLE_PATH="${PROJECT_SOURCE_DIR}/samples/jfk.wav") +add_test(NAME ${PARAKEET_TEST} COMMAND ${PARAKEET_TEST}) +set_tests_properties(${PARAKEET_TEST} PROPERTIES LABELS "parakeet;gh") + +# The following parakeet test require a real ggml-parakeet-tdt model to have +# been converted or downloaded: +# $ hf download danbev/parakeet parakeet-tdt-0.6b-v3-f32.bin --local-dir models +# +# And also required more audio samples that are shipped by default. These can +# downloaded by running: +# $ make samples +function(add_parakeet_transcription_test TEST_TARGET TEST_SOURCE SAMPLE_PATH EXPECTED_TRANSCRIPTION_PATH) + set(TRANSCRIPTION_SIMILARITY_THRESHOLD "1.0") + if (ARGC GREATER 4) + set(TRANSCRIPTION_SIMILARITY_THRESHOLD "${ARGV4}") + endif() + + add_executable(${TEST_TARGET} ${TEST_SOURCE}) + target_include_directories(${TEST_TARGET} PRIVATE ../include ../ggml/include ../examples) + target_link_libraries(${TEST_TARGET} PRIVATE parakeet common) + target_compile_definitions(${TEST_TARGET} PRIVATE + PARAKEET_MODEL_PATH="${PROJECT_SOURCE_DIR}/models/ggml-parakeet-tdt-0.6b-v3-f32.bin" + SAMPLE_PATH="${PROJECT_SOURCE_DIR}/${SAMPLE_PATH}" + EXPECTED_TRANSCRIPTION_PATH="${PROJECT_SOURCE_DIR}/${EXPECTED_TRANSCRIPTION_PATH}" + TRANSCRIPTION_SIMILARITY_THRESHOLD=${TRANSCRIPTION_SIMILARITY_THRESHOLD}) + + add_custom_target(run-${TEST_TARGET} + COMMAND $<TARGET_FILE:${TEST_TARGET}> + DEPENDS ${TEST_TARGET} + WORKING_DIRECTORY ${PROJECT_BINARY_DIR}) +endfunction() + +add_parakeet_transcription_test( + test-parakeet-full-jfk + test-parakeet-full.cpp + samples/jfk.wav + tests/parakeet-expected-jfk-output.txt) + +add_parakeet_transcription_test( + test-parakeet-full-gb1 + test-parakeet-full.cpp + samples/gb1.wav + tests/parakeet-expected-gb1-output.txt) + +add_parakeet_transcription_test( + test-parakeet-full-diffusion + test-parakeet-full.cpp + samples/diffusion2023-07-03.flac + tests/parakeet-expected-diffusion-output.txt + 0.95) + diff --git a/tests/librispeech-parakeet/.gitignore b/tests/librispeech-parakeet/.gitignore new file mode 100644 index 00000000000..838bfeae9db --- /dev/null +++ b/tests/librispeech-parakeet/.gitignore @@ -0,0 +1,6 @@ +__pycache__ +*.tar.gz +*.txt +eval.conf +venv +LibriSpeech diff --git a/tests/librispeech-parakeet/Makefile b/tests/librispeech-parakeet/Makefile new file mode 100644 index 00000000000..0afa2465f49 --- /dev/null +++ b/tests/librispeech-parakeet/Makefile @@ -0,0 +1,15 @@ +TAR_URL = https://www.openslr.org/resources/12/test-clean.tar.gz + +all: eval + +eval: + $(MAKE) -f eval.mk + +clean: + $(MAKE) -f eval.mk clean + +get-audio: + wget -c $(TAR_URL) + tar -xf test-clean.tar.gz + +.PHONY: all eval clean setup-venv clean-venv get-audio diff --git a/tests/librispeech-parakeet/README.md b/tests/librispeech-parakeet/README.md new file mode 100644 index 00000000000..e09cba405ef --- /dev/null +++ b/tests/librispeech-parakeet/README.md @@ -0,0 +1,57 @@ +# parakeet.cpp/tests/librispeech + +[LibriSpeech](https://www.openslr.org/12) is a standard dataset for +training and evaluating automatic speech recognition systems. + +This directory contains a set of tools to evaluate the recognition +performance of parakeet.cpp on LibriSpeech corpus. + +## Quick Start + +1. (Pre-requirement) Compile `parakeet-cli` and prepare the Parakeet + model in `ggml` format. + + ``` + $ # Execute the commands below in the project root dir. + $ cmake -B build + $ cmake --build build --config Release + ``` + +2. Download the audio files from LibriSpeech project. + + ``` + $ make get-audio + ``` + +3. Set up the environment to compute WER score. + + ``` + $ pip install -r requirements.txt + ``` + + For example, if you use `virtualenv`, you can set up it as follows: + + ``` + $ python3 -m venv venv + $ . venv/bin/activate + $ pip install -r requirements.txt + ``` + +4. Run the benchmark test. + + ``` + $ make + ``` + +## How-to guides + +### How to change the inference parameters + +Create `eval.conf` and override variables. + +``` +PARAKEET_MODEL = parakeet-tdt-0.6b-v3 +PARAKEET_FLAGS = --no-prints --threads 8 --language en --output-txt +``` + +Check out `eval.mk` for more details. diff --git a/tests/librispeech-parakeet/eval.mk b/tests/librispeech-parakeet/eval.mk new file mode 100644 index 00000000000..7d8992ec471 --- /dev/null +++ b/tests/librispeech-parakeet/eval.mk @@ -0,0 +1,39 @@ +PYTHON = python + +PARAKEET_PREFIX = ../../ +PARAKEET_MODEL = parakeet-tdt-0.6b-v3 + +PARAKEET_CLI = $(PARAKEET_PREFIX)build/bin/parakeet-cli +PARAKEET_FLAGS = --no-prints --output-txt + +# You can create eval.conf to override the PARAKEET_* variables +# defined above. +-include eval.conf + +# This follows the file structure of the LibriSpeech project. +AUDIO_SRCS = $(sort $(wildcard LibriSpeech/*/*/*/*.flac)) +TRANS_TXTS = $(addsuffix .txt, $(AUDIO_SRCS)) + +# We output the evaluation result to this file. +DONE = $(PARAKEET_MODEL).txt + +all: $(DONE) + +$(DONE): $(TRANS_TXTS) + $(PYTHON) eval.py > $@.tmp + mv $@.tmp $@ + +# Note: This task writes to a temporary file first to +# create the target file atomically. +%.flac.txt: %.flac + $(PARAKEET_CLI) $(PARAKEET_FLAGS) --model $(PARAKEET_PREFIX)models/ggml-$(PARAKEET_MODEL).bin --file $^ --output-file $^.tmp + mv $^.tmp.txt $^.txt + +archive: + tar -czf $(PARAKEET_MODEL).tar.gz --exclude="*.flac" LibriSpeech $(DONE) + +clean: + @rm -f $(TRANS_TXTS) + @rm -f $(DONE) + +.PHONY: all clean diff --git a/tests/librispeech-parakeet/eval.py b/tests/librispeech-parakeet/eval.py new file mode 100644 index 00000000000..cdaf8352fd4 --- /dev/null +++ b/tests/librispeech-parakeet/eval.py @@ -0,0 +1,47 @@ +import os +import glob +import jiwer +from normalizers import EnglishTextNormalizer + +def get_reference(): + ref = {} + for path in glob.glob('LibriSpeech/*/*/*/*.trans.txt'): + with open(path) as fp: + for line in fp: + code, text = line.strip().split(" ", maxsplit=1) + ref [code] = text + return ref + +def get_hypothesis(): + hyp = {} + for path in glob.glob('LibriSpeech/*/*/*/*.flac.txt'): + with open(path) as fp: + text = fp.read().strip() + code = os.path.basename(path).replace('.flac.txt', '') + hyp[code] = text + return hyp + +def get_codes(): + codes = [] + for path in glob.glob('LibriSpeech/*/*/*/*.flac'): + codes.append(os.path.basename(path).replace('.flac', '')) + return sorted(codes) + +def main(): + normalizer = EnglishTextNormalizer() + + ref_orig = get_reference() + hyp_orig = get_hypothesis() + + ref_clean = [] + hyp_clean = [] + + for code in get_codes(): + ref_clean.append(normalizer(ref_orig[code])) + hyp_clean.append(normalizer(hyp_orig[code])) + + wer = jiwer.wer(ref_clean, hyp_clean) + print(f"WER: {wer * 100:.2f}%") + +if __name__ == '__main__': + main() diff --git a/tests/librispeech-parakeet/normalizers/LICENSE b/tests/librispeech-parakeet/normalizers/LICENSE new file mode 100644 index 00000000000..7c8e603b0fc --- /dev/null +++ b/tests/librispeech-parakeet/normalizers/LICENSE @@ -0,0 +1,25 @@ +Code in this directory is adapted from OpenAI Whisper project +(https://github.com/openai/whisper) and carries the following +copyright and license. + + MIT License + + Copyright (c) 2022 OpenAI + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. diff --git a/tests/librispeech-parakeet/normalizers/__init__.py b/tests/librispeech-parakeet/normalizers/__init__.py new file mode 100644 index 00000000000..896d5e33641 --- /dev/null +++ b/tests/librispeech-parakeet/normalizers/__init__.py @@ -0,0 +1,2 @@ +from .basic import BasicTextNormalizer as BasicTextNormalizer +from .english import EnglishTextNormalizer as EnglishTextNormalizer diff --git a/tests/librispeech-parakeet/normalizers/basic.py b/tests/librispeech-parakeet/normalizers/basic.py new file mode 100644 index 00000000000..8690ae71c5f --- /dev/null +++ b/tests/librispeech-parakeet/normalizers/basic.py @@ -0,0 +1,80 @@ +import re +import unicodedata + +import regex + +# non-ASCII letters that are not separated by "NFKD" normalization +ADDITIONAL_DIACRITICS = { + "œ": "oe", + "Œ": "OE", + "ø": "o", + "Ø": "O", + "æ": "ae", + "Æ": "AE", + "ß": "ss", + "ẞ": "SS", + "đ": "d", + "Đ": "D", + "ð": "d", + "Ð": "D", + "þ": "th", + "Þ": "th", + "ł": "l", + "Ł": "L", +} + + +def remove_symbols_and_diacritics(s: str, keep=""): + """ + Replace any other markers, symbols, and punctuations with a space, + and drop any diacritics (category 'Mn' and some manual mappings) + """ + return "".join( + ( + c + if c in keep + else ( + ADDITIONAL_DIACRITICS[c] + if c in ADDITIONAL_DIACRITICS + else ( + "" + if unicodedata.category(c) == "Mn" + else " " if unicodedata.category(c)[0] in "MSP" else c + ) + ) + ) + for c in unicodedata.normalize("NFKD", s) + ) + + +def remove_symbols(s: str): + """ + Replace any other markers, symbols, punctuations with a space, keeping diacritics + """ + return "".join( + " " if unicodedata.category(c)[0] in "MSP" else c + for c in unicodedata.normalize("NFKC", s) + ) + + +class BasicTextNormalizer: + def __init__(self, remove_diacritics: bool = False, split_letters: bool = False): + self.clean = ( + remove_symbols_and_diacritics if remove_diacritics else remove_symbols + ) + self.split_letters = split_letters + + def __call__(self, s: str): + s = s.lower() + s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets + s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis + s = self.clean(s).lower() + + if self.split_letters: + s = " ".join(regex.findall(r"\X", s, regex.U)) + + s = re.sub( + r"\s+", " ", s + ) # replace any successive whitespace characters with a space + + return s diff --git a/tests/librispeech-parakeet/normalizers/english.json b/tests/librispeech-parakeet/normalizers/english.json new file mode 100644 index 00000000000..74a1c3521d9 --- /dev/null +++ b/tests/librispeech-parakeet/normalizers/english.json @@ -0,0 +1,1741 @@ +{ + "accessorise": "accessorize", + "accessorised": "accessorized", + "accessorises": "accessorizes", + "accessorising": "accessorizing", + "acclimatisation": "acclimatization", + "acclimatise": "acclimatize", + "acclimatised": "acclimatized", + "acclimatises": "acclimatizes", + "acclimatising": "acclimatizing", + "accoutrements": "accouterments", + "aeon": "eon", + "aeons": "eons", + "aerogramme": "aerogram", + "aerogrammes": "aerograms", + "aeroplane": "airplane", + "aeroplanes": "airplanes", + "aesthete": "esthete", + "aesthetes": "esthetes", + "aesthetic": "esthetic", + "aesthetically": "esthetically", + "aesthetics": "esthetics", + "aetiology": "etiology", + "ageing": "aging", + "aggrandisement": "aggrandizement", + "agonise": "agonize", + "agonised": "agonized", + "agonises": "agonizes", + "agonising": "agonizing", + "agonisingly": "agonizingly", + "almanack": "almanac", + "almanacks": "almanacs", + "aluminium": "aluminum", + "amortisable": "amortizable", + "amortisation": "amortization", + "amortisations": "amortizations", + "amortise": "amortize", + "amortised": "amortized", + "amortises": "amortizes", + "amortising": "amortizing", + "amphitheatre": "amphitheater", + "amphitheatres": "amphitheaters", + "anaemia": "anemia", + "anaemic": "anemic", + "anaesthesia": "anesthesia", + "anaesthetic": "anesthetic", + "anaesthetics": "anesthetics", + "anaesthetise": "anesthetize", + "anaesthetised": "anesthetized", + "anaesthetises": "anesthetizes", + "anaesthetising": "anesthetizing", + "anaesthetist": "anesthetist", + "anaesthetists": "anesthetists", + "anaesthetize": "anesthetize", + "anaesthetized": "anesthetized", + "anaesthetizes": "anesthetizes", + "anaesthetizing": "anesthetizing", + "analogue": "analog", + "analogues": "analogs", + "analyse": "analyze", + "analysed": "analyzed", + "analyses": "analyzes", + "analysing": "analyzing", + "anglicise": "anglicize", + "anglicised": "anglicized", + "anglicises": "anglicizes", + "anglicising": "anglicizing", + "annualised": "annualized", + "antagonise": "antagonize", + "antagonised": "antagonized", + "antagonises": "antagonizes", + "antagonising": "antagonizing", + "apologise": "apologize", + "apologised": "apologized", + "apologises": "apologizes", + "apologising": "apologizing", + "appal": "appall", + "appals": "appalls", + "appetiser": "appetizer", + "appetisers": "appetizers", + "appetising": "appetizing", + "appetisingly": "appetizingly", + "arbour": "arbor", + "arbours": "arbors", + "archeological": "archaeological", + "archaeologically": "archeologically", + "archaeologist": "archeologist", + "archaeologists": "archeologists", + "archaeology": "archeology</span>", + "ardour": "ardor", + "armour": "armor", + "armoured": "armored", + "armourer": "armorer", + "armourers": "armorers", + "armouries": "armories", + "armoury": "armory", + "artefact": "artifact", + "artefacts": "artifacts", + "authorise": "authorize", + "authorised": "authorized", + "authorises": "authorizes", + "authorising": "authorizing", + "axe": "ax", + "backpedalled": "backpedaled", + "backpedalling": "backpedaling", + "bannister": "banister", + "bannisters": "banisters", + "baptise": "baptize", + "baptised": "baptized", + "baptises": "baptizes", + "baptising": "baptizing", + "bastardise": "bastardize", + "bastardised": "bastardized", + "bastardises": "bastardizes", + "bastardising": "bastardizing", + "battleax": "battleaxe", + "baulk": "balk", + "baulked": "balked", + "baulking": "balking", + "baulks": "balks", + "bedevilled": "bedeviled", + "bedevilling": "bedeviling", + "behaviour": "behavior", + "behavioural": "behavioral", + "behaviourism": "behaviorism", + "behaviourist": "behaviorist", + "behaviourists": "behaviorists", + "behaviours": "behaviors", + "behove": "behoove", + "behoved": "behooved", + "behoves": "behooves", + "bejewelled": "bejeweled", + "belabour": "belabor", + "belaboured": "belabored", + "belabouring": "belaboring", + "belabours": "belabors", + "bevelled": "beveled", + "bevvies": "bevies", + "bevvy": "bevy", + "biassed": "biased", + "biassing": "biasing", + "bingeing": "binging", + "bougainvillaea": "bougainvillea", + "bougainvillaeas": "bougainvilleas", + "bowdlerise": "bowdlerize", + "bowdlerised": "bowdlerized", + "bowdlerises": "bowdlerizes", + "bowdlerising": "bowdlerizing", + "breathalyse": "breathalyze", + "breathalysed": "breathalyzed", + "breathalyser": "breathalyzer", + "breathalysers": "breathalyzers", + "breathalyses": "breathalyzes", + "breathalysing": "breathalyzing", + "brutalise": "brutalize", + "brutalised": "brutalized", + "brutalises": "brutalizes", + "brutalising": "brutalizing", + "busses": "buses", + "bussing": "busing", + "caesarean": "cesarean", + "caesareans": "cesareans", + "calibre": "caliber", + "calibres": "calibers", + "calliper": "caliper", + "callipers": "calipers", + "callisthenics": "calisthenics", + "canalise": "canalize", + "canalised": "canalized", + "canalises": "canalizes", + "canalising": "canalizing", + "cancelation": "cancellation", + "cancelations": "cancellations", + "cancelled": "canceled", + "cancelling": "canceling", + "candour": "candor", + "cannibalise": "cannibalize", + "cannibalised": "cannibalized", + "cannibalises": "cannibalizes", + "cannibalising": "cannibalizing", + "canonise": "canonize", + "canonised": "canonized", + "canonises": "canonizes", + "canonising": "canonizing", + "capitalise": "capitalize", + "capitalised": "capitalized", + "capitalises": "capitalizes", + "capitalising": "capitalizing", + "caramelise": "caramelize", + "caramelised": "caramelized", + "caramelises": "caramelizes", + "caramelising": "caramelizing", + "carbonise": "carbonize", + "carbonised": "carbonized", + "carbonises": "carbonizes", + "carbonising": "carbonizing", + "carolled": "caroled", + "carolling": "caroling", + "catalogue": "catalog", + "catalogued": "cataloged", + "catalogues": "catalogs", + "cataloguing": "cataloging", + "catalyse": "catalyze", + "catalysed": "catalyzed", + "catalyses": "catalyzes", + "catalysing": "catalyzing", + "categorise": "categorize", + "categorised": "categorized", + "categorises": "categorizes", + "categorising": "categorizing", + "cauterise": "cauterize", + "cauterised": "cauterized", + "cauterises": "cauterizes", + "cauterising": "cauterizing", + "cavilled": "caviled", + "cavilling": "caviling", + "centigramme": "centigram", + "centigrammes": "centigrams", + "centilitre": "centiliter", + "centilitres": "centiliters", + "centimetre": "centimeter", + "centimetres": "centimeters", + "centralise": "centralize", + "centralised": "centralized", + "centralises": "centralizes", + "centralising": "centralizing", + "centre": "center", + "centred": "centered", + "centrefold": "centerfold", + "centrefolds": "centerfolds", + "centrepiece": "centerpiece", + "centrepieces": "centerpieces", + "centres": "centers", + "channelled": "channeled", + "channelling": "channeling", + "characterise": "characterize", + "characterised": "characterized", + "characterises": "characterizes", + "characterising": "characterizing", + "cheque": "check", + "chequebook": "checkbook", + "chequebooks": "checkbooks", + "chequered": "checkered", + "cheques": "checks", + "chilli": "chili", + "chimaera": "chimera", + "chimaeras": "chimeras", + "chiselled": "chiseled", + "chiselling": "chiseling", + "circularise": "circularize", + "circularised": "circularized", + "circularises": "circularizes", + "circularising": "circularizing", + "civilise": "civilize", + "civilised": "civilized", + "civilises": "civilizes", + "civilising": "civilizing", + "clamour": "clamor", + "clamoured": "clamored", + "clamouring": "clamoring", + "clamours": "clamors", + "clangour": "clangor", + "clarinettist": "clarinetist", + "clarinettists": "clarinetists", + "collectivise": "collectivize", + "collectivised": "collectivized", + "collectivises": "collectivizes", + "collectivising": "collectivizing", + "colonisation": "colonization", + "colonise": "colonize", + "colonised": "colonized", + "coloniser": "colonizer", + "colonisers": "colonizers", + "colonises": "colonizes", + "colonising": "colonizing", + "colour": "color", + "colourant": "colorant", + "colourants": "colorants", + "coloured": "colored", + "coloureds": "coloreds", + "colourful": "colorful", + "colourfully": "colorfully", + "colouring": "coloring", + "colourize": "colorize", + "colourized": "colorized", + "colourizes": "colorizes", + "colourizing": "colorizing", + "colourless": "colorless", + "colours": "colors", + "commercialise": "commercialize", + "commercialised": "commercialized", + "commercialises": "commercializes", + "commercialising": "commercializing", + "compartmentalise": "compartmentalize", + "compartmentalised": "compartmentalized", + "compartmentalises": "compartmentalizes", + "compartmentalising": "compartmentalizing", + "computerise": "computerize", + "computerised": "computerized", + "computerises": "computerizes", + "computerising": "computerizing", + "conceptualise": "conceptualize", + "conceptualised": "conceptualized", + "conceptualises": "conceptualizes", + "conceptualising": "conceptualizing", + "connexion": "connection", + "connexions": "connections", + "contextualise": "contextualize", + "contextualised": "contextualized", + "contextualises": "contextualizes", + "contextualising": "contextualizing", + "cosier": "cozier", + "cosies": "cozies", + "cosiest": "coziest", + "cosily": "cozily", + "cosiness": "coziness", + "cosy": "cozy", + "councillor": "councilor", + "councillors": "councilors", + "counselled": "counseled", + "counselling": "counseling", + "counsellor": "counselor", + "counsellors": "counselors", + "crenelated": "crenellated", + "criminalise": "criminalize", + "criminalised": "criminalized", + "criminalises": "criminalizes", + "criminalising": "criminalizing", + "criticise": "criticize", + "criticised": "criticized", + "criticises": "criticizes", + "criticising": "criticizing", + "crueller": "crueler", + "cruellest": "cruelest", + "crystallisation": "crystallization", + "crystallise": "crystallize", + "crystallised": "crystallized", + "crystallises": "crystallizes", + "crystallising": "crystallizing", + "cudgelled": "cudgeled", + "cudgelling": "cudgeling", + "customise": "customize", + "customised": "customized", + "customises": "customizes", + "customising": "customizing", + "cypher": "cipher", + "cyphers": "ciphers", + "decentralisation": "decentralization", + "decentralise": "decentralize", + "decentralised": "decentralized", + "decentralises": "decentralizes", + "decentralising": "decentralizing", + "decriminalisation": "decriminalization", + "decriminalise": "decriminalize", + "decriminalised": "decriminalized", + "decriminalises": "decriminalizes", + "decriminalising": "decriminalizing", + "defence": "defense", + "defenceless": "defenseless", + "defences": "defenses", + "dehumanisation": "dehumanization", + "dehumanise": "dehumanize", + "dehumanised": "dehumanized", + "dehumanises": "dehumanizes", + "dehumanising": "dehumanizing", + "demeanour": "demeanor", + "demilitarisation": "demilitarization", + "demilitarise": "demilitarize", + "demilitarised": "demilitarized", + "demilitarises": "demilitarizes", + "demilitarising": "demilitarizing", + "demobilisation": "demobilization", + "demobilise": "demobilize", + "demobilised": "demobilized", + "demobilises": "demobilizes", + "demobilising": "demobilizing", + "democratisation": "democratization", + "democratise": "democratize", + "democratised": "democratized", + "democratises": "democratizes", + "democratising": "democratizing", + "demonise": "demonize", + "demonised": "demonized", + "demonises": "demonizes", + "demonising": "demonizing", + "demoralisation": "demoralization", + "demoralise": "demoralize", + "demoralised": "demoralized", + "demoralises": "demoralizes", + "demoralising": "demoralizing", + "denationalisation": "denationalization", + "denationalise": "denationalize", + "denationalised": "denationalized", + "denationalises": "denationalizes", + "denationalising": "denationalizing", + "deodorise": "deodorize", + "deodorised": "deodorized", + "deodorises": "deodorizes", + "deodorising": "deodorizing", + "depersonalise": "depersonalize", + "depersonalised": "depersonalized", + "depersonalises": "depersonalizes", + "depersonalising": "depersonalizing", + "deputise": "deputize", + "deputised": "deputized", + "deputises": "deputizes", + "deputising": "deputizing", + "desensitisation": "desensitization", + "desensitise": "desensitize", + "desensitised": "desensitized", + "desensitises": "desensitizes", + "desensitising": "desensitizing", + "destabilisation": "destabilization", + "destabilise": "destabilize", + "destabilised": "destabilized", + "destabilises": "destabilizes", + "destabilising": "destabilizing", + "dialled": "dialed", + "dialling": "dialing", + "dialogue": "dialog", + "dialogues": "dialogs", + "diarrhoea": "diarrhea", + "digitise": "digitize", + "digitised": "digitized", + "digitises": "digitizes", + "digitising": "digitizing", + "disc": "disk", + "discolour": "discolor", + "discoloured": "discolored", + "discolouring": "discoloring", + "discolours": "discolors", + "discs": "disks", + "disembowelled": "disemboweled", + "disembowelling": "disemboweling", + "disfavour": "disfavor", + "dishevelled": "disheveled", + "dishonour": "dishonor", + "dishonourable": "dishonorable", + "dishonourably": "dishonorably", + "dishonoured": "dishonored", + "dishonouring": "dishonoring", + "dishonours": "dishonors", + "disorganisation": "disorganization", + "disorganised": "disorganized", + "distil": "distill", + "distils": "distills", + "dramatisation": "dramatization", + "dramatisations": "dramatizations", + "dramatise": "dramatize", + "dramatised": "dramatized", + "dramatises": "dramatizes", + "dramatising": "dramatizing", + "draught": "draft", + "draughtboard": "draftboard", + "draughtboards": "draftboards", + "draughtier": "draftier", + "draughtiest": "draftiest", + "draughts": "drafts", + "draughtsman": "draftsman", + "draughtsmanship": "draftsmanship", + "draughtsmen": "draftsmen", + "draughtswoman": "draftswoman", + "draughtswomen": "draftswomen", + "draughty": "drafty", + "drivelled": "driveled", + "drivelling": "driveling", + "duelled": "dueled", + "duelling": "dueling", + "economise": "economize", + "economised": "economized", + "economises": "economizes", + "economising": "economizing", + "edoema": "edema", + "editorialise": "editorialize", + "editorialised": "editorialized", + "editorialises": "editorializes", + "editorialising": "editorializing", + "empathise": "empathize", + "empathised": "empathized", + "empathises": "empathizes", + "empathising": "empathizing", + "emphasise": "emphasize", + "emphasised": "emphasized", + "emphasises": "emphasizes", + "emphasising": "emphasizing", + "enamelled": "enameled", + "enamelling": "enameling", + "enamoured": "enamored", + "encyclopaedia": "encyclopedia", + "encyclopaedias": "encyclopedias", + "encyclopaedic": "encyclopedic", + "endeavour": "endeavor", + "endeavoured": "endeavored", + "endeavouring": "endeavoring", + "endeavours": "endeavors", + "energise": "energize", + "energised": "energized", + "energises": "energizes", + "energising": "energizing", + "enrol": "enroll", + "enrols": "enrolls", + "enthral": "enthrall", + "enthrals": "enthralls", + "epaulette": "epaulet", + "epaulettes": "epaulets", + "epicentre": "epicenter", + "epicentres": "epicenters", + "epilogue": "epilog", + "epilogues": "epilogs", + "epitomise": "epitomize", + "epitomised": "epitomized", + "epitomises": "epitomizes", + "epitomising": "epitomizing", + "equalisation": "equalization", + "equalise": "equalize", + "equalised": "equalized", + "equaliser": "equalizer", + "equalisers": "equalizers", + "equalises": "equalizes", + "equalising": "equalizing", + "eulogise": "eulogize", + "eulogised": "eulogized", + "eulogises": "eulogizes", + "eulogising": "eulogizing", + "evangelise": "evangelize", + "evangelised": "evangelized", + "evangelises": "evangelizes", + "evangelising": "evangelizing", + "exorcise": "exorcize", + "exorcised": "exorcized", + "exorcises": "exorcizes", + "exorcising": "exorcizing", + "extemporisation": "extemporization", + "extemporise": "extemporize", + "extemporised": "extemporized", + "extemporises": "extemporizes", + "extemporising": "extemporizing", + "externalisation": "externalization", + "externalisations": "externalizations", + "externalise": "externalize", + "externalised": "externalized", + "externalises": "externalizes", + "externalising": "externalizing", + "factorise": "factorize", + "factorised": "factorized", + "factorises": "factorizes", + "factorising": "factorizing", + "faecal": "fecal", + "faeces": "feces", + "familiarisation": "familiarization", + "familiarise": "familiarize", + "familiarised": "familiarized", + "familiarises": "familiarizes", + "familiarising": "familiarizing", + "fantasise": "fantasize", + "fantasised": "fantasized", + "fantasises": "fantasizes", + "fantasising": "fantasizing", + "favour": "favor", + "favourable": "favorable", + "favourably": "favorably", + "favoured": "favored", + "favouring": "favoring", + "favourite": "favorite", + "favourites": "favorites", + "favouritism": "favoritism", + "favours": "favors", + "feminise": "feminize", + "feminised": "feminized", + "feminises": "feminizes", + "feminising": "feminizing", + "fertilisation": "fertilization", + "fertilise": "fertilize", + "fertilised": "fertilized", + "fertiliser": "fertilizer", + "fertilisers": "fertilizers", + "fertilises": "fertilizes", + "fertilising": "fertilizing", + "fervour": "fervor", + "fibre": "fiber", + "fibreglass": "fiberglass", + "fibres": "fibers", + "fictionalisation": "fictionalization", + "fictionalisations": "fictionalizations", + "fictionalise": "fictionalize", + "fictionalised": "fictionalized", + "fictionalises": "fictionalizes", + "fictionalising": "fictionalizing", + "fillet": "filet", + "filleted": "fileted", + "filleting": "fileting", + "fillets": "filets", + "finalisation": "finalization", + "finalise": "finalize", + "finalised": "finalized", + "finalises": "finalizes", + "finalising": "finalizing", + "flautist": "flutist", + "flautists": "flutists", + "flavour": "flavor", + "flavoured": "flavored", + "flavouring": "flavoring", + "flavourings": "flavorings", + "flavourless": "flavorless", + "flavours": "flavors", + "flavoursome": "flavorsome", + "flyer / flier": "flier / flyer", + "foetal": "fetal", + "foetid": "fetid", + "foetus": "fetus", + "foetuses": "fetuses", + "formalisation": "formalization", + "formalise": "formalize", + "formalised": "formalized", + "formalises": "formalizes", + "formalising": "formalizing", + "fossilisation": "fossilization", + "fossilise": "fossilize", + "fossilised": "fossilized", + "fossilises": "fossilizes", + "fossilising": "fossilizing", + "fraternisation": "fraternization", + "fraternise": "fraternize", + "fraternised": "fraternized", + "fraternises": "fraternizes", + "fraternising": "fraternizing", + "fulfil": "fulfill", + "fulfilment": "fulfillment", + "fulfils": "fulfills", + "funnelled": "funneled", + "funnelling": "funneling", + "galvanise": "galvanize", + "galvanised": "galvanized", + "galvanises": "galvanizes", + "galvanising": "galvanizing", + "gambolled": "gamboled", + "gambolling": "gamboling", + "gaol": "jail", + "gaolbird": "jailbird", + "gaolbirds": "jailbirds", + "gaolbreak": "jailbreak", + "gaolbreaks": "jailbreaks", + "gaoled": "jailed", + "gaoler": "jailer", + "gaolers": "jailers", + "gaoling": "jailing", + "gaols": "jails", + "gasses": "gases", + "gage": "gauge", + "gaged": "gauged", + "gages": "gauges", + "gaging": "gauging", + "generalisation": "generalization", + "generalisations": "generalizations", + "generalise": "generalize", + "generalised": "generalized", + "generalises": "generalizes", + "generalising": "generalizing", + "ghettoise": "ghettoize", + "ghettoised": "ghettoized", + "ghettoises": "ghettoizes", + "ghettoising": "ghettoizing", + "gipsies": "gypsies", + "glamorise": "glamorize", + "glamorised": "glamorized", + "glamorises": "glamorizes", + "glamorising": "glamorizing", + "glamor": "glamour", + "globalisation": "globalization", + "globalise": "globalize", + "globalised": "globalized", + "globalises": "globalizes", + "globalising": "globalizing", + "glueing": "gluing", + "goitre": "goiter", + "goitres": "goiters", + "gonorrhoea": "gonorrhea", + "gramme": "gram", + "grammes": "grams", + "gravelled": "graveled", + "grey": "gray", + "greyed": "grayed", + "greying": "graying", + "greyish": "grayish", + "greyness": "grayness", + "greys": "grays", + "grovelled": "groveled", + "grovelling": "groveling", + "groyne": "groin", + "groynes": "groins", + "gruelling": "grueling", + "gruellingly": "gruelingly", + "gryphon": "griffin", + "gryphons": "griffins", + "gynaecological": "gynecological", + "gynaecologist": "gynecologist", + "gynaecologists": "gynecologists", + "gynaecology": "gynecology", + "haematological": "hematological", + "haematologist": "hematologist", + "haematologists": "hematologists", + "haematology": "hematology", + "haemoglobin": "hemoglobin", + "haemophilia": "hemophilia", + "haemophiliac": "hemophiliac", + "haemophiliacs": "hemophiliacs", + "haemorrhage": "hemorrhage", + "haemorrhaged": "hemorrhaged", + "haemorrhages": "hemorrhages", + "haemorrhaging": "hemorrhaging", + "haemorrhoids": "hemorrhoids", + "harbour": "harbor", + "harboured": "harbored", + "harbouring": "harboring", + "harbours": "harbors", + "harmonisation": "harmonization", + "harmonise": "harmonize", + "harmonised": "harmonized", + "harmonises": "harmonizes", + "harmonising": "harmonizing", + "homoeopath": "homeopath", + "homoeopathic": "homeopathic", + "homoeopaths": "homeopaths", + "homoeopathy": "homeopathy", + "homogenise": "homogenize", + "homogenised": "homogenized", + "homogenises": "homogenizes", + "homogenising": "homogenizing", + "honour": "honor", + "honourable": "honorable", + "honourably": "honorably", + "honoured": "honored", + "honouring": "honoring", + "honours": "honors", + "hospitalisation": "hospitalization", + "hospitalise": "hospitalize", + "hospitalised": "hospitalized", + "hospitalises": "hospitalizes", + "hospitalising": "hospitalizing", + "humanise": "humanize", + "humanised": "humanized", + "humanises": "humanizes", + "humanising": "humanizing", + "humour": "humor", + "humoured": "humored", + "humouring": "humoring", + "humourless": "humorless", + "humours": "humors", + "hybridise": "hybridize", + "hybridised": "hybridized", + "hybridises": "hybridizes", + "hybridising": "hybridizing", + "hypnotise": "hypnotize", + "hypnotised": "hypnotized", + "hypnotises": "hypnotizes", + "hypnotising": "hypnotizing", + "hypothesise": "hypothesize", + "hypothesised": "hypothesized", + "hypothesises": "hypothesizes", + "hypothesising": "hypothesizing", + "idealisation": "idealization", + "idealise": "idealize", + "idealised": "idealized", + "idealises": "idealizes", + "idealising": "idealizing", + "idolise": "idolize", + "idolised": "idolized", + "idolises": "idolizes", + "idolising": "idolizing", + "immobilisation": "immobilization", + "immobilise": "immobilize", + "immobilised": "immobilized", + "immobiliser": "immobilizer", + "immobilisers": "immobilizers", + "immobilises": "immobilizes", + "immobilising": "immobilizing", + "immortalise": "immortalize", + "immortalised": "immortalized", + "immortalises": "immortalizes", + "immortalising": "immortalizing", + "immunisation": "immunization", + "immunise": "immunize", + "immunised": "immunized", + "immunises": "immunizes", + "immunising": "immunizing", + "impanelled": "impaneled", + "impanelling": "impaneling", + "imperilled": "imperiled", + "imperilling": "imperiling", + "individualise": "individualize", + "individualised": "individualized", + "individualises": "individualizes", + "individualising": "individualizing", + "industrialise": "industrialize", + "industrialised": "industrialized", + "industrialises": "industrializes", + "industrialising": "industrializing", + "inflexion": "inflection", + "inflexions": "inflections", + "initialise": "initialize", + "initialised": "initialized", + "initialises": "initializes", + "initialising": "initializing", + "initialled": "initialed", + "initialling": "initialing", + "instal": "install", + "instalment": "installment", + "instalments": "installments", + "instals": "installs", + "instil": "instill", + "instils": "instills", + "institutionalisation": "institutionalization", + "institutionalise": "institutionalize", + "institutionalised": "institutionalized", + "institutionalises": "institutionalizes", + "institutionalising": "institutionalizing", + "intellectualise": "intellectualize", + "intellectualised": "intellectualized", + "intellectualises": "intellectualizes", + "intellectualising": "intellectualizing", + "internalisation": "internalization", + "internalise": "internalize", + "internalised": "internalized", + "internalises": "internalizes", + "internalising": "internalizing", + "internationalisation": "internationalization", + "internationalise": "internationalize", + "internationalised": "internationalized", + "internationalises": "internationalizes", + "internationalising": "internationalizing", + "ionisation": "ionization", + "ionise": "ionize", + "ionised": "ionized", + "ioniser": "ionizer", + "ionisers": "ionizers", + "ionises": "ionizes", + "ionising": "ionizing", + "italicise": "italicize", + "italicised": "italicized", + "italicises": "italicizes", + "italicising": "italicizing", + "itemise": "itemize", + "itemised": "itemized", + "itemises": "itemizes", + "itemising": "itemizing", + "jeopardise": "jeopardize", + "jeopardised": "jeopardized", + "jeopardises": "jeopardizes", + "jeopardising": "jeopardizing", + "jewelled": "jeweled", + "jeweller": "jeweler", + "jewellers": "jewelers", + "jewellery": "jewelry", + "judgement": "judgment", + "kilogramme": "kilogram", + "kilogrammes": "kilograms", + "kilometre": "kilometer", + "kilometres": "kilometers", + "labelled": "labeled", + "labelling": "labeling", + "labour": "labor", + "laboured": "labored", + "labourer": "laborer", + "labourers": "laborers", + "labouring": "laboring", + "labours": "labors", + "lacklustre": "lackluster", + "legalisation": "legalization", + "legalise": "legalize", + "legalised": "legalized", + "legalises": "legalizes", + "legalising": "legalizing", + "legitimise": "legitimize", + "legitimised": "legitimized", + "legitimises": "legitimizes", + "legitimising": "legitimizing", + "leukaemia": "leukemia", + "levelled": "leveled", + "leveller": "leveler", + "levellers": "levelers", + "levelling": "leveling", + "libelled": "libeled", + "libelling": "libeling", + "libellous": "libelous", + "liberalisation": "liberalization", + "liberalise": "liberalize", + "liberalised": "liberalized", + "liberalises": "liberalizes", + "liberalising": "liberalizing", + "licence": "license", + "licenced": "licensed", + "licences": "licenses", + "licencing": "licensing", + "likeable": "likable", + "lionisation": "lionization", + "lionise": "lionize", + "lionised": "lionized", + "lionises": "lionizes", + "lionising": "lionizing", + "liquidise": "liquidize", + "liquidised": "liquidized", + "liquidiser": "liquidizer", + "liquidisers": "liquidizers", + "liquidises": "liquidizes", + "liquidising": "liquidizing", + "litre": "liter", + "litres": "liters", + "localise": "localize", + "localised": "localized", + "localises": "localizes", + "localising": "localizing", + "louvre": "louver", + "louvred": "louvered", + "louvres": "louvers", + "lustre": "luster", + "magnetise": "magnetize", + "magnetised": "magnetized", + "magnetises": "magnetizes", + "magnetising": "magnetizing", + "manoeuvrability": "maneuverability", + "manoeuvrable": "maneuverable", + "manoeuvre": "maneuver", + "manoeuvred": "maneuvered", + "manoeuvres": "maneuvers", + "manoeuvring": "maneuvering", + "manoeuvrings": "maneuverings", + "marginalisation": "marginalization", + "marginalise": "marginalize", + "marginalised": "marginalized", + "marginalises": "marginalizes", + "marginalising": "marginalizing", + "marshalled": "marshaled", + "marshalling": "marshaling", + "marvelled": "marveled", + "marvelling": "marveling", + "marvellous": "marvelous", + "marvellously": "marvelously", + "materialisation": "materialization", + "materialise": "materialize", + "materialised": "materialized", + "materialises": "materializes", + "materialising": "materializing", + "maximisation": "maximization", + "maximise": "maximize", + "maximised": "maximized", + "maximises": "maximizes", + "maximising": "maximizing", + "meagre": "meager", + "mechanisation": "mechanization", + "mechanise": "mechanize", + "mechanised": "mechanized", + "mechanises": "mechanizes", + "mechanising": "mechanizing", + "mediaeval": "medieval", + "memorialise": "memorialize", + "memorialised": "memorialized", + "memorialises": "memorializes", + "memorialising": "memorializing", + "memorise": "memorize", + "memorised": "memorized", + "memorises": "memorizes", + "memorising": "memorizing", + "mesmerise": "mesmerize", + "mesmerised": "mesmerized", + "mesmerises": "mesmerizes", + "mesmerising": "mesmerizing", + "metabolise": "metabolize", + "metabolised": "metabolized", + "metabolises": "metabolizes", + "metabolising": "metabolizing", + "metre": "meter", + "metres": "meters", + "micrometre": "micrometer", + "micrometres": "micrometers", + "militarise": "militarize", + "militarised": "militarized", + "militarises": "militarizes", + "militarising": "militarizing", + "milligramme": "milligram", + "milligrammes": "milligrams", + "millilitre": "milliliter", + "millilitres": "milliliters", + "millimetre": "millimeter", + "millimetres": "millimeters", + "miniaturisation": "miniaturization", + "miniaturise": "miniaturize", + "miniaturised": "miniaturized", + "miniaturises": "miniaturizes", + "miniaturising": "miniaturizing", + "minibusses": "minibuses", + "minimise": "minimize", + "minimised": "minimized", + "minimises": "minimizes", + "minimising": "minimizing", + "misbehaviour": "misbehavior", + "misdemeanour": "misdemeanor", + "misdemeanours": "misdemeanors", + "misspelt": "misspelled", + "mitre": "miter", + "mitres": "miters", + "mobilisation": "mobilization", + "mobilise": "mobilize", + "mobilised": "mobilized", + "mobilises": "mobilizes", + "mobilising": "mobilizing", + "modelled": "modeled", + "modeller": "modeler", + "modellers": "modelers", + "modelling": "modeling", + "modernise": "modernize", + "modernised": "modernized", + "modernises": "modernizes", + "modernising": "modernizing", + "moisturise": "moisturize", + "moisturised": "moisturized", + "moisturiser": "moisturizer", + "moisturisers": "moisturizers", + "moisturises": "moisturizes", + "moisturising": "moisturizing", + "monologue": "monolog", + "monologues": "monologs", + "monopolisation": "monopolization", + "monopolise": "monopolize", + "monopolised": "monopolized", + "monopolises": "monopolizes", + "monopolising": "monopolizing", + "moralise": "moralize", + "moralised": "moralized", + "moralises": "moralizes", + "moralising": "moralizing", + "motorised": "motorized", + "mould": "mold", + "moulded": "molded", + "moulder": "molder", + "mouldered": "moldered", + "mouldering": "moldering", + "moulders": "molders", + "mouldier": "moldier", + "mouldiest": "moldiest", + "moulding": "molding", + "mouldings": "moldings", + "moulds": "molds", + "mouldy": "moldy", + "moult": "molt", + "moulted": "molted", + "moulting": "molting", + "moults": "molts", + "moustache": "mustache", + "moustached": "mustached", + "moustaches": "mustaches", + "moustachioed": "mustachioed", + "multicoloured": "multicolored", + "nationalisation": "nationalization", + "nationalisations": "nationalizations", + "nationalise": "nationalize", + "nationalised": "nationalized", + "nationalises": "nationalizes", + "nationalising": "nationalizing", + "naturalisation": "naturalization", + "naturalise": "naturalize", + "naturalised": "naturalized", + "naturalises": "naturalizes", + "naturalising": "naturalizing", + "neighbour": "neighbor", + "neighbourhood": "neighborhood", + "neighbourhoods": "neighborhoods", + "neighbouring": "neighboring", + "neighbourliness": "neighborliness", + "neighbourly": "neighborly", + "neighbours": "neighbors", + "neutralisation": "neutralization", + "neutralise": "neutralize", + "neutralised": "neutralized", + "neutralises": "neutralizes", + "neutralising": "neutralizing", + "normalisation": "normalization", + "normalise": "normalize", + "normalised": "normalized", + "normalises": "normalizes", + "normalising": "normalizing", + "odour": "odor", + "odourless": "odorless", + "odours": "odors", + "oesophagus": "esophagus", + "oesophaguses": "esophaguses", + "oestrogen": "estrogen", + "offence": "offense", + "offences": "offenses", + "omelette": "omelet", + "omelettes": "omelets", + "optimise": "optimize", + "optimised": "optimized", + "optimises": "optimizes", + "optimising": "optimizing", + "organisation": "organization", + "organisational": "organizational", + "organisations": "organizations", + "organise": "organize", + "organised": "organized", + "organiser": "organizer", + "organisers": "organizers", + "organises": "organizes", + "organising": "organizing", + "orthopaedic": "orthopedic", + "orthopaedics": "orthopedics", + "ostracise": "ostracize", + "ostracised": "ostracized", + "ostracises": "ostracizes", + "ostracising": "ostracizing", + "outmanoeuvre": "outmaneuver", + "outmanoeuvred": "outmaneuvered", + "outmanoeuvres": "outmaneuvers", + "outmanoeuvring": "outmaneuvering", + "overemphasise": "overemphasize", + "overemphasised": "overemphasized", + "overemphasises": "overemphasizes", + "overemphasising": "overemphasizing", + "oxidisation": "oxidization", + "oxidise": "oxidize", + "oxidised": "oxidized", + "oxidises": "oxidizes", + "oxidising": "oxidizing", + "paederast": "pederast", + "paederasts": "pederasts", + "paediatric": "pediatric", + "paediatrician": "pediatrician", + "paediatricians": "pediatricians", + "paediatrics": "pediatrics", + "paedophile": "pedophile", + "paedophiles": "pedophiles", + "paedophilia": "pedophilia", + "palaeolithic": "paleolithic", + "palaeontologist": "paleontologist", + "palaeontologists": "paleontologists", + "palaeontology": "paleontology", + "panelled": "paneled", + "panelling": "paneling", + "panellist": "panelist", + "panellists": "panelists", + "paralyse": "paralyze", + "paralysed": "paralyzed", + "paralyses": "paralyzes", + "paralysing": "paralyzing", + "parcelled": "parceled", + "parcelling": "parceling", + "parlour": "parlor", + "parlours": "parlors", + "particularise": "particularize", + "particularised": "particularized", + "particularises": "particularizes", + "particularising": "particularizing", + "passivisation": "passivization", + "passivise": "passivize", + "passivised": "passivized", + "passivises": "passivizes", + "passivising": "passivizing", + "pasteurisation": "pasteurization", + "pasteurise": "pasteurize", + "pasteurised": "pasteurized", + "pasteurises": "pasteurizes", + "pasteurising": "pasteurizing", + "patronise": "patronize", + "patronised": "patronized", + "patronises": "patronizes", + "patronising": "patronizing", + "patronisingly": "patronizingly", + "pedalled": "pedaled", + "pedalling": "pedaling", + "pedestrianisation": "pedestrianization", + "pedestrianise": "pedestrianize", + "pedestrianised": "pedestrianized", + "pedestrianises": "pedestrianizes", + "pedestrianising": "pedestrianizing", + "penalise": "penalize", + "penalised": "penalized", + "penalises": "penalizes", + "penalising": "penalizing", + "pencilled": "penciled", + "pencilling": "penciling", + "personalise": "personalize", + "personalised": "personalized", + "personalises": "personalizes", + "personalising": "personalizing", + "pharmacopoeia": "pharmacopeia", + "pharmacopoeias": "pharmacopeias", + "philosophise": "philosophize", + "philosophised": "philosophized", + "philosophises": "philosophizes", + "philosophising": "philosophizing", + "philtre": "filter", + "philtres": "filters", + "phoney": "phony", + "plagiarise": "plagiarize", + "plagiarised": "plagiarized", + "plagiarises": "plagiarizes", + "plagiarising": "plagiarizing", + "plough": "plow", + "ploughed": "plowed", + "ploughing": "plowing", + "ploughman": "plowman", + "ploughmen": "plowmen", + "ploughs": "plows", + "ploughshare": "plowshare", + "ploughshares": "plowshares", + "polarisation": "polarization", + "polarise": "polarize", + "polarised": "polarized", + "polarises": "polarizes", + "polarising": "polarizing", + "politicisation": "politicization", + "politicise": "politicize", + "politicised": "politicized", + "politicises": "politicizes", + "politicising": "politicizing", + "popularisation": "popularization", + "popularise": "popularize", + "popularised": "popularized", + "popularises": "popularizes", + "popularising": "popularizing", + "pouffe": "pouf", + "pouffes": "poufs", + "practise": "practice", + "practised": "practiced", + "practises": "practices", + "practising": "practicing", + "praesidium": "presidium", + "praesidiums": "presidiums", + "pressurisation": "pressurization", + "pressurise": "pressurize", + "pressurised": "pressurized", + "pressurises": "pressurizes", + "pressurising": "pressurizing", + "pretence": "pretense", + "pretences": "pretenses", + "primaeval": "primeval", + "prioritisation": "prioritization", + "prioritise": "prioritize", + "prioritised": "prioritized", + "prioritises": "prioritizes", + "prioritising": "prioritizing", + "privatisation": "privatization", + "privatisations": "privatizations", + "privatise": "privatize", + "privatised": "privatized", + "privatises": "privatizes", + "privatising": "privatizing", + "professionalisation": "professionalization", + "professionalise": "professionalize", + "professionalised": "professionalized", + "professionalises": "professionalizes", + "professionalising": "professionalizing", + "programme": "program", + "programmes": "programs", + "prologue": "prolog", + "prologues": "prologs", + "propagandise": "propagandize", + "propagandised": "propagandized", + "propagandises": "propagandizes", + "propagandising": "propagandizing", + "proselytise": "proselytize", + "proselytised": "proselytized", + "proselytiser": "proselytizer", + "proselytisers": "proselytizers", + "proselytises": "proselytizes", + "proselytising": "proselytizing", + "psychoanalyse": "psychoanalyze", + "psychoanalysed": "psychoanalyzed", + "psychoanalyses": "psychoanalyzes", + "psychoanalysing": "psychoanalyzing", + "publicise": "publicize", + "publicised": "publicized", + "publicises": "publicizes", + "publicising": "publicizing", + "pulverisation": "pulverization", + "pulverise": "pulverize", + "pulverised": "pulverized", + "pulverises": "pulverizes", + "pulverising": "pulverizing", + "pummelled": "pummel", + "pummelling": "pummeled", + "pyjama": "pajama", + "pyjamas": "pajamas", + "pzazz": "pizzazz", + "quarrelled": "quarreled", + "quarrelling": "quarreling", + "radicalise": "radicalize", + "radicalised": "radicalized", + "radicalises": "radicalizes", + "radicalising": "radicalizing", + "rancour": "rancor", + "randomise": "randomize", + "randomised": "randomized", + "randomises": "randomizes", + "randomising": "randomizing", + "rationalisation": "rationalization", + "rationalisations": "rationalizations", + "rationalise": "rationalize", + "rationalised": "rationalized", + "rationalises": "rationalizes", + "rationalising": "rationalizing", + "ravelled": "raveled", + "ravelling": "raveling", + "realisable": "realizable", + "realisation": "realization", + "realisations": "realizations", + "realise": "realize", + "realised": "realized", + "realises": "realizes", + "realising": "realizing", + "recognisable": "recognizable", + "recognisably": "recognizably", + "recognisance": "recognizance", + "recognise": "recognize", + "recognised": "recognized", + "recognises": "recognizes", + "recognising": "recognizing", + "reconnoitre": "reconnoiter", + "reconnoitred": "reconnoitered", + "reconnoitres": "reconnoiters", + "reconnoitring": "reconnoitering", + "refuelled": "refueled", + "refuelling": "refueling", + "regularisation": "regularization", + "regularise": "regularize", + "regularised": "regularized", + "regularises": "regularizes", + "regularising": "regularizing", + "remodelled": "remodeled", + "remodelling": "remodeling", + "remould": "remold", + "remoulded": "remolded", + "remoulding": "remolding", + "remoulds": "remolds", + "reorganisation": "reorganization", + "reorganisations": "reorganizations", + "reorganise": "reorganize", + "reorganised": "reorganized", + "reorganises": "reorganizes", + "reorganising": "reorganizing", + "revelled": "reveled", + "reveller": "reveler", + "revellers": "revelers", + "revelling": "reveling", + "revitalise": "revitalize", + "revitalised": "revitalized", + "revitalises": "revitalizes", + "revitalising": "revitalizing", + "revolutionise": "revolutionize", + "revolutionised": "revolutionized", + "revolutionises": "revolutionizes", + "revolutionising": "revolutionizing", + "rhapsodise": "rhapsodize", + "rhapsodised": "rhapsodized", + "rhapsodises": "rhapsodizes", + "rhapsodising": "rhapsodizing", + "rigour": "rigor", + "rigours": "rigors", + "ritualised": "ritualized", + "rivalled": "rivaled", + "rivalling": "rivaling", + "romanticise": "romanticize", + "romanticised": "romanticized", + "romanticises": "romanticizes", + "romanticising": "romanticizing", + "rumour": "rumor", + "rumoured": "rumored", + "rumours": "rumors", + "sabre": "saber", + "sabres": "sabers", + "saltpetre": "saltpeter", + "sanitise": "sanitize", + "sanitised": "sanitized", + "sanitises": "sanitizes", + "sanitising": "sanitizing", + "satirise": "satirize", + "satirised": "satirized", + "satirises": "satirizes", + "satirising": "satirizing", + "saviour": "savior", + "saviours": "saviors", + "savour": "savor", + "savoured": "savored", + "savouries": "savories", + "savouring": "savoring", + "savours": "savors", + "savoury": "savory", + "scandalise": "scandalize", + "scandalised": "scandalized", + "scandalises": "scandalizes", + "scandalising": "scandalizing", + "sceptic": "skeptic", + "sceptical": "skeptical", + "sceptically": "skeptically", + "scepticism": "skepticism", + "sceptics": "skeptics", + "sceptre": "scepter", + "sceptres": "scepters", + "scrutinise": "scrutinize", + "scrutinised": "scrutinized", + "scrutinises": "scrutinizes", + "scrutinising": "scrutinizing", + "secularisation": "secularization", + "secularise": "secularize", + "secularised": "secularized", + "secularises": "secularizes", + "secularising": "secularizing", + "sensationalise": "sensationalize", + "sensationalised": "sensationalized", + "sensationalises": "sensationalizes", + "sensationalising": "sensationalizing", + "sensitise": "sensitize", + "sensitised": "sensitized", + "sensitises": "sensitizes", + "sensitising": "sensitizing", + "sentimentalise": "sentimentalize", + "sentimentalised": "sentimentalized", + "sentimentalises": "sentimentalizes", + "sentimentalising": "sentimentalizing", + "sepulchre": "sepulcher", + "sepulchres": "sepulchers", + "serialisation": "serialization", + "serialisations": "serializations", + "serialise": "serialize", + "serialised": "serialized", + "serialises": "serializes", + "serialising": "serializing", + "sermonise": "sermonize", + "sermonised": "sermonized", + "sermonises": "sermonizes", + "sermonising": "sermonizing", + "sheikh": "sheik", + "shovelled": "shoveled", + "shovelling": "shoveling", + "shrivelled": "shriveled", + "shrivelling": "shriveling", + "signalise": "signalize", + "signalised": "signalized", + "signalises": "signalizes", + "signalising": "signalizing", + "signalled": "signaled", + "signalling": "signaling", + "smoulder": "smolder", + "smouldered": "smoldered", + "smouldering": "smoldering", + "smoulders": "smolders", + "snivelled": "sniveled", + "snivelling": "sniveling", + "snorkelled": "snorkeled", + "snorkelling": "snorkeling", + "snowplough": "snowplow", + "snowploughs": "snowplow", + "socialisation": "socialization", + "socialise": "socialize", + "socialised": "socialized", + "socialises": "socializes", + "socialising": "socializing", + "sodomise": "sodomize", + "sodomised": "sodomized", + "sodomises": "sodomizes", + "sodomising": "sodomizing", + "solemnise": "solemnize", + "solemnised": "solemnized", + "solemnises": "solemnizes", + "solemnising": "solemnizing", + "sombre": "somber", + "specialisation": "specialization", + "specialisations": "specializations", + "specialise": "specialize", + "specialised": "specialized", + "specialises": "specializes", + "specialising": "specializing", + "spectre": "specter", + "spectres": "specters", + "spiralled": "spiraled", + "spiralling": "spiraling", + "splendour": "splendor", + "splendours": "splendors", + "squirrelled": "squirreled", + "squirrelling": "squirreling", + "stabilisation": "stabilization", + "stabilise": "stabilize", + "stabilised": "stabilized", + "stabiliser": "stabilizer", + "stabilisers": "stabilizers", + "stabilises": "stabilizes", + "stabilising": "stabilizing", + "standardisation": "standardization", + "standardise": "standardize", + "standardised": "standardized", + "standardises": "standardizes", + "standardising": "standardizing", + "stencilled": "stenciled", + "stencilling": "stenciling", + "sterilisation": "sterilization", + "sterilisations": "sterilizations", + "sterilise": "sterilize", + "sterilised": "sterilized", + "steriliser": "sterilizer", + "sterilisers": "sterilizers", + "sterilises": "sterilizes", + "sterilising": "sterilizing", + "stigmatisation": "stigmatization", + "stigmatise": "stigmatize", + "stigmatised": "stigmatized", + "stigmatises": "stigmatizes", + "stigmatising": "stigmatizing", + "storey": "story", + "storeys": "stories", + "subsidisation": "subsidization", + "subsidise": "subsidize", + "subsidised": "subsidized", + "subsidiser": "subsidizer", + "subsidisers": "subsidizers", + "subsidises": "subsidizes", + "subsidising": "subsidizing", + "succour": "succor", + "succoured": "succored", + "succouring": "succoring", + "succours": "succors", + "sulphate": "sulfate", + "sulphates": "sulfates", + "sulphide": "sulfide", + "sulphides": "sulfides", + "sulphur": "sulfur", + "sulphurous": "sulfurous", + "summarise": "summarize", + "summarised": "summarized", + "summarises": "summarizes", + "summarising": "summarizing", + "swivelled": "swiveled", + "swivelling": "swiveling", + "symbolise": "symbolize", + "symbolised": "symbolized", + "symbolises": "symbolizes", + "symbolising": "symbolizing", + "sympathise": "sympathize", + "sympathised": "sympathized", + "sympathiser": "sympathizer", + "sympathisers": "sympathizers", + "sympathises": "sympathizes", + "sympathising": "sympathizing", + "synchronisation": "synchronization", + "synchronise": "synchronize", + "synchronised": "synchronized", + "synchronises": "synchronizes", + "synchronising": "synchronizing", + "synthesise": "synthesize", + "synthesised": "synthesized", + "synthesiser": "synthesizer", + "synthesisers": "synthesizers", + "synthesises": "synthesizes", + "synthesising": "synthesizing", + "syphon": "siphon", + "syphoned": "siphoned", + "syphoning": "siphoning", + "syphons": "siphons", + "systematisation": "systematization", + "systematise": "systematize", + "systematised": "systematized", + "systematises": "systematizes", + "systematising": "systematizing", + "tantalise": "tantalize", + "tantalised": "tantalized", + "tantalises": "tantalizes", + "tantalising": "tantalizing", + "tantalisingly": "tantalizingly", + "tasselled": "tasseled", + "technicolour": "technicolor", + "temporise": "temporize", + "temporised": "temporized", + "temporises": "temporizes", + "temporising": "temporizing", + "tenderise": "tenderize", + "tenderised": "tenderized", + "tenderises": "tenderizes", + "tenderising": "tenderizing", + "terrorise": "terrorize", + "terrorised": "terrorized", + "terrorises": "terrorizes", + "terrorising": "terrorizing", + "theatre": "theater", + "theatregoer": "theatergoer", + "theatregoers": "theatergoers", + "theatres": "theaters", + "theorise": "theorize", + "theorised": "theorized", + "theorises": "theorizes", + "theorising": "theorizing", + "tonne": "ton", + "tonnes": "tons", + "towelled": "toweled", + "towelling": "toweling", + "toxaemia": "toxemia", + "tranquillise": "tranquilize", + "tranquillised": "tranquilized", + "tranquilliser": "tranquilizer", + "tranquillisers": "tranquilizers", + "tranquillises": "tranquilizes", + "tranquillising": "tranquilizing", + "tranquillity": "tranquility", + "tranquillize": "tranquilize", + "tranquillized": "tranquilized", + "tranquillizer": "tranquilizer", + "tranquillizers": "tranquilizers", + "tranquillizes": "tranquilizes", + "tranquillizing": "tranquilizing", + "tranquilly": "tranquility", + "transistorised": "transistorized", + "traumatise": "traumatize", + "traumatised": "traumatized", + "traumatises": "traumatizes", + "traumatising": "traumatizing", + "travelled": "traveled", + "traveller": "traveler", + "travellers": "travelers", + "travelling": "traveling", + "travelog": "travelogue", + "travelogs": "travelogues", + "trialled": "trialed", + "trialling": "trialing", + "tricolour": "tricolor", + "tricolours": "tricolors", + "trivialise": "trivialize", + "trivialised": "trivialized", + "trivialises": "trivializes", + "trivialising": "trivializing", + "tumour": "tumor", + "tumours": "tumors", + "tunnelled": "tunneled", + "tunnelling": "tunneling", + "tyrannise": "tyrannize", + "tyrannised": "tyrannized", + "tyrannises": "tyrannizes", + "tyrannising": "tyrannizing", + "tyre": "tire", + "tyres": "tires", + "unauthorised": "unauthorized", + "uncivilised": "uncivilized", + "underutilised": "underutilized", + "unequalled": "unequaled", + "unfavourable": "unfavorable", + "unfavourably": "unfavorably", + "unionisation": "unionization", + "unionise": "unionize", + "unionised": "unionized", + "unionises": "unionizes", + "unionising": "unionizing", + "unorganised": "unorganized", + "unravelled": "unraveled", + "unravelling": "unraveling", + "unrecognisable": "unrecognizable", + "unrecognised": "unrecognized", + "unrivalled": "unrivaled", + "unsavoury": "unsavory", + "untrammelled": "untrammeled", + "urbanisation": "urbanization", + "urbanise": "urbanize", + "urbanised": "urbanized", + "urbanises": "urbanizes", + "urbanising": "urbanizing", + "utilisable": "utilizable", + "utilisation": "utilization", + "utilise": "utilize", + "utilised": "utilized", + "utilises": "utilizes", + "utilising": "utilizing", + "valour": "valor", + "vandalise": "vandalize", + "vandalised": "vandalized", + "vandalises": "vandalizes", + "vandalising": "vandalizing", + "vaporisation": "vaporization", + "vaporise": "vaporize", + "vaporised": "vaporized", + "vaporises": "vaporizes", + "vaporising": "vaporizing", + "vapour": "vapor", + "vapours": "vapors", + "verbalise": "verbalize", + "verbalised": "verbalized", + "verbalises": "verbalizes", + "verbalising": "verbalizing", + "victimisation": "victimization", + "victimise": "victimize", + "victimised": "victimized", + "victimises": "victimizes", + "victimising": "victimizing", + "videodisc": "videodisk", + "videodiscs": "videodisks", + "vigour": "vigor", + "visualisation": "visualization", + "visualisations": "visualizations", + "visualise": "visualize", + "visualised": "visualized", + "visualises": "visualizes", + "visualising": "visualizing", + "vocalisation": "vocalization", + "vocalisations": "vocalizations", + "vocalise": "vocalize", + "vocalised": "vocalized", + "vocalises": "vocalizes", + "vocalising": "vocalizing", + "vulcanised": "vulcanized", + "vulgarisation": "vulgarization", + "vulgarise": "vulgarize", + "vulgarised": "vulgarized", + "vulgarises": "vulgarizes", + "vulgarising": "vulgarizing", + "waggon": "wagon", + "waggons": "wagons", + "watercolour": "watercolor", + "watercolours": "watercolors", + "weaselled": "weaseled", + "weaselling": "weaseling", + "westernisation": "westernization", + "westernise": "westernize", + "westernised": "westernized", + "westernises": "westernizes", + "westernising": "westernizing", + "womanise": "womanize", + "womanised": "womanized", + "womaniser": "womanizer", + "womanisers": "womanizers", + "womanises": "womanizes", + "womanising": "womanizing", + "woollen": "woolen", + "woollens": "woolens", + "woollies": "woolies", + "woolly": "wooly", + "worshipped": "worshiped", + "worshipping": "worshiping", + "worshipper": "worshiper", + "yodelled": "yodeled", + "yodelling": "yodeling", + "yoghourt": "yogurt", + "yoghourts": "yogurts", + "yoghurt": "yogurt", + "yoghurts": "yogurts", + "mhm": "hmm", + "mmm": "hmm" +} \ No newline at end of file diff --git a/tests/librispeech-parakeet/normalizers/english.py b/tests/librispeech-parakeet/normalizers/english.py new file mode 100644 index 00000000000..4932042bc5b --- /dev/null +++ b/tests/librispeech-parakeet/normalizers/english.py @@ -0,0 +1,550 @@ +import json +import os +import re +from fractions import Fraction +from typing import Iterator, List, Match, Optional, Union + +from more_itertools import windowed + +from .basic import remove_symbols_and_diacritics + + +class EnglishNumberNormalizer: + """ + Convert any spelled-out numbers into arabic numbers, while handling: + + - remove any commas + - keep the suffixes such as: `1960s`, `274th`, `32nd`, etc. + - spell out currency symbols after the number. e.g. `$20 million` -> `20000000 dollars` + - spell out `one` and `ones` + - interpret successive single-digit numbers as nominal: `one oh one` -> `101` + """ + + def __init__(self): + super().__init__() + + self.zeros = {"o", "oh", "zero"} + self.ones = { + name: i + for i, name in enumerate( + [ + "one", + "two", + "three", + "four", + "five", + "six", + "seven", + "eight", + "nine", + "ten", + "eleven", + "twelve", + "thirteen", + "fourteen", + "fifteen", + "sixteen", + "seventeen", + "eighteen", + "nineteen", + ], + start=1, + ) + } + self.ones_plural = { + "sixes" if name == "six" else name + "s": (value, "s") + for name, value in self.ones.items() + } + self.ones_ordinal = { + "zeroth": (0, "th"), + "first": (1, "st"), + "second": (2, "nd"), + "third": (3, "rd"), + "fifth": (5, "th"), + "twelfth": (12, "th"), + **{ + name + ("h" if name.endswith("t") else "th"): (value, "th") + for name, value in self.ones.items() + if value > 3 and value != 5 and value != 12 + }, + } + self.ones_suffixed = {**self.ones_plural, **self.ones_ordinal} + + self.tens = { + "twenty": 20, + "thirty": 30, + "forty": 40, + "fifty": 50, + "sixty": 60, + "seventy": 70, + "eighty": 80, + "ninety": 90, + } + self.tens_plural = { + name.replace("y", "ies"): (value, "s") for name, value in self.tens.items() + } + self.tens_ordinal = { + name.replace("y", "ieth"): (value, "th") + for name, value in self.tens.items() + } + self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal} + + self.multipliers = { + "hundred": 100, + "thousand": 1_000, + "million": 1_000_000, + "billion": 1_000_000_000, + "trillion": 1_000_000_000_000, + "quadrillion": 1_000_000_000_000_000, + "quintillion": 1_000_000_000_000_000_000, + "sextillion": 1_000_000_000_000_000_000_000, + "septillion": 1_000_000_000_000_000_000_000_000, + "octillion": 1_000_000_000_000_000_000_000_000_000, + "nonillion": 1_000_000_000_000_000_000_000_000_000_000, + "decillion": 1_000_000_000_000_000_000_000_000_000_000_000, + } + self.multipliers_plural = { + name + "s": (value, "s") for name, value in self.multipliers.items() + } + self.multipliers_ordinal = { + name + "th": (value, "th") for name, value in self.multipliers.items() + } + self.multipliers_suffixed = { + **self.multipliers_plural, + **self.multipliers_ordinal, + } + self.decimals = {*self.ones, *self.tens, *self.zeros} + + self.preceding_prefixers = { + "minus": "-", + "negative": "-", + "plus": "+", + "positive": "+", + } + self.following_prefixers = { + "pound": "£", + "pounds": "£", + "euro": "€", + "euros": "€", + "dollar": "$", + "dollars": "$", + "cent": "¢", + "cents": "¢", + } + self.prefixes = set( + list(self.preceding_prefixers.values()) + + list(self.following_prefixers.values()) + ) + self.suffixers = { + "per": {"cent": "%"}, + "percent": "%", + } + self.specials = {"and", "double", "triple", "point"} + + self.words = set( + [ + key + for mapping in [ + self.zeros, + self.ones, + self.ones_suffixed, + self.tens, + self.tens_suffixed, + self.multipliers, + self.multipliers_suffixed, + self.preceding_prefixers, + self.following_prefixers, + self.suffixers, + self.specials, + ] + for key in mapping + ] + ) + self.literal_words = {"one", "ones"} + + def process_words(self, words: List[str]) -> Iterator[str]: + prefix: Optional[str] = None + value: Optional[Union[str, int]] = None + skip = False + + def to_fraction(s: str): + try: + return Fraction(s) + except ValueError: + return None + + def output(result: Union[str, int]): + nonlocal prefix, value + result = str(result) + if prefix is not None: + result = prefix + result + value = None + prefix = None + return result + + if len(words) == 0: + return + + for prev, current, next in windowed([None] + words + [None], 3): + if skip: + skip = False + continue + + next_is_numeric = next is not None and re.match(r"^\d+(\.\d+)?$", next) + has_prefix = current[0] in self.prefixes + current_without_prefix = current[1:] if has_prefix else current + if re.match(r"^\d+(\.\d+)?$", current_without_prefix): + # arabic numbers (potentially with signs and fractions) + f = to_fraction(current_without_prefix) + assert f is not None + if value is not None: + if isinstance(value, str) and value.endswith("."): + # concatenate decimals / ip address components + value = str(value) + str(current) + continue + else: + yield output(value) + + prefix = current[0] if has_prefix else prefix + if f.denominator == 1: + value = f.numerator # store integers as int + else: + value = current_without_prefix + elif current not in self.words: + # non-numeric words + if value is not None: + yield output(value) + yield output(current) + elif current in self.zeros: + value = str(value or "") + "0" + elif current in self.ones: + ones = self.ones[current] + + if value is None: + value = ones + elif isinstance(value, str) or prev in self.ones: + if ( + prev in self.tens and ones < 10 + ): # replace the last zero with the digit + assert value[-1] == "0" + value = value[:-1] + str(ones) + else: + value = str(value) + str(ones) + elif ones < 10: + if value % 10 == 0: + value += ones + else: + value = str(value) + str(ones) + else: # eleven to nineteen + if value % 100 == 0: + value += ones + else: + value = str(value) + str(ones) + elif current in self.ones_suffixed: + # ordinal or cardinal; yield the number right away + ones, suffix = self.ones_suffixed[current] + if value is None: + yield output(str(ones) + suffix) + elif isinstance(value, str) or prev in self.ones: + if prev in self.tens and ones < 10: + assert value[-1] == "0" + yield output(value[:-1] + str(ones) + suffix) + else: + yield output(str(value) + str(ones) + suffix) + elif ones < 10: + if value % 10 == 0: + yield output(str(value + ones) + suffix) + else: + yield output(str(value) + str(ones) + suffix) + else: # eleven to nineteen + if value % 100 == 0: + yield output(str(value + ones) + suffix) + else: + yield output(str(value) + str(ones) + suffix) + value = None + elif current in self.tens: + tens = self.tens[current] + if value is None: + value = tens + elif isinstance(value, str): + value = str(value) + str(tens) + else: + if value % 100 == 0: + value += tens + else: + value = str(value) + str(tens) + elif current in self.tens_suffixed: + # ordinal or cardinal; yield the number right away + tens, suffix = self.tens_suffixed[current] + if value is None: + yield output(str(tens) + suffix) + elif isinstance(value, str): + yield output(str(value) + str(tens) + suffix) + else: + if value % 100 == 0: + yield output(str(value + tens) + suffix) + else: + yield output(str(value) + str(tens) + suffix) + elif current in self.multipliers: + multiplier = self.multipliers[current] + if value is None: + value = multiplier + elif isinstance(value, str) or value == 0: + f = to_fraction(value) + p = f * multiplier if f is not None else None + if f is not None and p.denominator == 1: + value = p.numerator + else: + yield output(value) + value = multiplier + else: + before = value // 1000 * 1000 + residual = value % 1000 + value = before + residual * multiplier + elif current in self.multipliers_suffixed: + multiplier, suffix = self.multipliers_suffixed[current] + if value is None: + yield output(str(multiplier) + suffix) + elif isinstance(value, str): + f = to_fraction(value) + p = f * multiplier if f is not None else None + if f is not None and p.denominator == 1: + yield output(str(p.numerator) + suffix) + else: + yield output(value) + yield output(str(multiplier) + suffix) + else: # int + before = value // 1000 * 1000 + residual = value % 1000 + value = before + residual * multiplier + yield output(str(value) + suffix) + value = None + elif current in self.preceding_prefixers: + # apply prefix (positive, minus, etc.) if it precedes a number + if value is not None: + yield output(value) + + if next in self.words or next_is_numeric: + prefix = self.preceding_prefixers[current] + else: + yield output(current) + elif current in self.following_prefixers: + # apply prefix (dollars, cents, etc.) only after a number + if value is not None: + prefix = self.following_prefixers[current] + yield output(value) + else: + yield output(current) + elif current in self.suffixers: + # apply suffix symbols (percent -> '%') + if value is not None: + suffix = self.suffixers[current] + if isinstance(suffix, dict): + if next in suffix: + yield output(str(value) + suffix[next]) + skip = True + else: + yield output(value) + yield output(current) + else: + yield output(str(value) + suffix) + else: + yield output(current) + elif current in self.specials: + if next not in self.words and not next_is_numeric: + # apply special handling only if the next word can be numeric + if value is not None: + yield output(value) + yield output(current) + elif current == "and": + # ignore "and" after hundreds, thousands, etc. + if prev not in self.multipliers: + if value is not None: + yield output(value) + yield output(current) + elif current == "double" or current == "triple": + if next in self.ones or next in self.zeros: + repeats = 2 if current == "double" else 3 + ones = self.ones.get(next, 0) + value = str(value or "") + str(ones) * repeats + skip = True + else: + if value is not None: + yield output(value) + yield output(current) + elif current == "point": + if next in self.decimals or next_is_numeric: + value = str(value or "") + "." + else: + # should all have been covered at this point + raise ValueError(f"Unexpected token: {current}") + else: + # all should have been covered at this point + raise ValueError(f"Unexpected token: {current}") + + if value is not None: + yield output(value) + + def preprocess(self, s: str): + # replace "<number> and a half" with "<number> point five" + results = [] + + segments = re.split(r"\band\s+a\s+half\b", s) + for i, segment in enumerate(segments): + if len(segment.strip()) == 0: + continue + if i == len(segments) - 1: + results.append(segment) + else: + results.append(segment) + last_word = segment.rsplit(maxsplit=2)[-1] + if last_word in self.decimals or last_word in self.multipliers: + results.append("point five") + else: + results.append("and a half") + + s = " ".join(results) + + # put a space at number/letter boundary + s = re.sub(r"([a-z])([0-9])", r"\1 \2", s) + s = re.sub(r"([0-9])([a-z])", r"\1 \2", s) + + # but remove spaces which could be a suffix + s = re.sub(r"([0-9])\s+(st|nd|rd|th|s)\b", r"\1\2", s) + + return s + + def postprocess(self, s: str): + def combine_cents(m: Match): + try: + currency = m.group(1) + integer = m.group(2) + cents = int(m.group(3)) + return f"{currency}{integer}.{cents:02d}" + except ValueError: + return m.string + + def extract_cents(m: Match): + try: + return f"¢{int(m.group(1))}" + except ValueError: + return m.string + + # apply currency postprocessing; "$2 and ¢7" -> "$2.07" + s = re.sub(r"([€£$])([0-9]+) (?:and )?¢([0-9]{1,2})\b", combine_cents, s) + s = re.sub(r"[€£$]0.([0-9]{1,2})\b", extract_cents, s) + + # write "one(s)" instead of "1(s)", just for the readability + s = re.sub(r"\b1(s?)\b", r"one\1", s) + + return s + + def __call__(self, s: str): + s = self.preprocess(s) + s = " ".join(word for word in self.process_words(s.split()) if word is not None) + s = self.postprocess(s) + + return s + + +class EnglishSpellingNormalizer: + """ + Applies British-American spelling mappings as listed in [1]. + + [1] https://www.tysto.com/uk-us-spelling-list.html + """ + + def __init__(self): + mapping_path = os.path.join(os.path.dirname(__file__), "english.json") + self.mapping = json.load(open(mapping_path)) + + def __call__(self, s: str): + return " ".join(self.mapping.get(word, word) for word in s.split()) + + +class EnglishTextNormalizer: + def __init__(self): + self.ignore_patterns = r"\b(hmm|mm|mhm|mmm|uh|um)\b" + self.replacers = { + # common contractions + r"\bwon't\b": "will not", + r"\bcan't\b": "can not", + r"\blet's\b": "let us", + r"\bain't\b": "aint", + r"\by'all\b": "you all", + r"\bwanna\b": "want to", + r"\bgotta\b": "got to", + r"\bgonna\b": "going to", + r"\bi'ma\b": "i am going to", + r"\bimma\b": "i am going to", + r"\bwoulda\b": "would have", + r"\bcoulda\b": "could have", + r"\bshoulda\b": "should have", + r"\bma'am\b": "madam", + # contractions in titles/prefixes + r"\bmr\b": "mister ", + r"\bmrs\b": "missus ", + r"\bst\b": "saint ", + r"\bdr\b": "doctor ", + r"\bprof\b": "professor ", + r"\bcapt\b": "captain ", + r"\bgov\b": "governor ", + r"\bald\b": "alderman ", + r"\bgen\b": "general ", + r"\bsen\b": "senator ", + r"\brep\b": "representative ", + r"\bpres\b": "president ", + r"\brev\b": "reverend ", + r"\bhon\b": "honorable ", + r"\basst\b": "assistant ", + r"\bassoc\b": "associate ", + r"\blt\b": "lieutenant ", + r"\bcol\b": "colonel ", + r"\bjr\b": "junior ", + r"\bsr\b": "senior ", + r"\besq\b": "esquire ", + # prefect tenses, ideally it should be any past participles, but it's harder.. + r"'d been\b": " had been", + r"'s been\b": " has been", + r"'d gone\b": " had gone", + r"'s gone\b": " has gone", + r"'d done\b": " had done", # "'s done" is ambiguous + r"'s got\b": " has got", + # general contractions + r"n't\b": " not", + r"'re\b": " are", + r"'s\b": " is", + r"'d\b": " would", + r"'ll\b": " will", + r"'t\b": " not", + r"'ve\b": " have", + r"'m\b": " am", + } + self.standardize_numbers = EnglishNumberNormalizer() + self.standardize_spellings = EnglishSpellingNormalizer() + + def __call__(self, s: str): + s = s.lower() + + s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets + s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis + s = re.sub(self.ignore_patterns, "", s) + s = re.sub(r"\s+'", "'", s) # when there's a space before an apostrophe + + for pattern, replacement in self.replacers.items(): + s = re.sub(pattern, replacement, s) + + s = re.sub(r"(\d),(\d)", r"\1\2", s) # remove commas between digits + s = re.sub(r"\.([^0-9]|$)", r" \1", s) # remove periods not followed by numbers + s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep numeric symbols + + s = self.standardize_numbers(s) + s = self.standardize_spellings(s) + + # now remove prefix/suffix symbols that are not preceded/followed by numbers + s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s) + s = re.sub(r"([^0-9])%", r"\1 ", s) + + s = re.sub(r"\s+", " ", s) # replace any successive whitespaces with a space + + return s diff --git a/tests/parakeet-expected-diffusion-output.txt b/tests/parakeet-expected-diffusion-output.txt new file mode 100644 index 00000000000..9753a86953a --- /dev/null +++ b/tests/parakeet-expected-diffusion-output.txt @@ -0,0 +1 @@ +Hello and welcome to Diffusion. Sit back and relax while we stretch your brain with weird and wonderful science. I'm Ian Wolf. On this edition, Dr. Viv Robinson rewrites cosmology. But first up, here's news of two massive galaxies that might be older than the Big Bang. Galaxies too massive. Astronomers from the Swinburne University of Technology in Melbourne, using the James Webb Space Telescope, have observed six galaxies that formed in the universe's first 700 million years appear to be up to a hundred times more massive than our best theories say can possibly exist. Astronomer Ivo Labe and his colleagues wrote in his paper, adding up the stars in those galaxies, it would exceed the total amount of mass available in the universe at that time. There's too much mass and not enough time for it to get together. The galaxies must have had much longer than the 700 million years after the Big Bang that our standard model of the universe gives them, and the universe must have had more mass available, or galaxies must have formed differently than what we think. The Big Bang is currently thought to have started everything 13.77 billion years ago. And these galaxies, we're watching them at 0.77 billion years ago because they're so far away. Galaxies are thought to accumulate gas moved together by giant clumps of dark matter in their region. Generally, only about 10% of the gas in the galaxy ignites to make a star. For galaxies in the remotest parts of the universe where the gas is thin, it takes a long time to accumulate this much gas for this many stars. These six galaxies, however, have so many stars adding up to so much mass that all of the gas in each galaxy had to have become 100% converted into stars in the 700 million years since the universe started in the Big Bang. Under our current understanding, this is impossible. It suggests something in our understanding of the cosmos is wrong. Are we wrong about how to calculate astronomical masses, galaxy formation, dark matter, and the Big Bang and the age of the universe? An astronomer from the Cosmic Dawn Centre in Denmark used the James Webb telescope to look at closer galaxies, and then used the very high resolution of that telescope to calculate the mass more precisely with a different method, and found that these galaxies are three to ten times more massive than we previously thought. Applying this more accurate technique to the six galaxies that are 13 billion light years away would increase their mass, which makes it much worse than what we thought. The paper was titled A Population of Red Candidate Massive Galaxies, approximately 600 million years after the Big Bang, and was published in the journal Nature.com. We're brought to you across Australia on the Community Radio Network and podcast over the internet on www.diffusionradio.com Challenging Physics Newton said everything is either a particle or a wave. Faraday and Maxwell added fields. Einstein added space-time. Quantum physics says everything is made of quanta, which have the properties of both waves and particles, but is neither. Quantum mechanics has no explanation for gravity, and relativity doesn't account for the quantum world. There's a contradiction between our most basic explanations of the universe. Dr. Viv Robinson was the first person to create a physical explanation of Einstein's gravity in a paper published in the Journal of Physics Communications. He's made corrections to people's extensions of Einstein's mathematics and has a different way to interpret those mathematics that gives a different picture of the age of the universe and a different way of looking at how the physics works. From the standard model of quantum physics to Big Bang cosmology. Everything, including you and me, is made of light. It's a very big and very bold claim. I spoke to Dr. Viv Robinson via Zoom and began by asking him, what is the universe made of? The whole stuff of the universe, or entity. I won't call it items because one of them is absolutely nothing. The first thing to all the mass and all the energy is made up of photons. They're little packets of electromagnetic energy, postulated by Maxwell and Planck and proven by Einstein. They come in many different sizes, shapes, and which make that they make up all the mass and energy of the universe. The volume is made up by empty space, absolutely nothing. But it's the properties of the space that are important. And it does this through two of its properties, electric permittivity and magnetic permeability. And it's those properties which then transmit all of the fields. So that's really all it is. They're just the only two stars in a call because the photos are physical things, and space is just the absence of everything, but its property, its properties are what is important about it. And that's a little bit different to what you might hear from a quantum physics class where they talk about space being full of virtual particles coming into and out of existence so that it's not totally empty, or sometimes they say it's full of fields. The fields of every force is in there and things are coming up all the time. So if you go very fast, you'll interact with the fields, all the virtual particles, and you'll get radiation. Yes, well, uh the unfortunate part is that physics is doing exceedingly well under Newtonian mechanics and exceedingly well under Maxwell's mechanics. But as things get smaller and smaller, you get to a stage where things aren't continuous. I mean, Newton's work will anything that's continuous, but eventually you get to the stage where you know a droplet of water is fine, it has surface tension, evaporates, and you're left with one molecule of water. That doesn't behave the same as bulk water. Into that molecule you go hydrogen atoms and oxygen atoms, they behave nothing like water. And then you get, well, they're made of protons, neutrons, electrons, and they have completely different properties from bulk water. So quantum mechanics, things get quantized, and you get the smallest quantity you can get, and that has very, very different properties from the bulk. And what has happened in the past is that uh the uh early on in quantum mechanics and met men like Dirac and Schrdinger, they didn't know what an the structure was an electron was. Also, all they had to know, they knew it was it had wave properties. And so all they did was they attributed it to a way a wave property to it. Now, waves have the advantage over particles, you can manipulate them almost forever with all sorts of different transforms until you get the answer you want. And that gave some confidence to quantum mechanics guys that yes, waves work, and they've been using that forever, and all I'm saying, no, no, no, no, no. Everything is particles, and the particles have specific properties, and you can't manipulate those properties, or you can to a certain extent, but they are what they are, and it's when you know what those properties are that the whole quantum mechanics becomes much simpler. You don't need any of that uh foamy sort of stuff to get to explain whatever you want to explain. I mentioned that there are many different forms of photons, and photons are electromagnetic radiation with an electric field, saying on a magnetic field perpendicular to it, and the whole lot travels in the speed of light in the third dimension. There are many, many variations of that. So that that's fine for energy radiation. But how about matter particles? Well, matter particles are nothing more than photons of the appropriate wavelength making uh appropriate energy making two revolutions per wavelength. And when they do that, what holds what allows them to do that is that they rotate around the magnetic field. And suddenly, instead of in a linear photon, magnetic fields are open. When they rotate around the magnetic field, then the magnetic field of a particle is closed. And a closed magnetic field is much more stable than an open magnetic field, and that's why most of the universe, for example, when uh less about, I think the best estimate I've seen, one percent is radiation, the other 99% is photons struggling in circles, making two revolutions per wavelength. And it's for that that gives particles all their properties. Now, I may say this is a bit hairy-fairy, but it's been known for a long, long time that you get a particle and an antiparticle, you put them together, bing, two photons. At the same time, you can get a photon and goes and hit the target, bang, a particle and an antiparticle. Now that shows a relationship between the two that somehow lots of people missed. But what's the simplest relationship you can have? The simplest relationship is that a particle is a photon making two revolutions in one direction, an antiparticle is the same particle making two revolutions in the other direction. Put them together, they unlock. Because they have mass, they have this thing called angular momentum, which is a great Newtonian property. But because mathematicians sort of didn't know what an electron was, they called it a point particle. You can't have angular momentum with a point particle, so they call it spin and they wave all sorts of different things to make it seem as if they know what they're talking about. It's really just angular momentum. And that's the relationship between mass and energy. Energy is the photon zipping along at the speed of light. Mass is the same photon making two revolutions per wavelength. That's how they can interchange so easily. And that property gives particles all of their properties, including mass. And one of the things that Einstein did work out in 1905, those little what they called uh packets of radio of electromagnetic energy, he did work out that they carried momentum or carried inertia, they had momentum, they had mass. I don't know why people want to prove Einstein wrong. Photons have mass. Now I think the reason for this is that they think oh, Einstein's special relativity corrections, anything traveling at the speed of light, will have an infinite mass. The special relativity corrections only apply to photons which are spiraling. And that's just as um the reason for that is about as complicated as uh post Thagoras' theorem. And what he was at 300 BC or something like that, not difficult. And so photons themselves always travel at the speed of light. And so the rotating photons, photons that are rotating, are rotating also at the same speed of light. Well, that's one old hell of a gyroscope. And that is what gives particles a spin, that's why E equals mc squared, and it's all straightforward. There you go. Really? Well. So if we go back a little bit there where you're saying there's no wave nature, what about the double-slit experiment and other sorts of experiments that seem to show wave properties of particles other than photons? Particles um De Royal worked out in 1925 that if if photons, if um photons behave like particles, and particles to behave like photons, I agree with him, it's completely it's completely true. The actual nature of the rotating photon generates the de Broilie wavelength, and it has all the right properties. For me, and to me, Einstein's special and general relativity theories are relatively simple, so it may I may be talking a little bit out of line here. But the deuil wavelength is automatically generated by the particle as it moves. So it's not something that they hypothesize and don't know what occurs. They they hypothesized it, they measured it, but they don't know how it occurs. Well, yeah, it's quite it's fairly straightforward, but not at uh not not not at this level. What are the implications for this difference in understanding? So are there predictions that you would make that are different to the ones that people following the standard model would make? Oh, not the numbers of them, yeah. So probably the electron tunneling. Where electrons hit a barrier. That's got a very simple mechanical analog. I mean, the electrons are held in uh what you call a very taut field. Now, if you've got something coming up, you've got everything in a tight situation, you come something up banging it at this end, you can do it with billiard balls that'll transport through, and another one will knock out. So, what they call tunneling under this model, but in reality, what they call tunneling is just really a momentum exchange. So that's a little bit like one of those Newton cradles. Where you've got the balls on all attached by a string or a chain to a fulcrum over the top, and one will hit the other one and transfer the momentum to the other one without actually transferring itself. Yeah, you don't get electrons, you know, they have they have wave properties, but yes, but you won't get an electron uh tunneling the wave, the wave is in a very fixed position with respect to the uh electron. It's equal on either side of it. If their tunneling theory were correct, then the lower the energy of the electron, the longer its wavelength, therefore the easier it would be to tunnel. However, in the energy transfer one, the higher the energy, the greater probability it'll knock another electron out the other side. Or it's a simple experiment to do. Just increase the energy of uh an electron coming up to a barrier and see which ones go come out the other end first. Is anyone set up to do that? Oh, anyone could set up to do it. Well, a lot of laboratories could do it. And the so-called tunneling effect is what they use in all of the microelectronics systems. And they wouldn't, it wouldn't, it'd be a very, very simple exercise to carry that out. They may well have done it, and the mathematicians have turned around and added another factor. Yeah, it's a standard thing they do when they don't get the right answer, just add another factor. I can't do that. It's physical reality is physical reality. End of story. I guess that's something to look up and see if someone's done those experiments and and what they did with the results. I think there is I think I'm sure it has been done, and the result is that the higher the energy of the electron, the greater the probability of it emerging on the other side of the barrier. And on the very much bigger scale, are there differences in the way the universe looks for astronomy? Yeah, not as far as astronomy is concerned. What the astronomers see is what there is. No question about it. They're great, they're brilliant, as the astronomers, and most of the experimentalists are they're doing an exceedingly good job. The problem becomes in interpreting what they've seen. And when it comes to the whole universe, for example, it's all based on Einstein's theory of gravity. Well, it should be, but it's more advanced than Newton's inverse square, but for most practical purposes, uh Newton's inverse square works quite well. The two situations where it doesn't work, when the mass is so large, like the mass of the sun or the mass of the center of uh Sagittarius A with the planet or star S2 going around it. That's one situation. The reason why a planet uh or Mercury's orbit precesses in its direction of travel is simply that gravity, when mass is strong enough, gravity actually becomes weaker than inverse square. And that's one of the things you get when you solve Einstein's gravity theory accurately. It becomes weaker than inverse square. Now, when it's weak, if it's weaker than inverse square, Mercury travels a little bit closer to the Sun and is attracted by a slightly stronger force. So it'll arrive back at its perihelion point a little later, and it it'll um process in its direction of travel. And Newton pointed that out in 1687. So I don't know why they didn't sort of work it out correctly. But gravity is weaker than inverse square, is the solution to Einstein's gravity. The other thing is that when gravity is an infinite steady state universe under Newton's theory of gravity, inverse square, will collapse. The reason being that the relative to the universe density mass increases as r cubed, gravity decreases as r squared, so eventually you get to the stage where gravity just uh dominates mass and it collapses. But if gravity is weaker than inverse square, and I just tried to show you that Mercury is precessing orbit because the sun's gravity is weaker than inverse square, well, that applies to all gravity. There's nothing special about our sun, except that it's keeping all us alive on this. When you have an infinite steady-state universe, if gravity is weaker than inverse squares, its effect gets relatively weaker over long distances. And I'm talking typically uh 10 billion light years or something like that, maybe more. But that means an infinite steady-state universe won't collapse. That's a huge, huge difference. That's the biggest thing, mind you, what difference does it make to us here on Earth if uh if Bang's web has seen galaxies, fully formed galaxies 20 billion light years away, doesn't make a scrap of difference to us. But as far as understanding how the universe works, that mistake, and the simple the simple mistake that they the um all mathematicians were uh made, Einstein introduced approximations. He couldn't solve the gravity exactly himself. I have no problem solving his uh his gravity exactly. But he he uh introduced the approximation that one over one plus x approximately equals one minus x. You know, when x is ten to the minus seven or which is or ten to the minus eight, that's a good approximation. I mean you you just read his paper, he says so. And you read the mathematics, you don't even you could read the German version, look at the mathematics, and he says so, and you just work it out, and that was the difference. So, all of their exact solutions to Einstein's gravity, they took where he used the approximation, he derived the figure from one plus one over x, the equivalent of that, and then he rather than do that, he equated it to one minus x, which is which is true. You know, one plus one millionth is nine hundred and ninety millionth. Why they did it, I have no idea. Mind you, it'd be interesting to try and find out why. Uh I think it's if a mathematician of repute says one thing, and I I I will agree that uh on my first readings of Einstein's relativity theories, you think, oh my god, really? Could he understand that then? Then you get in and you start. It's not that difficult. And I think most of them had a solution. You know, somebody came up with a solution to Einstein's group, and everybody just followed it. And nobody, and this is the big thing that I always stress to everybody, don't take somebody's word for it. Go back and check the original yourself. I've seen a few times where people have just made terrible, terrible mistakes. But this would probably be the biggest one in the whole field of cosmology, sorry. Astronomy? You guys, great. Thanks, Uncle Sam, for providing us with all this information. That was part one of my interview with Dr. Viv Robinson. You heard Viv say that matter is made of photons moving in circles. Physicists took Einstein's approximations as gospel instead of using the exact solutions available with lather mathematics. Gravity changes to be weaker over distances, and the universe isn't expanding. Listen next week for part two. If you have any questions for Dr. Robinson, he'd love to answer them on the show. So send your questions to science at diffusionradio.com. If you're in Darlinghurst this Wednesday night, the 5th of July, I will be part of the lineup of scientists speaking at Future Science Talks at the East Village. Go to www.futurescience talks.com.au to grab a ticket and come up and say hello. And if you can't make it Wednesday night, I'll keep you posted on some future talks I'll be giving. And that's all from us this week on Diffusion. Are you a scientist, artist, biohacker, or maker who'd like to be interviewed about your work? Would your company like to sponsor diffusion? Send your contributions, opinions, helpful suggestions and donations to science at diffusionradio.com. That's science at diffusionradio.com. Please subscribe to the Diffusion Science Radio channel on youtube.com slash C slash Diffusion Radio and rate the show on iTunes and tell your friends. Follow me on Twitter at IanWorf. The news music was Rhinos Theme by Kevin McLeod of Incompitech.com. I produce diffusion, which is broadcast around Australia, to 28 stations on the community radio network, including Radio Blue Mountains 89.1 FM in New South Wales, 8CCC in Alice Springs and Tennant Creek, 2 MVR in Nambucker Valley, 3 MVR in the Malleigh Border Districts of Victoria and South Australia, City Park Radio 7LTN in Launcest and Tasmania, and 2XFM in Canberra. Diffusion is narrowcast on Indigo FM88 in Northeast Victoria. Diffusion is syndicated globally on astronomy.fm. Subscribe to the podcast on the diffusion website www.diffusionradio.com. That's www.diffusionradio.com and check the website for links, photos, and videos about this week's show. If you enjoyed the show, you can explore more than a thousand previous episodes archived on diffusionradio.com where the shows are labelled by keywords so you can focus in on the stories you want to hear. Make a donation through PayPal.me slash Ian Worf. Or join my patrons at patreon.com slash Diffusion Radio. I'm Ian Worf. Join us inside your audio device of choice for more science wondering next week on Diffusion Science Radio. Science is fun. It helps you to learn, to know, and to appreciate. When you study science, you make fun feel. diff --git a/tests/parakeet-expected-gb1-output.txt b/tests/parakeet-expected-gb1-output.txt new file mode 100644 index 00000000000..312ed1ce048 --- /dev/null +++ b/tests/parakeet-expected-gb1-output.txt @@ -0,0 +1 @@ +My fellow Americans, this day has brought terrible news and great sadness to our country. At nine o'clock this morning, mission control in Houston lost contact with our space shuttle Columbia. A short time later, debris was seen falling from the skies above Texas. The Columbia's lost. There are no survivors. On board was a crew of seven. Colonel Rick Husband, Lieutenant Colonel Michael Anderson, Commander Laurel Clark, Captain David Brown, Commander William McCool, Dr. Kulpna Shavla, and Ilan Ramon, a colonel in the Israeli Air Force. These men and women assumed great risk in the service to all humanity. In an age when space flight has come to seem almost routine. It is easy to overlook the dangers of travel by rocket and the difficulties of navigating the fierce outer atmosphere of the earth. These astronauts knew the dangers, and they faced them willingly, knowing they had a high and noble purpose in life. Because of their courage and daring and idealism, we will miss them all the more. And those you loved will always have the respect and gratitude of this country. The cause in which they died will continue. Mankind is led into the darkness beyond our world by the inspiration of discovery and the longing to understand. Our journey into space will go on. In the skies today, we saw destruction and tragedy. Yet farther than we can see, there is comfort and hope. In the words of the prophet Isaiah, lift your eyes and look to the heavens. Who created all these? He who brings out the starry hosts one by one and calls them each by name. Because of his great power and mighty strength, not one of them is missing. The same creator who names the stars also knows the names of the seven souls we mourn today. The crew of the shuttle Columbia did not return safely to Earth. Yet we can pray that all are safely home. May God bless the grieving families and make out may God continue to bless America. diff --git a/tests/parakeet-expected-jfk-output.txt b/tests/parakeet-expected-jfk-output.txt new file mode 100644 index 00000000000..ece35697ae8 --- /dev/null +++ b/tests/parakeet-expected-jfk-output.txt @@ -0,0 +1 @@ +And so, my fellow Americans, ask not what your country can do for you, ask what you can do for your country. diff --git a/tests/parakeet-verification.h b/tests/parakeet-verification.h new file mode 100644 index 00000000000..0e95610ba26 --- /dev/null +++ b/tests/parakeet-verification.h @@ -0,0 +1,110 @@ +#pragma once + +#include <algorithm> +#include <cassert> +#include <cctype> +#include <cstdio> +#include <fstream> +#include <iterator> +#include <string> +#include <vector> + +#ifndef TRANSCRIPTION_SIMILARITY_THRESHOLD +#define TRANSCRIPTION_SIMILARITY_THRESHOLD 1.0 +#endif + +static std::string read_expected_transcription(const char * path) { + std::ifstream fin(path); + assert(fin.is_open()); + + std::string text( + (std::istreambuf_iterator<char>(fin)), + std::istreambuf_iterator<char>()); + + while (!text.empty() && (text.back() == '\n' || text.back() == '\r')) { + text.pop_back(); + } + + return text; +} + +static std::vector<std::string> transcription_words(const std::string & text) { + std::vector<std::string> words; + std::string word; + + for (unsigned char ch : text) { + if (std::isalnum(ch)) { + word.push_back((char) std::tolower(ch)); + } else if (!word.empty()) { + words.push_back(word); + word.clear(); + } + } + + if (!word.empty()) { + words.push_back(word); + } + + return words; +} + +static double transcription_lcs_similarity(const std::string & expected, const std::string & actual) { + const std::vector<std::string> expected_words = transcription_words(expected); + const std::vector<std::string> actual_words = transcription_words(actual); + + if (expected_words.empty() && actual_words.empty()) { + return 1.0; + } + + if (expected_words.empty() || actual_words.empty()) { + return 0.0; + } + + std::vector<int> prev(actual_words.size() + 1, 0); + std::vector<int> cur (actual_words.size() + 1, 0); + + for (size_t i = 0; i < expected_words.size(); ++i) { + std::fill(cur.begin(), cur.end(), 0); + + for (size_t j = 0; j < actual_words.size(); ++j) { + if (expected_words[i] == actual_words[j]) { + cur[j + 1] = prev[j] + 1; + } else { + cur[j + 1] = std::max(prev[j + 1], cur[j]); + } + } + + prev.swap(cur); + } + + const int lcs = prev[actual_words.size()]; + return (2.0 * lcs) / (expected_words.size() + actual_words.size()); +} + +static bool verify_transcription(const std::string & expected, const std::string & actual) { + const double threshold = TRANSCRIPTION_SIMILARITY_THRESHOLD; + + if (threshold >= 1.0) { + if (actual == expected) { + return true; + } + + fprintf(stderr, "\n\n"); + fprintf(stderr, "[Failed] Transcript mismatched\n"); + fprintf(stderr, "expected:\n%s\n\n", expected.c_str()); + fprintf(stderr, "actual:\n%s\n", actual.c_str()); + return false; + } + + const double similarity = transcription_lcs_similarity(expected, actual); + printf("\nTranscript similarity: %.6f (threshold %.6f)\n", similarity, threshold); + + if (similarity >= threshold) { + return true; + } + + fprintf(stderr, "\n\nTranscript similarity below threshold: %.6f < %.6f\n", similarity, threshold); + fprintf(stderr, "Expected:\n%s\n\n", expected.c_str()); + fprintf(stderr, "Actual:\n%s\n", actual.c_str()); + return false; +} diff --git a/tests/run-tests.sh b/tests/run-tests.sh index ad2b8d3ec09..bc28314a704 100755 --- a/tests/run-tests.sh +++ b/tests/run-tests.sh @@ -21,13 +21,21 @@ cd `dirname $0` # Whisper models models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large-v2" "large-v3" "large-v3-turbo" ) +# Parakeet model variants +parakeet_models=( "f16" "f32" "q2_k" "q4_0" "q4_k" "q8_0" ) + # list available models function list_models { printf "\n" - printf " Available models:" + printf " Available whisper models:" for model in "${models[@]}"; do printf " $model" done + printf "\n" + printf " Available parakeet models:" + for model in "${parakeet_models[@]}"; do + printf " parakeet-$model" + done printf "\n\n" } @@ -39,15 +47,37 @@ if [ $# -eq 0 ]; then fi model=$1 -main="../build/bin/whisper-cli" threads="" if [ $# -eq 2 ]; then threads="-t $2" fi -if [ ! -f ../models/ggml-$model.bin ]; then - printf "Model $model not found. Aborting\n" +# Detect parakeet model (prefix "parakeet-" or a bare variant like "f32") +is_parakeet=0 +parakeet_variant="" +if [[ $model == parakeet-* ]]; then + is_parakeet=1 + parakeet_variant="${model#parakeet-}" +fi +for v in "${parakeet_models[@]}"; do + if [[ $model == "$v" ]]; then + is_parakeet=1 + parakeet_variant="$v" + break + fi +done + +if [ $is_parakeet -eq 1 ]; then + main="../build/bin/parakeet-cli" + model_path="../models/ggml-parakeet-tdt-0.6b-v3-${parakeet_variant}.bin" +else + main="../build/bin/whisper-cli" + model_path="../models/ggml-${model}.bin" +fi + +if [ ! -f $model_path ]; then + printf "Model $model not found ($model_path). Aborting\n" list_models exit 1 fi @@ -110,7 +140,11 @@ function run_lang() { fi fi - $main -m ../models/ggml-$model.bin $threads -f $fname_dst -l $lang -otxt 2> /dev/null + if [ $is_parakeet -eq 1 ]; then + $main -m $model_path $threads -f $fname_dst -otxt 2> /dev/null + else + $main -m $model_path $threads -f $fname_dst -l $lang -otxt 2> /dev/null + fi git diff --no-index --word-diff=color --word-diff-regex=. $lang-$i-ref.txt $fname_dst.txt @@ -120,7 +154,7 @@ function run_lang() { run_lang "en" "${urls_en[@]}" -if [[ $model != *.en* ]]; then +if [ $is_parakeet -eq 0 ] && [[ $model != *.en* ]]; then run_lang "es" "${urls_es[@]}" run_lang "it" "${urls_it[@]}" run_lang "pt" "${urls_pt[@]}" diff --git a/tests/test-common-utf8.cpp b/tests/test-common-utf8.cpp new file mode 100644 index 00000000000..91c73a7428d --- /dev/null +++ b/tests/test-common-utf8.cpp @@ -0,0 +1,34 @@ +#include "common-whisper.h" + +#include <cstdlib> +#include <cstdio> +#include <string> + +static void expect_needed(const std::string & input, int expected) { + const int actual = utf8_trailing_bytes_needed(input); + if (actual != expected) { + fprintf(stderr, "expected %d trailing UTF-8 bytes, got %d\n", expected, actual); + std::abort(); + } +} + +int main() { + expect_needed("", 0); + expect_needed("plain ascii", 0); + + const std::string cjk = "\xE4\xBD\xA0"; // U+4F60 + expect_needed(cjk.substr(0, 1), 2); + expect_needed(cjk.substr(0, 2), 1); + expect_needed(cjk, 0); + + const std::string emoji = "\xF0\x9F\x98\x80"; // U+1F600 + expect_needed(emoji.substr(0, 1), 3); + expect_needed(emoji.substr(0, 2), 2); + expect_needed(emoji.substr(0, 3), 1); + expect_needed(emoji, 0); + + expect_needed("\x80\x80", 0); + expect_needed("\xFF", 0); + + return 0; +} diff --git a/tests/test-parakeet-full.cpp b/tests/test-parakeet-full.cpp new file mode 100644 index 00000000000..22ac4c20e31 --- /dev/null +++ b/tests/test-parakeet-full.cpp @@ -0,0 +1,101 @@ +#include "parakeet.h" +#include "common-whisper.h" +#include "parakeet-verification.h" + +#include <cstdio> +#include <string> + +#ifdef NDEBUG +#undef NDEBUG +#endif +#include <cassert> + +struct test_state { + bool is_first = true; + std::string transcript; +}; + +void progress_callback(parakeet_context * ctx, parakeet_state * state, int progress, void * user_data) { + bool * called = static_cast<bool *>(user_data); + *called = true; +} + +bool encoder_begin_callback(parakeet_context * ctx, parakeet_state * state, void * user_data) { + bool * called = static_cast<bool *>(user_data); + *called = true; + return true; +} + +bool abort_callback(void * user_data) { + bool * called = static_cast<bool *>(user_data); + *called = true; + return false; // just continue without aborting. +} + +void token_callback(parakeet_context * ctx, parakeet_state * state, const parakeet_token_data * token_data, void * user_data) { + test_state * tstate = static_cast<test_state *>(user_data); + + const char * token_str = parakeet_token_to_str(ctx, token_data->id); + char text_buf[256]; + parakeet_token_to_text(token_str, tstate->is_first, text_buf, sizeof(text_buf)); + + printf("%s", text_buf); + fflush(stdout); + + tstate->transcript += text_buf; + tstate->is_first = false; +} + +int main() { + std::string model_path = PARAKEET_MODEL_PATH; + std::string sample_path = SAMPLE_PATH; + + std::vector<float> pcmf32; + std::vector<std::vector<float>> pcmf32s; + assert(read_audio_data(sample_path.c_str(), pcmf32, pcmf32s, false)); + assert(pcmf32.size() > 0); + assert(pcmf32s.size() == 0); // no stereo vector + + printf("Loading Parakeet model from: %s\n", model_path.c_str()); + + struct parakeet_context_params ctx_params = parakeet_context_default_params(); + + struct parakeet_context * pctx = parakeet_init_from_file_with_params(model_path.c_str(), ctx_params); + if (pctx == nullptr) { + fprintf(stderr, "Failed to load Parakeet model\n"); + return 1; + } + printf("Successfully loaded Parakeet model\n"); + + struct parakeet_full_params params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY); + test_state tstate; + params.new_token_callback = token_callback; + params.new_token_callback_user_data = &tstate; + bool progress_callback_called = false; + params.progress_callback = progress_callback; + params.progress_callback_user_data = &progress_callback_called; + bool encoder_begin_callback_called = false; + params.encoder_begin_callback = encoder_begin_callback; + params.encoder_begin_callback_user_data = &encoder_begin_callback_called; + bool abort_callback_called = false; + params.abort_callback = abort_callback; + params.abort_callback_user_data = &abort_callback_called; + + int ret = parakeet_full(pctx, params, pcmf32.data(), pcmf32.size()); + assert(ret == 0); + assert(progress_callback_called); + assert(encoder_begin_callback_called); + assert(abort_callback_called); + + const std::string expected = read_expected_transcription(EXPECTED_TRANSCRIPTION_PATH); + const bool transcript_matches = verify_transcription(expected, tstate.transcript); + + parakeet_free(pctx); + + if (!transcript_matches) { + return 1; + } + + printf("\nTest passed: parakeet_full succeeded!\n"); + return 0; +} diff --git a/tests/test-parakeet.cpp b/tests/test-parakeet.cpp new file mode 100644 index 00000000000..83237c600ac --- /dev/null +++ b/tests/test-parakeet.cpp @@ -0,0 +1,99 @@ +#include "parakeet.h" +#include "common-whisper.h" + +#include <cstdio> +#include <string> + +#ifdef NDEBUG +#undef NDEBUG +#endif +#include <cassert> + +void token_callback(parakeet_context * ctx, parakeet_state * state, const parakeet_token_data * token_data, void * user_data) { + static bool is_first = true; + const char * token_str = parakeet_token_to_str(ctx, token_data->id); + char text_buf[256]; + parakeet_token_to_text(token_str, is_first, text_buf, sizeof(text_buf)); + + int32_t time_ms = token_data->frame_index * 10; + + printf("%s", text_buf); + fflush(stdout); + + is_first = false; +} + +void segment_callback(parakeet_context * ctx, parakeet_state * state, int n_new, void * user_data) { + const int n_segments = parakeet_full_n_segments_from_state(state); + const int s0 = n_segments - n_new; + + printf("\nSegment Callback: %d new segment(s)\n", n_new); + + for (int i = s0; i < n_segments; i++) { + const char * text = parakeet_full_get_segment_text_from_state(state, i); + const int64_t t0 = parakeet_full_get_segment_t0_from_state(state, i); + const int64_t t1 = parakeet_full_get_segment_t1_from_state(state, i); + + printf("Segment %d: [%lld -> %lld] \"%s\"\n", i, (long long)t0, (long long)t1, text); + printf("Tokens:\n"); + + const int n_tokens = parakeet_full_n_tokens_from_state(state, i); + for (int j = 0; j < n_tokens; j++) { + parakeet_token_data token_data = parakeet_full_get_token_data_from_state(state, i, j); + const char * token_str = parakeet_token_to_str(ctx, token_data.id); + + printf(" [%2d] id=%5d frame=%3d dur_idx=%2d dur_val=%2d p=%.4f plog=%.4f t0=%4lld t1=%4lld word_start=%d \"%s\"\n", + j, + token_data.id, + token_data.frame_index, + token_data.duration_idx, + token_data.duration_value, + token_data.p, + token_data.plog, + (long long)token_data.t0, + (long long)token_data.t1, + token_data.is_word_start, + token_str); + } + } + printf("\n"); +} + +int main() { + std::string model_path = PARAKEET_MODEL_PATH; + std::string sample_path = SAMPLE_PATH; + + // Load the sample audio file + std::vector<float> pcmf32; + std::vector<std::vector<float>> pcmf32s; + assert(read_audio_data(sample_path.c_str(), pcmf32, pcmf32s, false)); + assert(pcmf32.size() > 0); + assert(pcmf32s.size() == 0); + + printf("Loading Parakeet model from: %s\n", model_path.c_str()); + + struct parakeet_context_params ctx_params = parakeet_context_default_params(); + + struct parakeet_context * pctx = parakeet_init_from_file_with_params_no_state(model_path.c_str(), ctx_params); + if (pctx == nullptr) { + fprintf(stderr, "Failed to load Parakeet model\n"); + return 1; + } + printf("Successfully loaded Parakeet model\n"); + + struct parakeet_full_params params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY); + params.new_token_callback = token_callback; + params.new_token_callback_user_data = nullptr; + params.new_segment_callback = segment_callback; + params.new_segment_callback_user_data = nullptr; + parakeet_state * state = parakeet_init_state(pctx); + + int ret = parakeet_chunk(pctx, state, params, pcmf32.data(), pcmf32.size()); + assert(ret == 0); + + parakeet_free_state(state); + parakeet_free(pctx); + + printf("\nTest passed: Parakeet model loaded and freed successfully\n"); + return 0; +}