diff --git a/.devops/intel.Dockerfile b/.devops/intel.Dockerfile index 955a2962ff4..8e830d46251 100644 --- a/.devops/intel.Dockerfile +++ b/.devops/intel.Dockerfile @@ -1,4 +1,4 @@ -ARG ONEAPI_VERSION=2025.3.2-0-devel-ubuntu24.04 +ARG ONEAPI_VERSION=2025.3.3-0-devel-ubuntu24.04 ## Build Image diff --git a/.devops/openvino.Dockerfile b/.devops/openvino.Dockerfile index 3ee4dd20180..31b58736d7e 100644 --- a/.devops/openvino.Dockerfile +++ b/.devops/openvino.Dockerfile @@ -2,7 +2,19 @@ ARG OPENVINO_VERSION_MAJOR=2026.0 ARG OPENVINO_VERSION_FULL=2026.0.0.20965.c6d6a13a886 ARG UBUNTU_VERSION=24.04 -# Optional proxy build arguments - empty by default +# Intel GPU driver versions. https://github.com/intel/compute-runtime/releases +ARG IGC_VERSION=v2.30.1 +ARG IGC_VERSION_FULL=2_2.30.1+20950 +ARG COMPUTE_RUNTIME_VERSION=26.09.37435.1 +ARG COMPUTE_RUNTIME_VERSION_FULL=26.09.37435.1-0 +ARG IGDGMM_VERSION=22.9.0 + +# Intel NPU driver versions. https://github.com/intel/linux-npu-driver/releases +ARG NPU_DRIVER_VERSION=v1.32.0 +ARG NPU_DRIVER_FULL=v1.32.0.20260402-23905121947 +ARG LIBZE1_VERSION=1.27.0-1~24.04~ppa2 + +# Optional proxy build arguments ARG http_proxy= ARG https_proxy= @@ -78,13 +90,47 @@ ARG http_proxy ARG https_proxy RUN apt-get update \ - && apt-get install -y libgomp1 libtbb12 curl \ + && apt-get install -y libgomp1 libtbb12 curl wget ocl-icd-libopencl1 \ && apt autoremove -y \ && apt clean -y \ && rm -rf /tmp/* /var/tmp/* \ && find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete \ && find /var/cache -type f -delete +# Install GPU drivers +ARG IGC_VERSION +ARG IGC_VERSION_FULL +ARG COMPUTE_RUNTIME_VERSION +ARG COMPUTE_RUNTIME_VERSION_FULL +ARG IGDGMM_VERSION +RUN mkdir /tmp/neo/ && cd /tmp/neo/ \ + && wget https://github.com/intel/intel-graphics-compiler/releases/download/${IGC_VERSION}/intel-igc-core-${IGC_VERSION_FULL}_amd64.deb \ + && wget https://github.com/intel/intel-graphics-compiler/releases/download/${IGC_VERSION}/intel-igc-opencl-${IGC_VERSION_FULL}_amd64.deb \ + && wget https://github.com/intel/compute-runtime/releases/download/${COMPUTE_RUNTIME_VERSION}/intel-ocloc-dbgsym_${COMPUTE_RUNTIME_VERSION_FULL}_amd64.ddeb \ + && wget https://github.com/intel/compute-runtime/releases/download/${COMPUTE_RUNTIME_VERSION}/intel-ocloc_${COMPUTE_RUNTIME_VERSION_FULL}_amd64.deb \ + && wget https://github.com/intel/compute-runtime/releases/download/${COMPUTE_RUNTIME_VERSION}/intel-opencl-icd-dbgsym_${COMPUTE_RUNTIME_VERSION_FULL}_amd64.ddeb \ + && wget https://github.com/intel/compute-runtime/releases/download/${COMPUTE_RUNTIME_VERSION}/intel-opencl-icd_${COMPUTE_RUNTIME_VERSION_FULL}_amd64.deb \ + && wget https://github.com/intel/compute-runtime/releases/download/${COMPUTE_RUNTIME_VERSION}/libigdgmm12_${IGDGMM_VERSION}_amd64.deb \ + && wget https://github.com/intel/compute-runtime/releases/download/${COMPUTE_RUNTIME_VERSION}/libze-intel-gpu1-dbgsym_${COMPUTE_RUNTIME_VERSION_FULL}_amd64.ddeb \ + && wget https://github.com/intel/compute-runtime/releases/download/${COMPUTE_RUNTIME_VERSION}/libze-intel-gpu1_${COMPUTE_RUNTIME_VERSION_FULL}_amd64.deb \ + && dpkg --install *.deb \ + && rm -rf /tmp/neo/ + +# Install NPU drivers +ARG NPU_DRIVER_VERSION +ARG NPU_DRIVER_FULL +ARG LIBZE1_VERSION +RUN mkdir /tmp/npu/ && cd /tmp/npu/ \ + && wget https://github.com/intel/linux-npu-driver/releases/download/${NPU_DRIVER_VERSION}/linux-npu-driver-${NPU_DRIVER_FULL}-ubuntu2404.tar.gz \ + && tar -xf linux-npu-driver-${NPU_DRIVER_FULL}-ubuntu2404.tar.gz \ + && dpkg --install *.deb 
\ + && rm -rf /tmp/npu/ + +RUN cd /tmp \ + && wget https://snapshot.ppa.launchpadcontent.net/kobuk-team/intel-graphics/ubuntu/20260324T100000Z/pool/main/l/level-zero-loader/libze1_${LIBZE1_VERSION}_amd64.deb \ + && dpkg --install libze1_${LIBZE1_VERSION}_amd64.deb \ + && rm libze1_${LIBZE1_VERSION}_amd64.deb + COPY --from=build /app/lib/ /app/ ### Full (all binaries) diff --git a/.github/workflows/build-and-test-snapdragon.yml b/.github/workflows/build-and-test-snapdragon.yml new file mode 100644 index 00000000000..7eb204ea2a6 --- /dev/null +++ b/.github/workflows/build-and-test-snapdragon.yml @@ -0,0 +1,113 @@ +name: CI (snapdragon) + +on: + workflow_dispatch: + push: + branches: + - master + paths: + - '.github/workflows/build-and-test-snapdragon.yml' + - 'ggml/include/ggml-hexagon.h' + - 'ggml/src/ggml-hexagon/**' + - 'docs/backend/snapdragon/**' + - 'scripts/snapdragon/**' + - 'CMakePresets.json' + + pull_request: + types: [opened, synchronize, reopened] + paths: + - '.github/workflows/build-and-test-snapdragon.yml' + - 'ggml/include/ggml-hexagon.h' + - 'ggml/src/ggml-hexagon/**' + - 'docs/backend/snapdragon/**' + - 'scripts/snapdragon/**' + - 'CMakePresets.json' + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +jobs: + android-ndk-snapdragon: + runs-on: ubuntu-latest + container: + image: 'ghcr.io/snapdragon-toolchain/arm64-android:v0.3' + defaults: + run: + shell: bash + + steps: + - name: Clone + uses: actions/checkout@v6 + with: + fetch-depth: 0 + lfs: false + + - name: Build Llama.CPP for Snapdragon Android + id: build_llama_cpp_snapdragon_android + run: | + cp docs/backend/snapdragon/CMakeUserPresets.json . + cmake --preset arm64-android-snapdragon-release -B build + cmake --build build + cmake --install build --prefix pkg-adb/llama.cpp + + - name: Upload Llama.CPP Snapdragon Android Build Artifact + if: ${{ always() && steps.build_llama_cpp_snapdragon_android.outcome == 'success' }} + uses: actions/upload-artifact@v6 + with: + name: llama-cpp-android-arm64-snapdragon + path: pkg-adb/llama.cpp + + check-secret: + runs-on: ubuntu-latest + outputs: + has-key: ${{ steps.check.outputs.has-key }} + steps: + - id: check + run: echo "has-key=${{ secrets.QDC_API_KEY != '' }}" >> "$GITHUB_OUTPUT" + + test-snapdragon-qdc: + name: Test on QDC Android Device (${{ matrix.device }}) + needs: [android-ndk-snapdragon, check-secret] + if: needs.check-secret.outputs.has-key == 'true' + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + device: [SM8750, SM8650, SM8850] + + steps: + - name: Checkout + uses: actions/checkout@v6 + + - name: Download build artifact + uses: actions/download-artifact@v4 + with: + name: llama-cpp-android-arm64-snapdragon + path: pkg-snapdragon/ + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.x' + cache: pip + + - name: Install QDC SDK wheel + run: | + curl -fSL -o qdc_sdk.zip https://softwarecenter.qualcomm.com/api/download/software/tools/Qualcomm_Device_Cloud_SDK/All/0.2.3/qualcomm_device_cloud_sdk-0.2.3.zip + unzip qdc_sdk.zip -d qdc_sdk + pip install qdc_sdk/qualcomm_device_cloud_sdk-0.2.3-py3-none-any.whl + + - name: Run QDC tests (${{ matrix.device }}) + run: | + python scripts/snapdragon/qdc/run_qdc_jobs.py \ + --test all \ + --pkg-dir pkg-snapdragon/llama.cpp \ + --model-url "https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q4_0.gguf" \ + --device ${{ matrix.device }} + env: + 
QDC_API_KEY: ${{ secrets.QDC_API_KEY }} + + - name: Cleanup + if: always() + run: rm -rf pkg-snapdragon qdc_sdk qdc_sdk.zip diff --git a/.github/workflows/build-android.yml b/.github/workflows/build-android.yml index b38a793f186..5d88305a4f0 100644 --- a/.github/workflows/build-android.yml +++ b/.github/workflows/build-android.yml @@ -1,26 +1,24 @@ name: CI (android) on: - workflow_dispatch: # allows manual triggering + workflow_dispatch: push: branches: - master - paths: [ - '.github/workflows/build-android.yml', - '**/CMakeLists.txt', - '**/.cmake', - '**/*.h', - '**/*.hpp', - '**/*.c', - '**/*.cpp' - ] + paths: + - '.github/workflows/build-android.yml' + - '**/CMakeLists.txt' + - '**/.cmake' + - '**/*.h' + - '**/*.hpp' + - '**/*.c' + - '**/*.cpp' pull_request: types: [opened, synchronize, reopened] - paths: [ - '.github/workflows/build-android.yml', - 'examples/llama.android/**' - ] + paths: + - '.github/workflows/build-android.yml' + - 'examples/llama.android/**' concurrency: group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} @@ -67,35 +65,24 @@ jobs: defaults: run: shell: bash - strategy: - matrix: - include: - - build: 'arm64-cpu' - defines: '-D ANDROID_ABI=arm64-v8a -D ANDROID_PLATFORM=android-31 -D CMAKE_TOOLCHAIN_FILE=${ANDROID_NDK_ROOT}/build/cmake/android.toolchain.cmake -D GGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8.5-a+fp16+i8mm -G Ninja -D LLAMA_OPENSSL=OFF -D GGML_OPENMP=OFF' - - build: 'arm64-snapdragon' - defines: '--preset arm64-android-snapdragon-release' steps: - name: Clone - id: checkout uses: actions/checkout@v6 with: fetch-depth: 0 lfs: false - - name: Build Llama.CPP for Hexagon Android - id: build_llama_cpp_hexagon_android + - name: Build + id: ndk_build run: | - if [[ "${{ matrix.build }}" == "arm64-snapdragon" ]]; then - cp docs/backend/snapdragon/CMakeUserPresets.json . 
- fi - cmake ${{ matrix.defines }} -B build + cmake -D ANDROID_ABI=arm64-v8a -D ANDROID_PLATFORM=android-31 -D CMAKE_TOOLCHAIN_FILE=${ANDROID_NDK_ROOT}/build/cmake/android.toolchain.cmake -D GGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8.5-a+fp16+i8mm -G Ninja -D LLAMA_OPENSSL=OFF -D GGML_OPENMP=OFF -B build cmake --build build cmake --install build --prefix pkg-adb/llama.cpp - - name: Upload Llama.CPP Hexagon Android Build Artifact - if: ${{ always() && steps.build_llama_cpp_hexagon_android.outcome == 'success' }} + - name: Upload Android Build Artifact + if: ${{ always() && steps.ndk_build.outcome == 'success' }} uses: actions/upload-artifact@v6 with: - name: llama-cpp-android-${{ matrix.build }} + name: llama-cpp-android-arm64-cpu path: pkg-adb/llama.cpp diff --git a/.github/workflows/build-cross.yml b/.github/workflows/build-cross.yml index 74508129ac5..aef45afdeac 100644 --- a/.github/workflows/build-cross.yml +++ b/.github/workflows/build-cross.yml @@ -246,6 +246,7 @@ jobs: apt-get install -y --no-install-recommends \ build-essential \ glslc \ + spirv-headers \ gcc-14-loongarch64-linux-gnu \ g++-14-loongarch64-linux-gnu \ libvulkan-dev:loong64 diff --git a/.github/workflows/build-openvino.yml b/.github/workflows/build-openvino.yml new file mode 100644 index 00000000000..f7177f6be37 --- /dev/null +++ b/.github/workflows/build-openvino.yml @@ -0,0 +1,120 @@ +name: CI (openvino) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: [ + '.github/workflows/build-openvino.yml', + '**/CMakeLists.txt', + '**/.cmake', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp', + ] + + pull_request: + types: [opened, synchronize, reopened] + paths: [ + '.github/workflows/build-openvino.yml', + 'ggml/src/ggml-openvino/**' + ] + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +env: + GGML_NLOOP: 3 + GGML_N_THREADS: 1 + LLAMA_LOG_COLORS: 1 + LLAMA_LOG_PREFIX: 1 + LLAMA_LOG_TIMESTAMPS: 1 + +jobs: + ubuntu-24-openvino: + name: ubuntu-24-openvino-${{ matrix.openvino_device }} + + concurrency: + group: openvino-${{ matrix.variant }}-${{ github.head_ref || github.ref }} + cancel-in-progress: false + + strategy: + matrix: + include: + - variant: cpu + runner: '"ubuntu-24.04"' + openvino_device: "CPU" + - variant: gpu + runner: '["self-hosted","Linux","Intel","OpenVINO"]' + openvino_device: "GPU" + + runs-on: ${{ fromJSON(matrix.runner) }} + + env: + # Sync versions in build-openvino.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile + OPENVINO_VERSION_MAJOR: "2026.0" + OPENVINO_VERSION_FULL: "2026.0.0.20965.c6d6a13a886" + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + + - name: ccache + if: runner.environment == 'github-hosted' + uses: ggml-org/ccache-action@v1.2.21 + with: + key: ubuntu-24-openvino-${{ matrix.variant }}-no-preset-v1 + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Dependencies + id: depends + run: | + sudo apt-get update + sudo apt-get install -y build-essential libssl-dev libtbb12 cmake ninja-build python3-pip + sudo apt-get install -y ocl-icd-opencl-dev opencl-headers opencl-clhpp-headers intel-opencl-icd + + - name: Use OpenVINO Toolkit Cache + if: runner.environment == 'github-hosted' + uses: actions/cache@v5 + id: cache-openvino + with: + path: ./openvino_toolkit + key: openvino-toolkit-v${{ env.OPENVINO_VERSION_FULL }}-${{ runner.os }} + + - name: 
Setup OpenVINO Toolkit + if: steps.cache-openvino.outputs.cache-hit != 'true' + uses: ./.github/actions/linux-setup-openvino + with: + path: ./openvino_toolkit + version_major: ${{ env.OPENVINO_VERSION_MAJOR }} + version_full: ${{ env.OPENVINO_VERSION_FULL }} + + - name: Install OpenVINO dependencies + run: | + cd ./openvino_toolkit + chmod +x ./install_dependencies/install_openvino_dependencies.sh + echo "Y" | sudo -E ./install_dependencies/install_openvino_dependencies.sh + + - name: Build + id: cmake_build + run: | + source ./openvino_toolkit/setupvars.sh + cmake -B build/ReleaseOV -G Ninja \ + -DCMAKE_BUILD_TYPE=Release \ + -DGGML_OPENVINO=ON + time cmake --build build/ReleaseOV --config Release -j $(nproc) + + - name: Test + id: cmake_test + # TODO: fix and re-enable the `test-llama-archs` test below + run: | + cd ${{ github.workspace }} + if [ "${{ matrix.openvino_device }}" = "GPU" ]; then + export GGML_OPENVINO_DEVICE=GPU + fi + ctest --test-dir build/ReleaseOV -L main -E "test-llama-archs" --verbose --timeout 2000 diff --git a/.github/workflows/build-self-hosted.yml b/.github/workflows/build-self-hosted.yml index 0efe8771625..e9148dd7399 100644 --- a/.github/workflows/build-self-hosted.yml +++ b/.github/workflows/build-self-hosted.yml @@ -97,6 +97,36 @@ jobs: vulkaninfo --summary GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp + # TODO: investigate slight precision issues in some operations for test-backend-ops on the WebGPU backend. + #ggml-ci-nvidia-webgpu: + # runs-on: [self-hosted, Linux, NVIDIA] + + # steps: + # - name: Clone + # id: checkout + # uses: actions/checkout@v6 + + # - name: Dawn Dependency + # id: dawn-depends + # run: | + # DAWN_VERSION="v20260317.182325" + # DAWN_OWNER="google" + # DAWN_REPO="dawn" + # DAWN_ASSET_NAME="Dawn-18eb229ef5f707c1464cc581252e7603c73a3ef0-ubuntu-latest-Release" + # echo "Fetching release asset from https://github.com/google/dawn/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.tar.gz" + # curl -L -o artifact.tar.gz \ + # "https://github.com/google/dawn/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.tar.gz" + # mkdir dawn + # tar -xvf artifact.tar.gz -C dawn --strip-components=1 + + # - name: Test + # id: ggml-ci + # run: | + # GG_BUILD_WEBGPU=1 \ + # GG_BUILD_WEBGPU_DAWN_PREFIX="$GITHUB_WORKSPACE/dawn" \ + # GG_BUILD_WEBGPU_DAWN_DIR="$GITHUB_WORKSPACE/dawn/lib64/cmake/Dawn" \ + # bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp + # TODO: provision AMX-compatible machine #ggml-ci-cpu-amx: # runs-on: [self-hosted, Linux, CPU, AMX] @@ -235,6 +265,10 @@ jobs: ggml-ci-intel-openvino-gpu-low-perf: runs-on: [self-hosted, Linux, Intel, OpenVINO] + concurrency: + group: openvino-gpu-${{ github.head_ref || github.ref }} + cancel-in-progress: false + env: # Sync versions in build.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile OPENVINO_VERSION_MAJOR: "2026.0" diff --git a/.github/workflows/build-sycl.yml b/.github/workflows/build-sycl.yml new file mode 100644 index 00000000000..2a6642292e6 --- /dev/null +++ b/.github/workflows/build-sycl.yml @@ -0,0 +1,142 @@ +name: CI (sycl) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: [ + '.github/workflows/build-sycl.yml', + '**/CMakeLists.txt', + '**/.cmake', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp' + ] + + pull_request: + types: [opened, synchronize, reopened] + paths: [ + '.github/workflows/build-sycl.yml', + 'ggml/src/ggml-sycl/**' + ] + +concurrency: + group: ${{ 
github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +env: + GGML_NLOOP: 3 + GGML_N_THREADS: 1 + LLAMA_LOG_COLORS: 1 + LLAMA_LOG_PREFIX: 1 + LLAMA_LOG_TIMESTAMPS: 1 + +jobs: + + ubuntu-24-sycl: + strategy: + matrix: + build: [fp32, fp16] + include: + - build: fp32 + fp16: OFF + - build: fp16 + fp16: ON + + runs-on: ubuntu-24.04 + + env: + ONEAPI_ROOT: /opt/intel/oneapi/ + ONEAPI_INSTALLER_VERSION: "2025.3.3" + + continue-on-error: true + + steps: + - uses: actions/checkout@v6 + + - name: Use oneAPI Installation Cache + uses: actions/cache@v5 + id: cache-sycl + with: + path: ${{ env.ONEAPI_ROOT }} + key: oneAPI-${{ env.ONEAPI_INSTALLER_VERSION }}-${{ runner.os }} + + - name: Download & Install oneAPI + shell: bash + if: steps.cache-sycl.outputs.cache-hit != 'true' + run: | + cd /tmp + wget https://registrationcenter-download.intel.com/akdlm/IRC_NAS/56f7923a-adb8-43f3-8b02-2b60fcac8cab/intel-deep-learning-essentials-2025.3.3.16_offline.sh -O intel-deep-learning-essentials_offline.sh + sudo bash intel-deep-learning-essentials_offline.sh -s -a --silent --eula accept + + - name: Clone + id: checkout + uses: actions/checkout@v6 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: ubuntu-24-sycl-${{ matrix.build }} + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Build + id: cmake_build + run: | + source /opt/intel/oneapi/setvars.sh + cmake -B build \ + -G "Ninja" \ + -DCMAKE_BUILD_TYPE=Release \ + -DGGML_SYCL=ON \ + -DCMAKE_C_COMPILER=icx \ + -DCMAKE_CXX_COMPILER=icpx \ + -DLLAMA_OPENSSL=OFF \ + -DGGML_NATIVE=OFF \ + -DGGML_SYCL_F16=${{ matrix.fp16 }} + time cmake --build build --config Release -j $(nproc) + + windows-latest-sycl: + runs-on: windows-2022 + + defaults: + run: + shell: bash + + env: + WINDOWS_BASEKIT_URL: https://registrationcenter-download.intel.com/akdlm/IRC_NAS/b60765d1-2b85-4e85-86b6-cb0e9563a699/intel-deep-learning-essentials-2025.3.3.18_offline.exe + WINDOWS_DPCPP_MKL: intel.oneapi.win.cpp-dpcpp-common:intel.oneapi.win.mkl.devel:intel.oneapi.win.dnnl:intel.oneapi.win.tbb.devel + ONEAPI_ROOT: "C:/Program Files (x86)/Intel/oneAPI" + ONEAPI_INSTALLER_VERSION: "2025.3.3" + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + + - name: Use oneAPI Installation Cache + uses: actions/cache@v5 + id: cache-sycl + with: + path: ${{ env.ONEAPI_ROOT }} + key: oneAPI-${{ env.ONEAPI_INSTALLER_VERSION }}-${{ runner.os }} + + - name: Download & Install oneAPI + shell: bash + if: steps.cache-sycl.outputs.cache-hit != 'true' + run: | + scripts/install-oneapi.bat $WINDOWS_BASEKIT_URL $WINDOWS_DPCPP_MKL + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: windows-latest-sycl + variant: ccache + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + # TODO: add ssl support ; we will also need to modify win-build-sycl.bat to accept user-specified args + + - name: Build + id: cmake_build + run: examples/sycl/win-build-sycl.bat diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 28c8665bd8b..21eb4d97b3e 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -555,186 +555,6 @@ jobs: -DGGML_MUSA=ON time cmake --build build --config Release -j $(nproc) - ubuntu-22-sycl: - runs-on: ubuntu-22.04 - - continue-on-error: true - - steps: - - uses: actions/checkout@v6 - - - name: add oneAPI to apt - shell: bash - run: | - cd /tmp - wget 
https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB - sudo apt-key add GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB - rm GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB - sudo add-apt-repository "deb https://apt.repos.intel.com/oneapi all main" - - - name: install oneAPI dpcpp compiler - shell: bash - run: | - sudo apt update - sudo apt install intel-oneapi-compiler-dpcpp-cpp libssl-dev - - - name: install oneAPI MKL library - shell: bash - run: | - sudo apt install intel-oneapi-mkl-devel - - - name: Clone - id: checkout - uses: actions/checkout@v6 - - - name: ccache - uses: ggml-org/ccache-action@v1.2.21 - with: - key: ubuntu-22-sycl - evict-old-files: 1d - save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} - - - name: Build - id: cmake_build - run: | - source /opt/intel/oneapi/setvars.sh - cmake -B build \ - -DGGML_SYCL=ON \ - -DCMAKE_C_COMPILER=icx \ - -DCMAKE_CXX_COMPILER=icpx - time cmake --build build --config Release -j $(nproc) - - ubuntu-22-sycl-fp16: - runs-on: ubuntu-22.04 - - continue-on-error: true - - steps: - - uses: actions/checkout@v6 - - - name: add oneAPI to apt - shell: bash - run: | - cd /tmp - wget https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB - sudo apt-key add GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB - rm GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB - sudo add-apt-repository "deb https://apt.repos.intel.com/oneapi all main" - - - name: install oneAPI dpcpp compiler - shell: bash - run: | - sudo apt update - sudo apt install intel-oneapi-compiler-dpcpp-cpp libssl-dev ninja-build - - - name: install oneAPI MKL library - shell: bash - run: | - sudo apt install intel-oneapi-mkl-devel - - - name: Clone - id: checkout - uses: actions/checkout@v6 - - - name: ccache - uses: ggml-org/ccache-action@v1.2.21 - with: - key: ubuntu-22-sycl-fp16 - evict-old-files: 1d - save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} - - - name: Build - id: cmake_build - run: | - source /opt/intel/oneapi/setvars.sh - cmake -B build \ - -G "Ninja" \ - -DCMAKE_BUILD_TYPE=Release \ - -DGGML_SYCL=ON \ - -DCMAKE_C_COMPILER=icx \ - -DCMAKE_CXX_COMPILER=icpx \ - -DGGML_SYCL_F16=ON - time cmake --build build --config Release -j $(nproc) - - ubuntu-24-openvino: - name: ubuntu-24-openvino-${{ matrix.openvino_device }} - strategy: - matrix: - include: - - variant: cpu - runner: '"ubuntu-24.04"' - openvino_device: "CPU" - - variant: gpu - runner: '["self-hosted","Linux","X64","Intel"]' - openvino_device: "GPU" - - runs-on: ${{ fromJSON(matrix.runner) }} - - env: - # Sync versions in build.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile - OPENVINO_VERSION_MAJOR: "2026.0" - OPENVINO_VERSION_FULL: "2026.0.0.20965.c6d6a13a886" - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v6 - - - name: ccache - if: runner.environment == 'github-hosted' - uses: ggml-org/ccache-action@v1.2.21 - with: - key: ubuntu-24-openvino-${{ matrix.variant }}-no-preset-v1 - evict-old-files: 1d - save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} - - - name: Dependencies - id: depends - run: | - sudo apt-get update - sudo apt-get install -y build-essential libssl-dev libtbb12 cmake ninja-build python3-pip - sudo apt-get install -y ocl-icd-opencl-dev opencl-headers opencl-clhpp-headers intel-opencl-icd - - - name: Use OpenVINO Toolkit Cache - if: runner.environment == 'github-hosted' - uses: actions/cache@v5 - id: cache-openvino - with: - path: ./openvino_toolkit - key: openvino-toolkit-v${{ 
env.OPENVINO_VERSION_FULL }}-${{ runner.os }} - - - name: Setup OpenVINO Toolkit - if: steps.cache-openvino.outputs.cache-hit != 'true' - uses: ./.github/actions/linux-setup-openvino - with: - path: ./openvino_toolkit - version_major: ${{ env.OPENVINO_VERSION_MAJOR }} - version_full: ${{ env.OPENVINO_VERSION_FULL }} - - - name: Install OpenVINO dependencies - run: | - cd ./openvino_toolkit - chmod +x ./install_dependencies/install_openvino_dependencies.sh - echo "Y" | sudo -E ./install_dependencies/install_openvino_dependencies.sh - - - name: Build - id: cmake_build - run: | - source ./openvino_toolkit/setupvars.sh - cmake -B build/ReleaseOV -G Ninja \ - -DCMAKE_BUILD_TYPE=Release \ - -DGGML_OPENVINO=ON - time cmake --build build/ReleaseOV --config Release -j $(nproc) - - - name: Test - id: cmake_test - # TODO: fix and re-enable the `test-llama-archs` test below - run: | - cd ${{ github.workspace }} - if [ "${{ matrix.openvino_device }}" = "GPU" ]; then - export GGML_OPENVINO_DEVICE=GPU - fi - ctest --test-dir build/ReleaseOV -L main -E "test-llama-archs" --verbose --timeout 2000 windows-latest: runs-on: windows-2025 @@ -943,39 +763,6 @@ jobs: cmake --build build --config Release -j %NINJA_JOBS% -t ggml cmake --build build --config Release - windows-latest-sycl: - runs-on: windows-2022 - - defaults: - run: - shell: bash - - env: - WINDOWS_BASEKIT_URL: https://registrationcenter-download.intel.com/akdlm/IRC_NAS/24751ead-ddc5-4479-b9e6-f9fe2ff8b9f2/intel-deep-learning-essentials-2025.2.1.25_offline.exe - WINDOWS_DPCPP_MKL: intel.oneapi.win.cpp-dpcpp-common:intel.oneapi.win.mkl.devel:intel.oneapi.win.dnnl:intel.oneapi.win.tbb.devel - ONEAPI_ROOT: "C:/Program Files (x86)/Intel/oneAPI" - steps: - - name: Clone - id: checkout - uses: actions/checkout@v6 - - - name: ccache - uses: ggml-org/ccache-action@v1.2.21 - with: - key: windows-latest-sycl - variant: ccache - evict-old-files: 1d - save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} - - - name: Install - run: | - scripts/install-oneapi.bat $WINDOWS_BASEKIT_URL $WINDOWS_DPCPP_MKL - - # TODO: add ssl support ; we will also need to modify win-build-sycl.bat to accept user-specified args - - - name: Build - id: cmake_build - run: examples/sycl/win-build-sycl.bat windows-latest-hip: runs-on: windows-2022 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 8a49715b395..924f6cd3fe3 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -598,15 +598,29 @@ jobs: shell: bash env: - WINDOWS_BASEKIT_URL: https://registrationcenter-download.intel.com/akdlm/IRC_NAS/24751ead-ddc5-4479-b9e6-f9fe2ff8b9f2/intel-deep-learning-essentials-2025.2.1.25_offline.exe + WINDOWS_BASEKIT_URL: https://registrationcenter-download.intel.com/akdlm/IRC_NAS/b60765d1-2b85-4e85-86b6-cb0e9563a699/intel-deep-learning-essentials-2025.3.3.18_offline.exe WINDOWS_DPCPP_MKL: intel.oneapi.win.cpp-dpcpp-common:intel.oneapi.win.mkl.devel:intel.oneapi.win.dnnl:intel.oneapi.win.tbb.devel ONEAPI_ROOT: "C:/Program Files (x86)/Intel/oneAPI" + ONEAPI_INSTALLER_VERSION: "2025.3.3" steps: - name: Clone id: checkout uses: actions/checkout@v6 + - name: Use oneAPI Installation Cache + uses: actions/cache@v5 + id: cache-sycl + with: + path: ${{ env.ONEAPI_ROOT }} + key: oneAPI-${{ env.ONEAPI_INSTALLER_VERSION }}-${{ runner.os }} + + - name: Download & Install oneAPI + shell: bash + if: steps.cache-sycl.outputs.cache-hit != 'true' + run: | + scripts/install-oneapi.bat $WINDOWS_BASEKIT_URL $WINDOWS_DPCPP_MKL + - name: 
ccache uses: ggml-org/ccache-action@v1.2.21 with: @@ -614,10 +628,6 @@ jobs: variant: ccache evict-old-files: 1d - - name: Install - run: | - scripts/install-oneapi.bat $WINDOWS_BASEKIT_URL $WINDOWS_DPCPP_MKL - - name: Build id: cmake_build shell: cmd @@ -670,6 +680,82 @@ jobs: path: llama-bin-win-sycl-x64.zip name: llama-bin-win-sycl-x64.zip + ubuntu-24-sycl: + strategy: + matrix: + build: [fp32, fp16] + include: + - build: fp32 + fp16: OFF + - build: fp16 + fp16: ON + + runs-on: ubuntu-24.04 + + env: + ONEAPI_ROOT: /opt/intel/oneapi/ + ONEAPI_INSTALLER_VERSION: "2025.3.3" + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + with: + fetch-depth: 0 + + - name: Use oneAPI Installation Cache + uses: actions/cache@v5 + id: cache-sycl + with: + path: ${{ env.ONEAPI_ROOT }} + key: oneAPI-${{ env.ONEAPI_INSTALLER_VERSION }}-${{ runner.os }} + + - name: Download & Install oneAPI + shell: bash + if: steps.cache-sycl.outputs.cache-hit != 'true' + run: | + cd /tmp + wget https://registrationcenter-download.intel.com/akdlm/IRC_NAS/56f7923a-adb8-43f3-8b02-2b60fcac8cab/intel-deep-learning-essentials-2025.3.3.16_offline.sh -O intel-deep-learning-essentials_offline.sh + sudo bash intel-deep-learning-essentials_offline.sh -s -a --silent --eula accept + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: ubuntu-24-sycl-${{ matrix.build }} + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Build + id: cmake_build + run: | + source /opt/intel/oneapi/setvars.sh + cmake -B build \ + -G "Ninja" \ + -DCMAKE_BUILD_TYPE=Release \ + -DGGML_SYCL=ON \ + -DCMAKE_C_COMPILER=icx \ + -DCMAKE_CXX_COMPILER=icpx \ + -DLLAMA_OPENSSL=OFF \ + -DGGML_NATIVE=OFF \ + -DGGML_SYCL_F16=${{ matrix.fp16 }} + time cmake --build build --config Release -j $(nproc) + + - name: Determine tag name + id: tag + uses: ./.github/actions/get-tag-name + + - name: Pack artifacts + id: pack_artifacts + run: | + cp LICENSE ./build/bin/ + tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-sycl-${{ matrix.build }}-x64.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin . 
+ + - name: Upload artifacts + uses: actions/upload-artifact@v6 + with: + path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-sycl-${{ matrix.build }}-x64.tar.gz + name: llama-bin-ubuntu-sycl-${{ matrix.build }}-x64.tar.gz + ubuntu-22-rocm: runs-on: ubuntu-22.04 @@ -687,6 +773,11 @@ jobs: with: fetch-depth: 0 + - name: Free up disk space + uses: ggml-org/free-disk-space@v1.3.1 + with: + tool-cache: true + - name: ccache uses: ggml-org/ccache-action@v1.2.21 with: @@ -1040,6 +1131,7 @@ jobs: - ubuntu-cpu - ubuntu-vulkan - ubuntu-24-openvino + - ubuntu-24-sycl - android-arm64 - macOS-cpu - ios-xcode-build @@ -1128,6 +1220,8 @@ jobs: - [Ubuntu arm64 (Vulkan)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-arm64.tar.gz) - [Ubuntu x64 (ROCm 7.2)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-rocm-7.2-x64.tar.gz) - [Ubuntu x64 (OpenVINO)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-openvino-${{ needs.ubuntu-24-openvino.outputs.openvino_version }}-x64.tar.gz) + - [Ubuntu x64 (SYCL FP32)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-sycl-fp32-x64.tar.gz) + - [Ubuntu x64 (SYCL FP16)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-sycl-fp16-x64.tar.gz) **Android:** - [Android arm64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-android-arm64.tar.gz) diff --git a/.gitignore b/.gitignore index 15dc4014f43..6136524d75a 100644 --- a/.gitignore +++ b/.gitignore @@ -145,3 +145,5 @@ poetry.toml /.windsurf/ # emscripten a.out.* + +AGENTS.local.md diff --git a/CODEOWNERS b/CODEOWNERS index 67d5c5a9f11..612fcdda1c0 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -23,6 +23,7 @@ /ci/ @ggerganov /cmake/ @ggerganov /common/ @ggml-org/llama-common +/common/fit.* @JohannesGaessler /common/jinja/ @CISC /common/ngram-map.* @srogmann /convert_*.py @CISC diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 7a911c63e9d..1a56c25857f 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -73,6 +73,8 @@ add_library(${TARGET} debug.h download.cpp download.h + fit.cpp + fit.h hf-cache.cpp hf-cache.h http.h diff --git a/common/arg.cpp b/common/arg.cpp index 6f22f781915..03596ced4d8 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -292,7 +292,7 @@ static bool common_params_handle_remote_preset(common_params & params, llama_exa hf_tag = "default"; } - std::string model_endpoint = get_model_endpoint(); + std::string model_endpoint = common_get_model_endpoint(); auto preset_url = model_endpoint + hf_repo + "/resolve/main/preset.ini"; // prepare local path for caching @@ -1316,13 +1316,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } ).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED, LLAMA_EXAMPLE_BENCH, LLAMA_EXAMPLE_PARALLEL})); add_opt(common_arg( - {"--clear-idle"}, - {"--no-clear-idle"}, + {"--cache-idle-slots"}, + {"--no-cache-idle-slots"}, "save and clear idle slots on new task (default: enabled, requires unified KV and cache-ram)", [](common_params & params, bool value) { - 
params.clear_idle = value; + params.cache_idle_slots = value; } - ).set_env("LLAMA_ARG_CLEAR_IDLE").set_examples({LLAMA_EXAMPLE_SERVER})); + ).set_env("LLAMA_ARG_CACHE_IDLE_SLOTS").set_examples({LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"--context-shift"}, {"--no-context-shift"}, @@ -2426,6 +2426,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } } ).set_env("LLAMA_ARG_FIT")); + add_opt(common_arg( + { "-fitp", "--fit-print" }, "[on|off]", + string_format("print the estimated required memory ('on' or 'off', default: '%s')", params.fit_params_print ? "on" : "off"), + [](common_params & params, const std::string & value) { + if (is_truthy(value)) { + params.fit_params_print = true; + } else if (is_falsey(value)) { + params.fit_params_print = false; + } else { + throw std::runtime_error( + string_format("error: unknown value for --fit-print: '%s'\n", value.c_str())); + } + } + ).set_examples({LLAMA_EXAMPLE_FIT_PARAMS}).set_env("LLAMA_ARG_FIT_ESTIMATE")); add_opt(common_arg( { "-fitt", "--fit-target" }, "MiB0,MiB1,MiB2,...", string_format("target margin per device for --fit, comma-separated list of values, " @@ -3073,7 +3087,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, bool value) { params.use_jinja = value; } - ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_MTMD}).set_env("LLAMA_ARG_JINJA")); + ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_SPECULATIVE}).set_env("LLAMA_ARG_JINJA")); add_opt(common_arg( {"--reasoning-format"}, "FORMAT", "controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:\n" @@ -3108,14 +3122,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex "token budget for thinking: -1 for unrestricted, 0 for immediate end, N>0 for token budget (default: -1)", [](common_params & params, int value) { if (value < -1) { throw std::invalid_argument("invalid value"); } - params.reasoning_budget = value; + params.sampling.reasoning_budget_tokens = value; } ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET")); add_opt(common_arg( {"--reasoning-budget-message"}, "MESSAGE", "message injected before the end-of-thinking tag when reasoning budget is exhausted (default: none)", [](common_params & params, const std::string & value) { - params.reasoning_budget_message = value; + params.sampling.reasoning_budget_message = value; } ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET_MESSAGE")); add_opt(common_arg( @@ -3129,7 +3143,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, const std::string & value) { params.chat_template = value; } - ).set_examples({LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MTMD}).set_env("LLAMA_ARG_CHAT_TEMPLATE")); + ).set_examples({LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_SPECULATIVE}).set_env("LLAMA_ARG_CHAT_TEMPLATE")); add_opt(common_arg( {"--chat-template-file"}, "JINJA_TEMPLATE_FILE", string_format( @@ -3453,6 +3467,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.speculative.p_min = std::stof(value); } 
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_DRAFT_P_MIN")); + add_opt(common_arg( + {"--eagle3"}, + "use EAGLE3 speculative decoding with the draft model", + [](common_params & params) { + params.speculative.eagle3 = true; + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_CLI})); + add_opt(common_arg( + {"--dflash"}, + "use DFlash speculative decoding with the draft model", + [](common_params & params) { + params.speculative.dflash = true; + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"-cd", "--ctx-size-draft"}, "N", string_format("size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.speculative.n_ctx), @@ -3888,6 +3916,17 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI})); + add_opt(common_arg( + {"--spec-default"}, + string_format("enable default speculative decoding config"), + [](common_params & params) { + params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MOD; + params.speculative.ngram_size_n = 24; + params.speculative.n_min = 48; + params.speculative.n_max = 64; + } + ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI})); + return ctx_arg; } diff --git a/common/chat-auto-parser-generator.cpp b/common/chat-auto-parser-generator.cpp index c6431b89852..453559a4b04 100644 --- a/common/chat-auto-parser-generator.cpp +++ b/common/chat-auto-parser-generator.cpp @@ -443,14 +443,14 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte if (!format.per_call_start.empty()) { auto wrapped_call = format.per_call_start + p.space() + tool_choice + p.space() + format.per_call_end; if (inputs.parallel_tool_calls) { - tool_calls = p.trigger_rule("tool-call", wrapped_call + p.zero_or_more(p.space() + wrapped_call)); + tool_calls = p.trigger_rule("tool-call", wrapped_call + p.zero_or_more(p.space() + wrapped_call) + p.space()); } else { - tool_calls = p.trigger_rule("tool-call", wrapped_call); + tool_calls = p.trigger_rule("tool-call", wrapped_call + p.space()); } if (!format.section_start.empty()) { tool_calls = p.trigger_rule("tool-calls", p.literal(format.section_start) + p.space() + tool_calls + p.space() + - (format.section_end.empty() ? p.end() : p.literal(format.section_end))); + (format.section_end.empty() ? 
p.end() : p.literal(format.section_end) + p.space())); } } else { std::string separator = ", "; // Default diff --git a/common/chat.cpp b/common/chat.cpp index e27b6c3413c..159d625de99 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -397,6 +397,25 @@ json common_chat_msgs_to_json_oaicompat(const std::vector & msg return render_message_to_json(msgs, c); } +json common_chat_tools_to_json_oaicompat(const std::vector & tools) { + if (tools.empty()) { + return json(); + } + + auto result = json::array(); + for (const auto & tool : tools) { + result.push_back({ + { "type", "function" }, + { "function", { + { "name", tool.name }, + { "description", tool.description }, + { "parameters", json::parse(tool.parameters) }, + }}, + }); + } + return result; +} + std::vector common_chat_tools_parse_oaicompat(const json & tools) { std::vector result; @@ -432,56 +451,6 @@ std::vector common_chat_tools_parse_oaicompat(const json & too return result; } -json common_chat_tools_to_json_oaicompat(const std::vector & tools) { - if (tools.empty()) { - return json(); - } - - auto result = json::array(); - for (const auto & tool : tools) { - result.push_back({ - { "type", "function" }, - { "function", - { - { "name", tool.name }, - { "description", tool.description }, - { "parameters", json::parse(tool.parameters) }, - } }, - }); - } - return result; -} - -json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) { - json delta = json::object(); - if (!diff.reasoning_content_delta.empty()) { - delta["reasoning_content"] = diff.reasoning_content_delta; - } - if (!diff.content_delta.empty()) { - delta["content"] = diff.content_delta; - } - if (diff.tool_call_index != std::string::npos) { - json tool_call; - tool_call["index"] = diff.tool_call_index; - if (!diff.tool_call_delta.id.empty()) { - tool_call["id"] = diff.tool_call_delta.id; - tool_call["type"] = "function"; - } - if (!diff.tool_call_delta.name.empty() || !diff.tool_call_delta.arguments.empty()) { - json function = json::object(); - if (!diff.tool_call_delta.name.empty()) { - function["name"] = diff.tool_call_delta.name; - } - if (!diff.tool_call_delta.arguments.empty()) { - function["arguments"] = diff.tool_call_delta.arguments; - } - tool_call["function"] = function; - } - delta["tool_calls"] = json::array({ tool_call }); - } - return delta; -} - bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) { if (use_jinja) { try { @@ -575,6 +544,26 @@ bool common_chat_templates_was_explicit(const struct common_chat_templates * tmp return tmpls->has_explicit_template; } +// LFM2 format detection: template uses <|tool_list_start|>[...]<|tool_list_end|> around the tool list +// and <|tool_call_start|>[...]<|tool_call_end|> around each tool call +static bool is_lfm2_template(const std::string & src) { + return src.find("<|tool_list_start|>") != std::string::npos && + src.find("<|tool_list_end|>") != std::string::npos; +} + +common_chat_prompt_preset common_chat_get_asr_prompt(const common_chat_templates * chat_templates) { + common_chat_prompt_preset asr_preset; + asr_preset.system = ""; + asr_preset.user = "Transcribe audio to text"; + + if (chat_templates && chat_templates->template_default && is_lfm2_template(chat_templates->template_default->source())) { + asr_preset.system = "Perform ASR."; + asr_preset.user = ""; + } + + return asr_preset; +} + std::string common_chat_templates_source(const struct common_chat_templates * tmpls, const std::string & variant) { if (!variant.empty()) { if (variant == "tool_use") { @@ 
-2084,10 +2073,7 @@ std::optional common_chat_try_specialized_template( return common_chat_params_init_kimi_k2(tmpl, params); } - // LFM2 format detection: template uses <|tool_list_start|>[...]<|tool_list_end|> around the tool list - // and <|tool_call_start|>[...]<|tool_call_end|> around each tool call - if (src.find("<|tool_list_start|>") != std::string::npos && - src.find("<|tool_list_end|>") != std::string::npos) { + if (is_lfm2_template(src)) { LOG_DBG("Using specialized template: LFM2\n"); return common_chat_params_init_lfm2(tmpl, params); } @@ -2334,7 +2320,7 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars ? input : params.generation_prompt + input; - LOG_DBG("Parsing PEG input with format %s: %s\n", common_chat_format_name(params.format), effective_input.c_str()); + //LOG_DBG("Parsing PEG input with format %s: %s\n", common_chat_format_name(params.format), effective_input.c_str()); common_peg_parse_flags flags = COMMON_PEG_PARSE_FLAG_LENIENT; if (params.debug) { @@ -2396,4 +2382,3 @@ std::map common_chat_templates_get_caps(const common_chat_tem GGML_ASSERT(chat_templates->template_default != nullptr); return chat_templates->template_default->caps.to_map(); } - diff --git a/common/chat.h b/common/chat.h index b06ca37fd74..01a47b383bf 100644 --- a/common/chat.h +++ b/common/chat.h @@ -256,14 +256,13 @@ bool common_chat_templates_support_enable_thinking(const common_chat_templates * // Parses a JSON array of messages in OpenAI's chat completion API format. std::vector common_chat_msgs_parse_oaicompat(const nlohmann::ordered_json & messages); +std::vector common_chat_tools_parse_oaicompat(const nlohmann::ordered_json & tools); + // DEPRECATED: only used in tests nlohmann::ordered_json common_chat_msgs_to_json_oaicompat(const std::vector & msgs, bool concat_typed_text = false); -std::vector common_chat_tools_parse_oaicompat(const nlohmann::ordered_json & tools); nlohmann::ordered_json common_chat_tools_to_json_oaicompat(const std::vector & tools); -nlohmann::ordered_json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff); - // get template caps, useful for reporting to server /props endpoint std::map common_chat_templates_get_caps(const common_chat_templates * chat_templates); @@ -275,3 +274,11 @@ std::optional common_chat_try_specialized_template( const common_chat_template & tmpl, const std::string & src, autoparser::generation_params & params); + +// specialized per-task preset +struct common_chat_prompt_preset { + std::string system; + std::string user; +}; + +common_chat_prompt_preset common_chat_get_asr_prompt(const common_chat_templates * chat_templates); diff --git a/common/common.cpp b/common/common.cpp index d3f1cee394c..f7f33e8172a 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -3,6 +3,7 @@ #include "build-info.h" #include "common.h" +#include "fit.h" #include "log.h" #include "llama.h" #include "sampling.h" @@ -1147,7 +1148,7 @@ common_init_result::common_init_result(common_params & params) : if (params.fit_params) { LOG_INF("%s: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on\n", __func__); - llama_params_fit(params.model.path.c_str(), &mparams, &cparams, + common_fit_params(params.model.path.c_str(), &mparams, &cparams, params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target.data(), @@ -1382,7 +1383,7 @@ common_init_result_ptr common_init_from_params(common_params & params) { 
common_init_result::~common_init_result() = default; -std::string get_model_endpoint() { +std::string common_get_model_endpoint() { const char * model_endpoint_env = getenv("MODEL_ENDPOINT"); // We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility. const char * hf_endpoint_env = getenv("HF_ENDPOINT"); @@ -1397,6 +1398,42 @@ std::string get_model_endpoint() { return model_endpoint; } +common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) { + auto * mem = llama_get_memory(ctx); + if (mem == nullptr) { + return COMMON_CONTEXT_SEQ_RM_TYPE_NO; + } + + common_context_seq_rm_type res = COMMON_CONTEXT_SEQ_RM_TYPE_PART; + + llama_memory_clear(mem, true); + + // eval 2 tokens to check if the context is compatible + std::vector tmp; + tmp.push_back(0); + tmp.push_back(0); + + int ret = llama_decode(ctx, llama_batch_get_one(tmp.data(), tmp.size())); + if (ret != 0) { + LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret); + res = COMMON_CONTEXT_SEQ_RM_TYPE_NO; + goto done; + } + + // try to remove the last tokens + if (!llama_memory_seq_rm(mem, 0, 1, -1)) { + LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__); + res = COMMON_CONTEXT_SEQ_RM_TYPE_FULL; + goto done; + } + +done: + llama_memory_clear(mem, true); + llama_synchronize(ctx); + + return res; +} + void common_set_adapter_lora(struct llama_context * ctx, std::vector & lora) { std::vector loras; std::vector scales; diff --git a/common/common.h b/common/common.h index 81c26955656..27c85d32bb4 100644 --- a/common/common.h +++ b/common/common.h @@ -11,7 +11,6 @@ #include #include #include -#include #include #include @@ -160,6 +159,7 @@ enum common_speculative_type { COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding COMMON_SPECULATIVE_TYPE_DRAFT, // draft model COMMON_SPECULATIVE_TYPE_EAGLE3, // eagle draft model + COMMON_SPECULATIVE_TYPE_DFLASH, // dflash draft model COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values @@ -275,6 +275,7 @@ struct common_params_sampling { std::vector reasoning_budget_start; // start tag token sequence std::vector reasoning_budget_end; // end tag token sequence std::vector reasoning_budget_forced; // forced sequence (message + end tag) + std::string reasoning_budget_message; // message injected before end tag when budget exhausted bool backend_sampling = false; @@ -303,15 +304,15 @@ struct common_params_speculative { // general-purpose speculative decoding parameters int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding - int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding + int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding float p_split = 0.1f; // speculative decoding split probability float p_min = 0.75f; // minimum speculative decoding probability (greedy) // ngram-based speculative decoding - uint16_t ngram_size_n = 12; // ngram size for lookup - uint16_t ngram_size_m = 48; // mgram size for speculative tokens - uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed + uint16_t ngram_size_n = 12; // ngram size for lookup + uint16_t ngram_size_m = 48; // mgram size for speculative tokens + uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to 
be proposed std::shared_ptr ngram_mod; @@ -322,10 +323,14 @@ struct common_params_speculative { struct common_params_model mparams_dft; + llama_model * model_tgt = nullptr; // the target model llama_model * model_dft = nullptr; // a llama_model that can be shared by multiple speculative contexts llama_context_params cparams_dft; // these are the parameters for the draft llama_context + bool eagle3 = false; // use EAGLE3 speculative decoding + bool dflash = false; // use DFlash speculative decoding + int32_t n_ctx = 0; // draft context size int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default) @@ -421,11 +426,12 @@ struct common_params { // offload params std::vector devices; // devices to use for offloading - int32_t n_gpu_layers = -1; // number of layers to store in VRAM, -1 is auto, <= -2 is all - int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors - float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs - bool fit_params = true; // whether to fit unset model/context parameters to free device memory - int32_t fit_params_min_ctx = 4096; // minimum context size to set when trying to reduce memory use + int32_t n_gpu_layers = -1; // number of layers to store in VRAM, -1 is auto, <= -2 is all + int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors + float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs + bool fit_params = true; // whether to fit unset model/context parameters to free device memory + bool fit_params_print = false; // print the estimated required memory to run the model + int32_t fit_params_min_ctx = 4096; // minimum context size to set when trying to reduce memory use // margin per device in bytes for fitting parameters to free memory: std::vector fit_params_target = std::vector(llama_max_devices(), 1024 * 1024*1024); @@ -567,7 +573,7 @@ struct common_params { int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool) int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting bool cache_prompt = true; // whether to enable prompt caching - bool clear_idle = true; // save and clear idle slots upon starting a new task + bool cache_idle_slots = true; // save and clear idle slots upon starting a new task int32_t n_ctx_checkpoints = 32; // max number of context checkpoints per slot int32_t checkpoint_every_nt = 8192; // make a checkpoint every n tokens during prefill int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc. 
@@ -581,8 +587,6 @@ struct common_params { bool force_pure_content_parser = false; common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; int enable_reasoning = -1; // -1 = auto, 0 = disable, 1 = enable - int reasoning_budget = -1; - std::string reasoning_budget_message; // message injected before end tag when budget exhausted bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response int sleep_idle_seconds = -1; // if >0, server will sleep after this many seconds of idle time @@ -747,6 +751,11 @@ inline bool string_starts_with(std::string_view str, std::string_view prefix) { str.compare(0, prefix.size(), prefix) == 0; } +// remove when moving to c++20 +inline bool string_starts_with(std::string_view str, char prefix) { + return !str.empty() && str.front() == prefix; +} + // remove when moving to c++20 inline bool string_ends_with(std::string_view str, std::string_view suffix) { return str.size() >= suffix.size() && @@ -847,7 +856,23 @@ struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_p // clear LoRA adapters from context, then apply new list of adapters void common_set_adapter_lora(struct llama_context * ctx, std::vector & lora); -std::string get_model_endpoint(); +// model endpoint from env +std::string common_get_model_endpoint(); + +// +// Context utils +// + +enum common_context_seq_rm_type { + COMMON_CONTEXT_SEQ_RM_TYPE_NO = 0, // seq_rm not supported (e.g. no memory module) + COMMON_CONTEXT_SEQ_RM_TYPE_PART = 1, // can seq_rm partial sequences + COMMON_CONTEXT_SEQ_RM_TYPE_FULL = 2, // can seq_rm full sequences only +}; + +// check if the llama_context can remove sequences +// note: clears the memory of the context +common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx); + // // Batch utils diff --git a/common/fit.cpp b/common/fit.cpp new file mode 100644 index 00000000000..4b952889070 --- /dev/null +++ b/common/fit.cpp @@ -0,0 +1,951 @@ +#include "fit.h" + +#include "log.h" + +#include "../src/llama-ext.h" + +#include +#include +#include +#include +#include +#include +#include + +// this enum is only used in llama_params_fit_impl but needs to be defined outside of it to fix a Windows compilation issue +// enum to identify part of a layer for distributing its tensors: +enum common_layer_fraction_t { + LAYER_FRACTION_NONE = 0, // nothing + LAYER_FRACTION_ATTN = 1, // attention + LAYER_FRACTION_UP = 2, // attention + up + LAYER_FRACTION_GATE = 3, // attention + up + gate + LAYER_FRACTION_MOE = 4, // everything but sparse MoE weights +}; + +class common_params_fit_exception : public std::runtime_error { + using std::runtime_error::runtime_error; +}; + +static std::vector common_get_device_memory_data( + const char * path_model, + const llama_model_params * mparams, + const llama_context_params * cparams, + std::vector & devs, + uint32_t & hp_ngl, + uint32_t & hp_n_ctx_train, + uint32_t & hp_n_expert, + ggml_log_level log_level) { + struct user_data_t { + struct { + ggml_log_callback callback; + void * user_data; + } original_logger; + ggml_log_level min_level; // prints below this log level go to debug log + }; + user_data_t ud; + llama_log_get(&ud.original_logger.callback, &ud.original_logger.user_data); + ud.min_level = log_level; + + llama_log_set([](ggml_log_level level, const char * text, void * user_data) { + const user_data_t * ud = (const user_data_t *) user_data; + const ggml_log_level level_eff = level >= ud->min_level ? 
level : GGML_LOG_LEVEL_DEBUG; + ud->original_logger.callback(level_eff, text, ud->original_logger.user_data); + }, &ud); + + llama_model_params mparams_copy = *mparams; + mparams_copy.no_alloc = true; + mparams_copy.use_mmap = false; + mparams_copy.use_mlock = false; + + llama_model * model = llama_model_load_from_file(path_model, mparams_copy); + if (model == nullptr) { + llama_log_set(ud.original_logger.callback, ud.original_logger.user_data); + throw std::runtime_error("failed to load model"); + } + + llama_context * ctx = llama_init_from_model(model, *cparams); + if (ctx == nullptr) { + llama_model_free(model); + llama_log_set(ud.original_logger.callback, ud.original_logger.user_data); + throw std::runtime_error("failed to create llama_context from model"); + } + + const size_t nd = llama_model_n_devices(model); + std::vector ret(nd + 1); + + llama_memory_breakdown memory_breakdown = llama_get_memory_breakdown(ctx); + + for (const auto & [buft, mb] : memory_breakdown) { + if (ggml_backend_buft_is_host(buft)) { + ret.back().mb.model += mb.model; + ret.back().mb.context += mb.context; + ret.back().mb.compute += mb.compute; + continue; + } + + ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft); + if (!dev) { + continue; + } + for (size_t i = 0; i < nd; i++) { + if (dev == llama_model_get_device(model, i)) { + ret[i].mb.model += mb.model; + ret[i].mb.context += mb.context; + ret[i].mb.compute += mb.compute; + break; + } + } + } + + { + ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (cpu_dev == nullptr) { + throw std::runtime_error("no CPU backend found"); + } + size_t free; + size_t total; + ggml_backend_dev_memory(cpu_dev, &free, &total); + ret.back().free = free; + ret.back().total = total; + } + for (size_t i = 0; i < nd; i++) { + size_t free; + size_t total; + ggml_backend_dev_memory(llama_model_get_device(model, i), &free, &total); + + // devices can return 0 bytes for free and total memory if they do not + // have any to report. 
in this case, we will use the host memory as a fallback + // fixes: https://github.com/ggml-org/llama.cpp/issues/18577 + if (free == 0 && total == 0) { + free = ret.back().free; + total = ret.back().total; + } + ret[i].free = free; + ret[i].total = total; + } + + devs.clear(); + for (int i = 0; i < llama_model_n_devices(model); i++) { + devs.push_back(llama_model_get_device(model, i)); + } + + hp_ngl = llama_model_n_layer(model); + hp_n_ctx_train = llama_model_n_ctx_train(model); + hp_n_expert = llama_model_n_expert(model); + + common_memory_breakdown_print(ctx); + + llama_free(ctx); + llama_model_free(model); + llama_log_set(ud.original_logger.callback, ud.original_logger.user_data); + + return ret; +} + +static void common_params_fit_impl( + const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams, + float * tensor_split, struct llama_model_tensor_buft_override * tensor_buft_overrides, + size_t * margins_s, uint32_t n_ctx_min, enum ggml_log_level log_level) { + if (mparams->split_mode == LLAMA_SPLIT_MODE_TENSOR) { + throw common_params_fit_exception("llama_params_fit is not implemented for SPLIT_MODE_TENSOR, abort"); + } + constexpr int64_t MiB = 1024*1024; + typedef std::vector dmds_t; + const llama_model_params default_mparams = llama_model_default_params(); + + std::vector devs; + uint32_t hp_ngl = 0; // hparams.n_gpu_layers + uint32_t hp_nct = 0; // hparams.n_ctx_train + uint32_t hp_nex = 0; // hparams.n_expert + + // step 1: get data for default parameters and check whether any changes are necessary in the first place + + LOG_INF("%s: getting device memory data for initial parameters:\n", __func__); + const dmds_t dmds_full = common_get_device_memory_data(path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level); + const size_t nd = devs.size(); // number of devices + + std::vector margins; // this function uses int64_t rather than size_t for memory sizes to more conveniently handle deficits + margins.reserve(nd); + if (nd == 0) { + margins.push_back(margins_s[0]); + } else { + for (size_t id = 0; id < nd; id++) { + margins.push_back(margins_s[id]); + } + } + + std::vector dev_names; + { + dev_names.reserve(nd); + size_t max_length = 0; + for (const auto & dev : devs) { + std::string name = ggml_backend_dev_name(dev); + name += " ("; + name += ggml_backend_dev_description(dev); + name += ")"; + dev_names.push_back(name); + max_length = std::max(max_length, name.length()); + } + for (std::string & dn : dev_names) { + dn.insert(dn.end(), max_length - dn.length(), ' '); + } + } + + int64_t sum_free = 0; + int64_t sum_projected_free = 0; + int64_t sum_projected_used = 0; + int64_t sum_projected_model = 0; + std::vector projected_free_per_device; + projected_free_per_device.reserve(nd); + + if (nd == 0) { + sum_projected_used = dmds_full.back().mb.total(); + sum_free = dmds_full.back().total; + sum_projected_free = sum_free - sum_projected_used; + LOG_INF("%s: projected to use %" PRId64 " MiB of host memory vs. 
%" PRId64 " MiB of total host memory\n", + __func__, sum_projected_used/MiB, sum_free/MiB); + if (sum_projected_free >= margins[0]) { + LOG_INF("%s: will leave %" PRId64 " >= %" PRId64 " MiB of system memory, no changes needed\n", + __func__, sum_projected_free/MiB, margins[0]/MiB); + return; + } + } else { + if (nd > 1) { + LOG_INF("%s: projected memory use with initial parameters [MiB]:\n", __func__); + } + for (size_t id = 0; id < nd; id++) { + const llama_device_memory_data & dmd = dmds_full[id]; + + const int64_t projected_used = dmd.mb.total(); + const int64_t projected_free = dmd.free - projected_used; + projected_free_per_device.push_back(projected_free); + + sum_free += dmd.free; + sum_projected_used += projected_used; + sum_projected_free += projected_free; + sum_projected_model += dmd.mb.model; + + if (nd > 1) { + LOG_INF("%s: - %s: %6" PRId64 " total, %6" PRId64 " used, %6" PRId64 " free vs. target of %6" PRId64 "\n", + __func__, dev_names[id].c_str(), dmd.total/MiB, projected_used/MiB, projected_free/MiB, margins[id]/MiB); + } + } + assert(sum_free >= 0 && sum_projected_used >= 0); + LOG_INF("%s: projected to use %" PRId64 " MiB of device memory vs. %" PRId64 " MiB of free device memory\n", + __func__, sum_projected_used/MiB, sum_free/MiB); + if (nd == 1) { + if (projected_free_per_device[0] >= margins[0]) { + LOG_INF("%s: will leave %" PRId64 " >= %" PRId64 " MiB of free device memory, no changes needed\n", + __func__, projected_free_per_device[0]/MiB, margins[0]/MiB); + return; + } + } else { + bool changes_needed = false; + for (size_t id = 0; id < nd; id++) { + if (projected_free_per_device[id] < margins[id]) { + changes_needed = true; + break; + } + } + if (!changes_needed) { + LOG_INF("%s: targets for free memory can be met on all devices, no changes needed\n", __func__); + return; + } + } + } + + // step 2: try reducing memory use by reducing the context size + + { + int64_t global_surplus = sum_projected_free; + if (nd == 0) { + global_surplus -= margins[0]; + } else { + for (size_t id = 0; id < nd; id++) { + global_surplus -= margins[id]; + } + } + if (global_surplus < 0) { + if (nd <= 1) { + LOG_INF("%s: cannot meet free memory target of %" PRId64 " MiB, need to reduce device memory by %" PRId64 " MiB\n", + __func__, margins[0]/MiB, -global_surplus/MiB); + } else { + LOG_INF( + "%s: cannot meet free memory targets on all devices, need to use %" PRId64 " MiB less in total\n", + __func__, -global_surplus/MiB); + } + if (cparams->n_ctx == 0) { + if (hp_nct > n_ctx_min) { + int64_t sum_used_target = sum_free; + if (nd == 0) { + sum_used_target -= margins[0]; + } else { + for (size_t id = 0; id < nd; id++) { + sum_used_target -= margins[id]; + } + } + if (nd > 1) { + // for multiple devices we need to be more conservative in terms of how much context we think can fit: + // - for dense models only whole layers can be assigned to devices + // - for MoE models only whole tensors can be assigned to devices, which we estimate to be <= 1/3 of a layer + // - on average we expect a waste of 0.5 layers/tensors per device + // - use slightly more than the expected average for nd devices to be safe + const int64_t model_per_layer = sum_projected_model / std::min(uint32_t(mparams->n_gpu_layers), hp_ngl); + sum_used_target -= (nd + 1) * model_per_layer / (hp_nex == 0 ? 
2 : 6); + } + + int64_t sum_projected_used_min_ctx = 0; + cparams->n_ctx = n_ctx_min; + const dmds_t dmds_min_ctx = common_get_device_memory_data(path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level); + if (nd == 0) { + sum_projected_used_min_ctx = dmds_min_ctx.back().mb.total(); + } else { + for (size_t id = 0; id < nd; id++) { + sum_projected_used_min_ctx += dmds_min_ctx[id].mb.total(); + } + } + if (sum_used_target > sum_projected_used_min_ctx) { + // linear interpolation between minimum and maximum context size: + cparams->n_ctx += (hp_nct - n_ctx_min) * (sum_used_target - sum_projected_used_min_ctx) + / (sum_projected_used - sum_projected_used_min_ctx); + cparams->n_ctx = std::max(cparams->n_ctx - cparams->n_ctx % 256, n_ctx_min); // round down context for CUDA backend + + const int64_t bytes_per_ctx = (sum_projected_used - sum_projected_used_min_ctx) / (hp_nct - n_ctx_min); + const int64_t memory_reduction = (hp_nct - cparams->n_ctx) * bytes_per_ctx; + LOG_INF("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n", + __func__, hp_nct, cparams->n_ctx, memory_reduction/MiB); + if (nd <= 1) { + LOG_INF("%s: entire model can be fit by reducing context\n", __func__); + return; + } + LOG_INF("%s: entire model should be fit across devices by reducing context\n", __func__); + } else { + const int64_t memory_reduction = sum_projected_used - sum_projected_used_min_ctx; + LOG_INF("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n", + __func__, hp_nct, cparams->n_ctx, memory_reduction/MiB); + } + } else { + if (n_ctx_min == UINT32_MAX) { + LOG_INF("%s: user has requested full context size of %" PRIu32 " -> no change\n", __func__, hp_nct); + } else { + LOG_INF("%s: default model context size is %" PRIu32 " which is <= the min. 
context size of %" PRIu32 " -> no change\n", + __func__, hp_nct, n_ctx_min); + } + } + } else { + LOG_INF("%s: context size set by user to %" PRIu32 " -> no change\n", __func__, cparams->n_ctx); + } + } + } + if (nd == 0) { + throw common_params_fit_exception("was unable to fit model into system memory by reducing context, abort"); + } + + if (mparams->n_gpu_layers != default_mparams.n_gpu_layers) { + throw common_params_fit_exception("n_gpu_layers already set by user to " + std::to_string(mparams->n_gpu_layers) + ", abort"); + } + if (nd > 1) { + if (!tensor_split) { + throw common_params_fit_exception("did not provide a buffer to write the tensor_split to, abort"); + } + if (mparams->tensor_split) { + for (size_t id = 0; id < nd; id++) { + if (mparams->tensor_split[id] != 0.0f) { + throw common_params_fit_exception("model_params::tensor_split already set by user, abort"); + } + } + } + if (mparams->split_mode == LLAMA_SPLIT_MODE_ROW) { + throw common_params_fit_exception("changing weight allocation for LLAMA_SPLIT_MODE_ROW not implemented, abort"); + } + } + if (!tensor_buft_overrides) { + throw common_params_fit_exception("did not provide buffer to set tensor_buft_overrides, abort"); + } + if (mparams->tensor_buft_overrides && (mparams->tensor_buft_overrides->pattern || mparams->tensor_buft_overrides->buft)) { + throw common_params_fit_exception("model_params::tensor_buft_overrides already set by user, abort"); + } + + // step 3: iteratively fill the back to front with "dense" layers + // - for a dense model simply fill full layers, giving each device a contiguous slice of the model + // - for a MoE model, same as dense model but with all MoE tensors in system memory + + // utility function that returns a static C string matching the tensors for a specific layer index and layer fraction: + auto get_overflow_pattern = [&](const size_t il, const common_layer_fraction_t lf) -> const char * { + constexpr size_t n_strings = 1000; + if (il >= n_strings) { + throw std::runtime_error("at most " + std::to_string(n_strings) + " model layers are supported"); + } + switch (lf) { + case LAYER_FRACTION_ATTN: { + static std::array patterns; + if (patterns[il].empty()) { + patterns[il] = "blk\\." + std::to_string(il) + "\\.ffn_(gate|up|gate_up|down).*"; + } + return patterns[il].c_str(); + } + case LAYER_FRACTION_UP: { + static std::array patterns; + if (patterns[il].empty()) { + patterns[il] = "blk\\." + std::to_string(il) + "\\.ffn_(gate|gate_up|down).*"; + } + return patterns[il].c_str(); + } + case LAYER_FRACTION_GATE: { + static std::array patterns; + if (patterns[il].empty()) { + patterns[il] = "blk\\." + std::to_string(il) + "\\.ffn_down.*"; + } + return patterns[il].c_str(); + } + case LAYER_FRACTION_MOE: { + static std::array patterns; + if (patterns[il].empty()) { + patterns[il] = "blk\\." 
+ std::to_string(il) + "\\.ffn_(up|down|gate_up|gate)_(ch|)exps"; + } + return patterns[il].c_str(); + } + default: + GGML_ABORT("fatal error"); + } + }; + + struct ngl_t { + uint32_t n_layer = 0; // number of total layers + uint32_t n_part = 0; // number of partial layers, <= n_layer + + // for the first partial layer varying parts can overflow, all further layers use LAYER_FRACTION_MOE: + common_layer_fraction_t overflow_type = LAYER_FRACTION_MOE; + + uint32_t n_full() const { + assert(n_layer >= n_part); + return n_layer - n_part; + } + }; + + const size_t ntbo = llama_max_tensor_buft_overrides(); + + // utility function to set n_gpu_layers and tensor_split + auto set_ngl_tensor_split_tbo = [&]( + const std::vector & ngl_per_device, + const std::vector & overflow_bufts, + llama_model_params & mparams) { + mparams.n_gpu_layers = 0; + for (size_t id = 0; id < nd; id++) { + mparams.n_gpu_layers += ngl_per_device[id].n_layer; + if (nd > 1) { + tensor_split[id] = ngl_per_device[id].n_layer; + } + } + assert(uint32_t(mparams.n_gpu_layers) <= hp_ngl + 1); + uint32_t il0 = hp_ngl + 1 - mparams.n_gpu_layers; // start index for tensor buft overrides + + mparams.tensor_split = tensor_split; + + size_t itbo = 0; + for (size_t id = 0; id < nd; id++) { + il0 += ngl_per_device[id].n_full(); + for (uint32_t il = il0; il < il0 + ngl_per_device[id].n_part; il++) { + if (itbo + 1 >= ntbo) { + tensor_buft_overrides[itbo].pattern = nullptr; + tensor_buft_overrides[itbo].buft = nullptr; + itbo++; + mparams.tensor_buft_overrides = tensor_buft_overrides; + throw common_params_fit_exception("llama_max_tensor_buft_overrides() == " + + std::to_string(ntbo) + " is insufficient for model"); + } + tensor_buft_overrides[itbo].pattern = get_overflow_pattern(il, il == il0 ? ngl_per_device[id].overflow_type : LAYER_FRACTION_MOE); + tensor_buft_overrides[itbo].buft = il == il0 ? 
overflow_bufts[id] : ggml_backend_cpu_buffer_type(); + itbo++; + } + il0 += ngl_per_device[id].n_part; + } + tensor_buft_overrides[itbo].pattern = nullptr; + tensor_buft_overrides[itbo].buft = nullptr; + itbo++; + mparams.tensor_buft_overrides = tensor_buft_overrides; + }; + + // utility function that returns the memory use per device for given numbers of layers per device + auto get_memory_for_layers = [&]( + const char * func_name, + const std::vector & ngl_per_device, + const std::vector & overflow_bufts) -> std::vector { + llama_model_params mparams_copy = *mparams; + set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, mparams_copy); + + const dmds_t dmd_nl = common_get_device_memory_data( + path_model, &mparams_copy, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level); + + LOG_INF("%s: memory for test allocation by device:\n", func_name); + for (size_t id = 0; id < nd; id++) { + const ngl_t & n = ngl_per_device[id]; + LOG_INF( + "%s: id=%zu, n_layer=%2" PRIu32 ", n_part=%2" PRIu32 ", overflow_type=%d, mem=%6" PRId64 " MiB\n", + func_name, id, n.n_layer, n.n_part, int(n.overflow_type), dmd_nl[id].mb.total()/MiB); + } + + std::vector ret; + ret.reserve(nd); + for (size_t id = 0; id < nd; id++) { + ret.push_back(dmd_nl[id].mb.total()); + } + return ret; + }; + + int64_t global_surplus_cpu_moe = 0; + if (hp_nex > 0) { + const static std::string pattern_moe_all = "blk\\.\\d+\\.ffn_(up|down|gate_up|gate)_(ch|)exps"; // matches all MoE tensors + ggml_backend_buffer_type_t cpu_buft = ggml_backend_cpu_buffer_type(); + tensor_buft_overrides[0] = {pattern_moe_all.c_str(), cpu_buft}; + tensor_buft_overrides[1] = {nullptr, nullptr}; + mparams->tensor_buft_overrides = tensor_buft_overrides; + + LOG_INF("%s: getting device memory data with all MoE tensors moved to system memory:\n", __func__); + const dmds_t dmds_cpu_moe = common_get_device_memory_data( + path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level); + + for (size_t id = 0; id < nd; id++) { + global_surplus_cpu_moe += dmds_cpu_moe[id].free; + global_surplus_cpu_moe -= int64_t(dmds_cpu_moe[id].mb.total()) + margins[id]; + } + + if (global_surplus_cpu_moe > 0) { + LOG_INF("%s: with only dense weights in device memory there is a total surplus of %" PRId64 " MiB\n", + __func__, global_surplus_cpu_moe/MiB); + } else { + LOG_INF("%s: with only dense weights in device memory there is still a total deficit of %" PRId64 " MiB\n", + __func__, -global_surplus_cpu_moe/MiB); + } + + // reset + tensor_buft_overrides[0] = {nullptr, nullptr}; + mparams->tensor_buft_overrides = tensor_buft_overrides; + } + + std::vector targets; // maximum acceptable memory use per device + targets.reserve(nd); + for (size_t id = 0; id < nd; id++) { + targets.push_back(dmds_full[id].free - margins[id]); + LOG_INF("%s: id=%zu, target=%" PRId64 " MiB\n", __func__, id, targets[id]/MiB); + } + + std::vector overflow_bufts; // which bufts the first partial layer of a device overflows to: + overflow_bufts.reserve(nd); + for (size_t id = 0; id < nd; id++) { + overflow_bufts.push_back(ggml_backend_cpu_buffer_type()); + } + + std::vector ngl_per_device(nd); + std::vector mem = get_memory_for_layers(__func__, ngl_per_device, overflow_bufts); + + // optimize the number of layers per device using the method of false position: + // - ngl_per_device has 0 layers for each device, lower bound + // - try a "high" configuration where a device is given all unassigned layers + // - interpolate the memory use / layer between low and high linearly to get a guess where it 
meets our target + // - check memory use of our guess, replace either the low or high bound + // - once we only have a difference of a single layer, stop and return the lower bound that just barely still fits + // - the last device has the output layer, which cannot be a partial layer + if (hp_nex == 0) { + LOG_INF("%s: filling dense layers back-to-front:\n", __func__); + } else { + LOG_INF("%s: filling dense-only layers back-to-front:\n", __func__); + } + for (int id = nd - 1; id >= 0; id--) { + uint32_t n_unassigned = hp_ngl + 1; + for (size_t jd = id + 1; jd < nd; ++jd) { + assert(n_unassigned >= ngl_per_device[jd].n_layer); + n_unassigned -= ngl_per_device[jd].n_layer; + } + + std::vector ngl_per_device_high = ngl_per_device; + ngl_per_device_high[id].n_layer = n_unassigned; + if (hp_nex > 0) { + ngl_per_device_high[id].n_part = size_t(id) < nd - 1 ? ngl_per_device_high[id].n_layer : ngl_per_device_high[id].n_layer - 1; + } + if (ngl_per_device_high[id].n_layer > 0) { + std::vector mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts); + if (mem_high[id] > targets[id]) { + assert(ngl_per_device_high[id].n_layer > ngl_per_device[id].n_layer); + uint32_t delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer; + LOG_INF("%s: start filling device %" PRIu32 ", delta=%" PRIu32 "\n", __func__, id, delta); + while (delta > 1) { + uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]); + step_size = std::max(step_size, uint32_t(1)); + step_size = std::min(step_size, delta - 1); + + std::vector ngl_per_device_test = ngl_per_device; + ngl_per_device_test[id].n_layer += step_size; + if (hp_nex) { + ngl_per_device_test[id].n_part += size_t(id) == nd - 1 && ngl_per_device_test[id].n_part == 0 ? + step_size - 1 : step_size; // the first layer is the output layer which must always be full + } + const std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts); + + if (mem_test[id] <= targets[id]) { + ngl_per_device = ngl_per_device_test; + mem = mem_test; + LOG_INF("%s: set ngl_per_device[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device[id].n_layer); + } else { + ngl_per_device_high = ngl_per_device_test; + mem_high = mem_test; + LOG_INF("%s: set ngl_per_device_high[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device_high[id].n_layer); + } + delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer; + } + } else { + assert(ngl_per_device_high[id].n_layer == n_unassigned); + ngl_per_device = ngl_per_device_high; + mem = mem_high; + LOG_INF("%s: set ngl_per_device[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device[id].n_layer); + } + } + + const int64_t projected_margin = dmds_full[id].free - mem[id]; + LOG_INF( + "%s: - %s: %2" PRIu32 " layers, %6" PRId64 " MiB used, %6" PRId64 " MiB free\n", + __func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, mem[id]/MiB, projected_margin/MiB); + } + if (hp_nex == 0 || global_surplus_cpu_moe <= 0) { + set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams); + return; + } + + // step 4: for a MoE model where all dense tensors fit, + // convert the dense-only layers in the back to full layers in the front until all devices are full + // essentially the same procedure as for the dense-only layers except front-to-back + // also, try fitting at least part of one more layer to reduce waste for "small" GPUs with e.g. 
24 GiB VRAM + + size_t id_dense_start = nd; + for (int id = nd - 1; id >= 0; id--) { + if (ngl_per_device[id].n_layer > 0) { + id_dense_start = id; + continue; + } + break; + } + assert(id_dense_start < nd); + + LOG_INF("%s: converting dense-only layers to full layers and filling them front-to-back with overflow to next device/system memory:\n", __func__); + for (size_t id = 0; id <= id_dense_start && id_dense_start < nd; id++) { + std::vector ngl_per_device_high = ngl_per_device; + for (size_t jd = id_dense_start; jd < nd; jd++) { + const uint32_t n_layer_move = jd < nd - 1 ? ngl_per_device_high[jd].n_layer : ngl_per_device_high[jd].n_layer - 1; + ngl_per_device_high[id].n_layer += n_layer_move; + ngl_per_device_high[jd].n_layer -= n_layer_move; + ngl_per_device_high[jd].n_part = 0; + } + size_t id_dense_start_high = nd - 1; + std::vector mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts); + + if (mem_high[id] > targets[id]) { + assert(ngl_per_device_high[id].n_full() >= ngl_per_device[id].n_full()); + uint32_t delta = ngl_per_device_high[id].n_full() - ngl_per_device[id].n_full(); + while (delta > 1) { + uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]); + step_size = std::max(step_size, uint32_t(1)); + step_size = std::min(step_size, delta - 1); + + std::vector ngl_per_device_test = ngl_per_device; + size_t id_dense_start_test = id_dense_start; + uint32_t n_converted_test = 0; + for (;id_dense_start_test < nd; id_dense_start_test++) { + const uint32_t n_convert_jd = std::min(step_size - n_converted_test, ngl_per_device_test[id_dense_start_test].n_part); + ngl_per_device_test[id_dense_start_test].n_layer -= n_convert_jd; + ngl_per_device_test[id_dense_start_test].n_part -= n_convert_jd; + ngl_per_device_test[id].n_layer += n_convert_jd; + n_converted_test += n_convert_jd; + + if (ngl_per_device_test[id_dense_start_test].n_part > 0) { + break; + } + } + const std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts); + + if (mem_test[id] <= targets[id]) { + ngl_per_device = ngl_per_device_test; + mem = mem_test; + id_dense_start = id_dense_start_test; + LOG_INF("%s: set ngl_per_device[%zu].(n_layer, n_part)=(%" PRIu32 ", %" PRIu32 "), id_dense_start=%zu\n", + __func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start); + } else { + ngl_per_device_high = ngl_per_device_test; + mem_high = mem_test; + id_dense_start_high = id_dense_start_test; + LOG_INF("%s: set ngl_per_device_high[%zu].(n_layer, n_part)=(%" PRIu32 ", %" PRIu32 "), id_dense_start_high=%zu\n", + __func__, id, ngl_per_device_high[id].n_layer, ngl_per_device_high[id].n_part, id_dense_start_high); + } + assert(ngl_per_device_high[id].n_full() >= ngl_per_device[id].n_full()); + delta = ngl_per_device_high[id].n_full() - ngl_per_device[id].n_full(); + } + } else { + ngl_per_device = ngl_per_device_high; + mem = mem_high; + id_dense_start = id_dense_start_high; + LOG_INF("%s: set ngl_per_device[%zu].(n_layer, n_part)=(%" PRIu32 ", %" PRIu32 "), id_dense_start=%zu\n", + __func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start); + } + + // try to fit at least part of one more layer + if (ngl_per_device[id_dense_start].n_layer > (id < nd - 1 ? 
0 : 1)) { + std::vector ngl_per_device_test = ngl_per_device; + size_t id_dense_start_test = id_dense_start; + ngl_per_device_test[id_dense_start_test].n_layer--; + ngl_per_device_test[id_dense_start_test].n_part--; + ngl_per_device_test[id].n_layer++; + ngl_per_device_test[id].n_part++; + if (ngl_per_device_test[id_dense_start_test].n_part == 0) { + id_dense_start_test++; + } + ngl_per_device_test[id].overflow_type = LAYER_FRACTION_UP; + std::vector overflow_bufts_test = overflow_bufts; + if (id < nd - 1) { + overflow_bufts_test[id] = ggml_backend_dev_buffer_type(devs[id + 1]); + } + LOG_INF("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_UP\n", __func__); + std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test); + if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) { + ngl_per_device = ngl_per_device_test; + overflow_bufts = overflow_bufts_test; + mem = mem_test; + id_dense_start = id_dense_start_test; + LOG_INF("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", UP), id_dense_start=%zu\n", + __func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start); + + ngl_per_device_test[id].overflow_type = LAYER_FRACTION_GATE; + LOG_INF("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_GATE\n", __func__); + mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test); + if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) { + ngl_per_device = ngl_per_device_test; + overflow_bufts = overflow_bufts_test; + mem = mem_test; + id_dense_start = id_dense_start_test; + LOG_INF("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", GATE), id_dense_start=%zu\n", + __func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start); + } + } else { + ngl_per_device_test[id].overflow_type = LAYER_FRACTION_ATTN; + LOG_INF("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_ATTN\n", __func__); + mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test); + if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) { + ngl_per_device = ngl_per_device_test; + overflow_bufts = overflow_bufts_test; + mem = mem_test; + id_dense_start = id_dense_start_test; + LOG_INF("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", ATTN), id_dense_start=%zu\n", + __func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start); + } + } + } + + const int64_t projected_margin = dmds_full[id].free - mem[id]; + LOG_INF( + "%s: - %s: %2" PRIu32 " layers (%2" PRIu32 " overflowing), %6" PRId64 " MiB used, %6" PRId64 " MiB free\n", + __func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, ngl_per_device[id].n_part, mem[id]/MiB, projected_margin/MiB); + } + + // print info for devices that were not changed during the conversion from dense only to full layers: + for (size_t id = id_dense_start + 1; id < nd; id++) { + const int64_t projected_margin = dmds_full[id].free - mem[id]; + LOG_INF( + "%s: - %s: %2" PRIu32 " layers (%2" PRIu32 " overflowing), %6" PRId64 " MiB used, %6" PRId64 " MiB free\n", + __func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, ngl_per_device[id].n_part, mem[id]/MiB, projected_margin/MiB); + } + + set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams); +} + +enum 
common_params_fit_status common_fit_params( + const char * path_model, + llama_model_params * mparams, + llama_context_params * cparams, + float * tensor_split, + llama_model_tensor_buft_override * tensor_buft_overrides, + size_t * margins, + uint32_t n_ctx_min, + ggml_log_level log_level) { + const int64_t t0_us = llama_time_us(); + common_params_fit_status status = COMMON_PARAMS_FIT_STATUS_SUCCESS; + try { + common_params_fit_impl(path_model, mparams, cparams, tensor_split, tensor_buft_overrides, margins, n_ctx_min, log_level); + LOG_INF("%s: successfully fit params to free device memory\n", __func__); + } catch (const common_params_fit_exception & e) { + LOG_WRN("%s: failed to fit params to free device memory: %s\n", __func__, e.what()); + status = COMMON_PARAMS_FIT_STATUS_FAILURE; + } catch (const std::runtime_error & e) { + LOG_ERR("%s: encountered an error while trying to fit params to free device memory: %s\n", __func__, e.what()); + status = COMMON_PARAMS_FIT_STATUS_ERROR; + } + const int64_t t1_us = llama_time_us(); + LOG_INF("%s: fitting params to free memory took %.2f seconds\n", __func__, (t1_us - t0_us) * 1e-6); + return status; +} + +void common_memory_breakdown_print(const struct llama_context * ctx) { + //const auto & devices = ctx->get_model().devices; + const auto * model = llama_get_model(ctx); + + std::vector devices; + for (int i = 0; i < llama_model_n_devices(model); i++) { + devices.push_back(llama_model_get_device(model, i)); + } + + llama_memory_breakdown memory_breakdown = llama_get_memory_breakdown(ctx); + + std::vector> table_data; + table_data.reserve(devices.size()); + const std::string template_header = "%s: | %s | %s %s %s %s %s %s %s |\n"; + const std::string template_gpu = "%s: | %s | %s = %s + (%s = %s + %s + %s) + %s |\n"; + const std::string template_other = "%s: | %s | %s %s %s = %s + %s + %s %s |\n"; + + table_data.push_back({template_header, "memory breakdown [MiB]", "total", "free", "self", "model", "context", "compute", "unaccounted"}); + + constexpr size_t MiB = 1024 * 1024; + const std::vector desc_prefixes_strip = {"NVIDIA ", "GeForce ", "Tesla ", "AMD ", "Radeon ", "Instinct "}; + + // track seen buffer types to avoid double counting: + std::set seen_buffer_types; + + // accumulative memory breakdown for each device and for host: + std::vector mb_dev(devices.size()); + llama_memory_breakdown_data mb_host; + + for (const auto & buft_mb : memory_breakdown) { + ggml_backend_buffer_type_t buft = buft_mb.first; + const llama_memory_breakdown_data & mb = buft_mb.second; + if (ggml_backend_buft_is_host(buft)) { + mb_host.model += mb.model; + mb_host.context += mb.context; + mb_host.compute += mb.compute; + seen_buffer_types.insert(buft); + continue; + } + ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft); + if (dev) { + int i_dev = -1; + for (size_t i = 0; i < devices.size(); i++) { + if (devices[i] == dev) { + i_dev = i; + break; + } + } + if (i_dev != -1) { + mb_dev[i_dev].model += mb.model; + mb_dev[i_dev].context += mb.context; + mb_dev[i_dev].compute += mb.compute; + seen_buffer_types.insert(buft); + continue; + } + } + } + + // print memory breakdown for each device: + for (size_t i = 0; i < devices.size(); i++) { + ggml_backend_dev_t dev = devices[i]; + llama_memory_breakdown_data mb = mb_dev[i]; + + const std::string name = ggml_backend_dev_name(dev); + std::string desc = ggml_backend_dev_description(dev); + for (const std::string & prefix : desc_prefixes_strip) { + if (desc.length() >= prefix.length() && desc.substr(0, 
prefix.length()) == prefix) { + desc = desc.substr(prefix.length()); + } + } + + size_t free, total; + ggml_backend_dev_memory(dev, &free, &total); + + const size_t self = mb.model + mb.context + mb.compute; + const size_t unaccounted = total - self - free; + + table_data.push_back({ + template_gpu, + " - " + name + " (" + desc + ")", + std::to_string(total / MiB), + std::to_string(free / MiB), + std::to_string(self / MiB), + std::to_string(mb.model / MiB), + std::to_string(mb.context / MiB), + std::to_string(mb.compute / MiB), + std::to_string(unaccounted / MiB)}); + } + + // print memory breakdown for host: + { + const size_t self = mb_host.model + mb_host.context + mb_host.compute; + table_data.push_back({ + template_other, + " - Host", + "", // total + "", // free + std::to_string(self / MiB), + std::to_string(mb_host.model / MiB), + std::to_string(mb_host.context / MiB), + std::to_string(mb_host.compute / MiB), + ""}); // unaccounted + } + + // print memory breakdown for all remaining buffer types: + for (const auto & buft_mb : memory_breakdown) { + ggml_backend_buffer_type_t buft = buft_mb.first; + const llama_memory_breakdown_data & mb = buft_mb.second; + if (seen_buffer_types.count(buft) == 1) { + continue; + } + const std::string name = ggml_backend_buft_name(buft); + const size_t self = mb.model + mb.context + mb.compute; + table_data.push_back({ + template_other, + " - " + name, + "", // total + "", // free + std::to_string(self / MiB), + std::to_string(mb.model / MiB), + std::to_string(mb.context / MiB), + std::to_string(mb.compute / MiB), + ""}); // unaccounted + seen_buffer_types.insert(buft); + } + + for (size_t j = 1; j < table_data[0].size(); j++) { + size_t max_len = 0; + for (const auto & td : table_data) { + max_len = std::max(max_len, td[j].length()); + } + for (auto & td : table_data) { + td[j].insert(j == 1 ? td[j].length() : 0, max_len - td[j].length(), ' '); + } + } + for (const auto & td : table_data) { + LOG_INF(td[0].c_str(), + __func__, td[1].c_str(), td[2].c_str(), td[3].c_str(), td[4].c_str(), td[5].c_str(), + td[6].c_str(), td[7].c_str(), td[8].c_str()); + } +} + +void common_fit_print( + const char * path_model, + llama_model_params * mparams, + llama_context_params * cparams) { + std::vector devs; + uint32_t hp_ngl = 0; // hparams.n_gpu_layers + uint32_t hp_nct = 0; // hparams.n_ctx_train + uint32_t hp_nex = 0; // hparams.n_expert + + auto dmd = common_get_device_memory_data(path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, GGML_LOG_LEVEL_ERROR); + GGML_ASSERT(dmd.size() == devs.size() + 1); + + for (size_t id = 0; id < devs.size(); id++) { + printf("%s ", ggml_backend_dev_name(devs[id])); + printf("%zu ", dmd[id].mb.model/1024/1024); + printf("%zu ", dmd[id].mb.context/1024/1024); + printf("%zu ", dmd[id].mb.compute/1024/1024); + printf("\n"); + } + + printf("Host "); + printf("%zu ", dmd.back().mb.model/1024/1024); + printf("%zu ", dmd.back().mb.context/1024/1024); + printf("%zu ", dmd.back().mb.compute/1024/1024); + printf("\n"); +} diff --git a/common/fit.h b/common/fit.h new file mode 100644 index 00000000000..e066092ec6c --- /dev/null +++ b/common/fit.h @@ -0,0 +1,32 @@ +#pragma once + +#include "ggml.h" + +enum common_params_fit_status { + COMMON_PARAMS_FIT_STATUS_SUCCESS = 0, // found allocations that are projected to fit + COMMON_PARAMS_FIT_STATUS_FAILURE = 1, // could not find allocations that are projected to fit + COMMON_PARAMS_FIT_STATUS_ERROR = 2, // a hard error occurred, e.g. 
because no model could be found at the specified path +}; + +// fits mparams and cparams to free device memory (assumes system memory is unlimited) +// - returns true if the parameters could be successfully modified to fit device memory +// - this function is NOT thread safe because it modifies the global llama logger state +// - only parameters that have the same value as in llama_default_model_params are modified +// with the exception of the context size which is modified if and only if equal to 0 +enum common_params_fit_status common_fit_params( + const char * path_model, + struct llama_model_params * mparams, + struct llama_context_params * cparams, + float * tensor_split, // writable buffer for tensor split, needs at least llama_max_devices elements + struct llama_model_tensor_buft_override * tensor_buft_overrides, // writable buffer for overrides, needs at least llama_max_tensor_buft_overrides elements + size_t * margins, // margins of memory to leave per device in bytes + uint32_t n_ctx_min, // minimum context size to set when trying to reduce memory use + enum ggml_log_level log_level); // minimum log level to print during fitting, lower levels go to debug log + +// print estimated memory to stdout +void common_fit_print( + const char * path_model, + struct llama_model_params * mparams, + struct llama_context_params * cparams); + +void common_memory_breakdown_print(const struct llama_context * ctx); diff --git a/common/hf-cache.cpp b/common/hf-cache.cpp index 38a4c17a98e..ea5b2150de4 100644 --- a/common/hf-cache.cpp +++ b/common/hf-cache.cpp @@ -230,7 +230,7 @@ static nl::json api_get(const std::string & url, static std::string get_repo_commit(const std::string & repo_id, const std::string & token) { try { - auto endpoint = get_model_endpoint(); + auto endpoint = common_get_model_endpoint(); auto json = api_get(endpoint + "api/models/" + repo_id + "/refs", token); if (!json.is_object() || @@ -308,7 +308,7 @@ hf_files get_repo_files(const std::string & repo_id, hf_files files; try { - auto endpoint = get_model_endpoint(); + auto endpoint = common_get_model_endpoint(); auto json = api_get(endpoint + "api/models/" + repo_id + "/tree/" + commit + "?recursive=true", token); if (!json.is_array()) { diff --git a/common/jinja/caps.cpp b/common/jinja/caps.cpp index ec207a53e85..ead864763e1 100644 --- a/common/jinja/caps.cpp +++ b/common/jinja/caps.cpp @@ -1,4 +1,3 @@ -#include "log.h" #include "value.h" #include "runtime.h" #include "caps.h" diff --git a/common/jinja/runtime.h b/common/jinja/runtime.h index 3ca5f1754fa..b6f4a6ab48e 100644 --- a/common/jinja/runtime.h +++ b/common/jinja/runtime.h @@ -106,10 +106,16 @@ struct statement { size_t pos; // position in source, for debugging virtual ~statement() = default; virtual std::string type() const { return "Statement"; } + // execute_impl must be overridden by derived classes - virtual value execute_impl(context &) { throw std::runtime_error("cannot exec " + type()); } + virtual value execute_impl(context &) { throw_exec_error(); } // execute is the public method to execute a statement with error handling value execute(context &); + +private: + [[noreturn]] void throw_exec_error() const { + throw std::runtime_error("cannot exec " + type()); + } }; // Type Checking Utilities @@ -143,7 +149,7 @@ struct program : public statement { program() = default; explicit program(statements && body) : body(std::move(body)) {} std::string type() const override { return "Program"; } - value execute_impl(context &) override { + [[noreturn]] value 
execute_impl(context &) override { throw std::runtime_error("Cannot execute program directly, use jinja::runtime instead"); } }; @@ -195,7 +201,7 @@ struct break_statement : public statement { } }; - value execute_impl(context &) override { + [[noreturn]] value execute_impl(context &) override { throw break_statement::signal(); } }; @@ -209,7 +215,7 @@ struct continue_statement : public statement { } }; - value execute_impl(context &) override { + [[noreturn]] value execute_impl(context &) override { throw continue_statement::signal(); } }; @@ -509,7 +515,7 @@ struct slice_expression : public expression { chk_type(this->step_expr); } std::string type() const override { return "SliceExpression"; } - value execute_impl(context &) override { + [[noreturn]] value execute_impl(context &) override { throw std::runtime_error("must be handled by MemberExpression"); } }; diff --git a/common/jinja/value.cpp b/common/jinja/value.cpp index 8e86a715f5f..0b79098cd1e 100644 --- a/common/jinja/value.cpp +++ b/common/jinja/value.cpp @@ -590,6 +590,10 @@ static bool string_endswith(const std::string & str, const std::string & suffix) return str.compare(str.length() - suffix.length(), suffix.length(), suffix) == 0; } +[[noreturn]] static value string_join_not_implemented(const func_args &) { + throw not_implemented_exception("String join builtin not implemented"); +} + const func_builtins & value_string_t::get_builtins() const { static const func_builtins builtins = { {"default", default_value}, @@ -851,9 +855,7 @@ const func_builtins & value_string_t::get_builtins() const { res->val_str.mark_input_based_on(val_input->as_string()); return res; }}, - {"join", [](const func_args &) -> value { - throw not_implemented_exception("String join builtin not implemented"); - }}, + {"join", string_join_not_implemented}, }; return builtins; } @@ -884,6 +886,9 @@ const func_builtins & value_bool_t::get_builtins() const { return builtins; } +[[noreturn]] static value array_unique_not_implemented(const func_args &) { + throw not_implemented_exception("Array unique builtin not implemented"); +} const func_builtins & value_array_t::get_builtins() const { static const func_builtins builtins = { @@ -1084,13 +1089,14 @@ const func_builtins & value_array_t::get_builtins() const { std::reverse(arr.begin(), arr.end()); return is_val(val) ? 
mk_val(std::move(arr)) : mk_val(std::move(arr)); }}, - {"unique", [](const func_args &) -> value { - throw not_implemented_exception("Array unique builtin not implemented"); - }}, + {"unique", array_unique_not_implemented}, }; return builtins; } +[[noreturn]] static value object_join_not_implemented(const func_args &) { + throw not_implemented_exception("object join not implemented"); +} const func_builtins & value_object_t::get_builtins() const { if (!has_builtins) { @@ -1183,9 +1189,7 @@ const func_builtins & value_object_t::get_builtins() const { }); return result; }}, - {"join", [](const func_args &) -> value { - throw not_implemented_exception("object join not implemented"); - }}, + {"join", object_join_not_implemented}, }; return builtins; } diff --git a/common/jinja/value.h b/common/jinja/value.h index 7d164588ad9..5cf85e4f544 100644 --- a/common/jinja/value.h +++ b/common/jinja/value.h @@ -129,27 +129,25 @@ struct value_t { // Note: only for debugging and error reporting purposes virtual std::string type() const { return ""; } - virtual int64_t as_int() const { throw std::runtime_error(type() + " is not an int value"); } - virtual double as_float() const { throw std::runtime_error(type() + " is not a float value"); } - virtual string as_string() const { throw std::runtime_error(type() + " is not a string value"); } - virtual bool as_bool() const { throw std::runtime_error(type() + " is not a bool value"); } - virtual const std::vector & as_array() const { throw std::runtime_error(type() + " is not an array value"); } - virtual const std::vector> & as_ordered_object() const { throw std::runtime_error(type() + " is not an object value"); } - virtual value invoke(const func_args &) const { throw std::runtime_error(type() + " is not a function value"); } + virtual int64_t as_int() const { throw_type_error("is not an int value"); } + virtual double as_float() const { throw_type_error("is not a float value"); } + virtual string as_string() const { throw_type_error("is not a string value"); } + virtual bool as_bool() const { throw_type_error("is not a bool value"); } + virtual const std::vector & as_array() const { throw_type_error("is not an array value"); } + virtual const std::vector> & as_ordered_object() const { throw_type_error("is not an object value"); } + virtual value invoke(const func_args &) const { throw_type_error("is not a function value"); } virtual bool is_none() const { return false; } virtual bool is_undefined() const { return false; } - virtual const func_builtins & get_builtins() const { - throw std::runtime_error("No builtins available for type " + type()); - } + virtual const func_builtins & get_builtins() const { throw_type_error("has no builtins"); } - virtual bool has_key(const value &) { throw std::runtime_error(type() + " is not an object value"); } - virtual void insert(const value & /* key */, const value & /* val */) { throw std::runtime_error(type() + " is not an object value"); } - virtual value & at(const value & /* key */, value & /* default_val */) { throw std::runtime_error(type() + " is not an object value"); } - virtual value & at(const value & /* key */) { throw std::runtime_error(type() + " is not an object value"); } - virtual value & at(const std::string & /* key */, value & /* default_val */) { throw std::runtime_error(type() + " is not an object value"); } - virtual value & at(const std::string & /* key */) { throw std::runtime_error(type() + " is not an object value"); } - virtual value & at(int64_t /* idx */, value & /* default_val */) { 
throw std::runtime_error(type() + " is not an array value"); } - virtual value & at(int64_t /* idx */) { throw std::runtime_error(type() + " is not an array value"); } + virtual bool has_key(const value &) { throw_type_error("is not an object value"); } + virtual void insert(const value & /* key */, const value & /* val */) { throw_type_error("is not an object value"); } + virtual value & at(const value & /* key */, value & /* default_val */) { throw_type_error("is not an object value"); } + virtual value & at(const value & /* key */) { throw_type_error("is not an object value"); } + virtual value & at(const std::string & /* key */, value & /* default_val */) { throw_type_error("is not an object value"); } + virtual value & at(const std::string & /* key */) { throw_type_error("is not an object value"); } + virtual value & at(int64_t /* idx */, value & /* default_val */) { throw_type_error("is not an array value"); } + virtual value & at(int64_t /* idx */) { throw_type_error("is not an array value"); } virtual bool is_numeric() const { return false; } virtual bool is_hashable() const { return false; } @@ -163,6 +161,11 @@ struct value_t { // Note: only for debugging purposes virtual std::string as_repr() const { return as_string().str(); } +private: + [[noreturn]] void throw_type_error(const char* expected) const { + throw std::runtime_error(type() + " " + expected); + } + protected: virtual bool equivalent(const value_t &) const = 0; virtual bool nonequal(const value_t & other) const { return !equivalent(other); } diff --git a/common/ngram-map.cpp b/common/ngram-map.cpp index ebf771a24a7..8e3978f7ed0 100644 --- a/common/ngram-map.cpp +++ b/common/ngram-map.cpp @@ -208,7 +208,7 @@ void common_ngram_map_begin( count_keys, count_keys_del, count_values_del, count_map_entries_upd); } - map.idx_last_check = (map.size_last_begin > 0) ? map.size_last_begin - 1 : 0; + map.idx_last_check = size_begin; map.size_last_begin = size_begin; } @@ -231,7 +231,7 @@ void common_ngram_map_draft(common_ngram_map & map, GGML_ABORT("%s: cur_len exceeds UINT32_MAX: %zu", __func__, cur_len); } - if (map.idx_last_check > cur_len) { + if (map.idx_last_check > cur_len) { // Should not happen because of common_ngram_map_begin(). GGML_ABORT("%s: map.idx_last_check > cur_len: %zu > %zu", __func__, map.idx_last_check, cur_len); } @@ -386,7 +386,7 @@ void common_ngram_map_draft(common_ngram_map & map, LOG_DBG("%s: key_idx = %zu, key_offset = %zu, key_num = %d, draft.size = %zu\n", __func__, curr_key.key_idx, key_offset, curr_key.key_num, draft.size()); - map.last_draft_created = false; + map.last_draft_created = true; map.last_draft_key_idx = key_offset; map.last_draft_value_idx = 0; // value 0 is used for simple mode return; @@ -524,7 +524,7 @@ void common_ngram_map_accept(common_ngram_map & map, uint16_t n_accepted) { struct common_ngram_map_value & curr_value = curr_key.values[val_idx]; // value used for draft generation. 
// update the value statistics - LOG_INF("common_ngram_map_send_accepted: n_accepted = %d, prev value_num = %d\n", + LOG_DBG("common_ngram_map_send_accepted: n_accepted = %d, prev value_num = %d\n", n_accepted, curr_value.n_accepted); curr_value.n_accepted = n_accepted; } diff --git a/common/sampling.cpp b/common/sampling.cpp index 526f036ff98..b2e6d8e8d89 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -1,10 +1,12 @@ #include "sampling.h" #include "common.h" -#include "ggml.h" +#include "fit.h" #include "log.h" #include "reasoning-budget.h" +#include "ggml.h" + #include #include #include @@ -511,7 +513,7 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam LOG_INF("%s: unaccounted time = %10.2f ms / %5.1f %% (total - sampling - prompt eval - eval) / (total)\n", __func__, t_unacc_ms, t_unacc_pc); LOG_INF("%s: graphs reused = %10d\n", __func__, data.n_reused); - llama_memory_breakdown_print(ctx); + common_memory_breakdown_print(ctx); } } diff --git a/common/speculative.cpp b/common/speculative.cpp index 3e68c38e49c..4980c03da62 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5 @@ -21,6 +22,7 @@ const std::vector common_speculative_types = { COMMON_SPECULATIVE_TYPE_NONE, COMMON_SPECULATIVE_TYPE_DRAFT, COMMON_SPECULATIVE_TYPE_EAGLE3, + COMMON_SPECULATIVE_TYPE_DFLASH, COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, @@ -32,6 +34,7 @@ const std::map common_speculative_typ {"none", COMMON_SPECULATIVE_TYPE_NONE}, {"draft", COMMON_SPECULATIVE_TYPE_DRAFT}, {"eagle3", COMMON_SPECULATIVE_TYPE_EAGLE3}, + {"dflash", COMMON_SPECULATIVE_TYPE_DFLASH}, {"ngram_simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE}, {"ngram_map_k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K}, {"ngram_map_k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V}, @@ -47,6 +50,7 @@ struct common_speculative_config { const common_params_speculative & p = common_params_speculative{}) : type(t), params(p) {} }; + static bool common_speculative_are_compatible( const llama_model * model_tgt, const llama_model * model_dft) { @@ -144,10 +148,28 @@ struct common_speculative_state { virtual void accept(uint16_t n_accepted) = 0; }; +struct common_speculative_checkpoint { + llama_pos pos_min = 0; + llama_pos pos_max = 0; + + int64_t n_tokens = 0; + + std::vector data; + + size_t size() const { + return data.size(); + } + + size_t ckpt_size = 0; +}; + struct common_speculative_state_draft : public common_speculative_state { llama_context * ctx_tgt; // only used for retokenizing from ctx_dft llama_context * ctx_dft; + bool use_ckpt = false; + struct common_speculative_checkpoint ckpt; + common_sampler * smpl; llama_batch batch; @@ -160,10 +182,12 @@ struct common_speculative_state_draft : public common_speculative_state { enum common_speculative_type type, llama_context * ctx_tgt, llama_context * ctx_dft, - const std::vector> & replacements) + const std::vector> & replacements, + bool use_ckpt) : common_speculative_state(type) , ctx_tgt(ctx_tgt) , ctx_dft(ctx_dft) + , use_ckpt(use_ckpt) { batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1); smpl = nullptr; @@ -210,7 +234,9 @@ struct common_speculative_state_draft : public common_speculative_state { ~common_speculative_state_draft() override { llama_perf_context_print(ctx_dft); - llama_free(ctx_dft); + if (ctx_dft) { + llama_free(ctx_dft); + } 
common_sampler_free(smpl); @@ -218,7 +244,48 @@ struct common_speculative_state_draft : public common_speculative_state { } void begin(const llama_tokens & prompt) override { - GGML_UNUSED(prompt); + if (use_ckpt && ckpt.size() > 0) { + // delete checkpoint + LOG_DBG("%s: delete checkpoint, prompt.size=%zu, pos_min=%d, pos_max=%d, n_tokens=%" PRId64 ", size=%.3f MiB\n", + __func__, prompt.size(), ckpt.pos_min, ckpt.pos_max, ckpt.n_tokens, (float) ckpt.data.size() / 1024 / 1024); + ckpt.pos_min = 0; + ckpt.pos_max = 0; + ckpt.n_tokens = 0; + ckpt.ckpt_size = 0; + ckpt.data.clear(); + } + } + + size_t draft_create_checkpoint(int n_tokens_prompt, int n_tokens_batch) { + int slot_id = 0; + const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx_dft, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + + ckpt.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_dft), slot_id); + ckpt.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), slot_id); + ckpt.n_tokens = n_tokens_prompt - n_tokens_batch; + ckpt.data.resize(checkpoint_size); + + const size_t n = llama_state_seq_get_data_ext(ctx_dft, ckpt.data.data(), checkpoint_size, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + if (n != checkpoint_size) { + GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n); + } + + LOG_DBG("%s: pos_min = %d, pos_max = %d, size = %.3f MiB\n", __func__, + ckpt.pos_min, ckpt.pos_max, (float) ckpt.data.size() / 1024 / 1024); + return n; + } + + size_t draft_restore_checkpoint(size_t ckpt_size_part_expected) { + int slot_id = 0; + LOG_DBG("%s: pos_min = %d, pos_max = %d\n", __func__, ckpt.pos_min, ckpt.pos_max); + const size_t n = llama_state_seq_set_data_ext(ctx_dft, ckpt.data.data(), ckpt.size(), slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + if (n != ckpt_size_part_expected) { + GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu", + __func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt_size_part_expected, n); + } + llama_memory_seq_rm(llama_get_memory(ctx_dft), slot_id, ckpt.pos_max + 1, -1); + + return n; } void draft( @@ -228,16 +295,16 @@ struct common_speculative_state_draft : public common_speculative_state { llama_tokens & result) override { auto * spec = this; - auto & batch = spec->batch; - auto & ctx_tgt = spec->ctx_tgt; - auto & ctx_dft = spec->ctx_dft; - auto & smpl = spec->smpl; - auto & prompt_dft = spec->prompt_dft; + auto & batch = spec->batch; + auto & ctx_tgt = spec->ctx_tgt; + auto & ctx_dft = spec->ctx_dft; + auto & smpl = spec->smpl; + auto & prompt_dft = spec->prompt_dft; auto * mem_dft = llama_get_memory(ctx_dft); - int reuse_i = 0; - int reuse_n = 0; + int reuse_i = 0; // index of part to be reused in prompt_dft + int reuse_n = 0; // length of part to be reused in prompt_dft const int n_ctx = llama_n_ctx(ctx_dft) - params.n_max; @@ -287,18 +354,26 @@ struct common_speculative_state_draft : public common_speculative_state { } } - LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt_dft.size()); + LOG_DBG("%s: reuse_i = %d, reuse_n = %d, #prompt_dft = %zu, #prompt_cur = %zu\n", + __func__, reuse_i, reuse_n, prompt_dft.size(), prompt_cur.size()); + if (use_ckpt && ckpt.ckpt_size == 0 && reuse_n > 0) { + LOG_DBG("%s: no checkpoint available, no reuse, (reuse_i=%d, reuse_n=%d) -> (0, 0)\n", + __func__, reuse_i, reuse_n); + reuse_i = 0; + reuse_n = 0; + } result.clear(); result.reserve(params.n_max); - if (reuse_n == 0) { + bool 
needs_ckpt = use_ckpt && prompt_dft.size() > 0; + if (reuse_n == 0 || (use_ckpt && reuse_i > 0)) { llama_memory_clear(mem_dft, false); prompt_dft.clear(); } else { // this happens when a previous draft has been discarded (for example, due to being too small), but the // target model agreed with it. in this case, we simply pass back the previous results to save compute - if (reuse_i + reuse_n < (int) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) { + if (reuse_i + reuse_n < (int64_t) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) { for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) { result.push_back(prompt_dft[i]); @@ -310,19 +385,50 @@ struct common_speculative_state_draft : public common_speculative_state { return; } + bool do_restore = false; + if (prompt_dft.size() > prompt_cur.size() && reuse_i + reuse_n < (int64_t) prompt_dft.size()) { + // This can happen after a partial acceptance (speculative decoding with checkpoints) + LOG_DBG("%s: #prompt_dft=%zu, #prompt_cur=%zu, shorten draft\n", + __func__, prompt_dft.size(), prompt_cur.size()); + prompt_dft.resize(prompt_cur.size()); + do_restore = true; + } + if (reuse_i > 0) { - llama_memory_seq_rm (mem_dft, 0, 0, reuse_i); + bool is_removed = llama_memory_seq_rm (mem_dft, 0, 0, reuse_i); + if (!is_removed) { + LOG_ERR("%s: llama_memory_seq_rm failed, reuse_i=%d\n", __func__, reuse_i); + } llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i); prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i); } - if (reuse_n < (int) prompt_dft.size()) { - llama_memory_seq_rm (mem_dft, 0, reuse_n, -1); - prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end()); + if (reuse_n < (int) prompt_dft.size() || do_restore) { + if (use_ckpt) { + if (ckpt.n_tokens > (int64_t) prompt_dft.size()) { + LOG_INF("%s: checkpoint is too large, prompt_tgt.size=%zu, ckpt.n_tokens=%" PRId64 ", reuse_n=%d, prompt_dft.size=%zu\n", + __func__, prompt_tgt.size(), ckpt.n_tokens, reuse_n, prompt_dft.size()); + } + draft_restore_checkpoint(ckpt.ckpt_size); + reuse_n = ckpt.n_tokens; + prompt_dft.resize(reuse_n); + needs_ckpt = false; + } else { + bool is_removed = llama_memory_seq_rm (mem_dft, 0, reuse_n, -1); + if (!is_removed) { + LOG_ERR("%s: llama_memory_seq_rm failed, reuse_n=%d, prompt_dft.size=%zu\n", + __func__, reuse_n, prompt_dft.size()); + } + prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end()); + } } } + if (needs_ckpt) { + ckpt.ckpt_size = draft_create_checkpoint(prompt_dft.size(), batch.n_tokens); + } + // prepare a batch to evaluate any new tokens in the prompt common_batch_clear(batch); @@ -337,7 +443,11 @@ struct common_speculative_state_draft : public common_speculative_state { if (batch.n_tokens > 0) { //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str()); - llama_decode(ctx_dft, batch); + int ret = llama_decode(ctx_dft, batch); + if (ret != 0 && ret != 1) { + LOG_WRN("%s: llama_decode returned %d, prompt_cur.size=%zu\n", + __func__, ret, prompt_cur.size()); + } } const llama_pos n_past = prompt_dft.size(); @@ -351,7 +461,11 @@ struct common_speculative_state_draft : public common_speculative_state { LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str()); - llama_decode(ctx_dft, batch); + int ret = llama_decode(ctx_dft, batch); + if (ret != 0 && ret != 1) { + LOG_WRN("%s: llama_decode returned %d, prompt_cur.size=%zu, prompt_dft.size=%zu\n", + __func__, ret, prompt_cur.size(), prompt_dft.size()); + } 
common_sampler_reset(smpl); @@ -387,7 +501,11 @@ struct common_speculative_state_draft : public common_speculative_state { common_batch_add(batch, id, n_past + i + 1, { 0 }, true); // evaluate the drafted tokens on the draft model - llama_decode(ctx_dft, batch); + ret = llama_decode(ctx_dft, batch); + if (ret != 0) { + LOG_WRN("%s: llama_decode[%d] returned %d, prompt_cur.size=%zu, prompt_dft.size=%zu\n", + __func__, i, ret, prompt_cur.size(), prompt_dft.size()); + } prompt_dft.push_back(id); } @@ -438,7 +556,52 @@ struct common_speculative_state_draft : public common_speculative_state { }; struct common_speculative_state_eagle3 : public common_speculative_state { - common_speculative_state_eagle3(enum common_speculative_type type) : common_speculative_state(type) {} + llama_context * ctx_tgt; + + common_sampler * smpl; + + llama_batch batch; + + struct llama_context * ctx_dft_enc = nullptr; + struct llama_context * ctx_dft_dec = nullptr; + + int32_t eagle3_n_past = 0; // number of verified positions in decoder KV cache + + common_speculative_state_eagle3( + enum common_speculative_type type, + llama_context * ctx_tgt, + llama_context * ctx_dft_enc, + llama_context * ctx_dft_dec) + : common_speculative_state(type) + , ctx_tgt(ctx_tgt) + , ctx_dft_enc(ctx_dft_enc) + , ctx_dft_dec(ctx_dft_dec) + { + batch = llama_batch_init(llama_n_batch(ctx_dft_dec), 0, 1); + + // Initialize sampler for EAGLE3 decoder + common_params_sampling params; + params.no_perf = false; + params.top_k = 10; // set 1 for greedy sampling (argmax) to match vLLM's default behavior but >1 always gets higher acceptance rate for eagle3 + params.samplers = { COMMON_SAMPLER_TYPE_TOP_K }; + smpl = common_sampler_init(llama_get_model(ctx_dft_dec), params); + } + + ~common_speculative_state_eagle3() override { + llama_perf_context_print(ctx_dft_dec); + + if (ctx_dft_dec) { + llama_free(ctx_dft_dec); + } + + if (ctx_dft_enc) { + llama_free(ctx_dft_enc); + } + + common_sampler_free(smpl); + + llama_batch_free(batch); + } void begin(const llama_tokens & prompt) override { GGML_UNUSED(prompt); @@ -448,12 +611,97 @@ struct common_speculative_state_eagle3 : public common_speculative_state { const common_params_speculative & params, const llama_tokens & prompt_tgt, llama_token id_last, - llama_tokens & draft_tokens) override { - // TODO: implement - GGML_UNUSED(params); - GGML_UNUSED(prompt_tgt); - GGML_UNUSED(id_last); - GGML_UNUSED(draft_tokens); + llama_tokens & result) override { + auto * spec = this; + + auto & batch = spec->batch; + auto & ctx_tgt = spec->ctx_tgt; + auto & ctx_dft_enc = spec->ctx_dft_enc; + auto & ctx_dft_dec = spec->ctx_dft_dec; + auto & smpl = spec->smpl; + + //result = gen_eagle3_draft(spec, params, prompt_tgt, id_last); + const int n_embd = llama_model_n_embd(llama_get_model(ctx_dft_enc)); + const int n = (int)prompt_tgt.size(); + const int n_new = n - spec->eagle3_n_past; + + GGML_ASSERT(n >= 1 && "prompt_tgt is empty"); + GGML_ASSERT(n_new >= 1 && "must have at least 1 new token"); + + // Clear draft positions from decoder KV cache [n_past, inf) + llama_memory_seq_rm(llama_get_memory(ctx_dft_dec), 0, spec->eagle3_n_past, -1); + + // Encoder: features → g_embeddings + const float * features = llama_get_eagle3_target_features(ctx_tgt); + GGML_ASSERT(features && "no target features"); + + llama_batch enc_batch = { + /*.n_tokens =*/ n_new, + /*.token =*/ nullptr, + /*.embd =*/ const_cast(features), + /*.pos =*/ nullptr, + /*.n_seq_id =*/ nullptr, + /*.seq_id =*/ nullptr, + /*.logits =*/ nullptr, + }; + 
GGML_ASSERT(llama_encode(ctx_dft_enc, enc_batch) == 0); + + const float * g_embd = llama_get_embeddings(ctx_dft_enc); + GGML_ASSERT(g_embd && "encoder output failed"); + + // Decoder batch: process new tokens with KV cache reuse + llama_set_eagle3_g_embeddings(ctx_dft_dec, g_embd, n_embd, n_new); + + common_batch_clear(batch); + for (int i = 0; i < n_new; i++) { + const int pos = spec->eagle3_n_past + i; + const llama_token tok = (pos < n - 1) ? prompt_tgt[pos + 1] : id_last; + common_batch_add(batch, tok, pos, {0}, true); + } + + GGML_ASSERT(llama_decode(ctx_dft_dec, batch) == 0); + + spec->eagle3_n_past = n; // update verified positions + + // Sample draft tokens + result.clear(); + common_sampler_reset(smpl); + + // Sample and check probability (consistent with standard speculative decoding) + auto sample_and_check = [&](int idx) -> bool { + common_sampler_sample(smpl, ctx_dft_dec, idx); + + const auto * cur_p = common_sampler_get_candidates(smpl, true); + const llama_token id = cur_p->data[0].id; + + common_sampler_accept(smpl, id, true); + result.push_back(id); + + return cur_p->data[0].p >= params.p_min; + }; + + // First draft token from batch decode + if (!sample_and_check(n_new - 1)) { + return; + } + + // Autoregressive: use prenorm as g_embd (-1 = last output) + const float * prenorm = llama_get_embeddings_ith(ctx_dft_dec, -1); + + for (int i = 1; i < params.n_max; i++) { + GGML_ASSERT(prenorm && "prenorm failed"); + llama_set_eagle3_g_embeddings(ctx_dft_dec, prenorm, n_embd, 1); + + common_batch_clear(batch); + common_batch_add(batch, result.back(), n - 1 + i, {0}, true); + GGML_ASSERT(llama_decode(ctx_dft_dec, batch) == 0); + + prenorm = llama_get_embeddings_ith(ctx_dft_dec, -1); + + if (!sample_and_check(0)) { + break; + } + } } void accept(uint16_t n_accepted) override { @@ -462,6 +710,139 @@ struct common_speculative_state_eagle3 : public common_speculative_state { } }; +struct common_speculative_state_dflash : public common_speculative_state { + llama_context * ctx_tgt; + + common_sampler * smpl; + + llama_batch batch; + + struct llama_context * ctx_dft_enc = nullptr; + struct llama_context * ctx_dft_dec = nullptr; + + int32_t dflash_n_past = 0; + + // Host-side buffer: accumulated DFlash-encoded target features across all + // committed prompt+drafted tokens. 
Grows by `n_new * n_embd` floats per draft step + // and is fed to the DFlash decoder via llama_set_dflash_accumulated_target_ctx() + std::vector accumulated_ctx; + + common_speculative_state_dflash( + enum common_speculative_type type, + llama_context * ctx_tgt, + llama_context * ctx_dft_enc, + llama_context * ctx_dft_dec) + : common_speculative_state(type) + , ctx_tgt(ctx_tgt) + , ctx_dft_enc(ctx_dft_enc) + , ctx_dft_dec(ctx_dft_dec) + { + batch = llama_batch_init(llama_n_batch(ctx_dft_dec), 0, 1); + + common_params_sampling params; + params.no_perf = false; + params.top_k = 1; + params.samplers = { COMMON_SAMPLER_TYPE_TOP_K }; + smpl = common_sampler_init(llama_get_model(ctx_dft_dec), params); + } + + ~common_speculative_state_dflash() override { + llama_perf_context_print(ctx_dft_dec); + + if (ctx_dft_dec) { + llama_free(ctx_dft_dec); + } + + if (ctx_dft_enc) { + llama_free(ctx_dft_enc); + } + + common_sampler_free(smpl); + llama_batch_free(batch); + } + + void begin(const llama_tokens & prompt) override { + GGML_UNUSED(prompt); + } + + void draft( + const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & result) override { + const int n_embd = llama_model_n_embd(llama_get_model(ctx_dft_dec)); + // block_size is bounded by the model's trained block_size (from GGUF metadata). + const int model_block_size = llama_model_dflash_block_size(llama_get_model(ctx_dft_dec)); + const int block_size = std::min((int)params.n_max, model_block_size); + const int n = (int)prompt_tgt.size(); + const int n_new = n - dflash_n_past; + + GGML_ASSERT(n >= 1 && "prompt_tgt is empty"); + GGML_ASSERT(n_new >= 1 && "must have at least 1 new token"); + + // Step 1: Encode new accepted tokens' features + const float * features = llama_get_dflash_target_features(ctx_tgt); + + llama_batch enc_batch = { + /*.n_tokens =*/ n_new, + /*.token =*/ nullptr, + /*.embd =*/ const_cast(features), + /*.pos =*/ nullptr, + /*.n_seq_id =*/ nullptr, + /*.seq_id =*/ nullptr, + /*.logits =*/ nullptr, + }; + if (llama_encode(ctx_dft_enc, enc_batch) != 0) { + LOG_ERR("DFlash: encoder failed\n"); + return; + } + + const float * target_ctx_new = llama_get_embeddings(ctx_dft_enc); + GGML_ASSERT(target_ctx_new && "encoder output is null"); + + // Step 2: Append to accumulated target_ctx and set on decoder context (writes to cross.v_embd) + const size_t new_size = (size_t)n_embd * n_new; + accumulated_ctx.insert(accumulated_ctx.end(), target_ctx_new, target_ctx_new + new_size); + + const int n_ctx_total = (int)(accumulated_ctx.size() / n_embd); + llama_set_dflash_accumulated_target_ctx(ctx_dft_dec, accumulated_ctx.data(), n_embd, n_ctx_total); + + // Step 3: Decode noise block + const llama_token mask_token_id = llama_model_dflash_mask_token_id(llama_get_model(ctx_dft_dec)); + + common_batch_clear(batch); + for (int i = 0; i < block_size; i++) { + const llama_token tok = (i == 0) ? 
id_last : mask_token_id; + common_batch_add(batch, tok, i, {0}, true); + } + + if (llama_decode(ctx_dft_dec, batch) != 0) { + LOG_ERR("DFlash: noise decode failed\n"); + return; + } + + dflash_n_past = n; + + // Step 4: Sample draft tokens from positions 1..block_size-1 + result.clear(); + common_sampler_reset(smpl); + + for (int i = 1; i < block_size; i++) { + common_sampler_sample(smpl, ctx_dft_dec, i); + + const auto * cur_p = common_sampler_get_candidates(smpl, true); + const llama_token id = cur_p->data[0].id; + + common_sampler_accept(smpl, id, true); + result.push_back(id); + } + } + + void accept(uint16_t n_accepted) override { + GGML_UNUSED(n_accepted); + } +}; + // state of self-speculation (simple implementation, not ngram-map) struct common_speculative_state_ngram_simple : public common_speculative_state { common_ngram_simple_config config; @@ -636,6 +1017,7 @@ struct common_speculative_state_ngram_mod : public common_speculative_state { mod.reset(); n_low = 0; + i_last = 0; } } else { n_low = 0; @@ -739,6 +1121,7 @@ struct common_speculative_state_ngram_cache : public common_speculative_state { struct common_speculative { std::vector> impls; // list of implementations to use and their states + common_speculative_state * curr_impl = nullptr; // current implementation in use (for stats) }; @@ -798,53 +1181,41 @@ enum common_speculative_type common_speculative_type_from_name(const std::string return it->second; } -bool common_speculative_is_compat(llama_context * ctx_tgt) { - auto * mem = llama_get_memory(ctx_tgt); - if (mem == nullptr) { - return false; - } - - bool res = true; - - llama_memory_clear(mem, true); - - // eval 2 tokens to check if the context is compatible - std::vector tmp; - tmp.push_back(0); - tmp.push_back(0); - - int ret = llama_decode(ctx_tgt, llama_batch_get_one(tmp.data(), tmp.size())); - if (ret != 0) { - LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret); - res = false; - goto done; - } - - // try to remove the last tokens - if (!llama_memory_seq_rm(mem, 0, 1, -1)) { - LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__); - res = false; - goto done; - } - -done: - llama_memory_clear(mem, true); - llama_synchronize(ctx_tgt); - - return res; -} - // initialization of the speculative decoding system // common_speculative * common_speculative_init( common_params_speculative & params, llama_context * ctx_tgt) { llama_context * ctx_dft = nullptr; + + llama_context * ctx_dft_enc = nullptr; + llama_context * ctx_dft_dec = nullptr; + if (params.model_dft) { - ctx_dft = llama_init_from_model(params.model_dft, params.cparams_dft); - if (ctx_dft == nullptr) { - LOG_ERR("%s", "failed to create draft context\n"); - return nullptr; + if (params.eagle3 || params.dflash) { + llama_context_params params_enc = params.cparams_dft; + params_enc.target_model = nullptr; + params_enc.embeddings = true; + ctx_dft_enc = llama_init_from_model(params.model_dft, params_enc); + if (!ctx_dft_enc) { + LOG_ERR("failed to create %s draft model encoder context\n", params.eagle3 ? "EAGLE3" : "DFlash"); + return nullptr; + } + + llama_context_params params_dec = params.cparams_dft; + params_dec.target_model = params.model_tgt; + params_dec.embeddings = true; + ctx_dft_dec = llama_init_from_model(params.model_dft, params_dec); + if (!ctx_dft_dec) { + LOG_ERR("failed to create %s draft model decoder context\n", params.eagle3 ? 
"EAGLE3" : "DFlash"); + return nullptr; + } + } else { + ctx_dft = llama_init_from_model(params.model_dft, params.cparams_dft); + if (ctx_dft == nullptr) { + LOG_ERR("failed to create draft model context\n"); + return nullptr; + } } } @@ -852,7 +1223,8 @@ common_speculative * common_speculative_init( std::vector configs = {}; // list of speculative configs to try { bool has_draft = !params.mparams_dft.path.empty(); - bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3 + bool has_draft_eagle3 = params.eagle3; + bool has_draft_dflash = params.dflash; bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE); bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE); @@ -893,10 +1265,13 @@ common_speculative * common_speculative_init( configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, params)); } if (has_draft) { - configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT, params)); - } - if (has_draft_eagle3) { - configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_EAGLE3, params)); + if (has_draft_eagle3) { + configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_EAGLE3, params)); + } else if (has_draft_dflash) { + configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DFLASH, params)); + } else { + configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT, params)); + } } } @@ -908,15 +1283,30 @@ common_speculative * common_speculative_init( case COMMON_SPECULATIVE_TYPE_NONE: break; case COMMON_SPECULATIVE_TYPE_DRAFT: { + const bool use_ckpt = common_context_can_seq_rm(ctx_dft) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; + impls.push_back(std::make_unique(config.type, /* .ctx_tgt = */ ctx_tgt, /* .ctx_dft = */ ctx_dft, - /* .replacements = */ params.replacements + /* .replacements = */ params.replacements, + /* .use_ckpt = */ use_ckpt )); break; } case COMMON_SPECULATIVE_TYPE_EAGLE3: { - impls.push_back(std::make_unique(config.type)); + impls.push_back(std::make_unique(config.type, + /* .ctx_tgt = */ ctx_tgt, + /* .ctx_dft_enc = */ ctx_dft_enc, + /* .ctx_dft_dec = */ ctx_dft_dec + )); + break; + } + case COMMON_SPECULATIVE_TYPE_DFLASH: { + impls.push_back(std::make_unique(config.type, + /* .ctx_tgt = */ ctx_tgt, + /* .ctx_dft_enc = */ ctx_dft_enc, + /* .ctx_dft_dec = */ ctx_dft_dec + )); break; } case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: { @@ -966,7 +1356,8 @@ common_speculative * common_speculative_init( } auto * result = new common_speculative { - /* .impls = */ std::move(impls) + /* .impls = */ std::move(impls), + /* .curr_impl = */ nullptr, }; return result; diff --git a/common/speculative.h b/common/speculative.h index 876cde3d180..bca78d32b5b 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -14,10 +14,6 @@ enum common_speculative_type common_speculative_type_from_name(const std::string // convert type to string std::string common_speculative_type_to_str(enum common_speculative_type type); -// check if the llama_context is compatible for speculative decoding -// note: clears the memory of the context -bool common_speculative_is_compat(llama_context * ctx_tgt); - common_speculative * common_speculative_init( common_params_speculative & params, llama_context * ctx_tgt); @@ -39,3 +35,9 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted); // print statistics about the speculative decoding void common_speculative_print_stats(const common_speculative * spec); + +struct 
common_speculative_deleter { + void operator()(common_speculative * s) { common_speculative_free(s); } +}; + +typedef std::unique_ptr common_speculative_ptr; diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 42d559dfecf..6a5ac25d945 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -97,6 +97,7 @@ class ModelBase: metadata_override: Path | None dir_model_card: Path remote_hf_model_id: str | None + target_model_dir: Path | None # subclasses should define this! model_arch: gguf.MODEL_ARCH @@ -116,7 +117,7 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None, disable_mistral_community_chat_template: bool = False, - sentence_transformers_dense_modules: bool = False, + sentence_transformers_dense_modules: bool = False, target_model_dir: Path | None = None, fuse_gate_up_exps: bool = False): if type(self) is ModelBase or \ type(self) is TextModel or \ @@ -136,6 +137,7 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, self.dry_run = dry_run self.remote_hf_model_id = remote_hf_model_id self.sentence_transformers_dense_modules = sentence_transformers_dense_modules + self.target_model_dir = target_model_dir self.fuse_gate_up_exps = fuse_gate_up_exps self._gate_exp_buffer: dict[int, Tensor] = {} self._up_exp_buffer: dict[int, Tensor] = {} @@ -746,7 +748,12 @@ def prepare_tensors(self): if (not quant_algo or not quant_layers) and quant_config_file.is_file(): with open(quant_config_file, "r", encoding="utf-8") as f: - quant_config = json.load(f).get("quantization") or {} + hf_quant_config = json.load(f) + quant_config = hf_quant_config.get("quantization") or {} + producer = hf_quant_config.get("producer") or {} + producer_name = (producer.get("name") or "").lower() + if quant_method is None: + self.hparams.setdefault("quantization_config", {})["quant_method"] = producer_name quant_algo = quant_config.get("quant_algo", quant_algo) quant_layers = quant_config.get("quantized_layers", quant_layers) or {} @@ -1850,20 +1857,28 @@ def _try_set_pooling_type(self) -> None: with open(module_path, encoding="utf-8") as f: modules = json.load(f) for mod in modules: - if mod["type"] == "sentence_transformers.models.Pooling": + if mod["type"].endswith("Pooling"): pooling_path = mod["path"] break + mode_mapping = { + "mean": gguf.PoolingType.MEAN, + "cls": gguf.PoolingType.CLS, + "lasttoken": gguf.PoolingType.LAST, + } + # get pooling type if pooling_path is not None: with open(self.dir_model / pooling_path / "config.json", encoding="utf-8") as f: pooling = json.load(f) - if pooling["pooling_mode_mean_tokens"]: + if pooling.get("pooling_mode_mean_tokens"): pooling_type = gguf.PoolingType.MEAN - elif pooling["pooling_mode_cls_token"]: + elif pooling.get("pooling_mode_cls_token"): pooling_type = gguf.PoolingType.CLS - elif pooling["pooling_mode_lasttoken"]: + elif pooling.get("pooling_mode_lasttoken"): pooling_type = gguf.PoolingType.LAST + elif (pooling_mode := pooling.get("pooling_mode")) in mode_mapping: + pooling_type = mode_mapping[pooling_mode] else: raise NotImplementedError("Only MEAN, CLS, and LAST pooling types supported") self.gguf_writer.add_pooling_type(pooling_type) @@ -2801,6 +2816,9 @@ def prepare_tensors(self): "VLlama3ForCausalLM", "LlavaForConditionalGeneration", "VoxtralForConditionalGeneration", + "LlamaForCausalLMEagle3", + 
"Eagle3Speculator", + "Eagle3DraftModel", "IQuestCoderForCausalLM", "LlamaModel") class LlamaModel(TextModel): @@ -2815,7 +2833,60 @@ def __init__(self, *args, **kwargs): hparams = ModelBase.load_hparams(self.dir_model, is_mistral_format=False) self.origin_hf_arch = hparams.get('architectures', [None])[0] + # detect EAGLE-3 llama checkpoint + if "draft_vocab_size" in self.hparams and self.hparams["num_hidden_layers"] == 1: + self.is_eagle3 = True + self.model_arch = gguf.MODEL_ARCH.EAGLE3 + logger.info("Detected EAGLE-3 draft model, switching to EAGLE3 architecture") + # Re-initialize tensor_map with EAGLE3 architecture + self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + # Update gguf_writer architecture + self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch] + self.gguf_writer.add_architecture() + if not hasattr(self, 'target_model_dir') or not self.target_model_dir: + raise ValueError( + "EAGLE3 model requires --target-model-dir to be specified. " + "Please provide the path to the target model directory to read config.json" + ) + # Read both EAGLE3 raw config and target model config + with open(self.dir_model / "config.json", 'r', encoding='utf-8') as f: + eagle3_raw_config = json.load(f) + with open(self.target_model_dir / "config.json", 'r', encoding='utf-8') as f: + target_config = json.load(f) + + # EAGLE3 extract_layers + target_num_layers = target_config["num_hidden_layers"] + extract_layers = [2, target_num_layers // 2, target_num_layers - 3] + logger.info(f"EAGLE3: extract_layers = {extract_layers} (target model has {target_num_layers} layers)") + self.gguf_writer.add_array(f"{self.gguf_writer.arch}.extract_layers", extract_layers) + + # EAGLE3 target_hidden_size: prefer EAGLE3 config, fallback to target config + if "target_hidden_size" in eagle3_raw_config and eagle3_raw_config["target_hidden_size"] is not None: + target_hidden_size = eagle3_raw_config["target_hidden_size"] + logger.info(f"EAGLE3: target_hidden_size = {target_hidden_size} (from EAGLE3 config)") + else: + target_hidden_size = target_config["hidden_size"] + logger.info(f"EAGLE3: target_hidden_size = {target_hidden_size} (from target model config)") + self.gguf_writer.add_uint32(f"{self.gguf_writer.arch}.target_hidden_size", target_hidden_size) + + # Eagle3Speculator norm_before_residual specific handling + norm_before_residual = eagle3_raw_config.get("norm_before_residual", False) + logger.info(f"EAGLE3: norm_before_residual = {norm_before_residual} (from EAGLE3 config)") + self.gguf_writer.add_bool(f"{self.gguf_writer.arch}.norm_before_residual", norm_before_residual) + def set_vocab(self): + # For EAGLE-3 models, use tokenizer from target model if provided + if hasattr(self, 'is_eagle3') and self.is_eagle3: + if self.target_model_dir is None: + raise ValueError( + "EAGLE-3 draft model requires --target-model-dir to be specified. " + "Please provide the path to the target model directory containing the tokenizer." 
+ ) + logger.info(f"EAGLE-3: Using tokenizer from target model: {self.target_model_dir}") + # Temporarily swap dir_model to load tokenizer from target model + original_dir_model = self.dir_model + self.dir_model = self.target_model_dir + if self.origin_hf_arch == "GlmasrModel": return self._set_vocab_glmedge() @@ -2859,6 +2930,10 @@ def set_vocab(self): if self.hparams.get("vocab_size", 32000) == 49152: self.gguf_writer.add_add_bos_token(False) + # Restore original dir_model for EAGLE-3 + if hasattr(self, 'is_eagle3') and self.is_eagle3: + self.dir_model = original_dir_model + def set_gguf_parameters(self): super().set_gguf_parameters() hparams = self.hparams @@ -2880,7 +2955,55 @@ def permute(weights: Tensor, n_head: int, n_head_kv: int | None): _experts: list[dict[str, Tensor]] | None = None + def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Callable[[], Tensor]]: + tensors = super().index_tensors(remote_hf_model_id) + + # Handle Eagle3Speculator nested config + if "transformer_layer_config" in self.hparams: + self.hparams = {**self.hparams, **self.hparams["transformer_layer_config"]} + + # EAGLE-3 detection: check hparams directly (before self.is_eagle3 is set) + if "draft_vocab_size" in self.hparams and self.hparams["num_hidden_layers"] == 1: + logger.info("EAGLE-3: Renaming midlayer.* or layers.0.* to model.layers.0.*") + new_tensors = {} + # EAGLE-3: rename midlayer.* to model.layers.0.* for compatibility with llama model + for name, gen in tensors.items(): + if name.startswith("midlayer."): + new_name = "model.layers.0." + name[len("midlayer."):] + new_tensors[new_name] = gen + elif name.startswith("layers.0."): # layers.0.* -> model.layers.0.* (Eagle3Speculator format) + new_name = "model." + name + new_tensors[new_name] = gen + else: + new_tensors[name] = gen + return new_tensors + else: + return tensors + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + + # Eagle-3 llama checkpoint special handling + if hasattr(self, 'is_eagle3') and self.is_eagle3: + # Eagle-3 llama checkpoint special weights handling + # fc.weight: feature fusion layer + if name == "fc.weight": + yield (name, data_torch) + return + # d2t: draft to target vocabulary mapping + elif name == "d2t": + # Skip parent class processing (store for manual handling in prepare_tensors) + if not hasattr(self, '_eagle3_int_tensors'): + self._eagle3_int_tensors = {} + self._eagle3_int_tensors[name] = data_torch + return + # t2d: target to draft vocabulary mapping (not used, skip completely) + elif name == "t2d": + return + # hidden_norm: EAGLE-3 specific layer normalization + elif name == "model.layers.0.hidden_norm.weight": + yield ("blk.0.hidden_norm.weight", data_torch) + return + n_head = self.find_hparam(["n_heads", "num_attention_heads"]) n_kv_head = self.find_hparam(["n_kv_heads", "num_key_value_heads"]) @@ -2950,6 +3073,17 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter yield from super().modify_tensors(data_torch, name, bid) def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + # EAGLE3: If no lm_head in draft model, load from target model + if hasattr(self, 'is_eagle3') and self.is_eagle3 and "lm_head.weight" not in self.model_tensors: + from safetensors import safe_open + for sf_file in self.target_model_dir.glob("*.safetensors"): + with safe_open(sf_file, framework="pt") as f: + if "lm_head.weight" in f.keys(): + lm_head = f.get_tensor("lm_head.weight") + logger.info(f"EAGLE3: No 
lm_head in draft model, loaded lm_head from {sf_file.name}, shape = {lm_head.shape}") + yield ("output.weight", lm_head) + break + if rope_params := self.rope_parameters.get("full_attention", self.rope_parameters): if rope_params.get("rope_type", '').lower() == "llama3": base = rope_params.get("rope_theta", 10000.0) @@ -2980,8 +3114,26 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32)) def prepare_tensors(self): + # EAGLE-3: collect original dtypes BEFORE parent class converts them to F32 + eagle3_original_dtypes = {} + if hasattr(self, 'is_eagle3') and self.is_eagle3: + for name, data_torch in self.get_tensors(): + if name == "d2t": + eagle3_original_dtypes[name] = data_torch.dtype + super().prepare_tensors() + if hasattr(self, 'is_eagle3') and self.is_eagle3 and hasattr(self, '_eagle3_int_tensors'): + for name, data_torch in self._eagle3_int_tensors.items(): + old_dtype = eagle3_original_dtypes.get(name, data_torch.dtype) + # Keep as int64 to match original torch tensor dtype + data = data_torch.to(torch.int64).numpy() + data_qtype = gguf.GGMLQuantizationType.I64 + + shape_str = f"{{{', '.join(str(n) for n in reversed(data.shape))}}}" + logger.info(f"{name + ',':<30} {old_dtype} --> {data_qtype.name}, shape = {shape_str}") + self.gguf_writer.add_tensor(name, data, raw_dtype=data_qtype) + if self._experts is not None: # flatten `list[dict[str, Tensor]]` into `list[str]` experts = [k for d in self._experts for k in d.keys()] @@ -4735,6 +4887,47 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter yield from super().modify_tensors(data_torch, name, bid) +@ModelBase.register("DFlashDraftModel") +class DFlashModel(Qwen3Model): + model_arch = gguf.MODEL_ARCH.DFLASH + + def set_vocab(self): + if self.target_model_dir is None: + raise ValueError( + "DFlash draft model requires --target-model-dir to be specified. " + "Please provide the path to the target model directory containing the tokenizer." + ) + logger.info(f"DFLASH: Using tokenizer from target model: {self.target_model_dir}") + original_dir = self.dir_model + self.dir_model = self.target_model_dir + super().set_vocab() + self.dir_model = original_dir + + def set_gguf_parameters(self): + super().set_gguf_parameters() + block_size = self.hparams.get("block_size", 16) + self.gguf_writer.add_uint32(f"{self.gguf_writer.arch}.block_size", block_size) + dflash_config = self.hparams.get("dflash_config", {}) + target_layer_ids = dflash_config.get("target_layer_ids", []) + if target_layer_ids: + extract_layer_ids = [i + 1 for i in target_layer_ids] + self.gguf_writer.add_array(f"{self.gguf_writer.arch}.target_layer_ids", extract_layer_ids) + mask_token_id = dflash_config.get("mask_token_id", None) + if mask_token_id is not None: + self.gguf_writer.add_uint32(f"{self.gguf_writer.arch}.mask_token_id", mask_token_id) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if name == "fc.weight": + yield (name, data_torch) + return + if name == "hidden_norm.weight": + yield ("hidden_norm.weight", data_torch) + return + if not name.startswith("model."): + name = "model." 
+ name + yield from super().modify_tensors(data_torch, name, bid) + + @ModelBase.register("Qwen3MoeForCausalLM") class Qwen3MoeModel(Qwen2MoeModel): model_arch = gguf.MODEL_ARCH.QWEN3MOE @@ -7180,7 +7373,7 @@ def __init__(self, *args, **kwargs): with open(modules_file, encoding="utf-8") as modules_json_file: mods = json.load(modules_json_file) for mod in mods: - if mod["type"] == "sentence_transformers.models.Dense": + if mod["type"].endswith("Dense"): mod_path = mod["path"] # check if model.safetensors file for Dense layer exists model_tensors_file = self.dir_model / mod_path / "model.safetensors" @@ -10912,14 +11105,14 @@ def set_vocab(self): vocab_size = -(vocab_size // -pad_vocab) * pad_vocab self.hparams["vocab_size"] = vocab_size - assert max(tokenizer.vocab.values()) < vocab_size + assert max(tokenizer.vocab.values()) < vocab_size # ty: ignore[unresolved-attribute] tokpre = self.get_vocab_base_pre(tokenizer) - reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()} - added_vocab = tokenizer.get_added_vocab() + reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()} # ty: ignore[unresolved-attribute] + added_vocab = tokenizer.get_added_vocab() # ty: ignore[unresolved-attribute] - added_tokens_decoder = tokenizer.added_tokens_decoder + added_tokens_decoder = tokenizer.added_tokens_decoder # ty: ignore[unresolved-attribute] for i in range(vocab_size): if i not in reverse_vocab: @@ -10930,7 +11123,7 @@ def set_vocab(self): if token in added_vocab: if not added_tokens_decoder[i].normalized: previous_token = token - token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False)) + token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False)) # ty: ignore[unresolved-attribute, invalid-assignment] if previous_token != token: logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer") @@ -11847,7 +12040,7 @@ def prepare_tensors(self): raise ValueError(f"Unprocessed experts: {experts}") -@ModelBase.register("HunYuanDenseV1ForCausalLM", "HunYuanVLForConditionalGeneration") +@ModelBase.register("HunYuanDenseV1ForCausalLM") class HunYuanModel(TextModel): model_arch = gguf.MODEL_ARCH.HUNYUAN_DENSE @@ -11986,28 +12179,58 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter @ModelBase.register("HunYuanVLForConditionalGeneration") -class HunyuanOCRVisionModel(MmprojModel): +class HunyuanVLVisionModel(MmprojModel): + # Handles both HunyuanOCR and HunyuanVL, which share the HF architecture name + # "HunYuanVLForConditionalGeneration" and the `vit.perceive.*` vision layout. + # Each variant maps to a different projector type in clip.cpp so image + # preprocessing follows the correct code path. + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) assert self.hparams_vision is not None - # HunyuanOCR uses max_image_size instead of image_size + # HunyuanOCR / HunyuanVL uses max_image_size instead of image_size if "image_size" not in self.hparams_vision: self.hparams_vision["image_size"] = self.hparams_vision.get("max_image_size", 2048) + @staticmethod + def is_ocr_variant(hparams: dict) -> bool: + """Return True for HunyuanOCR, False for HunyuanVL. + + The projector's output dim must equal the text model's hidden_size by + construction (that's what "projector" means). HunyuanOCR pairs a 1B text + backbone (hidden=1024); HunyuanVL pairs a 4B one (hidden=3072). 
So the + ViT -> LLM projection dim is a hard architectural signature, not a + magic number. + """ + vision_out = int((hparams.get("vision_config") or {}).get("out_hidden_size", 0)) + return vision_out == 1024 + def set_gguf_parameters(self): super().set_gguf_parameters() assert self.hparams_vision is not None - hparams = self.hparams_vision - self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.HUNYUANOCR) - self.gguf_writer.add_vision_use_gelu(True) - self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("rms_norm_eps", 1e-5)) - self.gguf_writer.add_vision_spatial_merge_size(hparams.get("spatial_merge_size", 2)) - self.gguf_writer.add_vision_min_pixels(self.preprocessor_config["min_pixels"]) - self.gguf_writer.add_vision_max_pixels(self.preprocessor_config["max_pixels"]) + vcfg = self.hparams_vision + + if self.is_ocr_variant(self.global_config): + # --- HunyuanOCR --- + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.HUNYUANOCR) + self.gguf_writer.add_vision_use_gelu(True) + self.gguf_writer.add_vision_attention_layernorm_eps(vcfg.get("rms_norm_eps", 1e-5)) + self.gguf_writer.add_vision_spatial_merge_size(vcfg.get("spatial_merge_size", 2)) + self.gguf_writer.add_vision_min_pixels(self.preprocessor_config["min_pixels"]) + self.gguf_writer.add_vision_max_pixels(self.preprocessor_config["max_pixels"]) + return + + # --- HunyuanVL --- + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.HUNYUANVL) + self.gguf_writer.add_vision_use_gelu(str(vcfg["hidden_act"]).lower() == "gelu") + self.gguf_writer.add_vision_attention_layernorm_eps(float(vcfg["rms_norm_eps"])) + self.gguf_writer.add_vision_spatial_merge_size(int(vcfg["spatial_merge_size"])) + self.gguf_writer.add_vision_min_pixels(int(self.preprocessor_config["min_pixels"])) + self.gguf_writer.add_vision_max_pixels(int(self.preprocessor_config["max_pixels"])) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: if not name.startswith("vit."): - return # skip text tensors + return # strip CLS token (row 0) from position embeddings so resize_position_embeddings works if "position_embedding" in name: data_torch = data_torch[1:] # [n_patches+1, n_embd] -> [n_patches, n_embd] @@ -12015,11 +12238,66 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter def tensor_force_quant(self, name, new_name, bid, n_dims): # force conv weights to F32 or F16 to avoid BF16 IM2COL issues on Metal + # Both HunyuanOCR and HunyuanVL emit the ViT -> LLM projection as mm.0/mm.2. if ("mm.0." in new_name or "mm.2." in new_name) and new_name.endswith(".weight"): return gguf.GGMLQuantizationType.F16 if self.ftype == gguf.LlamaFileType.MOSTLY_F16 else gguf.GGMLQuantizationType.F32 return super().tensor_force_quant(name, new_name, bid, n_dims) +@ModelBase.register("HunYuanVLForConditionalGeneration") +class HunyuanVLTextModel(HunYuanModel): + # The "HunYuanVLForConditionalGeneration" HF architecture covers both HunyuanOCR + # and HunyuanVL. HunyuanOCR reuses the HunYuan-Dense text backbone (standard RoPE), + # while HunyuanVL introduces a new LLM arch with XD-RoPE. Detect the variant from + # the config and pick the matching GGUF architecture. + model_arch = gguf.MODEL_ARCH.HUNYUAN_VL + + @staticmethod + def _is_ocr_config(hparams: dict) -> bool: + # OCR pairs a 1B text backbone (hidden=1024) with a ViT projector that + # outputs 1024-d; HunyuanVL uses 3072-d. Keep in sync with + # HunyuanVLVisionModel.is_ocr_variant. 
+ return int((hparams.get("vision_config") or {}).get("out_hidden_size", 0)) == 1024 + + def __init__(self, dir_model: Path, *args, **kwargs): + raw_hparams = kwargs.get("hparams") or ModelBase.load_hparams(dir_model, is_mistral_format=False) + if self._is_ocr_config(raw_hparams): + self.model_arch = gguf.MODEL_ARCH.HUNYUAN_DENSE + else: + self.model_arch = gguf.MODEL_ARCH.HUNYUAN_VL + super().__init__(dir_model, *args, **kwargs) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + # Only emit XD-RoPE metadata for the HunyuanVL backbone; HunyuanOCR uses + # the HunYuan-Dense arch which already handles standard rope in super(). + if self.model_arch != gguf.MODEL_ARCH.HUNYUAN_VL: + return + + if self.rope_parameters.get("rope_type") != "xdrope": + return + + # defaults for HunyuanVL. The C++ side later computes: + # freq_base = rope_theta * alpha ** (head_dim / (head_dim - 2)) + self.gguf_writer.add_rope_freq_base(float(self.rope_parameters["rope_theta"])) + self.gguf_writer.add_rope_scaling_alpha(float(self.rope_parameters["alpha"])) + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) + self.gguf_writer.add_rope_scaling_factor(float(self.rope_parameters.get("factor", 1))) + + ctx_len = int(self.hparams["max_position_embeddings"]) + self.gguf_writer.add_rope_scaling_orig_ctx_len(ctx_len) + self.gguf_writer.add_context_length(ctx_len) + + self.gguf_writer.add_rope_dimension_sections(list(self.rope_parameters["xdrope_section"])) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # Skip vision tensors — they are written by HunyuanVLVisionModel + if name.startswith("vit."): + return + yield from super().modify_tensors(data_torch, name, bid) + + @ModelBase.register("SmolLM3ForCausalLM") class SmolLM3Model(LlamaModel): model_arch = gguf.MODEL_ARCH.SMOLLM3 @@ -13114,6 +13392,7 @@ class LazyTorchTensor(gguf.LazyBase): torch.float16: np.float16, torch.float32: np.float32, torch.uint8: np.uint8, + torch.int64: np.int64, } # only used when byteswapping data. Only correct size is needed @@ -13275,6 +13554,10 @@ def parse_args() -> argparse.Namespace: "--no-tensor-first-split", action="store_true", help="do not add tensors to the first split (disabled by default)" ) + parser.add_argument( + "--target-model-dir", type=str, default=None, + help="directory containing target model tokenizer (for EAGLE-3 draft models that don't have their own tokenizer)", + ) parser.add_argument( "--metadata", type=Path, help="Specify the path for an authorship metadata override file" @@ -13459,6 +13742,7 @@ def main() -> None: small_first_shard=args.no_tensor_first_split, remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template, sentence_transformers_dense_modules=args.sentence_transformers_dense_modules, + target_model_dir=Path(args.target_model_dir) if args.target_model_dir else None, fuse_gate_up_exps=args.fuse_gate_up_exps ) diff --git a/docs/backend/OPENVINO.md b/docs/backend/OPENVINO.md index 96d0f672e30..c9c005a9981 100644 --- a/docs/backend/OPENVINO.md +++ b/docs/backend/OPENVINO.md @@ -244,7 +244,6 @@ build\ReleaseOV\bin\llama-cli.exe -m "C:\models\Llama-3.2-1B-Instruct-Q4_0.gguf" - `-fa 1` is required when running llama-bench with the OpenVINO backend. - `GGML_OPENVINO_STATEFUL_EXECUTION=1 GGML_OPENVINO_DEVICE=GPU ./llama-bench -fa 1` - `llama-server` with OpenVINO backend supports only one chat session/thread, when `GGML_OPENVINO_STATEFUL_EXECUTION=1` is enabled. 
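The XD-RoPE comment in `HunyuanVLTextModel.set_gguf_parameters()` above notes that the C++ side later computes `freq_base = rope_theta * alpha ** (head_dim / (head_dim - 2))`. A tiny standalone check of that formula, with illustrative numbers only (rope_theta = 10000, alpha = 1000, head_dim = 128 are assumptions for the example, not HunyuanVL's actual hyperparameters):

```cpp
#include <cmath>
#include <cstdio>

int main() {
    // hypothetical inputs, for illustration only
    const double rope_theta = 10000.0;
    const double alpha      = 1000.0;
    const int    head_dim   = 128;

    // freq_base = rope_theta * alpha ** (head_dim / (head_dim - 2))
    const double freq_base = rope_theta * std::pow(alpha, (double) head_dim / (head_dim - 2));

    std::printf("freq_base = %.3e\n", freq_base); // ~1.12e+07 for these inputs
    return 0;
}
```

For these inputs the exponent is 128/126 ≈ 1.016, so the alpha term scales rope_theta by roughly a factor of 1116.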
-- For Intel GPU, NPU detection in containers, GPU, NPU user-space drivers/libraries must be present inside the image. We will include in a future PR. Until then, you can use this reference Dockerfile: [openvino.Dockerfile](https://github.com/ravi9/llama.cpp/blob/ov-docker-update/.devops/openvino.Dockerfile) > [!NOTE] > The OpenVINO backend is actively under development. Fixes are underway, and this document will continue to be updated as issues are resolved. @@ -274,8 +273,6 @@ docker build --build-arg http_proxy=$http_proxy --build-arg https_proxy=$https_p Run llama.cpp with OpenVINO backend Docker container. Save sample models in `~/models` as [shown above](#3-download-sample-model). It will be mounted to the container in the examples below. -> [!NOTE] -> Intel GPU, NPU detection in containers will be included in a future PR. Until then, you can use this reference Dockerfile: [openvino.Dockerfile](https://github.com/ravi9/llama.cpp/blob/ov-docker-update/.devops/openvino.Dockerfile). ```bash # Run Docker container diff --git a/docs/backend/SYCL.md b/docs/backend/SYCL.md index d52c61acb66..1b86b3d4acb 100644 --- a/docs/backend/SYCL.md +++ b/docs/backend/SYCL.md @@ -31,6 +31,8 @@ SYCL cross-platform capabilities enable support for other vendor GPUs as well. ## Recommended Release +### Windows + The following releases are verified and recommended: |Commit ID|Tag|Release|Verified Platform| Update date| @@ -39,6 +41,13 @@ The following releases are verified and recommended: |3bcd40b3c593d14261fb2abfabad3c0fb5b9e318|b4040 |[llama-b4040-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b4040/llama-b4040-bin-win-sycl-x64.zip) |Arc A770/Linux/oneAPI 2024.1
MTL Arc GPU/Windows 11/oneAPI 2024.1| 2024-11-19| |fb76ec31a9914b7761c1727303ab30380fd4f05c|b3038 |[llama-b3038-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b3038/llama-b3038-bin-win-sycl-x64.zip) |Arc A770/Linux/oneAPI 2024.1
MTL Arc GPU/Windows 11/oneAPI 2024.1|| +### Ubuntu 24.04 + +The release packages for Ubuntu 24.04 x64 (FP32/FP16) only include the binary files of the llama.cpp SYCL backend. They require the target machine to have Intel GPU drivers and oneAPI packages pre-installed, with the same versions as those used to build the package. For the exact versions and installation steps, refer to release.yml: ubuntu-24-sycl -> Download & Install oneAPI. + +It is recommended to use them with the Intel Docker image. + +The FP32 and FP16 packages differ in accuracy and performance on LLMs. Please choose between them based on your own test results. ## News @@ -229,6 +238,7 @@ Upon a successful installation, SYCL is enabled for the available intel devices, |Verified release| |-| +|2025.3.3 | |2025.2.1| |2025.1| |2024.1| diff --git a/docs/backend/snapdragon/README.md b/docs/backend/snapdragon/README.md index e13fdfd05e7..2414eeaf6a4 100644 --- a/docs/backend/snapdragon/README.md +++ b/docs/backend/snapdragon/README.md @@ -249,18 +249,27 @@ build: 6a8cf8914 (6733) ``` - `GGML_HEXAGON_PROFILE=1` - Generates a host-side profile for the ggml-hexagon Ops.
Examples: - `GGML_HEXAGON_OPMASK=0x1 llama-completion ...` - Ops are enqueued but NPU-side processing is stubbed out - `GGML_HEXAGON_OPMASK=0x3 llama-completion ...` - Full queuing and processing of Ops (default) + `GGML_HEXAGON_OPSTAGE=0x1 llama-completion ...` - Ops are enqueued to the NPU but dma & compute are disabled + `GGML_HEXAGON_OPSTAGE=0x3 llama-completion ...` - Full queuing and processing of Ops (default) - `GGML_HEXAGON_OPFILTER=regex` Allows filtering (disabling) Ops that match the regex pattern: diff --git a/examples/llama.android/lib/src/main/cpp/CMakeLists.txt b/examples/llama.android/lib/src/main/cpp/CMakeLists.txt index 7862c61a3fc..20c9e3b2c1f 100644 --- a/examples/llama.android/lib/src/main/cpp/CMakeLists.txt +++ b/examples/llama.android/lib/src/main/cpp/CMakeLists.txt @@ -51,6 +51,6 @@ target_include_directories(${CMAKE_PROJECT_NAME} PRIVATE target_link_libraries(${CMAKE_PROJECT_NAME} llama - common + llama-common android log) diff --git a/examples/model-conversion/scripts/causal/convert-model.sh b/examples/model-conversion/scripts/causal/convert-model.sh index a5865f6acd3..4aa72206288 100755 --- a/examples/model-conversion/scripts/causal/convert-model.sh +++ b/examples/model-conversion/scripts/causal/convert-model.sh @@ -25,7 +25,11 @@ MODEL_NAME="${MODEL_NAME:-$(basename "$MODEL_PATH")}" OUTPUT_DIR="${OUTPUT_DIR:-../../models}" TYPE="${OUTTYPE:-f16}" METADATA_OVERRIDE="${METADATA_OVERRIDE:-}" -CONVERTED_MODEL="${OUTPUT_DIR}/${MODEL_NAME}.gguf" +if [[ -n "$MMPROJ" ]]; then + CONVERTED_MODEL="${OUTPUT_DIR}/mmproj-${MODEL_NAME}.gguf" +else + CONVERTED_MODEL="${OUTPUT_DIR}/${MODEL_NAME}.gguf" +fi echo "Model path: ${MODEL_PATH}" echo "Model name: ${MODEL_NAME}" @@ -38,6 +42,7 @@ if [[ -n "$DEBUG" ]]; then else CMD_ARGS=("python") fi + CMD_ARGS+=("../../convert_hf_to_gguf.py" "--verbose") CMD_ARGS+=("${MODEL_PATH}") CMD_ARGS+=("--outfile" "${CONVERTED_MODEL}") @@ -50,7 +55,3 @@ CMD_ARGS+=("--outtype" "${TYPE}") echo "" echo "The environment variable CONVERTED_MODEL can be set to this path using:" echo "export CONVERTED_MODEL=$(realpath ${CONVERTED_MODEL})" -if [[ -n "$MMPROJ" ]]; then - mmproj_file="${OUTPUT_DIR}/mmproj-$(basename "${CONVERTED_MODEL}")" - echo "The mmproj model was created in $(realpath "$mmproj_file")" -fi diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index a03dbce887f..804a16623a4 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -4,12 +4,29 @@ #include "speculative.h" #include "log.h" #include "llama.h" +#include "chat.h" #include #include #include +#include #include #include +#include + +struct spec_checkpoint { + int64_t n_tokens = 0; + + std::vector data; + + size_t size() const { + return data.size(); + } + + bool empty() const { + return data.empty(); + } +}; int main(int argc, char ** argv) { std::setlocale(LC_NUMERIC, "C"); @@ -46,6 +63,14 @@ int main(int argc, char ** argv) { model_tgt = llama_init_tgt->model(); ctx_tgt = llama_init_tgt->context(); + // check if the context supports partial sequence removal + const auto ctx_seq_rm = common_context_can_seq_rm(ctx_tgt); + const bool use_ckpt = (ctx_seq_rm == COMMON_CONTEXT_SEQ_RM_TYPE_FULL); + + if (use_ckpt) { + LOG_INF("speculative decoding will use checkpoints (context does not support partial sequence removal)\n"); + } + const llama_vocab * vocab = llama_model_get_vocab(model_tgt); // load the draft model @@ -79,13 +104,63 @@ int main(int argc, char 
** argv) { return 1; } + params.speculative.model_tgt = model_tgt; params.speculative.model_dft = model_dft.get(); params.speculative.cparams_dft = common_context_params_to_llama(params_dft); + + if (params.speculative.eagle3) { + llama_set_eagle3(ctx_tgt, model_dft.get()); + } + if (params.speculative.dflash) { + llama_set_dflash(ctx_tgt, model_dft.get()); + } } + // Apply chat template for EAGLE3 / DFlash if available which can increase the acceptance rate + std::string prompt = params.prompt; + if (params.speculative.eagle3 || params.speculative.dflash) { + auto chat_templates = common_chat_templates_init(model_tgt, params.chat_template); + if (common_chat_templates_was_explicit(chat_templates.get())) { + std::vector chat_msgs; + common_chat_msg user_msg; + user_msg.role = "user"; + user_msg.content = params.prompt; + chat_msgs.push_back(user_msg); + + common_chat_templates_inputs inputs; + inputs.messages = chat_msgs; + inputs.add_generation_prompt = true; + // Disable thinking mode can improve accept rate + if (const char * nt = std::getenv("LLAMA_SPEC_NO_THINK"); nt && std::string(nt) != "0") { + // Qwen3 / 3.5 + inputs.enable_thinking = false; + // gpt-oss + inputs.chat_template_kwargs["reasoning_effort"] = "\"low\""; + } + prompt = common_chat_templates_apply(chat_templates.get(), inputs).prompt; + LOG_INF("%s: %s chat template applied\n", __func__, params.speculative.eagle3 ? "EAGLE3" : "DFlash"); + } + } + + int n_predict = 0; + int n_drafted = 0; + int n_accept = 0; + + // used to determine end of generation + bool has_eos = false; + + // ================================================ + // everything until here is standard initialization + // the relevant stuff for speculative decoding starts here + + const auto t_enc_start = ggml_time_us(); + + // target model sampling context + common_sampler_ptr smpl(common_sampler_init(model_tgt, params.sampling)); + // Tokenize the prompt std::vector inp; - inp = common_tokenize(ctx_tgt, params.prompt, true, true); + inp = common_tokenize(ctx_tgt, prompt, true, true); if (llama_n_ctx(ctx_tgt) < (uint32_t) inp.size()) { LOG_ERR("%s: the prompt exceeds the context size (%d tokens, ctx %d)\n", __func__, (int) inp.size(), llama_n_ctx(ctx_tgt)); @@ -105,33 +180,39 @@ int main(int argc, char ** argv) { LOG("%s", common_token_to_piece(ctx_tgt, id).c_str()); } - int n_predict = 0; - int n_drafted = 0; - int n_accept = 0; - // used to determine end of generation - bool has_eos = false; + // eval the prompt + llama_token id_last; + llama_tokens prompt_tgt; + int n_past; - // ================================================ - // everything until here is standard initialization - // the relevant stuff for speculative decoding starts here + // TODO: simplify + if (params.speculative.eagle3 || params.speculative.dflash) { + // Target model decodes full prompt and sample first token and intermediate features are extracted + llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size())); - const auto t_enc_start = ggml_time_us(); + id_last = common_sampler_sample(smpl.get(), ctx_tgt, -1); + common_sampler_accept(smpl.get(), id_last, true); + LOG("%s", common_token_to_piece(ctx_tgt, id_last).c_str()); + n_predict++; - // target model sampling context - struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling); + // all tokens currently in the target context + prompt_tgt.assign(inp.begin(), inp.end()); + prompt_tgt.reserve(llama_n_ctx(ctx_tgt)); - // eval the prompt - llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() 
- 1)); + n_past = inp.size(); + } else { + llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1)); - // note: keep the last token separate! - llama_token id_last = inp.back(); + // note: keep the last token separate! + id_last = inp.back(); - // all tokens currently in the target context - llama_tokens prompt_tgt(inp.begin(), inp.end() - 1); - prompt_tgt.reserve(llama_n_ctx(ctx_tgt)); + // all tokens currently in the target context + prompt_tgt.assign(inp.begin(), inp.end() - 1); + prompt_tgt.reserve(llama_n_ctx(ctx_tgt)); - int n_past = inp.size() - 1; + n_past = inp.size() - 1; + } // init the speculator const auto & params_spec = params.speculative; @@ -142,21 +223,61 @@ int main(int argc, char ** argv) { llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1); + size_t n_draft = 0; + + llama_tokens draft; + spec_checkpoint spec_ckpt; + const auto t_enc_end = ggml_time_us(); const auto t_dec_start = ggml_time_us(); while (true) { - // optionally, generate draft tokens that can be appended to the target batch + // generate or reuse draft tokens // // this is the most important part of the speculation. the more probable tokens that are provided here // the better the performance will be. in theory, this computation can be performed asynchronously and even // offloaded to a remote device. it doesn't even have to be based on an LLM. instead, it can provide tokens // from a cache or lookup tables. // - llama_tokens draft = common_speculative_draft(spec, params_spec, prompt_tgt, id_last); + if (draft.empty()) { + // generate a new draft + draft = common_speculative_draft(spec, params_spec, prompt_tgt, id_last); + + if ((int) draft.size() > params_spec.n_max) { + LOG_WRN("draft size %zu exceeds max %d, truncating\n", draft.size(), params_spec.n_max); + draft.resize(params_spec.n_max); + } + + if ((int) draft.size() < params_spec.n_min) { + LOG_DBG("ignoring small draft: %zu < %d\n", draft.size(), params_spec.n_min); + draft.clear(); + } + + // save the original draft size + n_draft = draft.size(); + + // save a checkpoint of the target context before evaluating the draft + // this allows us to restore the state if partial draft acceptance occurs + if (!draft.empty() && use_ckpt) { + const size_t ckpt_size = llama_state_seq_get_size_ext(ctx_tgt, 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + spec_ckpt.data.resize(ckpt_size); - //LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str()); + const size_t n = llama_state_seq_get_data_ext(ctx_tgt, spec_ckpt.data.data(), ckpt_size, 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + GGML_ASSERT(n == ckpt_size); + + spec_ckpt.n_tokens = (int64_t) prompt_tgt.size(); + LOG_DBG("created speculative checkpoint (n_tokens = %" PRId64 ", size = %.3f MiB)\n", + spec_ckpt.n_tokens, (float) spec_ckpt.data.size() / 1024 / 1024); + } + } else { + // we have a previous (partial) draft to reuse from checkpoint restoration + if (use_ckpt) { + GGML_ASSERT(!spec_ckpt.empty()); + } + } + + GGML_ASSERT(n_draft > 0); // always have a token to evaluate from before - id_last common_batch_clear(batch_tgt); @@ -178,6 +299,12 @@ int main(int argc, char ** argv) { llama_decode(ctx_tgt, batch_tgt); } + // only save the sampler sampler state if we use checkpoints + common_sampler_ptr smpl_save; + if (use_ckpt) { + smpl_save.reset(common_sampler_clone(smpl.get())); + } + // sample from the full target batch and return the accepted tokens based on the target sampler // // for each token to be accepted, the sampler would have to sample that same token @@ -185,14 
+312,38 @@ int main(int argc, char ** argv) { // available logits from the batch and sample the next token until we run out of logits or the sampler // disagrees with the draft // - const auto ids = common_sampler_sample_and_accept_n(smpl, ctx_tgt, draft); + auto ids = common_sampler_sample_and_accept_n(smpl.get(), ctx_tgt, draft); //LOG_DBG("ids: %s\n", string_from(ctx_tgt, ids).c_str()); GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token + // check for partial draft acceptance: + // if the context doesn't support partial sequence removal, restore the checkpoint + // and make the accepted tokens the new partial draft for the next iteration + if (use_ckpt && ids.size() - 1 < draft.size()) { + LOG_DBG("partial acceptance: %zu < %zu, restoring checkpoint\n", ids.size() - 1, draft.size()); + + draft = std::move(ids); + + const size_t n = llama_state_seq_set_data_ext(ctx_tgt, spec_ckpt.data.data(), spec_ckpt.size(), 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + GGML_ASSERT(n == spec_ckpt.size()); + + llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, spec_ckpt.n_tokens, -1); + + prompt_tgt.resize(spec_ckpt.n_tokens); + smpl = std::move(smpl_save); + + n_past = (int) prompt_tgt.size(); + + continue; + } + + common_speculative_accept(spec, ids.size() - 1); + + // full acceptance: consume the draft and commit accepted tokens n_past += ids.size() - 1; - n_drafted += draft.size(); // note: we ignore the discarded small drafts + n_drafted += n_draft; // note: we ignore the discarded small drafts n_accept += ids.size() - 1; n_predict += ids.size(); @@ -222,6 +373,9 @@ int main(int argc, char ** argv) { LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d)\n", (int) ids.size() - 1, (int) draft.size(), id_last); + // clear the draft since it has been consumed + draft.clear(); + { LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past); @@ -254,11 +408,10 @@ int main(int argc, char ** argv) { LOG_INF("\n"); LOG_INF("target:\n\n"); - common_perf_print(ctx_tgt, smpl); + common_perf_print(ctx_tgt, smpl.get()); llama_batch_free(batch_tgt); - common_sampler_free(smpl); common_speculative_free(spec); llama_backend_free(); diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 6b65ecd6e5c..b9f7deb150d 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -1,17 +1,11 @@ cmake_minimum_required(VERSION 3.14...3.28) # for add_link_options and implicit target directories. -# ref: https://cmake.org/cmake/help/latest/policy/CMP0194.html -# MSVC is not a valid assembler for the ASM language. -# Set to NEW to avoid a warning on CMake 4.1+ with MSVC. -if (POLICY CMP0194) - cmake_policy(SET CMP0194 NEW) -endif() project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) -set(GGML_VERSION_MINOR 9) -set(GGML_VERSION_PATCH 11) +set(GGML_VERSION_MINOR 10) +set(GGML_VERSION_PATCH 0) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") @@ -219,7 +213,7 @@ set (GGML_CUDA_COMPRESSION_MODE "size" CACHE STRING set_property(CACHE GGML_CUDA_COMPRESSION_MODE PROPERTY STRINGS "none;speed;balance;size") option(GGML_HIP "ggml: use HIP" OFF) -option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF) +option(GGML_HIP_GRAPHS "ggml: use HIP graph" ON) option(GGML_HIP_RCCL "ggml: use ROCm Collective Comm. 
Library" OFF) option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON) option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF) diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 48fbe208d90..52754e1b9d6 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -473,7 +473,7 @@ target_link_libraries(ggml-base PRIVATE Threads::Threads) find_library(MATH_LIBRARY m) if (MATH_LIBRARY) if (NOT WIN32 OR NOT DEFINED ENV{ONEAPI_ROOT}) - target_link_libraries(ggml-base PRIVATE m) + target_link_libraries(ggml-base PRIVATE ${MATH_LIBRARY}) endif() endif() diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index 1ee3eeb4d96..6d22f3421b1 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -1133,7 +1133,7 @@ static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer if (t_ij->view_src != nullptr && ggml_backend_buffer_is_meta(t_ij->view_src->buffer)) { t_ij->view_src = ggml_backend_meta_buffer_simple_tensor(tensor->view_src, j); if (t_ij->view_offs > 0 && split_dim >= 0 && split_dim < GGML_MAX_DIMS) { - GGML_ASSERT(ne[split_dim] != 0 && tensor->ne[split_dim] != 0); + GGML_ASSERT(tensor->ne[split_dim] != 0); const int split_dim_view_src = ggml_backend_meta_get_split_state(tensor->view_src, /*assume_sync =*/ true).axis; GGML_ASSERT(split_dim_view_src >= 0 && split_dim_view_src < GGML_MAX_DIMS); @@ -1170,6 +1170,28 @@ static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer simple_tensors.push_back(t_ij); } + + // If one of the sources has a zero-sized slice, disable the computation: + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (tensor->src[i] == nullptr || !ggml_backend_buffer_is_meta(tensor->src[i]->buffer)) { + continue; + } + + const ggml_backend_meta_split_state split_state_src = ggml_backend_meta_get_split_state(tensor->src[i], /*assume_sync =*/ true); + if (split_state_src.axis < 0 || split_state_src.axis >= GGML_MAX_DIMS) { + continue; + } + for (size_t j = 0; j < n_simple_bufs; j++) { + int64_t ne_sum = 0; + for (size_t s = 0; s < split_state_src.n_segments; s++) { + ne_sum += split_state_src.ne[s*n_simple_bufs + j]; + } + if (ne_sum == 0) { + simple_tensors[j]->flags &= ~GGML_TENSOR_FLAG_COMPUTE; + } + } + } + buf_ctx->simple_tensors[tensor] = simple_tensors; return GGML_STATUS_SUCCESS; @@ -1270,7 +1292,45 @@ static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, co GGML_ASSERT(ggml_is_contiguous(tensor)); const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); - GGML_ASSERT(split_state.n_segments == 1); + + if (split_state.n_segments != 1) { + GGML_ASSERT(split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS); + GGML_ASSERT(offset == 0); + GGML_ASSERT(size == ggml_nbytes(tensor)); + GGML_ASSERT(tensor->ne[3] == 1); + size_t offset_data = 0; + std::vector simple_offsets(n_bufs, 0); + if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_0) { + GGML_ASSERT(tensor->ne[2] == 1); + const int64_t blck_size = ggml_blck_size(tensor->type); + for (size_t s = 0; s < split_state.n_segments; s++) { + for (size_t j = 0; j < n_bufs; j++) { + const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0); + const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0]; + ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, simple_offsets[j], nbytes, + 
tensor->ne[1], simple_tensor->nb[1], tensor->nb[1]); + offset_data += nbytes; + simple_offsets[j] += nbytes; + } + } + GGML_ASSERT(offset_data*tensor->ne[1] == size); + return; + } + GGML_ASSERT(split_state.axis == GGML_BACKEND_SPLIT_AXIS_1); + for (size_t s = 0; s < split_state.n_segments; s++) { + for (size_t j = 0; j < n_bufs; j++) { + const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1]; + ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, simple_offsets[j], nbytes, + tensor->ne[2], simple_tensor->nb[2], tensor->nb[2]); + offset_data += nbytes; + simple_offsets[j] += nbytes; + } + } + GGML_ASSERT(offset_data*tensor->ne[2] == size); + return; + } switch (split_state.axis) { case GGML_BACKEND_SPLIT_AXIS_0: @@ -1404,26 +1464,32 @@ struct ggml_backend_meta_context { struct backend_config { ggml_backend_t backend; - std::vector cgraphs; - std::vector nodes; - ggml_backend_buffer_ptr buf; + std::vector cgraphs; + std::vector nodes; + std::vector<ggml_backend_buffer_ptr> bufs; - backend_config(ggml_backend_t backend) : backend(backend) {} + backend_config(ggml_backend_t backend, const size_t n_reduce_steps) : backend(backend) { + bufs.resize(n_reduce_steps); + } }; std::string name; std::vector<backend_config> backend_configs; ggml_context_ptr ctx; std::vector<ggml_cgraph *> cgraphs_aux; std::vector<ggml_tensor *> nodes_aux; + size_t n_reduce_steps; int max_nnodes = 0; size_t max_tmp_size = 0; size_t max_subgraphs = 0; + size_t n_subgraphs = 0; + uint64_t uid = 0; void * comm_ctx = nullptr; ggml_backend_comm_allreduce_tensor_t comm_allreduce = nullptr; ggml_backend_meta_context(ggml_backend_dev_t meta_dev, const char * params) { const size_t n_devs = ggml_backend_meta_dev_n_devs(meta_dev); + n_reduce_steps = std::ceil(std::log2(n_devs)); name = "Meta("; std::vector<ggml_backend_t> simple_backends; backend_configs.reserve(n_devs); @@ -1435,7 +1501,7 @@ struct ggml_backend_meta_context { } name += ggml_backend_dev_name(simple_dev); simple_backends.push_back(ggml_backend_dev_init(simple_dev, params)); - backend_configs.emplace_back(simple_backends.back()); + backend_configs.emplace_back(simple_backends.back(), n_reduce_steps); } name += ")"; @@ -1465,10 +1531,6 @@ struct ggml_backend_meta_context { ggml_backend_free(bc.backend); } } - - size_t n_reduce_steps() const { - return std::ceil(std::log2(backend_configs.size())); - } }; static const char * ggml_backend_meta_get_name(ggml_backend_t backend) { @@ -1578,6 +1640,9 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, const size_t n_backends = ggml_backend_meta_n_backends(backend); ggml_backend_meta_context * backend_ctx = (ggml_backend_meta_context *) backend->context; + // If the previous cgraph had a defined UID it can be used to skip rebuilding the subgraphs per simple backend.
+ const bool needs_rebuild = (cgraph->uid == 0) || (cgraph->uid != backend_ctx->uid); + bool max_nnodes_raised = false; if (cgraph->n_nodes > backend_ctx->max_nnodes) { for (size_t j = 0; j < n_backends; j++) { @@ -1587,173 +1652,216 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, } backend_ctx->max_nnodes = cgraph->n_nodes; max_nnodes_raised = true; + assert(needs_rebuild); } - for (size_t j = 0; j < n_backends; j++) { - auto & bcj = backend_ctx->backend_configs[j]; - for (int i = 0; i < cgraph->n_nodes; i++) { - ggml_tensor * node = cgraph->nodes[i]; - if (node->view_src != nullptr && node->view_src->op == GGML_OP_NONE && ggml_backend_buffer_is_host(node->view_src->buffer)) { - // FIXME s_copy_main is on the CPU and its view seems to be incorrectly added to the graph nodes. - // For regular usage this doesn't matter since it's a noop but trying to call ggml_backend_meta_buffer_simple_tensor results in a crash. - bcj.nodes[i] = node; - continue; + if (needs_rebuild) { + size_t n_subgraphs = 0; + size_t max_tmp_size = 0; + + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + if (node->view_src != nullptr && node->view_src->op == GGML_OP_NONE && ggml_backend_buffer_is_host(node->view_src->buffer)) { + // FIXME s_copy_main is on the CPU and its view seems to be incorrectly added to the graph nodes. + // For regular usage this doesn't matter since it's a noop but trying to call ggml_backend_meta_buffer_simple_tensor results in a crash. + bcj.nodes[i] = node; + continue; + } + bcj.nodes[i] = ggml_backend_meta_buffer_simple_tensor(node, j); + GGML_ASSERT(bcj.nodes[i]); } - bcj.nodes[i] = ggml_backend_meta_buffer_simple_tensor(node, j); - GGML_ASSERT(bcj.nodes[i]); } - } - size_t n_subgraphs = 0; - size_t max_tmp_size = 0; - { - // For MoE models it may make sense to delay the AllReduce in order to reduce I/O: - auto get_i_delayed = [&](const int i) -> int { - int id = i; // i_delayed - int idr = i; // i_delayed return, last safe return value - - ggml_tensor * node = cgraph->nodes[id]; - int32_t n_used = ggml_node_get_use_count(cgraph, id); - if (id + 1 >= cgraph->n_nodes) { - return idr; - } - { - ggml_tensor * next = cgraph->nodes[id+1]; - if (next->op == GGML_OP_ADD_ID && next->src[0] == node && - ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL && - ggml_backend_meta_get_split_state(next->src[2], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { - node = next; + { + // For MoE models it may make sense to delay the AllReduce in order to reduce I/O: + auto get_i_delayed = [&](const int i) -> int { + int id = i; // i_delayed + int idr = i; // i_delayed return, last safe return value + + ggml_tensor * node = cgraph->nodes[id]; + int32_t n_used = ggml_node_get_use_count(cgraph, id); + + // Skip MIRRORED nodes that don't consume node + auto skip_unrelated = [&]() { + while (id + 1 < cgraph->n_nodes) { + ggml_tensor * next = cgraph->nodes[id+1]; + if (ggml_backend_meta_get_split_state(next, false).axis != GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + break; + } + bool safe = true; + for (int s = 0; s < GGML_MAX_SRC; s++) { + if (next->src[s] == nullptr) { + continue; + } + if (next->src[s] == node) { + safe = false; + break; + } + if (ggml_backend_meta_get_split_state(next->src[s], false).axis != GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + safe = false; + break; + } + } + if (!safe) { + break; + } + 
id++; + } + }; + + skip_unrelated(); + if (id + 1 >= cgraph->n_nodes) { + return idr; + } + { + ggml_tensor * next = cgraph->nodes[id+1]; + if (next->op == GGML_OP_ADD_ID && next->src[0] == node && + ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL && + ggml_backend_meta_get_split_state(next->src[2], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + node = next; + id++; + idr = id; + n_used = ggml_node_get_use_count(cgraph, id); + } + } + // Chain of MULs with MIRRORED src[1] + while (true) { + skip_unrelated(); + if (id + 1 >= cgraph->n_nodes) { + return idr; + } + ggml_tensor * next = cgraph->nodes[id+1]; + if (next->op == GGML_OP_MUL && next->src[0] == node && + ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + node = next; + id++; + idr = id; + n_used = ggml_node_get_use_count(cgraph, id); + } else { + break; + } + } + + if (n_used != node->ne[1] || id + 2*n_used-1 >= cgraph->n_nodes) { + return idr; + } + for (int32_t k = 0; k < n_used; k++) { + ggml_tensor * next = cgraph->nodes[id+1]; + if (next->op != GGML_OP_VIEW || next->view_src != node || next->view_offs != k*node->nb[1] || + next->ne[0] != node->ne[0] || next->ne[1] != node->ne[2] || next->nb[1] != node->nb[2] || + ggml_node_get_use_count(cgraph, id+1) != 1) { + return idr; + } id++; - idr = id; - n_used = ggml_node_get_use_count(cgraph, id); } - } - if (id + 1 >= cgraph->n_nodes) { - return idr; - } - { - ggml_tensor * next = cgraph->nodes[id+1]; - if (next->op == GGML_OP_MUL && next->src[0] == node && - ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { - node = next; + { + ggml_tensor * next = cgraph->nodes[id+1]; + if (next->op != GGML_OP_ADD || next->src[0] != cgraph->nodes[id - (n_used-1)] || + next->src[1] != cgraph->nodes[id - (n_used-2)] || ggml_node_get_use_count(cgraph, id+1) != 1) { + return idr; + } id++; - idr = id; - n_used = ggml_node_get_use_count(cgraph, id); } - } - - if (n_used != node->ne[1] || id + 2*n_used-1 >= cgraph->n_nodes) { + for (int32_t k = 0; k < n_used - 2; k++) { + ggml_tensor * next = cgraph->nodes[id+1]; + if (next->op != GGML_OP_ADD || next->src[0] != cgraph->nodes[id] || + next->src[1] != cgraph->nodes[id - (n_used-2)] || ggml_node_get_use_count(cgraph, id+1) != 1) { + return idr; + } + id++; + } + idr = id; return idr; - } - for (int32_t k = 0; k < n_used; k++) { - ggml_tensor * next = cgraph->nodes[id+1]; - if (next->op != GGML_OP_VIEW || next->view_src != node || next->view_offs != k*node->nb[1] || - next->ne[0] != node->ne[0] || next->ne[1] != node->ne[2] || next->nb[1] != node->nb[2] || - ggml_node_get_use_count(cgraph, id+1) != 1) { - return idr; + }; + + int i_start = 0; + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + if (node->view_src != nullptr && node->view_src->op == GGML_OP_NONE && ggml_backend_buffer_is_host(node->view_src->buffer)) { + continue; } - id++; - } - { - ggml_tensor * next = cgraph->nodes[id+1]; - if (next->op != GGML_OP_ADD || next->src[0] != cgraph->nodes[id - (n_used-1)] || - next->src[1] != cgraph->nodes[id - (n_used-2)] || ggml_node_get_use_count(cgraph, id+1) != 1) { - return idr; + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(node, /*assume_sync =*/ false); + if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL) { + max_tmp_size = std::max(max_tmp_size, ggml_nbytes(node)); } - id++; - } - for (int32_t k = 0; k < n_used - 2; k++) { 
- ggml_tensor * next = cgraph->nodes[id+1]; - if (next->op != GGML_OP_ADD || next->src[0] != cgraph->nodes[id] || - next->src[1] != cgraph->nodes[id - (n_used-2)] || ggml_node_get_use_count(cgraph, id+1) != 1) { - return idr; + const bool new_subgraph = i + 1 == cgraph->n_nodes || split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL; + if (!new_subgraph) { + continue; } - id++; - } - idr = id; - return idr; - }; - int i_start = 0; - for (int i = 0; i < cgraph->n_nodes; i++) { - ggml_tensor * node = cgraph->nodes[i]; - if (node->view_src != nullptr && node->view_src->op == GGML_OP_NONE && ggml_backend_buffer_is_host(node->view_src->buffer)) { - continue; - } - const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(node, /*assume_sync =*/ false); - if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL) { - max_tmp_size = std::max(max_tmp_size, ggml_nbytes(node)); - } - const bool new_subgraph = i + 1 == cgraph->n_nodes || split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL; - if (!new_subgraph) { - continue; + i = get_i_delayed(i); + + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + bcj.cgraphs[n_subgraphs].offset = i_start; + } + n_subgraphs++; + i_start = i + 1; } + GGML_ASSERT(i_start == cgraph->n_nodes); + } - i = get_i_delayed(i); + backend_ctx->uid = cgraph->uid; + backend_ctx->n_subgraphs = n_subgraphs; + if (max_tmp_size > backend_ctx->max_tmp_size) { for (size_t j = 0; j < n_backends; j++) { auto & bcj = backend_ctx->backend_configs[j]; - bcj.cgraphs[n_subgraphs].offset = i_start; + for (size_t i = 0; i < backend_ctx->n_reduce_steps; i++) { + bcj.bufs[i].reset(ggml_backend_alloc_buffer(bcj.backend, max_tmp_size)); + } + } + backend_ctx->max_tmp_size = max_tmp_size; + } + + if (max_nnodes_raised || n_subgraphs > backend_ctx->max_subgraphs) { + backend_ctx->max_subgraphs = std::max(backend_ctx->max_subgraphs, n_subgraphs); + const size_t n_nodes_per_device = 3 * backend_ctx->n_reduce_steps; // tmp + ADD (+zeroing) graph per step and device + const size_t n_cgraphs_per_device = 2 * backend_ctx->n_reduce_steps; // ADD ( + zeroing) graph per step and device + const size_t mem_per_device_graphs_main = backend_ctx->max_subgraphs*ggml_graph_overhead_custom(backend_ctx->max_nnodes, cgraph->grads); + const size_t mem_per_device_graphs_aux = n_cgraphs_per_device*backend_ctx->max_subgraphs*ggml_graph_overhead_custom(1, cgraph->grads); + const size_t mem_per_device_nodes_aux = n_nodes_per_device*backend_ctx->max_subgraphs*ggml_tensor_overhead(); + ggml_init_params params = { + /*.mem_size =*/ n_backends * (mem_per_device_graphs_main + mem_per_device_graphs_aux + mem_per_device_nodes_aux), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + backend_ctx->ctx.reset(ggml_init(params)); + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + for (size_t i = 0; i < n_subgraphs; i++) { + bcj.cgraphs[i].cgraph_main = ggml_new_graph_custom(backend_ctx->ctx.get(), cgraph->n_nodes, /*grads =*/ false); + } + } + backend_ctx->cgraphs_aux.resize(n_backends*n_cgraphs_per_device*backend_ctx->max_subgraphs); + for (size_t k = 0; k < backend_ctx->cgraphs_aux.size(); k++) { + backend_ctx->cgraphs_aux[k] = ggml_new_graph_custom(backend_ctx->ctx.get(), 1, cgraph->grads); + } + backend_ctx->nodes_aux.resize(n_backends*n_nodes_per_device*backend_ctx->max_subgraphs); + for (size_t k = 0; k < backend_ctx->nodes_aux.size(); k++) { + backend_ctx->nodes_aux[k] = ggml_new_tensor_1d(backend_ctx->ctx.get(), 
GGML_TYPE_F32, 1); } - n_subgraphs++; - i_start = i + 1; - } - GGML_ASSERT(i_start == cgraph->n_nodes); - } - - if (max_tmp_size > backend_ctx->max_tmp_size) { - for (size_t j = 0; j < n_backends; j++) { - auto & bcj = backend_ctx->backend_configs[j]; - bcj.buf.reset(ggml_backend_alloc_buffer(bcj.backend, max_tmp_size)); } - backend_ctx->max_tmp_size = max_tmp_size; - } - - if (max_nnodes_raised || n_subgraphs > backend_ctx->max_subgraphs) { - backend_ctx->max_subgraphs = std::max(backend_ctx->max_subgraphs, n_subgraphs); - const size_t n_reduce_steps = backend_ctx->n_reduce_steps(); - const size_t n_nodes_per_device = 2 * n_reduce_steps; // tmp + ADD per step - const size_t n_cgraphs_per_device = n_reduce_steps; // 1 ADD graph per step - const size_t mem_per_device_graphs_main = backend_ctx->max_subgraphs*ggml_graph_overhead_custom(backend_ctx->max_nnodes, cgraph->grads); - const size_t mem_per_device_graphs_aux = n_cgraphs_per_device*backend_ctx->max_subgraphs*ggml_graph_overhead_custom(1, cgraph->grads); - const size_t mem_per_device_nodes_aux = n_nodes_per_device*backend_ctx->max_subgraphs*ggml_tensor_overhead(); - ggml_init_params params = { - /*.mem_size =*/ n_backends * (mem_per_device_graphs_main + mem_per_device_graphs_aux + mem_per_device_nodes_aux), - /*.mem_buffer =*/ nullptr, - /*.no_alloc =*/ true, - }; - backend_ctx->ctx.reset(ggml_init(params)); for (size_t j = 0; j < n_backends; j++) { auto & bcj = backend_ctx->backend_configs[j]; - for (size_t i = 0; i < n_subgraphs; i++) { - bcj.cgraphs[i].cgraph_main = ggml_new_graph_custom(backend_ctx->ctx.get(), cgraph->n_nodes, /*grads =*/ false); - } - } - backend_ctx->cgraphs_aux.resize(n_backends*n_cgraphs_per_device*backend_ctx->max_subgraphs); - for (size_t k = 0; k < backend_ctx->cgraphs_aux.size(); k++) { - backend_ctx->cgraphs_aux[k] = ggml_new_graph_custom(backend_ctx->ctx.get(), 1, cgraph->grads); - } - backend_ctx->nodes_aux.resize(n_backends*n_nodes_per_device*backend_ctx->max_subgraphs); - for (size_t k = 0; k < backend_ctx->nodes_aux.size(); k++) { - backend_ctx->nodes_aux[k] = ggml_new_tensor_1d(backend_ctx->ctx.get(), GGML_TYPE_F32, 1); - } - } - - for (size_t j = 0; j < n_backends; j++) { - auto & bcj = backend_ctx->backend_configs[j]; - for (size_t i_graph = 0; i_graph < n_subgraphs; i_graph++) { - ggml_cgraph * cgraph_ij = bcj.cgraphs[i_graph].cgraph_main; - const size_t i_node_start = bcj.cgraphs[i_graph].offset; - const size_t i_node_stop = i_graph + 1 < n_subgraphs ? bcj.cgraphs[i_graph + 1].offset : cgraph->n_nodes; - cgraph_ij->n_nodes = i_node_stop - i_node_start; - ggml_hash_set_reset(&cgraph_ij->visited_hash_set); - for (size_t i_node = i_node_start; i_node < i_node_stop; i_node++) { - ggml_tensor * node_ij = bcj.nodes[i_node]; - cgraph_ij->nodes[i_node - i_node_start] = node_ij; - const size_t hash_pos_orig = ggml_hash_find(&cgraph->visited_hash_set, cgraph->nodes[i_node]); - const size_t hash_pos_ij = ggml_hash_insert(&cgraph_ij->visited_hash_set, node_ij); - cgraph_ij->use_counts[hash_pos_ij] = cgraph->use_counts[hash_pos_orig]; + for (size_t i_graph = 0; i_graph < n_subgraphs; i_graph++) { + ggml_cgraph * cgraph_ij = bcj.cgraphs[i_graph].cgraph_main; + const size_t i_node_start = bcj.cgraphs[i_graph].offset; + const size_t i_node_stop = i_graph + 1 < n_subgraphs ? 
bcj.cgraphs[i_graph + 1].offset : cgraph->n_nodes; + cgraph_ij->n_nodes = i_node_stop - i_node_start; + ggml_hash_set_reset(&cgraph_ij->visited_hash_set); + for (size_t i_node = i_node_start; i_node < i_node_stop; i_node++) { + ggml_tensor * node_ij = bcj.nodes[i_node]; + cgraph_ij->nodes[i_node - i_node_start] = node_ij; + const size_t hash_pos_orig = ggml_hash_find(&cgraph->visited_hash_set, cgraph->nodes[i_node]); + const size_t hash_pos_ij = ggml_hash_insert(&cgraph_ij->visited_hash_set, node_ij); + cgraph_ij->use_counts[hash_pos_ij] = cgraph->use_counts[hash_pos_orig]; + } + cgraph_ij->uid = ggml_graph_next_uid(); } } } @@ -1761,11 +1869,6 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, size_t iga = 0; // i graph aux size_t ina = 0; // i node aux - // FIXME usage_counts - auto get_cgraph_aux = [&]() -> ggml_cgraph * { - ggml_cgraph * ret = backend_ctx->cgraphs_aux[iga++]; - return ret; - }; auto get_node_aux = [&](ggml_tensor * t) -> ggml_tensor * { ggml_tensor * ret = backend_ctx->nodes_aux[ina++]; memset(ret, 0, sizeof(ggml_tensor)); @@ -1777,75 +1880,110 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, } return ret; }; + auto set_tmp_data = [&](ggml_tensor * tensor, const size_t j, const size_t i_buf) { + auto & bcj = backend_ctx->backend_configs[j]; + ggml_backend_buffer_ptr & buf_ptr = bcj.bufs[i_buf]; + if (!buf_ptr || ggml_backend_buffer_get_size(buf_ptr.get()) < backend_ctx->max_tmp_size) { + buf_ptr.reset(ggml_backend_alloc_buffer(bcj.backend, backend_ctx->max_tmp_size)); + } + tensor->buffer = buf_ptr.get(); + tensor->data = ggml_backend_buffer_get_base(buf_ptr.get()); + }; + // FIXME usage_counts + auto get_cgraph_aux = [&]() -> ggml_cgraph * { + ggml_cgraph * ret = backend_ctx->cgraphs_aux[iga++]; + return ret; + }; // Preferentially use backend-specific allreduce_tensor_async (e.g. 
NCCL for CUDA), use a generic fallback if unavailable: auto allreduce_fallback = [&](size_t i) -> ggml_status { std::vector step_cgraphs(n_backends, nullptr); - for (size_t offset_j = 1; offset_j < n_backends; offset_j *= 2) { + // Zero out nodes that were disabled due to having a zero-sized slice: + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + ggml_tensor * node = bcj.cgraphs[i].cgraph_main->nodes[bcj.cgraphs[i].cgraph_main->n_nodes - 1]; + if (node->flags & GGML_TENSOR_FLAG_COMPUTE) { + continue; + } + ggml_tensor * node_zero = get_node_aux(node); + node_zero->op = GGML_OP_SCALE; // FIXME 0.0f * NaN == NaN + node_zero->src[0] = node; + ggml_set_op_params_f32(node_zero, 0, 0.0f); + node_zero->data = node->data; + node_zero->flags |= GGML_TENSOR_FLAG_COMPUTE; + + step_cgraphs[j] = get_cgraph_aux(); + step_cgraphs[j]->nodes[0] = node_zero; + step_cgraphs[j]->n_nodes = 1; + const ggml_status status = ggml_backend_graph_compute_async(bcj.backend, step_cgraphs[j]); + if (status != GGML_STATUS_SUCCESS) { + return status; + } + } + std::fill(step_cgraphs.begin(), step_cgraphs.end(), nullptr); + + auto push_data = [&](const size_t j_src, const size_t j_dst, const size_t i_buf) { + assert(step_cgraphs[j_dst] == nullptr); + auto & bcj_src = backend_ctx->backend_configs[j_src]; + auto & bcj_dst = backend_ctx->backend_configs[j_dst]; + + ggml_tensor * node_src = bcj_src.cgraphs[i].cgraph_main->nodes[bcj_src.cgraphs[i].cgraph_main->n_nodes - 1]; + ggml_tensor * node_dst = bcj_dst.cgraphs[i].cgraph_main->nodes[bcj_dst.cgraphs[i].cgraph_main->n_nodes - 1]; + GGML_ASSERT(ggml_is_contiguous(node_src)); + GGML_ASSERT(ggml_is_contiguous(node_dst)); + + ggml_tensor * node_tmp = get_node_aux(node_dst); + set_tmp_data(node_tmp, j_dst, i_buf); + + ggml_backend_tensor_copy_async(bcj_src.backend, bcj_dst.backend, node_src, node_tmp); + + ggml_tensor * node_red = get_node_aux(node_dst); + node_red->view_src = node_dst->view_src == nullptr ? 
node_dst : node_dst->view_src; + node_red->view_offs = node_dst->view_offs; + node_red->op = GGML_OP_ADD; + node_red->src[0] = node_dst; + node_red->src[1] = node_tmp; + node_red->flags |= GGML_TENSOR_FLAG_COMPUTE; + ggml_backend_view_init(node_red); + + ggml_cgraph * cgraph_aux = get_cgraph_aux(); + cgraph_aux->nodes[0] = node_red; + cgraph_aux->n_nodes = 1; + step_cgraphs[j_dst] = cgraph_aux; + }; + + size_t offset_j = n_backends/2; + while ((offset_j & (offset_j - 1)) != 0) { + offset_j--; + } + const size_t offset_j_max = offset_j; + size_t i_buf = 0; + + // If n_backends is not a power of 2, fold in the excess prior to butterfly reduction: + for (size_t j_src = 2*offset_j_max; j_src < n_backends; j_src++) { + const size_t j_dst = j_src - 2*offset_j_max; + push_data(j_src, j_dst, i_buf); + const ggml_status status = ggml_backend_graph_compute_async(backend_ctx->backend_configs[j_dst].backend, step_cgraphs[j_dst]); + if (status != GGML_STATUS_SUCCESS) { + return status; + } + i_buf = 1; + } + + // Butterfly reduction: + for (; offset_j >= 1; offset_j /= 2) { std::fill(step_cgraphs.begin(), step_cgraphs.end(), nullptr); - for (size_t j = 0; j < n_backends; j++) { + for (size_t j = 0; j < 2*offset_j_max; j++) { const size_t j_other = j ^ offset_j; - if (j_other > j) { + if (j_other >= n_backends) { continue; } - - auto & bcj1 = backend_ctx->backend_configs[j]; - auto & bcj2 = backend_ctx->backend_configs[j_other]; - - ggml_tensor * node1 = bcj1.cgraphs[i].cgraph_main->nodes[bcj1.cgraphs[i].cgraph_main->n_nodes - 1]; - ggml_tensor * node2 = bcj2.cgraphs[i].cgraph_main->nodes[bcj2.cgraphs[i].cgraph_main->n_nodes - 1]; - GGML_ASSERT(ggml_is_contiguous(node1)); - GGML_ASSERT(ggml_is_contiguous(node2)); - - // Tmp tensors to receive P2P copies - ggml_tensor * node_tmp_1 = get_node_aux(node1); - node_tmp_1->buffer = bcj1.buf.get(); - node_tmp_1->data = ggml_backend_buffer_get_base(bcj1.buf.get()); - - ggml_tensor * node_tmp_2 = get_node_aux(node2); - node_tmp_2->buffer = bcj2.buf.get(); - node_tmp_2->data = ggml_backend_buffer_get_base(bcj2.buf.get()); - - // 2 P2P copies: exchange full buffers - ggml_backend_tensor_copy_async(bcj1.backend, bcj2.backend, node1, node_tmp_2); - ggml_backend_tensor_copy_async(bcj2.backend, bcj1.backend, node2, node_tmp_1); - - // Local ADD: node1 += tmp1 (in-place via view) - ggml_tensor * node_red_1 = get_node_aux(node1); - node_red_1->view_src = node1->view_src == nullptr ? node1 : node1->view_src; - node_red_1->view_offs = node1->view_offs; - node_red_1->op = GGML_OP_ADD; - node_red_1->src[0] = node1; - node_red_1->src[1] = node_tmp_1; - node_red_1->flags |= GGML_TENSOR_FLAG_COMPUTE; - ggml_backend_view_init(node_red_1); - - // Local ADD: node2 += tmp2 (in-place via view) - ggml_tensor * node_red_2 = get_node_aux(node2); - node_red_2->view_src = node2->view_src == nullptr ? 
node2 : node2->view_src; - node_red_2->view_offs = node2->view_offs; - node_red_2->op = GGML_OP_ADD; - node_red_2->src[0] = node2; - node_red_2->src[1] = node_tmp_2; - node_red_2->flags |= GGML_TENSOR_FLAG_COMPUTE; - ggml_backend_view_init(node_red_2); - - // Build 1-node cgraphs for the ADD ops - ggml_cgraph * cgraph_aux_1 = get_cgraph_aux(); - cgraph_aux_1->nodes[0] = node_red_1; - cgraph_aux_1->n_nodes = 1; - step_cgraphs[j] = cgraph_aux_1; - - ggml_cgraph * cgraph_aux_2 = get_cgraph_aux(); - cgraph_aux_2->nodes[0] = node_red_2; - cgraph_aux_2->n_nodes = 1; - step_cgraphs[j_other] = cgraph_aux_2; + push_data(j, j_other, i_buf); } - // Execute local ADDs for this step - for (size_t j = 0; j < n_backends; j++) { + for (size_t j = 0; j < 2*offset_j_max; j++) { if (step_cgraphs[j] == nullptr) { continue; } @@ -1855,12 +1993,25 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, return status; } } + i_buf++; } + assert(i_buf == backend_ctx->n_reduce_steps); + + // If n_backends is not a power of 2, copy back the reduced tensors to the excess: + for (size_t j = 2*offset_j_max; j < n_backends; j++) { + auto & bcj_src = backend_ctx->backend_configs[j - 2*offset_j_max]; + auto & bcj_dst = backend_ctx->backend_configs[j]; + + ggml_tensor * node_src = bcj_src.cgraphs[i].cgraph_main->nodes[bcj_src.cgraphs[i].cgraph_main->n_nodes - 1]; + ggml_tensor * node_dst = bcj_dst.cgraphs[i].cgraph_main->nodes[bcj_dst.cgraphs[i].cgraph_main->n_nodes - 1]; + ggml_backend_tensor_copy_async(bcj_src.backend, bcj_dst.backend, node_src, node_dst); + } + return GGML_STATUS_SUCCESS; }; - for (size_t i = 0; i < n_subgraphs; i++) { + for (size_t i = 0; i < backend_ctx->n_subgraphs; i++) { for (size_t j = 0; j < n_backends; j++) { auto & bcj = backend_ctx->backend_configs[j]; const ggml_status status = ggml_backend_graph_compute_async(bcj.backend, bcj.cgraphs[i].cgraph_main); @@ -1869,7 +2020,7 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, } } - if (n_backends > 1 && i < n_subgraphs - 1) { + if (n_backends > 1 && i < backend_ctx->n_subgraphs - 1) { bool backend_allreduce_success = false; if (backend_ctx->comm_ctx) { std::vector nodes; diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index c589a213e9d..595ded09f03 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -83,7 +83,6 @@ #elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64) // quants.c #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 -#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 diff --git a/ggml/src/ggml-cpu/arch/arm/quants.c b/ggml/src/ggml-cpu/arch/arm/quants.c index 64d811fafe7..fe621332970 100644 --- a/ggml/src/ggml-cpu/arch/arm/quants.c +++ b/ggml/src/ggml-cpu/arch/arm/quants.c @@ -151,8 +151,6 @@ void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi const block_q1_0 * GGML_RESTRICT x = vx; const block_q8_0 * GGML_RESTRICT y = vy; - float sumf = 0.0f; - #if defined(__ARM_NEON) float32x4_t sumv = vdupq_n_f32(0.0f); @@ -212,31 +210,13 @@ void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi } } - sumf = vaddvq_f32(sumv); + *s = vaddvq_f32(sumv); #else - // Scalar fallback - for (int i = 0; i < nb; i++) { - const float d0 = GGML_FP16_TO_FP32(x[i].d); - - 
// Process 4 Q8_0 blocks - for (int k = 0; k < 4; k++) { - const float d1 = GGML_FP16_TO_FP32(y[i*4 + k].d); - - int sumi = 0; - for (int j = 0; j < QK8_0; j++) { - const int bit_index = k * QK8_0 + j; - const int byte_index = bit_index / 8; - const int bit_offset = bit_index % 8; - - const int xi = ((x[i].qs[byte_index] >> bit_offset) & 1) ? 1 : -1; - sumi += xi * y[i*4 + k].qs[j]; - } - sumf += d0 * d1 * sumi; - } - } + UNUSED(nb); + UNUSED(x); + UNUSED(y); + ggml_vec_dot_q1_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); #endif - - *s = sumf; } diff --git a/ggml/src/ggml-cpu/arch/x86/quants.c b/ggml/src/ggml-cpu/arch/x86/quants.c index 74d699f633d..0a3e071e57c 100644 --- a/ggml/src/ggml-cpu/arch/x86/quants.c +++ b/ggml/src/ggml-cpu/arch/x86/quants.c @@ -274,6 +274,18 @@ static inline __m256 quad_mx_delta_float(const uint8_t x0, const float y0, const } #endif #elif defined(__SSSE3__) +static inline __m128i bytes_from_bits_16(const uint8_t * x) { + uint16_t x16; + memcpy(&x16, x, sizeof(uint16_t)); + + const __m128i shuf_mask = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000); + __m128i bytes = _mm_shuffle_epi8(_mm_set1_epi16((short) x16), shuf_mask); + const __m128i bit_mask = _mm_set_epi64x(0x7fbfdfeff7fbfdfe, 0x7fbfdfeff7fbfdfe); + bytes = _mm_or_si128(bytes, bit_mask); + + return _mm_cmpeq_epi8(bytes, _mm_set1_epi64x(-1)); +} + // horizontally add 4x4 floats static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) { __m128 res_0 =_mm_hadd_ps(a, b); @@ -540,6 +552,152 @@ static inline __m128i get_scale_shuffle(int i) { } #endif +void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK1_0; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q1_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + +#if defined(__AVX2__) + const __m256i ones_8 = _mm256_set1_epi8(1); + const __m256i ones_16 = _mm256_set1_epi16(1); + const __m256i byte_shuf = _mm256_setr_epi8( + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3); + const __m256i bit_masks = _mm256_setr_epi8( + 1, 2, 4, 8, 16, 32, 64, (char) -128, 1, 2, 4, 8, 16, 32, 64, (char) -128, + 1, 2, 4, 8, 16, 32, 64, (char) -128, 1, 2, 4, 8, 16, 32, 64, (char) -128); + const __m256i zero = _mm256_setzero_si256(); + __m256 acc = _mm256_setzero_ps(); + + for (int ib = 0; ib < nb; ++ib) { + const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d); + const uint32_t * GGML_RESTRICT qs32 = (const uint32_t *) x[ib].qs; + const block_q8_0 * GGML_RESTRICT y_ptr = &y[ib * 4]; + + __m256 acc_block; + { + const __m256i qy = _mm256_loadu_si256((const __m256i *) y_ptr[0].qs); + const __m256i sm = _mm256_cmpeq_epi8( + _mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs32[0]), byte_shuf), bit_masks), zero); + const __m256i sy = _mm256_sub_epi8(_mm256_xor_si256(qy, sm), sm); + const __m256i s32 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy), ones_16); + acc_block = _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[0].d)), _mm256_cvtepi32_ps(s32)); + } + for (int K = 1; K < 4; ++K) { + const __m256i qy = _mm256_loadu_si256((const __m256i *) y_ptr[K].qs); + const __m256i sm = _mm256_cmpeq_epi8( + _mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs32[K]), byte_shuf), bit_masks), zero); + const __m256i sy = 
_mm256_sub_epi8(_mm256_xor_si256(qy, sm), sm); + const __m256i s32 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy), ones_16); + acc_block = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[K].d)), _mm256_cvtepi32_ps(s32), acc_block); + } + acc = _mm256_fmadd_ps(_mm256_set1_ps(d0), acc_block, acc); + } + + *s = hsum_float_8(acc); +#elif defined(__AVX__) + const __m128i ones_8 = _mm_set1_epi8(1); + const __m128i ones_16 = _mm_set1_epi16(1); + const __m128i zero = _mm_setzero_si128(); + __m256 acc = _mm256_setzero_ps(); + + for (int ib = 0; ib < nb; ++ib) { + const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d); + const block_q8_0 * GGML_RESTRICT y_ptr = &y[ib * 4]; + __m256 acc_block; + { + const __m256i bit_mask = bytes_from_bits_32(&x[ib].qs[0]); + const __m128i bit_mask_0 = _mm256_castsi256_si128(bit_mask); + const __m128i bit_mask_1 = _mm256_extractf128_si256(bit_mask, 1); + const __m128i qy_0 = _mm_loadu_si128((const __m128i *) &y_ptr[0].qs[0]); + const __m128i qy_1 = _mm_loadu_si128((const __m128i *) &y_ptr[0].qs[16]); + const __m128i sign_mask_0 = _mm_cmpeq_epi8(bit_mask_0, zero); + const __m128i sign_mask_1 = _mm_cmpeq_epi8(bit_mask_1, zero); + const __m128i sy_0 = _mm_sub_epi8(_mm_xor_si128(qy_0, sign_mask_0), sign_mask_0); + const __m128i sy_1 = _mm_sub_epi8(_mm_xor_si128(qy_1, sign_mask_1), sign_mask_1); + const __m128i sum16_0 = _mm_maddubs_epi16(ones_8, sy_0); + const __m128i sum16_1 = _mm_maddubs_epi16(ones_8, sy_1); + const __m128i sum32_0 = _mm_madd_epi16(sum16_0, ones_16); + const __m128i sum32_1 = _mm_madd_epi16(sum16_1, ones_16); + const __m256 q = _mm256_cvtepi32_ps(MM256_SET_M128I(sum32_1, sum32_0)); + acc_block = _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[0].d)), q); + } + for(int K = 1; K < 4; ++K) { + const __m256i bit_mask = bytes_from_bits_32(&x[ib].qs[(K) * 4]); + const __m128i bit_mask_0 = _mm256_castsi256_si128(bit_mask); + const __m128i bit_mask_1 = _mm256_extractf128_si256(bit_mask, 1); + const __m128i qy_0 = _mm_loadu_si128((const __m128i *) &y_ptr[(K)].qs[0]); + const __m128i qy_1 = _mm_loadu_si128((const __m128i *) &y_ptr[(K)].qs[16]); + const __m128i sign_mask_0 = _mm_cmpeq_epi8(bit_mask_0, zero); + const __m128i sign_mask_1 = _mm_cmpeq_epi8(bit_mask_1, zero); + const __m128i sy_0 = _mm_sub_epi8(_mm_xor_si128(qy_0, sign_mask_0), sign_mask_0); + const __m128i sy_1 = _mm_sub_epi8(_mm_xor_si128(qy_1, sign_mask_1), sign_mask_1); + const __m128i sum16_0 = _mm_maddubs_epi16(ones_8, sy_0); + const __m128i sum16_1 = _mm_maddubs_epi16(ones_8, sy_1); + const __m128i sum32_0 = _mm_madd_epi16(sum16_0, ones_16); + const __m128i sum32_1 = _mm_madd_epi16(sum16_1, ones_16); + const __m256 q = _mm256_cvtepi32_ps(MM256_SET_M128I(sum32_1, sum32_0)); + acc_block = _mm256_add_ps(acc_block, _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[(K)].d)), q)); + } +#undef Q1_AVX_BLOCK + + acc = _mm256_add_ps(acc, _mm256_mul_ps(_mm256_set1_ps(d0), acc_block)); + } + + *s = hsum_float_8(acc); +#elif defined(__SSSE3__) + const __m128i ones_8 = _mm_set1_epi8(1); + const __m128i ones_16 = _mm_set1_epi16(1); + const __m128i zero = _mm_setzero_si128(); + __m128 acc_0 = _mm_setzero_ps(); + __m128 acc_1 = _mm_setzero_ps(); + __m128 acc_2 = _mm_setzero_ps(); + __m128 acc_3 = _mm_setzero_ps(); + + for (int ib = 0; ib < nb; ++ib) { + const __m128 d0 = _mm_set1_ps(GGML_CPU_FP16_TO_FP32(x[ib].d)); + const block_q8_0 * GGML_RESTRICT y_ptr = &y[ib * 4]; + +#define Q1_SSSE3_BLOCK(QS_OFF, Y_IDX, ACC) \ + { \ + const __m128i bit_mask_0 = 
bytes_from_bits_16(&x[ib].qs[(QS_OFF) + 0]); \ + const __m128i bit_mask_1 = bytes_from_bits_16(&x[ib].qs[(QS_OFF) + 2]); \ + const __m128i qy_0 = _mm_loadu_si128((const __m128i *) &y_ptr[(Y_IDX)].qs[0]); \ + const __m128i qy_1 = _mm_loadu_si128((const __m128i *) &y_ptr[(Y_IDX)].qs[16]); \ + const __m128i sign_mask_0 = _mm_cmpeq_epi8(bit_mask_0, zero); \ + const __m128i sign_mask_1 = _mm_cmpeq_epi8(bit_mask_1, zero); \ + const __m128i sy_0 = _mm_sub_epi8(_mm_xor_si128(qy_0, sign_mask_0), sign_mask_0); \ + const __m128i sy_1 = _mm_sub_epi8(_mm_xor_si128(qy_1, sign_mask_1), sign_mask_1); \ + const __m128i sum_0 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy_0), ones_16); \ + const __m128i sum_1 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy_1), ones_16); \ + const __m128 q = _mm_cvtepi32_ps(_mm_add_epi32(sum_0, sum_1)); \ + (ACC) = _mm_add_ps((ACC), _mm_mul_ps(_mm_mul_ps(d0, _mm_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[(Y_IDX)].d))), q)); \ + } + Q1_SSSE3_BLOCK(0, 0, acc_0) + Q1_SSSE3_BLOCK(4, 1, acc_1) + Q1_SSSE3_BLOCK(8, 2, acc_2) + Q1_SSSE3_BLOCK(12, 3, acc_3) +#undef Q1_SSSE3_BLOCK + } + + *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); +#else + UNUSED(nb); + UNUSED(x); + UNUSED(y); + ggml_vec_dot_q1_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; diff --git a/ggml/src/ggml-cpu/quants.c b/ggml/src/ggml-cpu/quants.c index f66127c2290..e5f9a4083f9 100644 --- a/ggml/src/ggml-cpu/quants.c +++ b/ggml/src/ggml-cpu/quants.c @@ -137,22 +137,28 @@ void ggml_vec_dot_q1_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, c float sumf = 0.0; for (int i = 0; i < nb; i++) { - const float d0 = GGML_FP16_TO_FP32(x[i].d); + const float d0 = GGML_CPU_FP16_TO_FP32(x[i].d); float sumi = 0.0f; for (int k = 0; k < 4; k++) { - const float d1 = GGML_FP16_TO_FP32(y[i*4 + k].d); - + const block_q8_0 * GGML_RESTRICT yb = &y[i * 4 + k]; + const float d1 = GGML_CPU_FP16_TO_FP32(yb->d); int sumi_block = 0; - for (int j = 0; j < QK8_0; j++) { - const int bit_index = k * QK8_0 + j; - const int byte_index = bit_index / 8; - const int bit_offset = bit_index % 8; - - const int xi = ((x[i].qs[byte_index] >> bit_offset) & 1) ? 1 : -1; - sumi_block += xi * y[i*4 + k].qs[j]; + const uint8_t * GGML_RESTRICT bits = &x[i].qs[k * 4]; + const int8_t * GGML_RESTRICT qy = yb->qs; + + for (int b = 0; b < 4; ++b, qy += 8) { + const unsigned mask = bits[b]; + sumi_block += ((mask & 0x01) ? qy[0] : -qy[0]) + + ((mask & 0x02) ? qy[1] : -qy[1]) + + ((mask & 0x04) ? qy[2] : -qy[2]) + + ((mask & 0x08) ? qy[3] : -qy[3]) + + ((mask & 0x10) ? qy[4] : -qy[4]) + + ((mask & 0x20) ? qy[5] : -qy[5]) + + ((mask & 0x40) ? qy[6] : -qy[6]) + + ((mask & 0x80) ? 
qy[7] : -qy[7]); } sumi += d1 * sumi_block; diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 66ed02d2923..3aec1742ee1 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -269,10 +269,6 @@ static const char * cu_get_error_str(CUresult err) { #define FLASH_ATTN_AVAILABLE #endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220) -#if defined(TURING_MMA_AVAILABLE) -#define LDMATRIX_TRANS_AVAILABLE -#endif // defined(TURING_MMA_AVAILABLE) - static bool fp16_available(const int cc) { return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL || (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1); @@ -1187,6 +1183,7 @@ struct ggml_cuda_graph { bool disable_due_to_gpu_arch = false; bool warmup_complete = false; uint64_t uid = 0; + int64_t last_used_time = 0; struct node_properties { ggml_tensor node; void * node_src_data_ptrs[GGML_MAX_SRC]; @@ -1368,12 +1365,28 @@ struct ggml_backend_cuda_context { // when the computation is split across CPU/GPU (e.g., with --n-cpu-moe) std::unordered_map<const void *, std::unique_ptr<ggml_cuda_graph>> cuda_graphs; + int64_t last_graph_eviction_sweep = 0; + ggml_cuda_graph * cuda_graph(const void * first_node_ptr) { + const int64_t time_now = ggml_time_us(); + + // sweep every 5s, evicting cuda graphs unused for >=10s + if (time_now - last_graph_eviction_sweep >= 5'000'000) { + last_graph_eviction_sweep = time_now; + for (auto it = cuda_graphs.begin(); it != cuda_graphs.end(); ) { + if (time_now - it->second->last_used_time >= 10'000'000) { + it = cuda_graphs.erase(it); + } else { + ++it; + } + } + } + auto it = cuda_graphs.find(first_node_ptr); if (it == cuda_graphs.end()) { - cuda_graphs[first_node_ptr] = std::make_unique<ggml_cuda_graph>(); - return cuda_graphs[first_node_ptr].get(); + it = cuda_graphs.emplace(first_node_ptr, std::make_unique<ggml_cuda_graph>()).first; } + it->second->last_used_time = time_now; return it->second.get(); } diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index b613ae61fb8..e185449d491 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -305,12 +305,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV, const int i_sup) { constexpr int warp_size = ggml_cuda_get_physical_warp_size(); // K/V data is loaded with decreasing granularity for D for better memory bandwidth. - // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes. + // The minimum granularity is 16 bytes. + constexpr int h2_per_chunk = 16/sizeof(half2); + const int chunks_per_row = D2 / h2_per_chunk; if constexpr (use_cp_async) { + static_assert(warp_size == 32, "bad warp_size"); static_assert(!oob_check, "OOB check not compatible with cp_async"); constexpr int preload = 64; - constexpr int h2_per_chunk = 16/sizeof(half2); - const int chunks_per_row = D2 / h2_per_chunk; const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV); @@ -348,11 +349,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( // 6: max 1*16= 16 bytes, 8 half ggml_cuda_unroll<6>{}(load); } else { - // TODO use ggml_cuda_memcpy_1 + const half2 zero[4] = {{0.0f, 0.0f}, {0.0f, 0.0f}, {0.0f, 0.0f}, {0.0f, 0.0f}}; auto load = [&] __device__ (const int n) { - const int stride_k = warp_size >> n; - const int k0_start = stride_k == warp_size ?
0 : D2 - D2 % (2*stride_k); - const int k0_stop = D2 - D2 % (1*stride_k); + const int stride_k = 32 >> n; + const int k0_start = stride_k == 32 ? 0 : chunks_per_row - chunks_per_row % (2*stride_k); + const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k); const int stride_i = warp_size / stride_k; if (k0_start == k0_stop) { @@ -371,15 +372,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k); - tile_KV[i*stride_tile + k] = !oob_check || i < i_sup ? KV[i*stride_KV + k] : make_half2(0.0f, 0.0f); + ggml_cuda_memcpy_1<16>(tile_KV + i*stride_tile + k*4, + !oob_check || i < i_sup ? KV + i*stride_KV + k*h2_per_chunk : zero); } } }; - // 1: max 32* 4=128 bytes, 64 half - // 2: max 16* 4= 64 bytes, 32 half - // 3: max 8* 4= 32 bytes, 16 half - // 4: max 4* 4= 16 bytes, 8 half - ggml_cuda_unroll<4>{}(load); + // 1: max 32*16=512 bytes, 256 half + // 2: max 16*16=256 bytes, 128 half + // 3: max 8*16=128 bytes, 64 half + // 4: max 4*16= 64 bytes, 32 half + // 5: max 2*16= 32 bytes, 16 half + // 6: max 1*16= 16 bytes, 8 half + ggml_cuda_unroll<6>{}(load); } } @@ -862,11 +866,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } -#if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE) - T_A_VKQ A_identity; - make_identity_mat(A_identity); -#endif // defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE) - // Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V: #pragma unroll for (int i0_start = 0; i0_start < DV; i0_start += 2*nbatch_V2) { @@ -897,29 +896,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::J; T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load. -#if defined(LDMATRIX_TRANS_AVAILABLE) load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); -#elif defined(AMD_MFMA_AVAILABLE) - // MFMA A register layout: A_mat[i=lane%16][k=4*(lane/16)+reg]. - // Normal load gives A_mat[seq][dv] but we need A_mat[dv][seq] = V^T. - // Load with transposed addressing: 4 strided half loads. - { - const half2 * xs0 = tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2; - const half * xs0_h = (const half *) xs0; - const int stride_h = stride_tile_V * 2; // stride in half units - half * A_h = (half *) A.x; -#pragma unroll - for (int l = 0; l < 4; ++l) { - A_h[l] = xs0_h[(4*(threadIdx.x / 16) + l) * stride_h + threadIdx.x % 16]; - } - } -#else - // TODO: Try to transpose tile_V when loading gmem to smem. - // Use mma to transpose T_A_VKQ for RDNA. 
- T_A_VKQ A_trans; - load_ldmatrix(A_trans, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); - mma(A, A_trans, A_identity); -#endif // defined(LDMATRIX_TRANS_AVAILABLE) if constexpr (T_B_KQ::I == 8) { mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]); } else { diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index de579d2ed50..1c2c3b4ac69 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -368,15 +368,21 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { } ~ggml_cuda_pool_leg() { + clear_pool(); + GGML_ASSERT(pool_size == 0); + } + + void clear_pool() { ggml_cuda_set_device(device); for (int i = 0; i < MAX_BUFFERS; ++i) { ggml_cuda_buffer & b = buffer_pool[i]; if (b.ptr != nullptr) { CUDA_CHECK(cudaFree(b.ptr)); pool_size -= b.size; + b.ptr = nullptr; + b.size = 0; } } - GGML_ASSERT(pool_size == 0); } void * alloc(size_t size, size_t * actual_size) override { @@ -421,7 +427,20 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { size_t look_ahead_size = (size_t) (1.05 * size); look_ahead_size = 256 * ((look_ahead_size + 255)/256); ggml_cuda_set_device(device); - CUDA_CHECK(ggml_cuda_device_malloc(&ptr, look_ahead_size, device)); + cudaError_t err = ggml_cuda_device_malloc(&ptr, look_ahead_size, device); + if (err == cudaErrorMemoryAllocation) { + (void)cudaGetLastError(); + const size_t cached_bytes = pool_size; + GGML_LOG_DEBUG(GGML_CUDA_NAME " pool[%d]: alloc of %.2f MiB failed, flushing %.2f MiB of cached buffers and retrying\n", + device, look_ahead_size/1024.0/1024.0, cached_bytes/1024.0/1024.0); + CUDA_CHECK(cudaDeviceSynchronize()); + clear_pool(); + err = ggml_cuda_device_malloc(&ptr, look_ahead_size, device); + if (err == cudaSuccess) { + GGML_LOG_DEBUG(GGML_CUDA_NAME " pool[%d]: retry succeeded\n", device); + } + } + CUDA_CHECK(err); *actual_size = look_ahead_size; pool_size += look_ahead_size; #ifdef DEBUG_CUDA_MALLOC @@ -1203,6 +1222,13 @@ static bool ggml_backend_cuda_comm_allreduce_tensor(void * comm_ctx_v, struct gg // For small tensors, simply reduce them as FP32. // The following heuristic for how "small" a tensor should be is based on RTX 4090s connected via 16x PCIe 4.0. 
if ((n_backends <= 2 && ne < 32768) || (n_backends == 3 && ne < 131072) || (n_backends >= 4 && ne < 262144)) { + for (size_t i = 0; i < n_backends; ++i) { + if ((tensors[i]->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context; + ggml_cuda_set_device(cuda_ctx->device); + CUDA_CHECK(cudaMemsetAsync(tensors[i]->data, 0, ggml_nbytes(tensors[i]), cuda_ctx->stream())); + } + } NCCL_CHECK(ncclGroupStart()); for (size_t i = 0; i < n_backends; ++i) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context; @@ -1224,7 +1250,11 @@ static bool ggml_backend_cuda_comm_allreduce_tensor(void * comm_ctx_v, struct gg tmp[i].alloc(ne); ggml_cuda_set_device(cuda_ctx->device); - to_bf16(tensors[i]->data, tmp[i].get(), ne, cuda_ctx->stream()); + if (tensors[i]->flags & GGML_TENSOR_FLAG_COMPUTE) { + to_bf16(tensors[i]->data, tmp[i].get(), ne, cuda_ctx->stream()); + } else { + CUDA_CHECK(cudaMemsetAsync(tmp[i].get(), 0, ne * sizeof(nv_bfloat16), cuda_ctx->stream())); + } CUDA_CHECK(cudaGetLastError()); } @@ -3562,6 +3592,30 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, return true; } + if (ops.size() == 2 && ops.begin()[0] == GGML_OP_UNARY && ops.begin()[1] == GGML_OP_SQR + && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_RELU) { + const ggml_tensor * unary = cgraph->nodes[node_idx]; + const ggml_tensor * sqr = cgraph->nodes[node_idx+1]; + + if (ggml_get_unary_op(unary) != GGML_UNARY_OP_RELU) { + return false; + } + + if (unary->type != GGML_TYPE_F32 && unary->type != GGML_TYPE_F16) { + return false; + } + + if (unary->type != sqr->type) { + return false; + } + + if (!ggml_is_contiguous(unary->src[0])) { + return false; + } + + return true; + } + if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) { const ggml_tensor *scale = cgraph->nodes[node_idx]; @@ -4070,6 +4124,12 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud continue; } + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_SQR }, { GGML_UNARY_OP_RELU })) { + ggml_cuda_op_relu_sqr(*cuda_ctx, node, cgraph->nodes[i+1]); + i++; + continue; + } + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) { i += 2; ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node); diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index c91dd2d9ad6..b0f674635f1 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -86,17 +86,12 @@ namespace ggml_cuda_mma { // - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR // - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR - static constexpr bool is_i_major(const data_layout dl) { - return dl == DATA_LAYOUT_I_MAJOR || - dl == DATA_LAYOUT_I_MAJOR_MIRRORED; - } - static constexpr __device__ data_layout get_input_data_layout() { -#if defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if defined(RDNA3) || defined(VOLTA_MMA_AVAILABLE) return DATA_LAYOUT_I_MAJOR_MIRRORED; #else return DATA_LAYOUT_I_MAJOR; -#endif // defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#endif // defined(RDNA3) || defined(VOLTA_MMA_AVAILABLE) } template @@ -113,7 +108,6 @@ namespace ggml_cuda_mma { T x[ne] = {0}; static constexpr __device__ bool supported() { - if (I == 64 && J == 2) return true; if (I == 16 && J == 8) 
return true; if (I == 32 && J == 4) return true; if (I == 16 && J == 16) return true; @@ -122,7 +116,7 @@ namespace ggml_cuda_mma { } static __device__ __forceinline__ int get_i(const int l) { - if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8> + if constexpr (I == 16 && J == 4) { return threadIdx.x % 16; } else if constexpr (I == 16 && J == 8) { return threadIdx.x % 16; @@ -139,8 +133,8 @@ namespace ggml_cuda_mma { } static __device__ __forceinline__ int get_j(const int l) { - if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8> - return (2 * ((threadIdx.x / 16) % 2) + l); + if constexpr (I == 16 && J == 4) { + return threadIdx.x / 16; } else if constexpr (I == 16 && J == 8) { return 2 * (threadIdx.x / 16) + l; } else if constexpr (I == 32 && J == 4) { @@ -154,7 +148,7 @@ namespace ggml_cuda_mma { return -1; } } -#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#elif defined(VOLTA_MMA_AVAILABLE) static constexpr int ne = I * J / 32; T x[ne] = {0}; @@ -283,7 +277,7 @@ namespace ggml_cuda_mma { static constexpr int J = J_; static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR; -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if defined(VOLTA_MMA_AVAILABLE) static constexpr int ne = I * J / WARP_SIZE; half2 x[ne] = {{0.0f, 0.0f}}; @@ -407,7 +401,7 @@ namespace ggml_cuda_mma { return -1; } } -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#endif // defined(VOLTA_MMA_AVAILABLE) }; template @@ -701,57 +695,12 @@ namespace ggml_cuda_mma { } #endif // defined(TURING_MMA_AVAILABLE) - static __device__ __forceinline__ void make_identity_mat(tile<16, 8, half2> & t) { -#if defined(RDNA4) - const int row = t.get_i(0); - const int left_right = t.get_j(0) / 4; - const int up_down = row / 8; - const int idx = row % 8; - reinterpret_cast(t.x)[idx] = left_right == up_down ? 1.0f : 0.0f; -#else - GGML_UNUSED_VARS(t); - NO_DEVICE_CODE; -#endif // defined(RDNA4) - } - template static __device__ __forceinline__ void load_generic(tile & t, const T * __restrict__ xs0, const int stride) { -#if defined(AMD_MFMA_AVAILABLE) - if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8> -#pragma unroll - for (int l = 0; l < t.ne; ++l) { - t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)]; - } - } else { - ggml_cuda_memcpy_1(t.x, xs0 + t.get_i(0) * stride + t.get_j(0)); - } -#elif defined(AMD_WMMA_AVAILABLE) - // All wmma layout has contiguous data when i-major. 
- if constexpr (is_i_major(dl)) { - // the data must be aligned to 16 bytes when bigger than ggml_cuda_get_max_cpy_bytes() - constexpr int aligned_copy_bytes = ggml_cuda_get_max_cpy_bytes(); - if constexpr (sizeof(t.x) > aligned_copy_bytes) { - static_assert(sizeof(t.x) % aligned_copy_bytes == 0, "bad type size"); - constexpr int aligned_copy_count = sizeof(t.x)/aligned_copy_bytes; -#pragma unroll - for (int i = 0; i < aligned_copy_count; ++i) { - ggml_cuda_memcpy_1(t.x + t.ne/aligned_copy_count*i, xs0 + t.get_i(0) * stride + t.get_j(t.ne/aligned_copy_count*i)); - } - } else { - ggml_cuda_memcpy_1(t.x, xs0 + t.get_i(0) * stride + t.get_j(0)); - } - } else { -#pragma unroll - for (int l = 0; l < t.ne; ++l) { - t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)]; - } - } -#else #pragma unroll for (int l = 0; l < t.ne; ++l) { t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)]; } -#endif // defined(AMD_MFMA_AVAILABLE) } template @@ -764,26 +713,37 @@ namespace ggml_cuda_mma { : "=r"(xi[0]), "=r"(xi[1]) : "l"(xs)); #else - load_generic(t, xs0, stride); + GGML_UNUSED_VARS(t, xs0, stride); + NO_DEVICE_CODE; #endif // TURING_MMA_AVAILABLE } - template + template static __device__ __forceinline__ void load_ldmatrix( - tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) { + tile<16, 4, T, dl> & t, const T * __restrict__ xs0, const int stride) { #ifdef TURING_MMA_AVAILABLE int * xi = (int *) t.x; const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride; asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" : "=r"(xi[0]), "=r"(xi[1]) : "l"(xs)); +#elif defined(AMD_WMMA_AVAILABLE) +#ifdef RDNA3 + static_assert(dl == DATA_LAYOUT_I_MAJOR_MIRRORED, "bad data layout"); + static_assert(sizeof(t.x) == 16, "bad ne"); + ggml_cuda_memcpy_1<8>(t.x + 0, xs0 + t.get_i(0)*stride + 0); + ggml_cuda_memcpy_1<8>(t.x + 2, xs0 + t.get_i(0)*stride + 2); +#else + static_assert(dl == DATA_LAYOUT_I_MAJOR, "bad data layout"); + static_assert(sizeof(t.x) == 8, "bad ne"); + ggml_cuda_memcpy_1<8>(t.x, xs0 + t.get_i(0)*stride + t.get_j(0)); +#endif // RDNA3 +#elif defined(AMD_MFMA_AVAILABLE) + static_assert(sizeof(t.x) == 4, "bad ne"); + ggml_cuda_memcpy_1<4>(t.x, xs0 + t.get_i(0)*stride + t.get_j(0)); #else -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA GGML_UNUSED_VARS(t, xs0, stride); NO_DEVICE_CODE; -#else - load_generic(t, xs0, stride); -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA #endif // TURING_MMA_AVAILABLE } @@ -796,19 +756,26 @@ namespace ggml_cuda_mma { asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];" : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3]) : "l"(xs)); -#else -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA -#if 1 - // TODO: more generic handling - static_assert(sizeof(T) == 4, "bad type size"); +#elif defined(VOLTA_MMA_AVAILABLE) ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 0, xs0 + t.get_i(0)*stride + 0); ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 4, xs0 + t.get_i(4)*stride + 4); +#elif defined(AMD_WMMA_AVAILABLE) +#ifdef RDNA3 + static_assert(dl == DATA_LAYOUT_I_MAJOR_MIRRORED, "bad data layout"); + static_assert(sizeof(t.x) == 32, "bad ne"); + ggml_cuda_memcpy_1<16>(t.x + 0, xs0 + t.get_i(0)*stride + 0); + ggml_cuda_memcpy_1<16>(t.x + 4, xs0 + t.get_i(0)*stride + 4); #else - load_generic(t, xs0, stride); -#endif // 1 + static_assert(dl == DATA_LAYOUT_I_MAJOR, "bad data layout"); + static_assert(sizeof(t.x) == 16, "bad ne"); + ggml_cuda_memcpy_1<16>(t.x, xs0 + t.get_i(0)*stride + t.get_j(0)); +#endif // RDNA3 +#elif defined(AMD_MFMA_AVAILABLE) + static_assert(sizeof(t.x) == 8, "bad 
ne"); + ggml_cuda_memcpy_1<8>(t.x, xs0 + t.get_i(0)*stride + t.get_j(0)); #else - load_generic(t, xs0, stride); -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + GGML_UNUSED_VARS(t, xs0, stride); + NO_DEVICE_CODE; #endif // TURING_MMA_AVAILABLE } @@ -827,23 +794,30 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ void load_ldmatrix( tile<32, 4, half2> & t, const half2 * __restrict__ xs0, const int stride) { -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if defined(VOLTA_MMA_AVAILABLE) ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride); #else GGML_UNUSED_VARS(t, xs0, stride); NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#endif // defined(VOLTA_MMA_AVAILABLE) } template static __device__ __forceinline__ void load_ldmatrix_trans( tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) { #ifdef TURING_MMA_AVAILABLE - int * xi = (int * ) t.x; + int * xi = (int *) t.x; const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2); asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];" : "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3]) : "l"(xs)); +#elif defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + half * xh = (half *) t.x; +#pragma unroll + for (int l = 0; l < t.ne; ++l) { + xh[2*l + 0] = ((const half *) xs0)[(2*t.get_j(l) + 0)*(2*stride) + t.get_i(l)]; + xh[2*l + 1] = ((const half *) xs0)[(2*t.get_j(l) + 1)*(2*stride) + t.get_i(l)]; + } #else GGML_UNUSED_VARS(t, xs0, stride); NO_DEVICE_CODE; @@ -1218,73 +1192,27 @@ namespace ggml_cuda_mma { using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; int32x4_t * acc = (int32x4_t *) D.x; #if defined(CDNA4) || defined(CDNA3) - acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0], - ((int64_t *) B.x)[0], - acc[0], - 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0], ((int64_t *) B.x)[0], acc[0], 0, 0, 0); #elif defined(CDNA2) || defined(CDNA1) - acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0], - B.x[0], - acc[0], - 0, 0, 0); - acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[1], - B.x[1], - acc[0], - 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0], B.x[0], acc[0], 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[1], B.x[1], acc[0], 0, 0, 0); #endif // defined(CDNA4) || defined(CDNA3) - #elif defined(AMD_WMMA_AVAILABLE) - using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int; int32x8_t * acc = (int32x8_t *) D.x; - #if defined(RDNA4) using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int; int32x2_t * a_vec = (int32x2_t *) A.x; int32x2_t * b_vec = (int32x2_t *) B.x; - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( - true, - a_vec[0], - true, - b_vec[0], - acc[0], - true - ); - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( - true, - a_vec[1], - true, - b_vec[1], - acc[0], - true - ); - + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, a_vec[0], true, b_vec[0], acc[0], true); + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, a_vec[1], true, b_vec[1], acc[0], true); #elif defined(RDNA3) using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; int32x4_t * a_vec = (int32x4_t *) A.x; int32x4_t * b_vec = (int32x4_t *) B.x; - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32( - true, - a_vec[0], - true, - b_vec[0], - acc[0], - true - ); - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32( - true, - a_vec[1], - true, - 
b_vec[1], - acc[0], - true - ); + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(true, a_vec[0], true, b_vec[0], acc[0], true); + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(true, a_vec[1], true, b_vec[1], acc[0], true); #endif // RDNA4 - #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; @@ -1297,19 +1225,10 @@ namespace ggml_cuda_mma { using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int; int32x16_t * acc = (int32x16_t *) D.x; #if defined(CDNA4) || defined(CDNA3) - acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0], - ((int64_t *) B.x)[0], - acc[0], - 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0], ((int64_t *) B.x)[0], acc[0], 0, 0, 0); #elif defined(CDNA2) || defined(CDNA1) - acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0], - B.x[0], - acc[0], - 0, 0, 0); - acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[1], - B.x[1], - acc[0], - 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0], B.x[0], acc[0], 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[1], B.x[1], acc[0], 0, 0, 0); #endif // defined(CDNA4) || defined(CDNA3) #else @@ -1329,7 +1248,7 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ void mma( tile<32, 8, float> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & B) { -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if defined(VOLTA_MMA_AVAILABLE) const int * Axi = (const int *) A.x; const int * Bxi = (const int *) B.x; int * Dxi = (int *) D.x; @@ -1344,12 +1263,12 @@ namespace ggml_cuda_mma { #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA +#endif // defined(VOLTA_MMA_AVAILABLE) } static __device__ __forceinline__ void mma( tile<32, 4, half2> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & B) { -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if defined(VOLTA_MMA_AVAILABLE) const int * Axi = (const int *) A.x; const int * Bxi = (const int *) B.x; int * Dxi = (int *) D.x; @@ -1364,41 +1283,35 @@ namespace ggml_cuda_mma { #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA +#endif // defined(VOLTA_MMA_AVAILABLE) } template static __device__ __forceinline__ void mma( tile<16, 16, int, dl_d> & D, const tile<16, 4, int, dl_ab> & A, const tile<16, 4, int, dl_ab> & B) { -#if defined(AMD_WMMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) + using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; + int32x4_t * acc = (int32x4_t *) D.x; +#if defined(CDNA4) || defined(CDNA3) + const int64_t xA = uint32_t(A.x[0]); + const int64_t xB = uint32_t(B.x[0]); + acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(xA, xB, acc[0], 0, 0, 0); +#elif defined(CDNA2) || defined(CDNA1) + acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0], B.x[0], acc[0], 0, 0, 0); +#endif // defined(CDNA4) || defined(CDNA3) +#elif defined(AMD_WMMA_AVAILABLE) using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int; int32x8_t * acc = (int32x8_t *) D.x; #if defined(RDNA4) using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int; int32x2_t * a_vec = (int32x2_t *) A.x; int32x2_t * b_vec = (int32x2_t *) B.x; - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( - true, - a_vec[0], - true, - b_vec[0], - acc[0], - false - ); + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, a_vec[0], true, b_vec[0], acc[0], false); #elif defined(RDNA3) using int32x4_t = 
__attribute__((__vector_size__(4 * sizeof(int)))) int; int32x4_t * a_vec = (int32x4_t *) A.x; int32x4_t * b_vec = (int32x4_t *) B.x; - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32( - true, - a_vec[0], - true, - b_vec[0], - acc[0], - false - ); + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(true, a_vec[0], true, b_vec[0], acc[0], false); #endif // RDNA4 #else GGML_UNUSED(D); diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 28b662df925..b1a319de9be 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -104,7 +104,7 @@ struct tile_x_sizes { }; static int get_mmq_x_max_host(const int cc) { - return (amd_mfma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc)) ? 128 : + return (turing_mma_available(cc) || amd_wmma_available(cc)) ? 128 : GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ? #ifdef GGML_CUDA_FORCE_MMQ 128 : 64; @@ -114,9 +114,9 @@ static int get_mmq_x_max_host(const int cc) { } static constexpr __device__ int get_mmq_x_max_device() { -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) +#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) return 128; -#else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#else // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) #if defined(GGML_USE_HIP) return 64; @@ -1054,13 +1054,13 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( tile_A A[ntx]; #pragma unroll for (int n = 0; n < ntx; ++n) { - load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); + load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { tile_B B; - load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); float dB; const int j = j0 + tile_C::get_j(0); @@ -1295,13 +1295,13 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( tile_A A[ntx]; #pragma unroll for (int n = 0; n < ntx; ++n) { - load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); + load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { tile_B B; - load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); const int j = j0 + tile_C::get_j(0); const float2 dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]); @@ -1435,57 +1435,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#if defined(AMD_MFMA_AVAILABLE) - constexpr data_layout input_layout = get_input_data_layout(); - typedef tile<16, 8, int, input_layout> tile_A; - typedef tile<16, 8, int, input_layout> tile_B; - typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; - typedef tile<64, 2, int, input_layout> tile_load; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = granularity; - constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
- - y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; - const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; - - const int i0 = (threadIdx.y / ntx) * rows_per_warp; - - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { - const int k0 = k00 + k01; - - tile_A A[ntx]; -#pragma unroll - for (int n = 0; n < ntx; ++n) { - load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { - tile_B B[1]; - load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); - - const int j = j0 + tile_C::get_j(0); - const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2; - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - tile_C C; - mma(C, A[n], B[0]); - -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - const int i = i0 + n*tile_C::I + tile_C::get_i(l); - sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB; - } - } - } - } -#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr data_layout input_layout = get_input_data_layout(); typedef tile<16, 4, int, input_layout> tile_A; typedef tile<16, 4, int, input_layout> tile_B; @@ -1510,13 +1460,13 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( tile_A A[ntx]; #pragma unroll for (int n = 0; n < ntx; ++n) { - load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); + load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { tile_B B; - load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); const int j = j0 + tile_C::get_j(0); const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; @@ -1742,74 +1692,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#if defined(AMD_MFMA_AVAILABLE) - constexpr data_layout input_layout = get_input_data_layout(); - typedef tile<16, 8, int, input_layout> tile_A; - typedef tile<16, 8, int, input_layout> tile_B; - typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; - typedef tile<64, 2, int, input_layout> tile_load; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = granularity; - constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
- - y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2; - const int * y_qs = (const int *) y + 4; - const half2 * y_ds = (const half2 *) y; - - const int i0 = (threadIdx.y / ntx) * rows_per_warp; - - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { - const int k0 = k00 + k01; - - tile_A A[ntx]; -#pragma unroll - for (int n = 0; n < ntx; ++n) { - load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { - tile_B B[1]; - load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); - - const int j = j0 + tile_C::get_j(0); - const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x/2 : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y/2; - const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0 - : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y - : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x); - - tile_C Cm; - if (k01 >= MMQ_TILE_NE_K * 3/4) { - tile_A A1; - A1.x[0] = 0x01010101; - A1.x[1] = 0x01010101; - mma(Cm, A1, B[0]); - } - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - tile_C Cd; - mma(Cd, A[n], B[0]); - -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - const int i = i0 + n*tile_C::I + tile_C::get_i(l); - const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]); - float tmp = Cd.x[l]*dm.x; - if (k01 >= MMQ_TILE_NE_K * 3/4) { - tmp -= Cm.x[l]*dm.y; - } - sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB; - sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB; - } - } - } - } -#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr data_layout input_layout = get_input_data_layout(); typedef tile<16, 4, int, input_layout> tile_A; typedef tile<16, 4, int, input_layout> tile_B; @@ -1834,13 +1717,13 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( tile_A A[ntx]; #pragma unroll for (int n = 0; n < ntx; ++n) { - load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); + load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { tile_B B; - load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); const int j = j0 + tile_C::get_j(0); const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y; @@ -2573,59 +2456,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#if defined(AMD_MFMA_AVAILABLE) - constexpr data_layout input_layout = get_input_data_layout(); - typedef tile<16, 8, int, input_layout> tile_A; - typedef tile<16, 8, int, input_layout> tile_B; - typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; - typedef tile<64, 2, int, input_layout> tile_load; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = granularity; - constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
- - y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; - const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K; - const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; - - const int i0 = (threadIdx.y / ntx) * rows_per_warp; - - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { - const int k0 = k00 + k01; - - tile_A A[ntx]; -#pragma unroll - for (int n = 0; n < ntx; ++n) { - load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K); - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { - tile_B B[1]; - load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); - - const int j = j0 + tile_C::get_j(0); - const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2; - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - tile_C C; - mma(C, A[n], B[0]); - -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - const int i = i0 + n*tile_C::I + tile_C::get_i(l); - const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16); - sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB; - } - } - } - } -#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr data_layout input_layout = get_input_data_layout(); typedef tile<16, 4, int, input_layout> tile_A; typedef tile<16, 4, int, input_layout> tile_B; @@ -2651,13 +2482,13 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( tile_A A[ntx]; #pragma unroll for (int n = 0; n < ntx; ++n) { - load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K); + load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K); } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { tile_B B; - load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); const int j = j0 + tile_C::get_j(0); const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 4ad30fa1f35..2aeba26f414 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -65,6 +65,11 @@ static __device__ __forceinline__ float op_sqr(float x) { return x * x; } +static __device__ __forceinline__ float op_relu_sqr(float x) { + const float r = fmaxf(x, 0.0f); + return r * r; +} + static __device__ __forceinline__ float op_sqrt(float x) { return sqrtf(x); } @@ -615,3 +620,21 @@ void ggml_cuda_op_unary_mul(ggml_backend_cuda_context & ctx, ggml_tensor * unary GGML_ABORT("Unsupported unary op for fused unary+mul"); } } + +/* fused relu + sqr */ + +void ggml_cuda_op_relu_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * relu_node, ggml_tensor * sqr_node) { + const ggml_tensor * src = relu_node->src[0]; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(src)); + GGML_ASSERT(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16); + GGML_ASSERT(src->type == sqr_node->type); + + const int k = ggml_nelements(src); + if (src->type == GGML_TYPE_F16) { + unary_cuda((const half *)src->data, (half *)sqr_node->data, k, stream); + } else { + unary_cuda((const float *)src->data, (float *)sqr_node->data, k, stream); + } +} diff --git 
a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index f1dd2183a6c..81ed873ecc3 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -91,6 +91,8 @@ void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_unary_mul(ggml_backend_cuda_context & ctx, ggml_tensor * unary_node, ggml_tensor * mul_node); +void ggml_cuda_op_relu_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * relu_node, ggml_tensor * sqr_node); + __device__ __forceinline__ float ggml_cuda_op_silu_single(float x) { return x / (1.0f + expf(-x)); } diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 898fec31e36..78ca364d38f 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -33,7 +33,6 @@ #define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice #define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite #define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }} -#define NCCL_CHECK(fn) {ncclResult_t err = fn; if(err != ncclSuccess) { GGML_ABORT("RCCL Failure RCCL returned: %i\n", err); }} #define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width) #define __shfl_up_sync(mask, var, laneMask, width) __shfl_up(var, laneMask, width) #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) @@ -59,6 +58,7 @@ #define cudaDeviceProp hipDeviceProp_t #define cudaDeviceSynchronize hipDeviceSynchronize #define cudaError_t hipError_t +#define cudaErrorMemoryAllocation hipErrorOutOfMemory #define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled #define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled #define cudaEventCreateWithFlags hipEventCreateWithFlags diff --git a/ggml/src/ggml-cuda/vendors/musa.h b/ggml/src/ggml-cuda/vendors/musa.h index 1abb8acfd4b..8aa056e9174 100644 --- a/ggml/src/ggml-cuda/vendors/musa.h +++ b/ggml/src/ggml-cuda/vendors/musa.h @@ -42,6 +42,7 @@ #define cudaDeviceProp musaDeviceProp #define cudaDeviceSynchronize musaDeviceSynchronize #define cudaError_t musaError_t +#define cudaErrorMemoryAllocation musaErrorMemoryAllocation #define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled #define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled #define cudaEventCreateWithFlags musaEventCreateWithFlags diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 3d68b80048f..0d9b5e289bb 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -12,9 +12,12 @@ #include #include #include +#include +#include #include #include #include +#include #ifdef _WIN32 # include @@ -41,18 +44,26 @@ #include "htp_iface.h" #include "htp-drv.h" +using intvec = std::vector; +using uintvec = std::vector; +using u32vec = std::vector; + static size_t opt_ndev = 1; static size_t opt_nhvx = 0; // use all static int opt_arch = 0; // autodetect static int opt_etm = 0; static int opt_verbose = 0; -static int opt_profile = 0; +static int opt_profile = 0; // profiling mode (0-disabled, 1-basic, 2-pmu) static int opt_hostbuf = 1; // hostbuf ON by default static int opt_use_hmx = 1; // when set, enable HMX; when 0, use HVX only +// Default PMU events, if profiling with PMU (mode=2) is enabled +// See https://docs.qualcomm.com/doc/80-N2040-60/topic/pmu-events.html +// https://docs.qualcomm.com/doc/80-N2040-61/topic/hvx-pmu-events.html +static 
u32vec opt_pmu_evt { 0x3, 0x111, 0x100, 0x105, 0x240, 0x256, 0x7D, 0x8C }; + // Enable all stages by default -static int opt_opmask = HTP_OPMASK_QUEUE | HTP_OPMASK_COMPUTE; -static int opt_opsync = 0; // synchronous ops +static int opt_opstage = HTP_OPSTAGE_QUEUE | HTP_OPSTAGE_COMPUTE; static int opt_opbatch = 1024; // max number of ops in a batch static int opt_opqueue = 16; // max number of pending batches static std::regex* opt_opfilter = NULL; // regex of ops to not claim @@ -104,19 +115,26 @@ static void ggml_hexagon_dump_op_supp(const std::string &sess_name, const struct } static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const ggml_tensor * op, - uint32_t op_usec, uint32_t op_cycles, uint32_t op_pkts, uint64_t call_usec) { + uint32_t op_usec, uint32_t op_cycles, const uint32_t pmu[]) { if (!opt_profile) return; op_desc desc(op); - GGML_LOG_DEBUG("ggml-hex: %s profile-op %s: %s : %s : %s : %s : %s : op-usec %u op-cycles %u op-pkts %u (%f) call-usec %llu\n", sess_name.c_str(), - ggml_op_desc(op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, - op_usec, op_cycles, op_pkts, (float) op_cycles / op_pkts, (unsigned long long) call_usec); + + char pmu_str[256] = ""; + if (opt_profile > 1) { + static_assert(HTP_PROF_PMU_NCNT == 8, "current implementation assumes 8 PMU counters"); + sprintf(pmu_str, " pmu [%u,%u,%u,%u,%u,%u,%u,%u]", + pmu[0], pmu[1], pmu[2], pmu[3], pmu[4], pmu[5], pmu[6], pmu[7]); + } + + GGML_LOG_DEBUG("ggml-hex: %s profile-op %s: %s : %s : %s : %s : usec %u cycles %u%s\n", sess_name.c_str(), + ggml_op_desc(op), desc.names, desc.dims, desc.types, desc.strides, op_usec, op_cycles, pmu_str); } // ** backend sessions struct ggml_hexagon_opbatch; -struct ggml_hexagon_opshm; +struct ggml_hexagon_opqueue; struct ggml_hexagon_session { std::string name; @@ -132,8 +150,8 @@ struct ggml_hexagon_session { bool valid_iface; std::atomic op_pending; - ggml_hexagon_opbatch *op_batch; - ggml_hexagon_opshm *op_shm; + ggml_hexagon_opbatch* op_batch; + ggml_hexagon_opqueue* op_queue; ggml_backend_buffer_type buffer_type = {}; ggml_backend_buffer_type repack_buffer_type = {}; @@ -1521,65 +1539,14 @@ static ggml_backend_buffer_type_i ggml_backend_hexagon_repack_buffer_type_interf // Backend session implementation -struct ggml_hexagon_opshm { - ggml_hexagon_shared_buffer *sbuf; - - std::vector block_mask; - size_t block_size; - - uint8_t * base() const { return this->sbuf->base; } - int fd() const { return this->sbuf->fd; } - size_t n_blocks() const { return this->block_mask.size(); } - - ggml_hexagon_opshm(ggml_hexagon_session *sess, size_t max_batch, size_t max_pending) { - size_t n_bufs = HTP_OP_MAX_BUFS; - size_t n_ops = max_batch; - size_t n_tensors = n_ops + n_ops * HTP_OP_MAX_INPUTS; - - block_mask.resize(max_pending, true); - - block_size = sizeof(htp_buf_desc) * n_bufs + - sizeof(htp_tensor) * n_tensors + - sizeof(htp_op_desc) * n_ops; - - sbuf = new ggml_hexagon_shared_buffer(sess, block_size * block_mask.size(), true /* pinned */); - - if (opt_verbose) { - GGML_LOG_INFO("ggml-hex: %s allocated shared buf %zu : block-size %zu max-batch %zu max-pending %zu\n", - sess->c_name(), (size_t) sbuf->size, block_size, max_batch, max_pending); - } - } - - ~ggml_hexagon_opshm() { - delete sbuf; - } - - uint8_t * allocate() { - auto it = std::find(block_mask.begin(), block_mask.end(), true); - if (it == block_mask.end()) - return nullptr; - - unsigned int i = std::distance(block_mask.begin(), it); - uint8_t* addr = sbuf->base + (i * block_size); - block_mask[i] = 
false; - - HEX_VERBOSE("ggml-hex: %s allocated op shm #%u %p\n", sbuf->sess->c_name(), i, (void*) addr); - return addr; - } - - void release(uint8_t * addr) { - int i = (addr - sbuf->base) / block_size; - block_mask[i] = true; - HEX_VERBOSE("ggml-hex: %s released op shm #%u %p\n", sbuf->sess->c_name(), i, (void*) addr); - } -}; - struct ggml_hexagon_opbatch { - const char* name; + ggml_hexagon_session* sess; - std::vector buffers; - std::vector tensors; - std::vector ops; + std::vector ops; // pointers to original ops + + std::vector h_bufs; // htp buffer descriptors + std::vector h_tens; // htp tensor descriptors + std::vector h_ops; // htp op descriptors std::unordered_map b_map; // buffer fd to index std::unordered_map t_map; // tensor ptr to index @@ -1606,19 +1573,21 @@ struct ggml_hexagon_opbatch { d_map.clear(); } - ggml_hexagon_opbatch(ggml_hexagon_session *sess, size_t max_batch) { - name = sess->c_name(); + ggml_hexagon_opbatch(ggml_hexagon_session *sess, size_t batch_size) { + this->sess = sess; n_bufs_max = HTP_OP_MAX_BUFS; - n_ops_max = max_batch; + n_ops_max = batch_size; n_tens_max = n_ops_max + n_ops_max * HTP_OP_MAX_INPUTS; b_vmem_max = HTP_OP_MAX_VMEM; - buffers.resize(n_bufs_max); - tensors.resize(n_tens_max); ops.resize(n_ops_max); + h_bufs.resize(n_bufs_max); + h_tens.resize(n_tens_max); + h_ops.resize(n_ops_max); + b_map.reserve(n_bufs_max); t_map.reserve(n_tens_max); d_map.reserve(n_tens_max); @@ -1640,7 +1609,7 @@ struct ggml_hexagon_opbatch { b_map.insert({sbuf->fd, bi}); - htp_buf_desc &b = buffers[bi]; + htp_buf_desc &b = h_bufs[bi]; b.base = (uint64_t) sbuf->base; b.fd = sbuf->fd; b.size = sbuf->size; @@ -1664,7 +1633,7 @@ struct ggml_hexagon_opbatch { // First lookup by tensor data auto range = d_map.equal_range(t->data); for (auto it = range.first; it != range.second; ++it) { - htp_tensor * h = &tensors[it->second]; + htp_tensor * h = &h_tens[it->second]; if (same_shape(h, t)) { return it->second; } } @@ -1682,7 +1651,7 @@ struct ggml_hexagon_opbatch { uint64_t t_offset = (uint8_t *) t->data - sbuf->base; size_t t_size = ggml_nbytes(t); - htp_tensor &h = tensors[ti]; + htp_tensor &h = h_tens[ti]; h.bi = add_buffer(sbuf); h.data = t_offset; h.size = t_size; @@ -1737,65 +1706,170 @@ struct ggml_hexagon_opbatch { // assumes that fit_op() was called first and returned true void add_op(htp_op_code opcode, const struct ggml_tensor * t) { // Add new op - htp_op_desc &o = ops[n_ops++]; + + unsigned int n = n_ops++; GGML_ASSERT(n_ops <= n_ops_max); + ops[n] = t; + + htp_op_desc &o = h_ops[n]; memcpy(&o.params, &t->op_params, sizeof(t->op_params)); o.opcode = opcode; o.flags = 0; - if (!(opt_opmask & HTP_OPMASK_COMPUTE)) { + if (!(opt_opstage & HTP_OPSTAGE_COMPUTE)) { o.flags |= HTP_OPFLAGS_SKIP_COMPUTE; } - ggml_hexagon_dump_op_exec(name, t, o.flags); + ggml_hexagon_dump_op_exec(sess->c_name(), t, o.flags); for (unsigned int i=0; i < HTP_OP_MAX_INPUTS; i++) { o.src[i] = t->src[i] ? 
add_tensor(t->src[i]) : 0xffff; } o.dst = add_tensor(t); } +}; + +struct ggml_hexagon_opqueue { + // Shared buffer for storing batches + ggml_hexagon_shared_buffer *shm_buf; + size_t shm_blk_size; + + using opvec = std::vector; + + std::queue done; // completed batch ids + std::vector op_cache; // per batch op cache + std::vector start_usec; // per batch start time + + ggml_hexagon_opqueue(ggml_hexagon_session *sess, size_t batch_size, size_t depth) { + size_t n_bufs = HTP_OP_MAX_BUFS; + size_t n_ops = batch_size; + size_t n_tensors = n_ops + n_ops * HTP_OP_MAX_INPUTS; + + shm_blk_size = sizeof(htp_buf_desc) * n_bufs + + sizeof(htp_tensor) * n_tensors + + sizeof(htp_op_desc) * n_ops + + sizeof(htp_prof_desc) * n_ops; + + shm_buf = new ggml_hexagon_shared_buffer(sess, shm_blk_size * depth, true /* pinned */); + + op_cache.resize(depth); + start_usec.resize(depth, 0); + + // init done queue + for (unsigned int i = 0; i < depth; i++) { done.push(i); } + + if (opt_verbose) { + GGML_LOG_INFO("ggml-hex: %s allocated op-queue : batch-size %zu depth %zu shm-size %zu shm-block-size %zu\n", + sess->c_name(), batch_size, depth, shm_buf->size, shm_blk_size); + } + } - size_t flush(uint8_t * mem_addr, size_t mem_size) { - static_assert(sizeof(htp_buf_desc) % 8 == 0, "sizeof(htp_buf_desc) must be multiple of 8"); - static_assert(sizeof(htp_tensor) % 8 == 0, "sizeof(htp_tensor) must be multiple of 8"); - static_assert(sizeof(htp_op_desc) % 8 == 0, "sizeof(htp_op_desc) must be multiple of 8"); + ~ggml_hexagon_opqueue() { + delete shm_buf; + } - const size_t b_size = sizeof(htp_buf_desc) * n_bufs; - const size_t t_size = sizeof(htp_tensor) * n_tens; - const size_t o_size = sizeof(htp_op_desc) * n_ops; + // push new batch + bool push(htp_opbatch_req& req, dspqueue_buffer& dbuf, ggml_hexagon_opbatch* op_batch) { + static_assert(sizeof(htp_opbatch_req) % 8 == 0, "sizeof(htp_opbatch_req) must be multiple of 8"); + static_assert(sizeof(htp_opbatch_rsp) % 8 == 0, "sizeof(htp_opbatch_rsp) must be multiple of 8"); + static_assert(sizeof(htp_buf_desc) % 8 == 0, "sizeof(htp_buf_desc) must be multiple of 8"); + static_assert(sizeof(htp_tensor) % 8 == 0, "sizeof(htp_tensor) must be multiple of 8"); + static_assert(sizeof(htp_op_desc) % 8 == 0, "sizeof(htp_op_desc) must be multiple of 8"); + static_assert(sizeof(htp_prof_desc) % 8 == 0, "sizeof(htp_prof_desc) must be multiple of 8"); - const size_t m_size = b_size + t_size + o_size; - GGML_ASSERT(m_size <= mem_size); + if (done.empty()) { return false; } - uint8_t * b_ptr = (uint8_t *) mem_addr; - uint8_t * t_ptr = (uint8_t *) b_ptr + b_size; - uint8_t * o_ptr = (uint8_t *) t_ptr + t_size; + req.id = done.front(); done.pop(); // batch id + req.n_bufs = op_batch->n_bufs; + req.n_tensors = op_batch->n_tens; + req.n_ops = op_batch->n_ops; - memcpy(b_ptr, (void *) buffers.data(), b_size); - memcpy(t_ptr, (void *) tensors.data(), t_size); - memcpy(o_ptr, (void *) ops.data(), o_size); + op_cache[req.id] = op_batch->ops; + start_usec[req.id] = ggml_time_us(); - HEX_VERBOSE("ggml-hex: %s flush-opbatch : n-bufs %u n-tensors %u n-ops %u vmem %zu : b-size %zu t-size %zu o-size %zu\n", - name, n_bufs, n_tens, n_ops, b_vmem, b_size, t_size, o_size); + const size_t b_size = sizeof(htp_buf_desc) * req.n_bufs; + const size_t t_size = sizeof(htp_tensor) * req.n_tensors; + const size_t o_size = sizeof(htp_op_desc) * req.n_ops; + const size_t p_size = sizeof(htp_prof_desc) * req.n_ops; + + dbuf.ptr = shm_buf->base + (req.id * shm_blk_size); + dbuf.fd = shm_buf->fd; + dbuf.flags = 
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT; + dbuf.offset = (uint8_t*) dbuf.ptr - (uint8_t*) shm_buf->base; + dbuf.size = b_size + t_size + o_size + p_size; + + GGML_ASSERT(dbuf.size <= shm_blk_size); + + uint8_t * m_ptr = (uint8_t*) dbuf.ptr; + uint8_t * b_ptr = m_ptr; m_ptr += b_size; + uint8_t * t_ptr = m_ptr; m_ptr += t_size; + uint8_t * o_ptr = m_ptr; + + memcpy(b_ptr, (void *) op_batch->h_bufs.data(), b_size); + memcpy(t_ptr, (void *) op_batch->h_tens.data(), t_size); + memcpy(o_ptr, (void *) op_batch->h_ops.data(), o_size); + + HEX_VERBOSE("ggml-hex: %s op-queue push batch #%u : n-bufs %u n-tensors %u n-ops %u vmem %zu : b-size %zu t-size %zu o-size %zu m-size %zu\n", + shm_buf->sess->c_name(), req.id, req.n_bufs, req.n_tensors, req.n_ops, op_batch->b_vmem, + b_size, t_size, o_size, (size_t) dbuf.size); + + op_batch->reset(); if (opt_verbose > 1) { htp_buf_desc *b = (htp_buf_desc*) b_ptr; - for (unsigned int i=0; i < n_bufs; i++) { - GGML_LOG_DEBUG("ggml-hex: %s htp-buf #%u : fd %d base %p size %zu\n", name, i, + for (unsigned int i=0; i < req.n_bufs; i++) { + GGML_LOG_DEBUG("ggml-hex: %s htp-buf #%u : fd %d base %p size %zu\n", shm_buf->sess->c_name(), i, b[i].fd, (void *) b[i].base, (size_t) b[i].size); } htp_tensor *t = (htp_tensor*) t_ptr; - for (unsigned int i=0; i < n_tens; i++) { + for (unsigned int i=0; i < req.n_tensors; i++) { GGML_LOG_DEBUG("ggml-hex: %s htp-tensor #%u : bi %u offset %u size %u : %zu:%zu:%zu:%zu\n", - name, i, t[i].bi, t[i].data, t[i].size, + shm_buf->sess->c_name(), i, t[i].bi, t[i].data, t[i].size, (size_t) t[i].ne[0], (size_t) t[i].ne[1], (size_t) t[i].ne[2], (size_t) t[i].ne[3]); } } - reset(); + return true; + } - return m_size; + void pop(htp_opbatch_rsp rsp, dspqueue_buffer dbuf) { + GGML_ASSERT(rsp.id < op_cache.size()); + + done.push(rsp.id); + + const size_t b_size = sizeof(htp_buf_desc) * rsp.n_bufs; + const size_t t_size = sizeof(htp_tensor) * rsp.n_tensors; + const size_t o_size = sizeof(htp_op_desc) * rsp.n_ops; + const size_t p_size = sizeof(htp_prof_desc) * rsp.n_ops; + + const size_t m_size = b_size + t_size + o_size + p_size; + GGML_ASSERT(m_size <= shm_blk_size); + + HEX_VERBOSE("ggml-hex: %s op-queue pop batch #%u : n-bufs %u n-tensors %u n-ops %u : m-size %zu b-size %zu t-size %zu o-size %zu\n", + shm_buf->sess->c_name(), rsp.id, rsp.n_bufs, rsp.n_tensors, rsp.n_ops, + (size_t) dbuf.size, b_size, t_size, o_size); + + uint8_t * m_ptr = (uint8_t*) dbuf.ptr; + uint8_t * p_ptr = m_ptr + (b_size + t_size + o_size); + + if (opt_profile && rsp.n_ops > 0) { + auto & ops = op_cache[rsp.id]; + + uint64_t batch_usec = ggml_time_us() - start_usec[rsp.id]; + uint32_t htp_usec = 0; + + GGML_ASSERT(rsp.n_ops <= ops.size()); + + const htp_prof_desc * pd = (const htp_prof_desc *) p_ptr; + for (uint32_t i = 0; i < rsp.n_ops; i++) { + htp_usec += pd[i].usecs; + ggml_hexagon_dump_op_prof(shm_buf->sess->name, ops[i], pd[i].usecs, pd[i].cycles, pd[i].pmu); + } + + GGML_LOG_DEBUG("ggml-hex: %s profile-batch n-ops %u batch-dur-usec %lld htp-ops-usec %u\n", + shm_buf->sess->c_name(), rsp.n_ops, (long long) batch_usec, htp_usec); + } } }; @@ -1824,17 +1898,12 @@ void ggml_hexagon_session::flush_pending(bool all) { GGML_ABORT("ggml-hex: %s dspcall : bad response : size %u dspbufs %u\n", this->c_name(), rsp_size, n_dbufs); } - op_shm->release((uint8_t*) dbuf.ptr); - if (rsp.status != HTP_STATUS_OK) { GGML_LOG_ERROR("ggml-hex: %s dspcall : dsp-rsp: %s\n", this->c_name(), status_to_str(rsp.status)); // TODO: handle errors } - // 
FIXME: profile will be per opreq - // this->prof_usecs = rsp.prof_usecs; - // this->prof_cycles = rsp.prof_cycles; - // this->prof_pkts = rsp.prof_pkts; + op_queue->pop(rsp, dbuf); this->op_pending--; // atomic dec @@ -1845,28 +1914,17 @@ void ggml_hexagon_session::flush_pending(bool all) { void ggml_hexagon_session::flush_batch() { if (op_batch->empty()) { return; } - htp_opbatch_req req; - req.n_bufs = op_batch->n_bufs; - req.n_tensors = op_batch->n_tens; - req.n_ops = op_batch->n_ops; + htp_opbatch_req req {}; + dspqueue_buffer dbuf{}; - dspqueue_buffer dbuf; - dbuf.fd = op_shm->fd(); - dbuf.flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT; - dbuf.ptr = op_shm->allocate(); - if (!dbuf.ptr) { + if (!op_queue->push(req, dbuf, op_batch)) { flush_pending(false); - dbuf.ptr = op_shm->allocate(); + op_queue->push(req, dbuf, op_batch); } - dbuf.offset = (uint8_t*) dbuf.ptr - (uint8_t*) op_shm->base(); - dbuf.size = op_batch->flush((uint8_t*) dbuf.ptr, op_shm->block_size); - // Bump pending flag (cleared in the session::flush once we get the response) this->op_pending++; // atomic inc - HEX_VERBOSE("ggml-hex: %s: queue-opbatch : %p size %u\n", this->c_name(), dbuf.ptr, dbuf.size); - int err = dspqueue_write(this->queue, 0, 1, &dbuf, sizeof(req), (const uint8_t*) &req, DSPQUEUE_TIMEOUT); if (err != 0) { GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", this->c_name(), (unsigned) err); @@ -2016,25 +2074,33 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { } if (opt_etm) { - err = htp_iface_enable_etm(this->handle); + err = htp_iface_etm(this->handle, 1); if (err != 0) { GGML_LOG_ERROR("ggml-hex: failed to enable ETM tracing: 0x%08x\n", (unsigned) err); } } - // Start the DSP-side service. We need to pass the queue ID to the - // DSP in a FastRPC call; the DSP side will import the queue and start - // listening for packets in a callback. 
+ if (opt_profile) { + htp_iface_pmu_conf pmu_conf{}; + std::copy(opt_pmu_evt.begin(), opt_pmu_evt.end(), pmu_conf.events); + + err = htp_iface_profiler(this->handle, opt_profile, &pmu_conf); + if (err != 0) { + GGML_LOG_ERROR("ggml-hex: failed to enable profiling: 0x%08x\n", (unsigned) err); + } + } + + // Allocate buffers and state for op batching + this->op_batch = new ggml_hexagon_opbatch(this, opt_opbatch); + this->op_queue = new ggml_hexagon_opqueue(this, opt_opbatch, opt_opqueue); + + // Start processing op batch requests err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx, opt_use_hmx); if (err != 0) { GGML_LOG_ERROR("ggml-hex: failed to start session: 0x%08x\n", (unsigned) err); throw std::runtime_error("ggml-hex: iface start failed (see log for details)"); } this->valid_iface = true; - - // Allocate buffers and state for op batching - this->op_batch = new ggml_hexagon_opbatch(this, opt_opbatch); - this->op_shm = new ggml_hexagon_opshm(this, opt_opbatch, opt_opqueue); } void ggml_hexagon_session::release() noexcept(true) { @@ -2043,7 +2109,7 @@ void ggml_hexagon_session::release() noexcept(true) { int err; delete this->op_batch; - delete this->op_shm; + delete this->op_queue; // Stop the DSP-side service and close the queue if (this->valid_iface) { @@ -2054,12 +2120,20 @@ void ggml_hexagon_session::release() noexcept(true) { } if (opt_etm) { - err = htp_iface_disable_etm(this->handle); + err = htp_iface_etm(this->handle, 0); if (err != 0) { GGML_LOG_ERROR("ggml-hex: warn : failed to disable ETM tracing: 0x%08x\n", (unsigned) err); } } + if (opt_profile) { + htp_iface_pmu_conf pmu_conf{}; + err = htp_iface_profiler(this->handle, 0, &pmu_conf); + if (err != 0) { + GGML_LOG_ERROR("ggml-hex: warn : failed to disable profiling: 0x%08x\n", (unsigned) err); + } + } + if (this->valid_queue) { err = dspqueue_close(queue); if (err != 0) { @@ -2077,7 +2151,7 @@ ggml_hexagon_session::ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) n repack_buffer_type.device = dev; op_batch = nullptr; - op_shm = nullptr; + op_queue = nullptr; try { allocate(dev_id); @@ -2596,6 +2670,62 @@ static bool ggml_hexagon_supported_cumsum(const struct ggml_hexagon_session * se return true; } +static bool ggml_hexagon_supported_diag(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * dst = op; + + // diag only supports F32 currently + if (src0->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + return false; + } + + // Input must have ne[1] == 1 (vector input) + if (src0->ne[1] != 1) { + return false; + } + + // Output must be square in first two dimensions + if (dst->ne[0] != dst->ne[1] || dst->ne[0] != src0->ne[0]) { + return false; + } + + GGML_UNUSED(sess); + return true; +} + +static bool ggml_hexagon_supported_solve_tri(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; // A + const struct ggml_tensor * src1 = op->src[1]; // B + const struct ggml_tensor * dst = op; // X + + if (!src0 || !src1) { + return false; + } + + if (src0->type != GGML_TYPE_F32 || src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + return false; + } + + if (src0->ne[0] != src0->ne[1]) { + return false; + } + + if (src0->ne[1] != src1->ne[1]) { + return false; + } + + if (src0->ne[2] != src1->ne[2] || src0->ne[3] != src1->ne[3]) { + return false; + } + + if (dst->ne[0] != src1->ne[0] || dst->ne[1] != src1->ne[1] || dst->ne[2] != 
src1->ne[2] || dst->ne[3] != src1->ne[3]) { + return false; + } + + GGML_UNUSED(sess); + return true; +} + static const char * ggml_backend_hexagon_name(ggml_backend_t backend) { auto sess = static_cast(backend->context); return sess->c_name(); @@ -2632,7 +2762,9 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) { case GGML_OP_ROPE: return HTP_OP_ROPE; case GGML_OP_REPEAT: return HTP_OP_REPEAT; case GGML_OP_CUMSUM: return HTP_OP_CUMSUM; - + case GGML_OP_FILL: return HTP_OP_FILL; + case GGML_OP_DIAG: return HTP_OP_DIAG; + case GGML_OP_SOLVE_TRI: return HTP_OP_SOLVE_TRI; case GGML_OP_UNARY: switch (ggml_get_unary_op(t)) { case GGML_UNARY_OP_SILU: return HTP_OP_UNARY_SILU; @@ -2673,7 +2805,7 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg for (int i = 0; i < graph->n_nodes; ++i) { ggml_tensor * n = graph->nodes[i]; - if (op_is_compute(n)) { + if (op_is_compute(n) && (opt_opstage & HTP_OPSTAGE_QUEUE)) { sess->enqueue_op(op_remap_to_htp(n), n); } } @@ -3029,6 +3161,17 @@ static bool ggml_hexagon_supported_repeat(const struct ggml_hexagon_session * se return true; } +static bool ggml_hexagon_supported_fill(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * dst = op; + + if (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) { + return false; + } + + GGML_UNUSED(sess); + return true; +} + static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { auto sess = static_cast(dev->context); @@ -3159,6 +3302,18 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_cumsum(sess, op); break; + case GGML_OP_FILL: + supp = ggml_hexagon_supported_fill(sess, op); + break; + + case GGML_OP_DIAG: + supp = ggml_hexagon_supported_diag(sess, op); + break; + + case GGML_OP_SOLVE_TRI: + supp = ggml_hexagon_supported_solve_tri(sess, op); + break; + default: break; } @@ -3294,6 +3449,26 @@ static void * ggml_backend_hexagon_get_proc_address(ggml_backend_reg_t reg, cons return NULL; } +template std::vector str_to_vec(const char* str) { + std::stringstream ss(str); + std::vector v; + std::string t; + + while (std::getline(ss, t, ',')) { + v.push_back(std::stoul(t, nullptr, 0)); + } + + return v; +} + +template std::string vec_to_str(std::vector v) { + std::stringstream ss; + ss << std::setbase(BASE) << std::showbase; + for (auto i : v) { ss << i << ','; } + auto str = ss.str(); str.pop_back(); // drop last comma + return str; +} + static void ggml_hexagon_init(ggml_backend_reg * reg) { // Basic sanity checks to make sure definitions match static_assert((unsigned int) HTP_TYPE_Q4_0 == (unsigned int) GGML_TYPE_Q4_0, @@ -3307,8 +3482,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { const char * str_verbose = getenv("GGML_HEXAGON_VERBOSE"); const char * str_hostbuf = getenv("GGML_HEXAGON_HOSTBUF"); - const char * str_opmask = getenv("GGML_HEXAGON_OPMASK"); - const char * str_opsync = getenv("GGML_HEXAGON_OPSYNC"); + const char * str_opstage = getenv("GGML_HEXAGON_OPSTAGE"); const char * str_opbatch = getenv("GGML_HEXAGON_OPBATCH"); const char * str_opqueue = getenv("GGML_HEXAGON_OPQUEUE"); const char * str_opfilter= getenv("GGML_HEXAGON_OPFILTER"); @@ -3321,19 +3495,30 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { auto RE_ICASE = std::regex_constants::icase; - opt_opfilter = str_opfilter ? new std::regex(str_opfilter, RE_ICASE) : NULL; - opt_verbose = str_verbose ? 
atoi(str_verbose) : 0; - opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf; - opt_opmask = str_opmask ? strtoul(str_opmask, NULL, 0) : opt_opmask; - opt_opsync = str_opsync ? atoi(str_opsync) : opt_opsync; - opt_opbatch = str_opbatch ? strtoul(str_opbatch, NULL, 0) : opt_opbatch; - opt_opqueue = str_opqueue ? strtoul(str_opqueue, NULL, 0) : opt_opqueue; - opt_profile = str_profile ? atoi(str_profile) : 0; - opt_etm = str_etm ? atoi(str_etm) : 0; - opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx; - opt_use_hmx = str_use_hmx ? atoi(str_use_hmx) : opt_use_hmx; - opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev; - opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf; + opt_opfilter = str_opfilter ? new std::regex(str_opfilter, RE_ICASE) : NULL; + opt_verbose = str_verbose ? atoi(str_verbose) : 0; + opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf; + opt_opstage = str_opstage ? strtoul(str_opstage, NULL, 0) : opt_opstage; + opt_opbatch = str_opbatch ? strtoul(str_opbatch, NULL, 0) : opt_opbatch; + opt_opqueue = str_opqueue ? strtoul(str_opqueue, NULL, 0) : opt_opqueue; + opt_etm = str_etm ? atoi(str_etm) : 0; + opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx; + opt_use_hmx = str_use_hmx ? atoi(str_use_hmx) : opt_use_hmx; + opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev; + opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf; + + if (str_profile) { + opt_pmu_evt = [&]() -> std::vector { + auto v = str_to_vec(str_profile); + switch (v.size()) { + case 1: opt_profile = v[0]; return opt_pmu_evt; // mode with default pmu events + case 8: opt_profile = 2; return v; // mode with custom pmu events + default: opt_profile = 0; return {}; // garbage input + }}(); + if (opt_profile == 1) opt_pmu_evt = {}; + GGML_LOG_INFO("ggml-hex: Profiling mode %u : pmu-evt [ %s ]\n", opt_profile, + vec_to_str(opt_pmu_evt).c_str()); + } if (opt_ndev > GGML_HEXAGON_MAX_SESSIONS) { opt_ndev = GGML_HEXAGON_MAX_SESSIONS; diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index 9ca759459d4..8bd528478ba 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -34,6 +34,9 @@ add_library(${HTP_LIB} SHARED argsort-ops.c ssm-conv.c cumsum-ops.c + fill-ops.c + diag-ops.c + solve-tri-ops.c ) target_compile_definitions(${HTP_LIB} PRIVATE diff --git a/ggml/src/ggml-hexagon/htp/diag-ops.c b/ggml/src/ggml-hexagon/htp/diag-ops.c new file mode 100644 index 00000000000..9b3194d9084 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/diag-ops.c @@ -0,0 +1,216 @@ +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-ops.h" +#include "hvx-types.h" +#include "hex-utils.h" +#include "hvx-copy.h" +#include "hex-dma.h" + +#define htp_diag_tensors_preamble \ + const struct htp_tensor * restrict src0 = octx->src[0]; \ + const struct htp_tensor * restrict dst = octx->dst; \ + \ + const uint32_t ne02 = src0->ne[2]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; + +struct htp_diag_context { + struct htp_ops_context * octx; + size_t src_batch_size; + size_t dst_row_size; + size_t src_batch_size_aligned; + size_t dst_row_size_aligned; + 
uint32_t batches_per_thread; + uint32_t total_batches; +}; + +#define htp_diag_preamble \ + struct htp_diag_context * dctx = (struct htp_diag_context *) data; \ + struct htp_ops_context * octx = dctx->octx; \ + htp_diag_tensors_preamble; + +static inline void hvx_diag_row_f32(const float * restrict src, float * restrict dst, + uint32_t row_idx, uint32_t n) { + hvx_splat_f32_a((uint8_t *) dst, 0.0f, n); + dst[row_idx] = src[row_idx]; +} + +// --------------------------------------------------------------------------- +// Per thread worker: DMA src fetch, compute in VTCM, DMA dst writeback +// --------------------------------------------------------------------------- + +static void diag_thread_f32_dma(unsigned int nth, unsigned int ith, void * data) { + htp_diag_preamble; + dma_queue * dma_queue = octx->ctx->dma[ith]; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const uint32_t ib0 = dctx->batches_per_thread * ith; + const uint32_t ib1 = MIN(ib0 + dctx->batches_per_thread, dctx->total_batches); + + if (ib0 >= ib1) { + return; + } + + const size_t src_batch_size = dctx->src_batch_size; + const size_t dst_row_size = dctx->dst_row_size; + const size_t src_batch_size_aligned = dctx->src_batch_size_aligned; + const size_t dst_row_size_aligned = dctx->dst_row_size_aligned; + + const uint8_t * src_data = (const uint8_t *) src0->data; + uint8_t * dst_data = (uint8_t *) dst->data; + + // 1 src buffer + 1 dst row buffer per thread in VTCM + uint8_t * src_spad = octx->src0_spad.data + (ith * src_batch_size_aligned); + uint8_t * dst_spad = octx->dst_spad.data + (ith * dst_row_size_aligned); + + for (uint32_t ib = ib0; ib < ib1; ib++) { + const uint32_t i3 = ib / ne02; + const uint32_t i2 = ib % ne02; + + const uint8_t * src_batch = src_data + i3 * nb03 + i2 * nb02; + + // Fetch source vector into VTCM + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr(src_spad, src_batch), + src_batch_size_aligned, src_batch_size, 1); + dma_queue_flush(dma_queue); + + const float * src_spad_f32 = (const float *) src_spad; + float * dst_spad_f32 = (float *) dst_spad; + + for (uint32_t i1 = 0; i1 < ne1; i1++) { + // Compute row in VTCM + hvx_diag_row_f32(src_spad_f32, dst_spad_f32, i1, ne0); + + // Write completed row back to DDR + uint8_t * dst_row = dst_data + i3 * nb3 + i2 * nb2 + i1 * nb1; + dma_queue_push_vtcm_to_ddr(dma_queue, + dma_make_ptr(dst_row, dst_spad), + dst_row_size, dst_row_size_aligned, 1); + dma_queue_flush(dma_queue); + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "diag-f32-dma %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", + ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ib0, ib1, + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// --------------------------------------------------------------------------- +// Per thread worker: Direct HVX (no DMA) +// --------------------------------------------------------------------------- + +static void diag_thread_f32(unsigned int nth, unsigned int ith, void * data) { + htp_diag_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const uint8_t * src_data = (const uint8_t *) src0->data; + uint8_t * dst_data = (uint8_t *) dst->data; + + const uint32_t ib0 = dctx->batches_per_thread * ith; + const uint32_t ib1 = MIN(ib0 + dctx->batches_per_thread, dctx->total_batches); + + for (uint32_t ib = ib0; ib < ib1; ib++) { + const uint32_t i3 = ib / ne02; + const uint32_t i2 = ib % ne02; + + const float * restrict src_batch = (const float 
*)(src_data + i3 * nb03 + i2 * nb02); + + for (uint32_t i1 = 0; i1 < ne1; i1++) { + float * restrict dst_row = (float *)(dst_data + i3 * nb3 + i2 * nb2 + i1 * nb1); + hvx_diag_row_f32(src_batch, dst_row, i1, ne0); + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "diag-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", + ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ib0, ib1, + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +int op_diag_f32(struct htp_ops_context * octx) { + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * dst = octx->dst; + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + const uint32_t total_batches = src0->ne[2] * src0->ne[3]; + const uint32_t n_threads = MIN(octx->n_threads, total_batches); + + const size_t src_batch_size = src0->ne[0] * sizeof(float); + const size_t dst_row_size = dst->ne[0] * sizeof(float); + const size_t src_batch_size_aligned = hex_round_up(src_batch_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + + // 1 src buffer + 1 dst row buffer per thread + const size_t spad_per_thread = src_batch_size_aligned + dst_row_size_aligned; + + octx->src0_spad.size_per_thread = src_batch_size_aligned; + octx->dst_spad.size_per_thread = dst_row_size_aligned; + + octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread; + octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; + + octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.src = NULL; + octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->dst_spad.src = NULL; + + struct htp_diag_context dctx = { + .octx = octx, + .src_batch_size = src_batch_size, + .dst_row_size = dst_row_size, + .src_batch_size_aligned = src_batch_size_aligned, + .dst_row_size_aligned = dst_row_size_aligned, + .batches_per_thread = (total_batches + n_threads - 1) / n_threads, + .total_batches = total_batches, + }; + + if (octx->ctx->vtcm_size < spad_per_thread * n_threads) { + worker_pool_run_func(octx->ctx->worker_pool, diag_thread_f32, &dctx, n_threads); + } else { + worker_pool_run_func(octx->ctx->worker_pool, diag_thread_f32_dma, &dctx, n_threads); + } + + return HTP_STATUS_OK; +} + +int op_diag(struct htp_ops_context * octx) { + const struct htp_tensor * dst = octx->dst; + + int err = HTP_STATUS_OK; + + switch (dst->type) { + case HTP_TYPE_F32: + err = op_diag_f32(octx); + break; + default: + err = HTP_STATUS_NO_SUPPORT; + break; + } + + return err; +} diff --git a/ggml/src/ggml-hexagon/htp/fill-ops.c b/ggml/src/ggml-hexagon/htp/fill-ops.c new file mode 100644 index 00000000000..3ccfbe74ee4 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/fill-ops.c @@ -0,0 +1,123 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include + +#include + +#include "hvx-copy.h" +#include "hvx-utils.h" + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-ops.h" + +// ggml op_params layout for FILL: +// op_params[0] (as float) - the scalar fill value + +#define fill_preamble \ + const struct htp_tensor * dst = octx->dst; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const 
uint32_t nb3 = dst->nb[3]; \ + \ + const uint32_t nr = ne1 * ne2 * ne3; + +struct htp_fill_context { + struct htp_ops_context * octx; + uint32_t nrows_per_thread; + uint32_t total_rows; // ne1 * ne2 * ne3 + bool opt_path; + HVX_Vector splat_vec; + uint32_t elem_size; +}; + +static void fill_thread(unsigned int nth, unsigned int ith, void * data) { + const struct htp_fill_context * fctx = (const struct htp_fill_context *) data; + struct htp_ops_context * octx = fctx->octx; + fill_preamble; + + // Parallelise over the flat row index spanning ne1*ne2*ne3 + const uint32_t ir0 = fctx->nrows_per_thread * ith; + const uint32_t ir1 = MIN(ir0 + fctx->nrows_per_thread, fctx->total_rows); + + uint64_t t1 = HAP_perf_get_qtimer_count(); + + if (fctx->opt_path) { + // Opt path: tensor is fully contiguous, treat as flat array + const uint32_t elem_start = ir0 * ne0; + const uint32_t elem_end = ir1 * ne0; + uint8_t * dst_ptr = (uint8_t *) dst->data + elem_start * fctx->elem_size; + hvx_splat_u(dst_ptr, fctx->splat_vec, elem_end - elem_start, fctx->elem_size); + } else { + // Non-contiguous path: must respect strides + for (uint32_t ir = ir0; ir < ir1; ++ir) { + const uint32_t i1 = ir % ne1; + const uint32_t i2 = (ir / ne1) % ne2; + const uint32_t i3 = ir / (ne1 * ne2); + uint8_t * dst_ptr = (uint8_t *) dst->data + i1*nb1 + i2*nb2 + i3*nb3; + hvx_splat_u(dst_ptr, fctx->splat_vec, ne0, fctx->elem_size); + } + } + + uint64_t t2 = HAP_perf_get_qtimer_count(); + FARF(HIGH, "fill %u/%u: rows %u:%u usec %u\n", + ith, nth, ir0, ir1, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +int op_fill(struct htp_ops_context * octx) { + fill_preamble; + + if (dst->type != HTP_TYPE_F32 && dst->type != HTP_TYPE_F16) { + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + // nr = ne1*ne2*ne3 (flat row count across all outer dims); parallelise over it. 
+ const uint32_t n_threads = MIN(nr, octx->n_threads); + + // Optimize if fully contiguous: skip stride arithmetic, treat as flat array + const bool opt_path = (nb2 == nb1 * ne1) && (nb3 == nb2 * ne2); + + FARF(HIGH, "fill: (%ux%ux%ux%u) type=%u opt=%d\n", + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], dst->type, (int) opt_path); + + float val_f32 = 0.f; + memcpy(&val_f32, &octx->op_params[0], sizeof(float)); + + struct htp_fill_context fctx = { + .octx = octx, + .nrows_per_thread = (nr + n_threads - 1) / n_threads, + .total_rows = nr, + .opt_path = opt_path, + }; + + switch (dst->type) { + case HTP_TYPE_F32: + fctx.splat_vec = hvx_vec_splat_f32(val_f32); + fctx.elem_size = sizeof(float); + break; + case HTP_TYPE_F16: + fctx.splat_vec = hvx_vec_splat_f16((_Float16) val_f32); + fctx.elem_size = sizeof(_Float16); + break; + default: + return HTP_STATUS_NO_SUPPORT; + } + + worker_pool_run_func(octx->ctx->worker_pool, fill_thread, &fctx, n_threads); + + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/hex-utils.h b/ggml/src/ggml-hexagon/htp/hex-utils.h index f6713c5cf8f..329249e11da 100644 --- a/ggml/src/ggml-hexagon/htp/hex-utils.h +++ b/ggml/src/ggml-hexagon/htp/hex-utils.h @@ -4,6 +4,7 @@ #include #include #include +#include #include "hexagon_types.h" #include "hexagon_protos.h" @@ -100,4 +101,31 @@ static inline void hex_pause() { asm volatile(" pause(#255)\n"); } +#ifndef HEX_NUM_PMU_COUNTERS +#define HEX_NUM_PMU_COUNTERS 8 +#endif + +static inline void hex_get_pmu(uint32_t counters[]) { +#if __HVX_ARCH__ >= 79 + asm volatile("%0 = upmucnt0" : "=r"(counters[0])); + asm volatile("%0 = upmucnt1" : "=r"(counters[1])); + asm volatile("%0 = upmucnt2" : "=r"(counters[2])); + asm volatile("%0 = upmucnt3" : "=r"(counters[3])); + asm volatile("%0 = upmucnt4" : "=r"(counters[4])); + asm volatile("%0 = upmucnt5" : "=r"(counters[5])); + asm volatile("%0 = upmucnt6" : "=r"(counters[6])); + asm volatile("%0 = upmucnt7" : "=r"(counters[7])); +#else + counters[0] = qurt_pmu_get(QURT_PMUCNT0); + counters[1] = qurt_pmu_get(QURT_PMUCNT1); + counters[2] = qurt_pmu_get(QURT_PMUCNT2); + counters[3] = qurt_pmu_get(QURT_PMUCNT3); + counters[4] = qurt_pmu_get(QURT_PMUCNT4); + counters[5] = qurt_pmu_get(QURT_PMUCNT5); + counters[6] = qurt_pmu_get(QURT_PMUCNT6); + counters[7] = qurt_pmu_get(QURT_PMUCNT7); + // qurt_pmu_get_pmucnt(counters); +#endif +} + #endif /* HEX_UTILS_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index 8b5e47adef8..d704fedee9d 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -10,6 +10,7 @@ #include #include #include +#include #define HTP_MAX_NTHREADS 10 #define HTP_MAX_MMAPS 16 @@ -66,7 +67,9 @@ struct htp_context { int thread_id; int thread_prio; - int hmx_enabled; + bool hmx_enabled; + bool etm; + uint32_t profiler; uint8_t * vtcm_base; size_t vtcm_size; @@ -98,5 +101,8 @@ int op_repeat(struct htp_ops_context * octx); int op_argsort(struct htp_ops_context * octx); int op_ssm_conv(struct htp_ops_context * octx); int op_cumsum(struct htp_ops_context * octx); +int op_fill(struct htp_ops_context * octx); +int op_diag(struct htp_ops_context * octx); +int op_solve_tri(struct htp_ops_context * octx); #endif /* HTP_CTX_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 79b5ecd2270..4397245c5b8 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -42,9 +42,9 @@ enum htp_data_type { // Mask to enable various 
stages of the Ops. // Used for debugging and profiling. -enum htp_op_mask { - HTP_OPMASK_QUEUE = (1 << 0), // Enable Queueing (ie calls into the DSP) - HTP_OPMASK_COMPUTE = (1 << 1), // Enable Compute +enum htp_op_stage { + HTP_OPSTAGE_QUEUE = (1 << 0), // Enable Queueing (ie calls into NPU) + HTP_OPSTAGE_COMPUTE = (1 << 1), // Enable Compute }; // Do not reorder first 4 (used as an index) @@ -80,7 +80,9 @@ enum htp_op_code { HTP_OP_SSM_CONV, HTP_OP_REPEAT, HTP_OP_CUMSUM, - + HTP_OP_FILL, + HTP_OP_DIAG, + HTP_OP_SOLVE_TRI, HTP_OP_INVALID }; @@ -135,27 +137,45 @@ struct htp_op_desc { int32_t params[HTP_OP_MAX_PARAMS]; // Params for the op, e.g. epsilon of RMS norm uint16_t src[HTP_OP_MAX_INPUTS]; // Input tensors indices uint16_t dst; // Output tensor index +}; + +enum htp_profiler_mode { + HTP_PROF_DISABLED = 0, + HTP_PROF_BASIC = 1, + HTP_PROF_PMU = 2, +}; + +#define HTP_PROF_PMU_NCNT 8 - // the rest is filled in-place by the NPU - uint32_t prof_usecs; // Number of usec per request - uint32_t prof_cycles; // Number of cycles per request - uint32_t prof_pkts; // Number of instruction packets per request - uint32_t unused; +// Profile descriptor +struct htp_prof_desc { + uint32_t opcode; // GGML/HTP Op + uint32_t usecs; // Number of usec + uint32_t cycles; // Number of cycles + uint32_t pad; // Unused + uint32_t pmu[HTP_PROF_PMU_NCNT]; // PMU counters }; struct htp_opbatch_req { + uint32_t id; // Batch id uint32_t n_bufs; // Number of buffers uint32_t n_tensors; // Number of tensors uint32_t n_ops; // Number of ops uint32_t flags; // unused + uint32_t pad; // unused // struct htp_buf_desc bufs[]; -- dspqueue buf 0 // struct htp_tensor tensors[]; -- dspqueue buf 0 // struct htp_op_desc ops[]; -- dspqueue buf 0 }; struct htp_opbatch_rsp { + uint32_t id; // Batch id uint32_t status; // HTP_STATUS_... 
- // struct htp_op_req ops[]; -- dspqueue buf 0 + uint32_t n_bufs; // Number of buffers + uint32_t n_tensors; // Number of tensors + uint32_t n_ops; // Number of op profile descriptors + uint32_t pad; // unused + // struct htp_prof_desc profs[]; -- dspqueue buf 0 }; #endif /* HTP_OPS_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp_iface.idl b/ggml/src/ggml-hexagon/htp/htp_iface.idl index 3eb5d5a6912..dbcafd1d856 100644 --- a/ggml/src/ggml-hexagon/htp/htp_iface.idl +++ b/ggml/src/ggml-hexagon/htp/htp_iface.idl @@ -6,13 +6,17 @@ #include "AEEStdDef.idl" #include "remote.idl" +struct htp_iface_pmu_conf { + uint32 events[8]; +}; + interface htp_iface : remote_handle64 { AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 use_hmx); AEEResult stop(); AEEResult mmap(in uint32 fd, in uint32 size, in uint32 pinned); AEEResult munmap(in uint32 fd); - AEEResult enable_etm(); - AEEResult disable_etm(); + AEEResult profiler(in uint32 mode, in htp_iface_pmu_conf pmu); + AEEResult etm(in uint32 enable); }; #endif /* HTP_IDL */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-base.h b/ggml/src/ggml-hexagon/htp/hvx-base.h index ed6026e762a..d0926dedd28 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-base.h +++ b/ggml/src/ggml-hexagon/htp/hvx-base.h @@ -256,6 +256,18 @@ static inline HVX_Vector hvx_vec_mul_f16_f16(HVX_Vector a, HVX_Vector b) return Q6_Vhf_equals_Wqf32(Q6_Wqf32_vmpy_VhfVhf(a, b)); } +static inline HVX_Vector hvx_vec_add_f32_f32(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b)); +} + +static inline HVX_Vector hvx_vec_sub_f32_f32(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b)); +} + +static inline HVX_Vector hvx_vec_mul_f32_f32(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)); +} + #else static inline HVX_Vector hvx_vec_add_f16_f16(HVX_Vector a, HVX_Vector b) @@ -273,6 +285,18 @@ static inline HVX_Vector hvx_vec_mul_f16_f16(HVX_Vector a, HVX_Vector b) return Q6_Vhf_vmpy_VhfVhf(a, b); } +static inline HVX_Vector hvx_vec_add_f32_f32(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_vadd_VsfVsf(a, b); +} + +static inline HVX_Vector hvx_vec_sub_f32_f32(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_vsub_VsfVsf(a, b); +} + +static inline HVX_Vector hvx_vec_mul_f32_f32(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_vmpy_VsfVsf(a, b); +} + #endif // __HVX_ARCH__ < 79 #endif /* HVX_BASE_H */ diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 5091623a653..db277a25e5a 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -27,6 +27,7 @@ #include "htp-ctx.h" #include "htp-ops.h" #include "htp-ops.h" +#include "htp_iface.h" #include "worker-pool.h" AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { @@ -103,6 +104,54 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { return AEE_SUCCESS; } +AEEResult htp_iface_etm(remote_handle64 handle, uint32_t enable) { + int err = enable ? 
HAP_user_etm_enable() : HAP_user_etm_disable(); + if (err) { + if (err == AEE_EVERSIONNOTSUPPORT) { + FARF(ERROR, "API HAP_user_etm_enable/disable is not supported\n"); + } else { + FARF(ERROR, "Error executing HAP_user_etm_enable/disable with error code : 0x%x\n", err); + } + } + return err; +} + +AEEResult htp_iface_profiler(remote_handle64 handle, uint32_t mode, const htp_iface_pmu_conf* pmu_conf) { + struct htp_context * ctx = (struct htp_context *) handle; + if (!ctx) { + return AEE_EBADPARM; + } + + if (mode == HTP_PROF_PMU) { + const uint32_t* events = pmu_conf->events; + + // Pack 4 event IDs (low 8 bits) into each 32-bit config register + uint32_t evtcfg = 0, evtcfg1 = 0, cfg = 0, i = 0; + for (; i < HEX_NUM_PMU_COUNTERS/2; i++) { + evtcfg |= ((events[i + 0] & 0xFF) << (i * 8)); + evtcfg1 |= ((events[i + 4] & 0xFF) << (i * 8)); + } + + // For events >255 pack high 2 bits of all 8 event IDs into cfg register + // 2 bits per counter: bits [1:0] for counter 0, [3:2] for counter 1, etc. + for (i = 0; i < HEX_NUM_PMU_COUNTERS; i++) { + cfg |= (((events[i] >> 8) & 3) << (i * 2)); + } + + FARF(ALWAYS, "Configuring PMU registers: evtcfg = 0x%x, evtcfg1 = 0x%x, pmucfg = 0x%x", evtcfg, evtcfg1, cfg); + + // Configure PMU registers + qurt_pmu_set(QURT_PMUCFG, cfg); + qurt_pmu_set(QURT_PMUEVTCFG, evtcfg); + qurt_pmu_set(QURT_PMUEVTCFG1, evtcfg1); + qurt_pmu_enable(1); + } + + ctx->profiler = mode; + + return AEE_SUCCESS; +} + AEEResult htp_iface_close(remote_handle64 handle) { struct htp_context * ctx = (struct htp_context *) handle; @@ -129,35 +178,19 @@ AEEResult htp_iface_close(remote_handle64 handle) { } } - free(ctx); - return AEE_SUCCESS; -} - -AEEResult htp_iface_enable_etm(remote_handle64 handle) { - int err = HAP_user_etm_enable(); - if (err) { - if (err == AEE_EVERSIONNOTSUPPORT) { - FARF(ERROR, "API HAP_user_etm_enable is not supported\n"); - } else { - FARF(ERROR, "Error executing HAP_user_etm_enable with error code : 0x%x\n", err); - } + if (ctx->profiler) { + qurt_pmu_enable(1); } - return err; -} -AEEResult htp_iface_disable_etm(remote_handle64 handle) { - int err = HAP_user_etm_disable(); - if (err) { - if (err == AEE_EVERSIONNOTSUPPORT) { - FARF(ERROR, "API HAP_user_etm_disable is not supported\n"); - } else { - FARF(ERROR, "Error executing HAP_user_etm_disable with error code : 0x%x\n", err); - } + if (ctx->etm) { + HAP_user_etm_disable(); } - return err; + + free(ctx); + return AEE_SUCCESS; } -AEEResult htp_iface_mmap(remote_handle64 handle, int fd, uint32_t size, uint32_t pinned) { +AEEResult htp_iface_mmap(remote_handle64 handle, uint32 fd, uint32 size, uint32 pinned) { struct htp_context * ctx = (struct htp_context *) handle; if (!ctx) { return AEE_EBADPARM; @@ -204,7 +237,7 @@ AEEResult htp_iface_mmap(remote_handle64 handle, int fd, uint32_t size, uint32_t return AEE_ENOMEMORY; } -AEEResult htp_iface_munmap(remote_handle64 handle, int fd) { +AEEResult htp_iface_munmap(remote_handle64 handle, uint32 fd) { struct htp_context * ctx = (struct htp_context *) handle; if (!ctx) { return AEE_EBADPARM; @@ -434,19 +467,39 @@ static void htp_error_callback(dspqueue_t queue, int error, void * context) { struct profile_data { uint64_t usecs; uint64_t cycles; - uint64_t pkts; + uint32_t pmu_counters[HEX_NUM_PMU_COUNTERS]; }; -static inline void profile_start(struct profile_data * d) { - d->usecs = HAP_perf_get_qtimer_count(); - d->cycles = hex_get_cycles(); - d->pkts = hex_get_pktcnt(); +static inline void profile_start(uint32_t mode, struct profile_data * d) { + switch (mode) { + 
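        // Note: HTP_PROF_PMU intentionally falls through to HTP_PROF_BASIC below, so PMU-mode
        // profiling also captures the qtimer timestamp and cycle count used for the usecs/cycles metrics.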
case HTP_PROF_PMU: + hex_get_pmu(d->pmu_counters); + // fallthrough + case HTP_PROF_BASIC: + d->usecs = HAP_perf_get_qtimer_count(); + d->cycles = hex_get_cycles(); + break; + default: + break; + } } -static inline void profile_stop(struct profile_data * d) { - d->usecs = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - d->usecs); - d->cycles = hex_get_cycles() - d->cycles; - d->pkts = hex_get_pktcnt() - d->pkts; +static inline void profile_stop(uint32_t mode, struct profile_data * d) { + uint32_t pmu_counters[HEX_NUM_PMU_COUNTERS]; + switch (mode) { + case HTP_PROF_PMU: + hex_get_pmu(pmu_counters); + for (int i = 0; i < HEX_NUM_PMU_COUNTERS; i++) { + d->pmu_counters[i] = pmu_counters[i] - d->pmu_counters[i]; + } + // fallthrough + case HTP_PROF_BASIC: + d->usecs = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - d->usecs); + d->cycles = hex_get_cycles() - d->cycles; + break; + default: + break; + } } static int execute_op(struct htp_ops_context * octx) { @@ -514,6 +567,15 @@ static int execute_op(struct htp_ops_context * octx) { case HTP_OP_CUMSUM: return op_cumsum(octx); + case HTP_OP_FILL: + return op_fill(octx); + + case HTP_OP_DIAG: + return op_diag(octx); + + case HTP_OP_SOLVE_TRI: + return op_solve_tri(octx); + case HTP_OP_INVALID: break; @@ -720,29 +782,32 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { continue; } + // Reset poll count for valid requests + poll_count = DSPQUEUE_POLL_COUNT; + const uint32_t n_bufs = req.n_bufs; const uint32_t n_tens = req.n_tensors; const uint32_t n_ops = req.n_ops; - const uint32_t b_size = sizeof(struct htp_buf_desc) * n_bufs; - const uint32_t t_size = sizeof(struct htp_tensor) * n_tens; - const uint32_t o_size = sizeof(struct htp_op_desc) * n_ops; + const uint32_t b_size = sizeof(struct htp_buf_desc) * n_bufs; + const uint32_t t_size = sizeof(struct htp_tensor) * n_tens; + const uint32_t o_size = sizeof(struct htp_op_desc) * n_ops; + const uint32_t p_size = sizeof(struct htp_prof_desc) * n_ops; - if (dbuf.size < b_size + t_size + o_size) { + if (dbuf.size < b_size + t_size + o_size + p_size) { FARF(ERROR, "invalid opbatch memory block size %u", dbuf.size); break; } - // Reset poll count for valid requests - poll_count = DSPQUEUE_POLL_COUNT; + FARF(HIGH, "processing opbatch #%u: n-bufs %u n-tensors %u n-ops %u : m-size %u b-size %u t-size %u o-size %u", req.id, + n_bufs, n_tens, n_ops, dbuf.size, b_size, t_size, o_size); + // Setup descriptor pointers uint8_t * m_ptr = dbuf.ptr; - struct htp_buf_desc* bufs = (struct htp_buf_desc*) m_ptr; m_ptr += b_size; - struct htp_tensor* tens = (struct htp_tensor*) m_ptr; m_ptr += t_size; - struct htp_op_desc* ops = (struct htp_op_desc*) m_ptr; - - FARF(HIGH, "processing opbatch: n-bufs %u n-tensors %u n-ops %u : m-size %u b-size %u t-size %u o-size %u", - n_bufs, n_tens, n_ops, dbuf.size, b_size, t_size, o_size); + struct htp_buf_desc* bufs = (struct htp_buf_desc*) m_ptr; m_ptr += b_size; + struct htp_tensor* tens = (struct htp_tensor*) m_ptr; m_ptr += t_size; + struct htp_op_desc* ops = (struct htp_op_desc*) m_ptr; m_ptr += o_size; + struct htp_prof_desc* pds = (struct htp_prof_desc*) m_ptr; prep_op_bufs(ctx, bufs, n_bufs); prep_tensors(ctx, bufs, tens, n_tens); @@ -754,22 +819,34 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { for (uint32_t i=0; i < n_ops; i++) { struct profile_data prof; - profile_start(&prof); + + profile_start(ctx->profiler, &prof); proc_op_req(octx, tens, i, &ops[i]); - profile_stop(&prof); - 
ops[i].prof_usecs = prof.usecs; - ops[i].prof_cycles = prof.cycles; - ops[i].prof_pkts = prof.pkts; + profile_stop(ctx->profiler, &prof); + + if (ctx->profiler) { + pds[i].opcode = ops[i].opcode; + pds[i].usecs = prof.usecs; + pds[i].cycles = prof.cycles; + for (int j = 0; j < HEX_NUM_PMU_COUNTERS; j++) { + pds[i].pmu[j] = prof.pmu_counters[j]; + } + } } // dspqueue_write_early_wakeup_noblock(ctx->queue, 10, 0); struct htp_opbatch_rsp rsp; - rsp.status = HTP_STATUS_OK; // FIXME + rsp.id = req.id; + rsp.status = HTP_STATUS_OK; + rsp.n_bufs = n_bufs; + rsp.n_tensors = n_tens; + rsp.n_ops = n_ops; dbuf.flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT; + err = dspqueue_write(queue, 0, 1, &dbuf, sizeof(rsp), (const uint8_t *) &rsp, DSPQUEUE_TIMEOUT_NONE); if (err != 0) { FARF(ERROR, "dspqueue_write failed: 0x%08x", (unsigned) err); diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index bac06693d81..a0c265132c8 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -3017,6 +3017,10 @@ int op_matmul(struct htp_ops_context * octx) { const int act_stride = (int)(src1->nb[1] / sizeof(float)); const int wgt_stride = (int)(src0->nb[1] / sizeof(__fp16)); + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + if (src0->type == HTP_TYPE_F16) { if (is_batched) { hmx_matmul_w16a32_batched_params_t batch_params = { diff --git a/ggml/src/ggml-hexagon/htp/solve-tri-ops.c b/ggml/src/ggml-hexagon/htp/solve-tri-ops.c new file mode 100644 index 00000000000..ae8e1a50495 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/solve-tri-ops.c @@ -0,0 +1,267 @@ +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-ops.h" +#include "hvx-types.h" +#include "hvx-utils.h" + +struct htp_solve_tri_context { + struct htp_ops_context * octx; + uint32_t jobs_per_thread; + uint32_t total_jobs; + uint32_t k_chunks; + uint32_t col_block; +}; + +static inline void solve_tri_row_scalar(const float * A_row, + const float * B_row, + float * X, + uint32_t row, + uint32_t k, + uint32_t col0, + uint32_t coln, + float inv_diag) { + for (uint32_t col = col0; col < col0 + coln; ++col) { + float sum = 0.0f; + for (uint32_t t = 0; t < row; ++t) { + sum += A_row[t] * X[t * k + col]; + } + X[row * k + col] = (B_row[col] - sum) * inv_diag; + } +} + +static inline HVX_Vector hvx_load_partial_f32(const float * src, uint32_t n) { + HVX_Vector v = *((const HVX_UVector *) src); + HVX_VectorPred mask = Q6_Q_vsetq2_R(n * sizeof(float)); + return Q6_V_vmux_QVV(mask, v, Q6_V_vzero()); +} + +static inline void solve_tri_row_hvx(const float * A_row, + const float * B_row, + float * X, + uint32_t row, + uint32_t k, + uint32_t col0, + uint32_t coln, + float inv_diag) { + const bool full = (coln == VLEN_FP32); + + HVX_Vector sum_v = Q6_V_vzero(); + for (uint32_t t = 0; t < row; ++t) { + const float a = A_row[t]; + const float * x_row_col = X + t * k + col0; + + HVX_Vector x_v = full ? *((const HVX_UVector *) x_row_col) : hvx_load_partial_f32(x_row_col, coln); + HVX_Vector a_v = hvx_vec_splat_f32(a); + sum_v = hvx_vec_add_f32_f32(sum_v, hvx_vec_mul_f32_f32(x_v, a_v)); + } + + const float * b_row_col = B_row + col0; + float * x_out_col = X + row * k + col0; + + HVX_Vector b_v = full ? 
*((const HVX_UVector *) b_row_col) : hvx_load_partial_f32(b_row_col, coln); + HVX_Vector inv_diag_v = hvx_vec_splat_f32(inv_diag); + + HVX_Vector out_v = hvx_vec_mul_f32_f32(hvx_vec_sub_f32_f32(b_v, sum_v), inv_diag_v); + hvx_vec_store_u((void *) x_out_col, coln * sizeof(float), out_v); +} + +// Batch-level thread: each job is one full batch. +static void solve_tri_batch_thread_f32(unsigned int nth, unsigned int ith, void * data) { + struct htp_solve_tri_context * sctx = (struct htp_solve_tri_context *) data; + struct htp_ops_context * octx = sctx->octx; + + const struct htp_tensor * src0 = octx->src[0]; // A + const struct htp_tensor * src1 = octx->src[1]; // B + const struct htp_tensor * dst = octx->dst; // X + + const uint32_t n = src0->ne[0]; + const uint32_t k = src1->ne[0]; + + const uint32_t ne02 = src0->ne[2]; + + const uint32_t col_block = VLEN_FP32; + const uint32_t k_full = (k / col_block) * col_block; + + const uint32_t start_batch = sctx->jobs_per_thread * ith; + const uint32_t end_batch = MIN(start_batch + sctx->jobs_per_thread, sctx->total_jobs); + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + for (uint32_t batch = start_batch; batch < end_batch; ++batch) { + const uint32_t i03 = batch / ne02; + const uint32_t i02 = batch - i03 * ne02; + + const float * A_batch = + (const float *) ((const uint8_t *) (uintptr_t) src0->data + i02 * src0->nb[2] + i03 * src0->nb[3]); + const float * B_batch = + (const float *) ((const uint8_t *) (uintptr_t) src1->data + i02 * src1->nb[2] + i03 * src1->nb[3]); + float * X_batch = (float *) ((uint8_t *) (uintptr_t) dst->data + i02 * dst->nb[2] + i03 * dst->nb[3]); + + for (uint32_t row = 0; row < n; ++row) { + const float diag = A_batch[row * n + row]; + const float inv_diag = 1.0f / diag; + const float * A_row = A_batch + row * n; + const float * B_row = B_batch + row * k; + + uint32_t col0 = 0; + for (; col0 < k_full; col0 += col_block) { + solve_tri_row_hvx(A_row, B_row, X_batch, row, k, col0, col_block, inv_diag); + } + + if (col0 < k) { + const uint32_t coln = k - col0; + if (coln >= 8) { + solve_tri_row_hvx(A_row, B_row, X_batch, row, k, col0, coln, inv_diag); + } else { + solve_tri_row_scalar(A_row, B_row, X_batch, row, k, col0, coln, inv_diag); + } + } + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "solve-tri-batch %d/%d: A=(%ux%u) B=(%ux%u) batch %u:%u usec %u\n", + ith, nth, n, n, k, n, start_batch, end_batch, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// Chunk-level thread: each job is one (batch, col_chunk) pair. 
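// Both the batch-level and chunk-level workers implement the same forward substitution:
//     X[row][col] = (B[row][col] - sum_{t < row} A[row][t] * X[t][col]) / A[row][row]
// For reference, a minimal scalar sketch of one batch (illustrative only; assumes row-major
// A of size n x n and B/X of size n x k, function name is hypothetical):
static void solve_tri_reference_f32(const float * A, const float * B, float * X, uint32_t n, uint32_t k) {
    for (uint32_t row = 0; row < n; ++row) {
        const float inv_diag = 1.0f / A[row * n + row];
        for (uint32_t col = 0; col < k; ++col) {
            float sum = 0.0f;
            for (uint32_t t = 0; t < row; ++t) {
                sum += A[row * n + t] * X[t * k + col];
            }
            X[row * k + col] = (B[row * k + col] - sum) * inv_diag;
        }
    }
}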
+static void solve_tri_chunk_thread_f32(unsigned int nth, unsigned int ith, void * data) { + struct htp_solve_tri_context * sctx = (struct htp_solve_tri_context *) data; + struct htp_ops_context * octx = sctx->octx; + + const struct htp_tensor * src0 = octx->src[0]; // A + const struct htp_tensor * src1 = octx->src[1]; // B + const struct htp_tensor * dst = octx->dst; // X + + const uint32_t n = src0->ne[0]; + const uint32_t k = src1->ne[0]; + + const uint32_t ne02 = src0->ne[2]; + + const uint32_t start_job = sctx->jobs_per_thread * ith; + const uint32_t end_job = MIN(start_job + sctx->jobs_per_thread, sctx->total_jobs); + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + for (uint32_t job = start_job; job < end_job; ++job) { + const uint32_t batch = job / sctx->k_chunks; + const uint32_t chunk = job - batch * sctx->k_chunks; + + const uint32_t i03 = batch / ne02; + const uint32_t i02 = batch - i03 * ne02; + + const uint32_t col0 = chunk * sctx->col_block; + const uint32_t coln = MIN(sctx->col_block, k - col0); + + const float * A_batch = + (const float *) ((const uint8_t *) (uintptr_t) src0->data + i02 * src0->nb[2] + i03 * src0->nb[3]); + const float * B_batch = + (const float *) ((const uint8_t *) (uintptr_t) src1->data + i02 * src1->nb[2] + i03 * src1->nb[3]); + float * X_batch = (float *) ((uint8_t *) (uintptr_t) dst->data + i02 * dst->nb[2] + i03 * dst->nb[3]); + + const bool use_hvx = (coln >= 8); + + for (uint32_t row = 0; row < n; ++row) { + const float diag = A_batch[row * n + row]; + const float inv_diag = 1.0f / diag; + + const float * A_row = A_batch + row * n; + const float * B_row = B_batch + row * k; + + if (use_hvx) { + solve_tri_row_hvx(A_row, B_row, X_batch, row, k, col0, coln, inv_diag); + } else { + solve_tri_row_scalar(A_row, B_row, X_batch, row, k, col0, coln, inv_diag); + } + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "solve-tri-chunk %d/%d: A=(%ux%u) B=(%ux%u) job %u:%u usec %u\n", + ith, nth, n, n, k, n, start_job, end_job, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +int op_solve_tri(struct htp_ops_context * octx) { + const struct htp_tensor * src0 = octx->src[0]; // A + const struct htp_tensor * src1 = octx->src[1]; // B + const struct htp_tensor * dst = octx->dst; // X + + if (src0->type != HTP_TYPE_F32 || src1->type != HTP_TYPE_F32 || dst->type != HTP_TYPE_F32) { + return HTP_STATUS_NO_SUPPORT; + } + + // left=true, lower=true, uni=false only + if (src0->ne[0] != src0->ne[1]) { + return HTP_STATUS_INVAL_PARAMS; + } + if (src0->ne[1] != src1->ne[1]) { + return HTP_STATUS_INVAL_PARAMS; + } + if (src0->ne[2] != src1->ne[2] || src0->ne[3] != src1->ne[3]) { + return HTP_STATUS_INVAL_PARAMS; + } + if (dst->ne[0] != src1->ne[0] || dst->ne[1] != src1->ne[1] || dst->ne[2] != src1->ne[2] || + dst->ne[3] != src1->ne[3]) { + return HTP_STATUS_INVAL_PARAMS; + } + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + const uint32_t k = src1->ne[0]; + + const uint32_t col_block = VLEN_FP32; + const uint32_t k_chunks = (k + col_block - 1) / col_block; + const uint32_t total_batches = src0->ne[2] * src0->ne[3]; + const bool batched = total_batches >= (uint32_t) octx->n_threads; + + FARF(HIGH, "solve-tri: (%ux%ux%ux%u) x (%ux%ux%ux%u) -> (%ux%ux%ux%u) : batched %d\n", + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], batched); + + if (batched) { + // Batch-level parallelism + const uint32_t n_threads = 
MIN((uint32_t) octx->n_threads, total_batches); + + struct htp_solve_tri_context sctx = { + .octx = octx, + .jobs_per_thread = (total_batches + n_threads - 1) / n_threads, + .total_jobs = total_batches, + .k_chunks = k_chunks, + .col_block = col_block, + }; + + worker_pool_run_func(octx->ctx->worker_pool, solve_tri_batch_thread_f32, &sctx, n_threads); + } else { + // Chunk-level parallelism + const uint32_t total_jobs = total_batches * k_chunks; + const uint32_t n_threads = MIN((uint32_t) octx->n_threads, MAX(total_jobs, 1)); + + struct htp_solve_tri_context sctx = { + .octx = octx, + .jobs_per_thread = (total_jobs + n_threads - 1) / n_threads, + .total_jobs = total_jobs, + .k_chunks = k_chunks, + .col_block = col_block, + }; + + worker_pool_run_func(octx->ctx->worker_pool, solve_tri_chunk_thread_f32, &sctx, n_threads); + } + + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/libggml-htp.inf b/ggml/src/ggml-hexagon/libggml-htp.inf index 656d2d9ab26..360d8b1228e 100644 --- a/ggml/src/ggml-hexagon/libggml-htp.inf +++ b/ggml/src/ggml-hexagon/libggml-htp.inf @@ -18,6 +18,7 @@ libggml-htp-v68.so = 1 libggml-htp-v69.so = 1 libggml-htp-v73.so = 1 libggml-htp-v75.so = 1 +libggml-htp-v79.so = 1 libggml-htp-v81.so = 1 [ControlFlags] @@ -31,6 +32,7 @@ libggml-htp-v68.so,,,0x10 ;COPYFLG_NO_OVERWRITE libggml-htp-v69.so,,,0x10 ;COPYFLG_NO_OVERWRITE libggml-htp-v73.so,,,0x10 ;COPYFLG_NO_OVERWRITE libggml-htp-v75.so,,,0x10 ;COPYFLG_NO_OVERWRITE +libggml-htp-v79.so,,,0x10 ;COPYFLG_NO_OVERWRITE libggml-htp-v81.so,,,0x10 ;COPYFLG_NO_OVERWRITE [Strings] diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 27cb1683518..27b78c5e6d7 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -814,7 +814,7 @@ ggml_metal_device_t ggml_metal_device_init(int device) { } // print MTL GPU family: - GGML_LOG_INFO("%s: GPU name: %s\n", __func__, dev->props.name); + GGML_LOG_INFO("%s: GPU name: %s (%s)\n", __func__, dev->props.name, dev->props.desc); // determine max supported GPU family // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf @@ -931,13 +931,13 @@ void ggml_metal_device_rsets_keep_alive(ggml_metal_device_t dev) { } struct ggml_metal_event { - void * obj; // id + void * obj; // id atomic_int value; }; void ggml_metal_event_encode_signal(ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf_raw) { - id event = (id)ev->obj; + id event = (id)ev->obj; id cmd_buf = (id) cmd_buf_raw; @@ -945,7 +945,7 @@ void ggml_metal_event_encode_signal(ggml_metal_event_t ev, ggml_metal_cmd_buf_t } void ggml_metal_event_encode_wait(ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf_raw) { - id event = (id)ev->obj; + id event = (id)ev->obj; id cmd_buf = (id) cmd_buf_raw; @@ -953,7 +953,7 @@ void ggml_metal_event_encode_wait(ggml_metal_event_t ev, ggml_metal_cmd_buf_t cm } ggml_metal_event_t ggml_metal_device_event_init(ggml_metal_device_t dev) { - id event = [dev->mtl_device newEvent]; + id event = [dev->mtl_device newSharedEvent]; ggml_metal_event_t ev = calloc(1, sizeof(struct ggml_metal_event)); @@ -964,7 +964,7 @@ ggml_metal_event_t ggml_metal_device_event_init(ggml_metal_device_t dev) { } void ggml_metal_device_event_free(ggml_metal_device_t dev, ggml_metal_event_t ev) { - id event = ev->obj; + id event = ev->obj; [event release]; free(ev); @@ -973,14 +973,13 @@ void ggml_metal_device_event_free(ggml_metal_device_t dev, ggml_metal_event_t ev } void 
ggml_metal_device_event_synchronize(ggml_metal_device_t dev, ggml_metal_event_t ev) { - @autoreleasepool { - id event = ev->obj; - - id cmd_buf = [dev->mtl_queue commandBuffer]; - [cmd_buf encodeWaitForEvent:event value:atomic_load_explicit(&ev->value, memory_order_relaxed)]; - [cmd_buf commit]; - [cmd_buf waitUntilCompleted]; + id event = ev->obj; + const bool res = [event waitUntilSignaledValue:atomic_load_explicit(&ev->value, memory_order_relaxed) timeoutMS:60000]; + if (!res) { + GGML_ABORT("%s: failed to wait for event\n", __func__); } + + GGML_UNUSED(dev); } void ggml_metal_device_get_memory(ggml_metal_device_t dev, size_t * free, size_t * total) { diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp index 4dbf8e6fea9..6a836e45908 100644 --- a/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp @@ -918,6 +918,10 @@ ggml_backend_reg_t ggml_backend_metal_reg(void) { static std::vector devs; if (!initialized) { + // workaround macOS limitation (kIOGPUCommandBufferCallbackErrorImpactingInteractivity) until proper fix becomes possible + // ref: https://github.com/ggml-org/llama.cpp/issues/20141#issuecomment-4272947703 + setenv("AGX_RELAX_CDM_CTXSTORE_TIMEOUT", "1", true); + static ggml_backend_metal_reg_ptr reg_ctx(ggml_backend_metal_reg_init()); for (int i = 0; i < g_devices; ++i) { diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp index 0938d2273e9..5095e799849 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.cpp +++ b/ggml/src/ggml-openvino/ggml-decoder.cpp @@ -19,7 +19,6 @@ #include #include #include -#include #include #include #include @@ -207,8 +206,22 @@ int GgmlOvDecoder::compute_op_case(const ggml_tensor * node) const { break; } case GGML_OP_ROPE: { + const int mode = node->op_params[2]; + switch (mode) { + case GGML_ROPE_TYPE_NEOX: { + op_case = 0x00010000; + break; + } + case GGML_ROPE_TYPE_IMROPE: { + op_case = 0x00020000; + break; + } + default: + op_case = 0x00000000; + break; + } if (node->src[0]->op == GGML_OP_VIEW) { - op_case = 2; + op_case = (op_case | 0x00000002); } break; } @@ -573,9 +586,6 @@ std::map GgmlOvDecoder::get_kv_param_res_names() const } std::map> GgmlOvDecoder::create_weight_nodes(ggml_cgraph * cgraph, bool naive) { - static std::mutex weights_mutex; - std::lock_guard lock(weights_mutex); - std::map> model_weights; auto * nodes = cgraph->nodes; auto n_nodes = cgraph->n_nodes; diff --git a/ggml/src/ggml-openvino/ggml-openvino-extra.cpp b/ggml/src/ggml-openvino/ggml-openvino-extra.cpp index cc3cb4583cd..4140136aca2 100644 --- a/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino-extra.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include ov::Core & ov_singleton_core() { @@ -42,11 +43,13 @@ void ggml_openvino_device_config::init() { {"NPUW_DQ", "YES" }, {"NPUW_DQ_FULL", "NO" }, }; - if (cache_dir) { + if (cache_dir && strlen(cache_dir) > 0) { compile_config["NPUW_CACHE_DIR"] = cache_dir; + compile_config.insert(ov::cache_mode(ov::CacheMode::OPTIMIZE_SIZE)); } - } else if (cache_dir) { - ov_singleton_core().set_property(ov::cache_dir(cache_dir)); + } else if (cache_dir && strlen(cache_dir) > 0) { + compile_config.insert(ov::cache_dir(cache_dir)); + compile_config.insert(ov::cache_mode(ov::CacheMode::OPTIMIZE_SIZE)); } // Initialize remote context with queue sharing for GPU @@ -259,10 +262,12 @@ ggml_openvino_extracted_layout ggml_openvino_get_extracted_layout(const ggml_ten layout.weights_size = layout.is_u4 ? 
(n_elements / 2) : n_elements; int64_t n_blocks = n_elements / layout.weights_per_block; layout.scales_size = n_blocks * sizeof(uint16_t); - // For symmetric quantization, we only need one zp value (not one per block) - // Zero points are stored in U4 or U8 format matching the weight type - size_t n_zp_elements = layout.is_symmetric ? 1 : n_blocks; - layout.zp_size = layout.is_u4 ? ((n_zp_elements + 1) / 2) : n_zp_elements; + // For symmetric quantization, no zp needed (weights stored as signed) + if (layout.is_symmetric) { + layout.zp_size = 0; + } else { + layout.zp_size = layout.is_u4 ? ((n_blocks + 1) / 2) : n_blocks; + } layout.weights_offset = 0; layout.scales_offset = ((layout.weights_size + alignment - 1) / alignment) * alignment; @@ -313,10 +318,12 @@ ggml_openvino_extracted_layout ggml_openvino_get_extracted_layout(const ggml_ten // Scales: F16 per block int64_t n_blocks = n_elements / layout.weights_per_block; layout.scales_size = n_blocks * sizeof(uint16_t); // F16 = 2 bytes - // Zero points: U4 or U8 matching weight type - // For symmetric quantization, we only need one zp value (not one per block) - size_t n_zp_elements = layout.is_symmetric ? 1 : n_blocks; - layout.zp_size = layout.is_u4 ? ((n_zp_elements + 1) / 2) : n_zp_elements; + // For symmetric quantization, no zp needed (weights stored as signed) + if (layout.is_symmetric) { + layout.zp_size = 0; + } else { + layout.zp_size = layout.is_u4 ? ((n_blocks + 1) / 2) : n_blocks; + } // Layout in buffer: [weights | scales | zp] with alignment layout.weights_offset = 0; diff --git a/ggml/src/ggml-openvino/ggml-openvino.cpp b/ggml/src/ggml-openvino/ggml-openvino.cpp index 0c8d3508e87..4f3ebf2536b 100644 --- a/ggml/src/ggml-openvino/ggml-openvino.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino.cpp @@ -145,13 +145,18 @@ static void * ggml_backend_openvino_buffer_get_base(ggml_backend_buffer_t buffer return ctx->data; } +static bool is_stateful_enabled() { + static const auto * stateful = getenv("GGML_OPENVINO_STATEFUL_EXECUTION"); + return stateful && *stateful != '\0' && strcmp(stateful, "0") != 0; +} + static enum ggml_status ggml_backend_openvino_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { // GGML_LOG_DEBUG("%s: buffer usage=%d, tensor name=%s\n", __func__, buffer->usage, tensor->name); ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; // Put kvcache on device memory for GPU (NPU memory is too small even for kvcache) if (strncmp(tensor->name, "cache_", 6) == 0 && !ctx->is_remote && ggml_openvino_get_device_name() == "GPU" && - !getenv("GGML_OPENVINO_STATEFUL_EXECUTION")) { + !is_stateful_enabled()) { GGML_ASSERT(ctx->tensor_extras.empty()); auto device = ctx->device; auto size = ctx->size; @@ -600,6 +605,14 @@ bool ggml_backend_buft_is_openvino_host(ggml_backend_buffer_type_t buft) { static void ggml_backend_openvino_free(ggml_backend_t backend) { ggml_backend_openvino_context * ctx = (ggml_backend_openvino_context *) backend->context; + + if (ctx->runtime_context) { + auto r_ctx = std::static_pointer_cast(ctx->runtime_context); + if (--r_ctx->backend_count == 0) { + r_ctx->clear_caches(); + } + } + delete ctx; delete backend; } @@ -644,7 +657,12 @@ static ggml_guid_t ggml_backend_openvino_guid(void) { } static std::shared_ptr get_ov_runtime_context_ptr() { - static std::shared_ptr r_ctx = std::make_shared(); + static std::shared_ptr r_ctx = [] { + auto ctx = std::make_shared(); + ctx->device = ggml_openvino_get_device_name(); + ctx->stateful = 
is_stateful_enabled() && !ggml_openvino_is_npu(); + return ctx; + }(); return r_ctx; } @@ -669,8 +687,7 @@ GGML_BACKEND_API ggml_backend_t ggml_backend_openvino_init(int device) { } std::shared_ptr r_ctx = std::static_pointer_cast(ctx->runtime_context); - r_ctx->device = ggml_openvino_get_device_name(); - r_ctx->stateful = getenv("GGML_OPENVINO_STATEFUL_EXECUTION") && !ggml_openvino_is_npu(); + r_ctx->backend_count++; ggml_backend_t openvino_backend = new ggml_backend{ /* .guid = */ ggml_backend_openvino_guid(), @@ -883,7 +900,7 @@ static bool is_op_unsupported_case(const ggml_tensor * op) { const int32_t * op_params = op->op_params; const int n_dims = op_params[1]; const int mode = op_params[2]; - if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX) { + if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX && mode != GGML_ROPE_TYPE_IMROPE) { // GGML_LOG_WARN("OpenVINO backend does not support ROPE with mode %d\n", mode); return true; } @@ -896,14 +913,6 @@ static bool is_op_unsupported_case(const ggml_tensor * op) { // GGML_LOG_WARN("OpenVINO backend does not support ROPE with type %s\n", ggml_type_name(op->type)); return true; } - float freq_scale; - float ext_factor; - memcpy(&freq_scale, op_params + 6, sizeof(float)); - memcpy(&ext_factor, op_params + 7, sizeof(float)); - if (ext_factor != 0.0f) { - // GGML_LOG_WARN("OpenVINO backend does not support ROPE with ext_factor %f != 0.0f\n", ext_factor); - return true; - } if (op->src[0]->op == GGML_OP_VIEW) { if (op->src[0]->view_src->ne[1] != op->src[0]->ne[2]) { // GGML_LOG_WARN( @@ -913,6 +922,12 @@ static bool is_op_unsupported_case(const ggml_tensor * op) { return true; } } + if (mode == GGML_ROPE_TYPE_IMROPE && + (op->src[2] != 0 || ((const float *) op_params)[6] != 1 || ((const float *) op_params)[7] != 0 || + ((const float *) op_params)[8] != 1)) { + // GGML_LOG_WARN("OpenVINO backend does not support IMROPE with freq_factors, freq_scale, ext_factor, and attn_factor\n"); + return true; + } break; } default: @@ -942,6 +957,7 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con // GGML_OP_SOFT_MAX, GGML_OP_SET_ROWS, GGML_OP_FLASH_ATTN_EXT, GGML_OP_CPY}; static const std::set supported_unary_ops{ + GGML_UNARY_OP_GELU, GGML_UNARY_OP_SILU, }; static const std::set supported_glu_ops{ diff --git a/ggml/src/ggml-openvino/ggml-quants.cpp b/ggml/src/ggml-openvino/ggml-quants.cpp index dbf38646ddd..57d66df4f01 100644 --- a/ggml/src/ggml-openvino/ggml-quants.cpp +++ b/ggml/src/ggml-openvino/ggml-quants.cpp @@ -46,6 +46,7 @@ void unpack_32_4(const uint8_t * data, uint8_t * dst) { // Extracts (weight, scales, zp) from Q4_0 tensors. // Data layout is: |16 bit scale|32 x 4bit weights|. +// When zp_arr is empty (symmetric), weights are stored as signed i4 (value - 8). 
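// For example (illustrative): a packed byte 0x2B holds nibbles 11 (low) and 2 (high); with block
// scale d they dequantize to d*(11-8) = 3d and d*(2-8) = -6d. XOR-ing the byte with 0x88 yields
// 0xA3, whose nibbles are 0x3 (= +3) and 0xA (= -6 in 4-bit two's complement), i.e. exactly
// (q - 8) stored as signed i4; this is why the symmetric path below XORs each packed byte with 0x88.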
void extract_q4_0_data(const ggml_tensor * tensor, ov::Tensor & weights_arr, ov::Tensor & scales_arr, @@ -55,28 +56,32 @@ void extract_q4_0_data(const ggml_tensor * tensor, auto * data = static_cast(tensor->data); auto * weights = static_cast(weights_arr.data()); auto * scales = scales_arr.data::value_type>(); - auto * zp = static_cast(zp_arr.data()); - - bool is_scalar_zp = (zp_arr.get_size() == 1); // Symmetric quantization - // For Q4_0, zero point is always 8 - if (is_scalar_zp) { - zp[0] = 8 | (8 << 4); // Pack two 4-bit values - } + bool is_symmetric = (weights_arr.get_element_type() == ov::element::i4); // Signed i4 path - ov::parallel_for(scales_arr.get_size(), [&](size_t i) { - scales[i] = ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block))); - // For asymmetric quantization, compute per-block zero points - if (!is_scalar_zp) { + if (!is_symmetric) { + auto * zp = static_cast(zp_arr.data()); + ov::parallel_for(scales_arr.get_size(), [&](size_t i) { + scales[i] = ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block))); // Pack two 4-bit zero points per byte if (i % 2 == 0) { zp[i / 2] = 8; // Lower nibble } else { zp[i / 2] |= (8 << 4); // Upper nibble } - } - unpack_32_4(data + i * bytes_per_block + 2, weights + i * 16); - }); + unpack_32_4(data + i * bytes_per_block + 2, weights + i * 16); + }); + } else { + // Symmetric: unpack as u4 then convert to i4 by subtracting 8 (XOR each nibble) + ov::parallel_for(scales_arr.get_size(), [&](size_t i) { + scales[i] = ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block))); + unpack_32_4(data + i * bytes_per_block + 2, weights + i * 16); + // Convert u4 to i4: subtract 8 from each nibble. XOR 0x88 flips each nibble by 8. + for (int j = 0; j < 16; ++j) { + weights[i * 16 + j] ^= 0x88; + } + }); + } } // Extracts (weight, scales, zp) from Q4_1 tensors. @@ -123,6 +128,7 @@ void extract_q4_1_data(const ggml_tensor * tensor, // Extracts (weight, scales, zp) from Q8_0 tensors. // Data layout is: |16 bit scale|32 x 8bit weights|. +// When zp_arr is empty (symmetric), weights are stored as signed i8 directly. void extract_q8_0_data(const ggml_tensor * tensor, ov::Tensor & weights_arr, ov::Tensor & scales_arr, @@ -133,29 +139,30 @@ void extract_q8_0_data(const ggml_tensor * tensor, auto * data = static_cast(tensor->data); auto * weights = static_cast(weights_arr.data()); auto * scales = scales_arr.data::value_type>(); - auto * zp = static_cast(zp_arr.data()); - - bool is_scalar_zp = (zp_arr.get_size() == 1); // Symmetric quantization - // For Q8_0, zero point is always 128 - if (is_scalar_zp) { - zp[0] = 128; - } + bool is_symmetric = (weights_arr.get_element_type() == ov::element::i8); // Signed i8 path - ov::parallel_for(scales_arr.get_size(), [&](size_t i) { - uint8_t * block_data = data + i * bytes_per_block; - scales[i] = ov::float16::from_bits(*(uint16_t *) block_data); - // For asymmetric quantization, store per-block zero points - if (!is_scalar_zp) { + if (!is_symmetric) { + auto * zp = static_cast(zp_arr.data()); + ov::parallel_for(scales_arr.get_size(), [&](size_t i) { + uint8_t * block_data = data + i * bytes_per_block; + scales[i] = ov::float16::from_bits(*(uint16_t *) block_data); zp[i] = 128; - } - for (size_t j = 0; j < weights_per_block; ++j) { - uint8_t x = block_data[j + 2]; // j+2 to skip the scale bytes. - // Original data is in int8_t, so we add a bias of -128 and invert the first bit. 
- x ^= 1 << 7; - weights[i * weights_per_block + j] = x; - } - }); + for (size_t j = 0; j < weights_per_block; ++j) { + uint8_t x = block_data[j + 2]; + x ^= 1 << 7; // Convert int8 to uint8 by flipping sign bit + weights[i * weights_per_block + j] = x; + } + }); + } else { + // Symmetric: store original int8 values directly (no unsigned bias) + ov::parallel_for(scales_arr.get_size(), [&](size_t i) { + uint8_t * block_data = data + i * bytes_per_block; + scales[i] = ov::float16::from_bits(*(uint16_t *) block_data); + // Copy int8 weights as-is (the tensor element type is i8) + memcpy(weights + i * weights_per_block, block_data + 2, weights_per_block); + }); + } } void unpack_256_4(const uint8_t * data, uint8_t * dst) { @@ -256,44 +263,62 @@ void extract_q6_k_data(const ggml_tensor * tensor, auto * data = static_cast(tensor->data); auto * weights = static_cast(weights_arr.data()); auto * scales = scales_arr.data::value_type>(); - auto * zp = static_cast(zp_arr.data()); - - bool is_scalar_zp = (zp_arr.get_size() == 1); // Symmetric quantization - - // For Q6_K, zero point is always 32 - if (is_scalar_zp) { - zp[0] = 32; - } - - ov::parallel_for(n_super_block, [&](size_t i) { - uint8_t * block_data = data + i * bytes_per_block; - float scale_factor = - static_cast(ov::float16::from_bits(*((uint16_t *) block_data + 104))); // (128+64+16)/2 + bool is_symmetric = (weights_arr.get_element_type() == ov::element::i8); // Signed i8 path - for (size_t j = 0; j < 16; j++) { - scales[j + i * 16] = - ov::float16(scale_factor * static_cast(*((int8_t *) (block_data + 128 + 64 + j)))); - // For asymmetric quantization, store per-block zero points - if (!is_scalar_zp) { + if (!is_symmetric) { + auto * zp = static_cast(zp_arr.data()); + ov::parallel_for(n_super_block, [&](size_t i) { + uint8_t * block_data = data + i * bytes_per_block; + float scale_factor = static_cast(ov::float16::from_bits(*((uint16_t *) block_data + 104))); + for (size_t j = 0; j < 16; j++) { + scales[j + i * 16] = + ov::float16(scale_factor * static_cast(*((int8_t *) (block_data + 128 + 64 + j)))); zp[j + i * 16] = 32; } - } - - uint8_t * ql = block_data; - uint8_t * qh = block_data + 128; - - for (int64_t j = 0; j < 32; ++j) { - weights[i * 256 + j] = (ql[j] & 0xF) | (((qh[j] >> 0) & 3) << 4); - weights[i * 256 + j + 32] = (ql[32 + j] & 0xF) | (((qh[j] >> 2) & 3) << 4); - weights[i * 256 + j + 64] = (ql[j] >> 4) | (((qh[j] >> 4) & 3) << 4); - weights[i * 256 + j + 96] = (ql[32 + j] >> 4) | (((qh[j] >> 6) & 3) << 4); - weights[i * 256 + j + 128] = (ql[64 + j] & 0xF) | (((qh[32 + j] >> 0) & 3) << 4); - weights[i * 256 + j + 160] = (ql[96 + j] & 0xF) | (((qh[32 + j] >> 2) & 3) << 4); - weights[i * 256 + j + 192] = (ql[64 + j] >> 4) | (((qh[32 + j] >> 4) & 3) << 4); - weights[i * 256 + j + 224] = (ql[96 + j] >> 4) | (((qh[32 + j] >> 6) & 3) << 4); - } - }); + uint8_t * ql = block_data; + uint8_t * qh = block_data + 128; + for (int64_t j = 0; j < 32; ++j) { + weights[i * 256 + j] = (ql[j] & 0xF) | (((qh[j] >> 0) & 3) << 4); + weights[i * 256 + j + 32] = (ql[32 + j] & 0xF) | (((qh[j] >> 2) & 3) << 4); + weights[i * 256 + j + 64] = (ql[j] >> 4) | (((qh[j] >> 4) & 3) << 4); + weights[i * 256 + j + 96] = (ql[32 + j] >> 4) | (((qh[j] >> 6) & 3) << 4); + weights[i * 256 + j + 128] = (ql[64 + j] & 0xF) | (((qh[32 + j] >> 0) & 3) << 4); + weights[i * 256 + j + 160] = (ql[96 + j] & 0xF) | (((qh[32 + j] >> 2) & 3) << 4); + weights[i * 256 + j + 192] = (ql[64 + j] >> 4) | (((qh[32 + j] >> 4) & 3) << 4); + weights[i * 256 + j + 224] = (ql[96 + j] >> 4) 
| (((qh[32 + j] >> 6) & 3) << 4); + } + }); + } else { + // Symmetric: subtract 32 from each weight to store as signed i8 + ov::parallel_for(n_super_block, [&](size_t i) { + uint8_t * block_data = data + i * bytes_per_block; + float scale_factor = static_cast(ov::float16::from_bits(*((uint16_t *) block_data + 104))); + for (size_t j = 0; j < 16; j++) { + scales[j + i * 16] = + ov::float16(scale_factor * static_cast(*((int8_t *) (block_data + 128 + 64 + j)))); + } + uint8_t * ql = block_data; + uint8_t * qh = block_data + 128; + auto * signed_weights = reinterpret_cast(weights); + for (int64_t j = 0; j < 32; ++j) { + signed_weights[i * 256 + j] = static_cast((ql[j] & 0xF) | (((qh[j] >> 0) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 32] = + static_cast((ql[32 + j] & 0xF) | (((qh[j] >> 2) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 64] = static_cast((ql[j] >> 4) | (((qh[j] >> 4) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 96] = + static_cast((ql[32 + j] >> 4) | (((qh[j] >> 6) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 128] = + static_cast((ql[64 + j] & 0xF) | (((qh[32 + j] >> 0) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 160] = + static_cast((ql[96 + j] & 0xF) | (((qh[32 + j] >> 2) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 192] = + static_cast((ql[64 + j] >> 4) | (((qh[32 + j] >> 4) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 224] = + static_cast((ql[96 + j] >> 4) | (((qh[32 + j] >> 6) & 3) << 4)) - 32; + } + }); + } } static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t * d, uint8_t * m) { @@ -389,11 +414,10 @@ ov::Output make_int8_weights(ov::Tensor & weight, size_t group_size, bool use_bias) { ov::Shape orig_shape = weight.get_shape(); + bool is_signed = (weight.get_element_type() == ov::element::i8); // Symmetric: signed weights, no ZP // Expand dimensions for scales and zp/bias auto scale_shape = scales.get_shape(); - auto zp_shape = zp.get_shape(); - bool is_scalar_zp = zp_shape.empty(); // Symmetric quantization ov::Shape packed_shape = {orig_shape[0], orig_shape[1] / group_size, group_size}; @@ -403,37 +427,48 @@ ov::Output make_int8_weights(ov::Tensor & weight, } else { scale_shape.push_back(1); scales.set_shape(scale_shape); - // For symmetric quantization, zp remains scalar (don't resize) - if (!is_scalar_zp) { + if (!is_signed && zp.get_size() > 0) { + auto zp_shape = zp.get_shape(); zp_shape.push_back(1); zp.set_shape(zp_shape); } } - // Create graph nodes - auto weights_node = std::make_shared(ov::element::u8, packed_shape, - static_cast(weight.data()), nullptr); - weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; auto scales_f16 = std::make_shared(scales); - auto weights_f16 = std::make_shared(weights_node, ov::element::f16); ov::Output result; - if (use_bias && !is_scalar_zp) { - // Bias path: w * s + b (zp tensor holds f16 bias values) - auto bias_f16 = std::make_shared(zp); - auto w_s = std::make_shared(weights_f16, scales_f16, ov::op::AutoBroadcastType::NUMPY); - result = std::make_shared(w_s, bias_f16, ov::op::AutoBroadcastType::NUMPY); + if (is_signed) { + // Signed path: q * s (no zero point subtraction needed) + auto weights_node = std::make_shared(ov::element::i8, packed_shape, + static_cast(weight.data()), nullptr); + weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; + auto weights_f16 = std::make_shared(weights_node, ov::element::f16); + result = std::make_shared(weights_f16, scales_f16, ov::op::AutoBroadcastType::NUMPY); } else { - // Zero point path: (w - zp) * s - auto zero_point = 
std::make_shared(zp); - float zp_value; - if (ov::op::util::get_single_value(zero_point, zp_value)) { - zero_point = ov::op::v0::Constant::create(zero_point->get_element_type(), {}, {zp_value}); + // Unsigned path + auto weights_node = std::make_shared(ov::element::u8, packed_shape, + static_cast(weight.data()), nullptr); + weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; + auto weights_f16 = std::make_shared(weights_node, ov::element::f16); + + if (use_bias && zp.get_size() > 0) { + // Bias path: w * s + b (zp tensor holds f16 bias values) + auto bias_f16 = std::make_shared(zp); + auto w_s = + std::make_shared(weights_f16, scales_f16, ov::op::AutoBroadcastType::NUMPY); + result = std::make_shared(w_s, bias_f16, ov::op::AutoBroadcastType::NUMPY); + } else { + // Zero point path: (w - zp) * s + auto zero_point = std::make_shared(zp); + float zp_value; + if (ov::op::util::get_single_value(zero_point, zp_value)) { + zero_point = ov::op::v0::Constant::create(zero_point->get_element_type(), {}, {zp_value}); + } + auto zero_point_f16 = std::make_shared(zero_point, ov::element::f16); + auto w_zp = + std::make_shared(weights_f16, zero_point_f16, ov::op::AutoBroadcastType::NUMPY); + result = std::make_shared(w_zp, scales_f16, ov::op::AutoBroadcastType::NUMPY); } - auto zero_point_f16 = std::make_shared(zero_point, ov::element::f16); - auto w_zp = - std::make_shared(weights_f16, zero_point_f16, ov::op::AutoBroadcastType::NUMPY); - result = std::make_shared(w_zp, scales_f16, ov::op::AutoBroadcastType::NUMPY); } if (packed_shape.size() != 2) { @@ -452,11 +487,10 @@ ov::Output make_int4_weights(ov::Tensor & weight, size_t group_size, bool use_bias) { ov::Shape orig_weight_shape = weight.get_shape(); + bool is_signed = (weight.get_element_type() == ov::element::i4); // Symmetric: signed weights, no ZP // Expand dimensions for scales and zp/bias ov::Shape scale_shape = scales.get_shape(); - auto zp_shape = zp.get_shape(); - bool is_scalar_zp = zp_shape.empty(); // Symmetric quantization // Create INT4 weight tensor ov::Shape packed_shape = {orig_weight_shape[0], orig_weight_shape[1] / group_size, group_size}; @@ -467,36 +501,48 @@ ov::Output make_int4_weights(ov::Tensor & weight, } else { scale_shape.push_back(1); scales.set_shape(scale_shape); - // For symmetric quantization, zp remains scalar (don't resize) - if (!is_scalar_zp) { + if (!is_signed && zp.get_size() > 0) { + auto zp_shape = zp.get_shape(); zp_shape.push_back(1); zp.set_shape(zp_shape); } } - auto weights_node = std::make_shared(ov::element::u4, packed_shape, - static_cast(weight.data()), nullptr); - weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; - auto weights_f16 = std::make_shared(weights_node, ov::element::f16); auto scales_f16 = std::make_shared(scales); ov::Output result; - if (use_bias && !is_scalar_zp) { - // Bias path: w * s + b (zp tensor holds f16 bias values) - auto bias_f16 = std::make_shared(zp); - auto w_s = std::make_shared(weights_f16, scales_f16, ov::op::AutoBroadcastType::NUMPY); - result = std::make_shared(w_s, bias_f16, ov::op::AutoBroadcastType::NUMPY); + if (is_signed) { + // Signed path: q * s (no zero point subtraction needed) + auto weights_node = std::make_shared(ov::element::i4, packed_shape, + static_cast(weight.data()), nullptr); + weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; + auto weights_f16 = std::make_shared(weights_node, ov::element::f16); + result = std::make_shared(weights_f16, scales_f16, ov::op::AutoBroadcastType::NUMPY); } else { - // Zero point path: (w 
- zp) * s - auto zero_points_node = std::make_shared(zp); - float zp_value; - if (ov::op::util::get_single_value(zero_points_node, zp_value)) { - zero_points_node = ov::op::v0::Constant::create(zero_points_node->get_element_type(), {}, {zp_value}); + // Unsigned path + auto weights_node = std::make_shared(ov::element::u4, packed_shape, + static_cast(weight.data()), nullptr); + weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; + auto weights_f16 = std::make_shared(weights_node, ov::element::f16); + + if (use_bias && zp.get_size() > 0) { + // Bias path: w * s + b (zp tensor holds f16 bias values) + auto bias_f16 = std::make_shared(zp); + auto w_s = + std::make_shared(weights_f16, scales_f16, ov::op::AutoBroadcastType::NUMPY); + result = std::make_shared(w_s, bias_f16, ov::op::AutoBroadcastType::NUMPY); + } else { + // Zero point path: (w - zp) * s + auto zero_points_node = std::make_shared(zp); + float zp_value; + if (ov::op::util::get_single_value(zero_points_node, zp_value)) { + zero_points_node = ov::op::v0::Constant::create(zero_points_node->get_element_type(), {}, {zp_value}); + } + auto zero_points_f16 = std::make_shared(zero_points_node, ov::element::f16); + auto w_zp = + std::make_shared(weights_f16, zero_points_f16, ov::op::AutoBroadcastType::NUMPY); + result = std::make_shared(w_zp, scales_f16, ov::op::AutoBroadcastType::NUMPY); } - auto zero_points_f16 = std::make_shared(zero_points_node, ov::element::f16); - auto w_zp = - std::make_shared(weights_f16, zero_points_f16, ov::op::AutoBroadcastType::NUMPY); - result = std::make_shared(w_zp, scales_f16, ov::op::AutoBroadcastType::NUMPY); } if (packed_shape.size() != 2) { @@ -699,24 +745,32 @@ OvWeight process_weight_tensor(const ggml_tensor * tensor, const void * data, vo // Quantized path (normal extraction or quantized requant) // Create weight/scale/zp tensors - shared between both paths - ov::element::Type weight_type = layout.is_u4 ? ov::element::u4 : ov::element::u8; + // For symmetric quantization, use signed types (i4/i8) and no ZP tensor + ov::element::Type weight_type = layout.is_symmetric ? (layout.is_u4 ? ov::element::i4 : ov::element::i8) : + (layout.is_u4 ? ov::element::u4 : ov::element::u8); ov::Shape scale_shape = {node_shape[0], node_shape[1] / layout.weights_per_block}; - ov::Shape zp_shape = layout.is_symmetric ? ov::Shape{} : scale_shape; if (output_base_ptr) { uint8_t * buf_base = static_cast(output_base_ptr); result.weights = ov::Tensor(weight_type, node_shape, buf_base + layout.weights_offset); result.scales = ov::Tensor(ov::element::f16, scale_shape, buf_base + layout.scales_offset); - result.zp = ov::Tensor(weight_type, zp_shape, buf_base + layout.zp_offset); + if (!layout.is_symmetric) { + ov::element::Type zp_type = layout.is_u4 ? ov::element::u4 : ov::element::u8; + result.zp = ov::Tensor(zp_type, scale_shape, buf_base + layout.zp_offset); + } + // else: result.zp remains default-constructed (empty) for symmetric } else { result.weights = ov::Tensor(weight_type, node_shape); result.scales = ov::Tensor(ov::element::f16, scale_shape); - if (use_bias && !layout.is_symmetric) { - // bias only has effect for asymmetric quant - result.zp = ov::Tensor(ov::element::f16, zp_shape); - } else { - result.zp = ov::Tensor(weight_type, zp_shape); + if (!layout.is_symmetric) { + if (use_bias) { + result.zp = ov::Tensor(ov::element::f16, scale_shape); + } else { + ov::element::Type zp_type = layout.is_u4 ? 
ov::element::u4 : ov::element::u8; + result.zp = ov::Tensor(zp_type, scale_shape); + } } + // else: result.zp remains default-constructed (empty) for symmetric } if (layout.is_requant && layout.requant_type.has_value()) { @@ -741,59 +795,75 @@ void quantize_q4_0(const float * x, auto * weights = static_cast(weights_arr.data()); auto * scales = scales_arr.data::value_type>(); - auto * zp = static_cast(zp_arr.data()); - bool is_scalar_zp = (zp_arr.get_size() == 1); // Symmetric quantization - - // For Q4_0, zero point is always 8 - if (is_scalar_zp) { - zp[0] = 8 | (8 << 4); // Pack two 4-bit values - } + bool is_symmetric = (weights_arr.get_element_type() == ov::element::i4); // Signed i4 path - for (int i = 0; i < nb; i++) { - float amax = 0.0f; // absolute max - float max = 0.0f; - - for (int j = 0; j < qk; j++) { - const float v = x[i * qk + j]; - if (amax < fabsf(v)) { - amax = fabsf(v); - max = v; + if (!is_symmetric) { + auto * zp = static_cast(zp_arr.data()); + for (int i = 0; i < nb; i++) { + float amax = 0.0f; + float max = 0.0f; + for (int j = 0; j < qk; j++) { + const float v = x[i * qk + j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + max = v; + } } - } - - const float d = max / -8; - - if (d == 0) { - scales[i] = ov::float16(1.0f); - // zp is already set to 8 for symmetric, or set per-block for asymmetric - if (!is_scalar_zp) { + const float d = max / -8; + if (d == 0) { + scales[i] = ov::float16(1.0f); if (i % 2 == 0) { zp[i / 2] = 8; } else { zp[i / 2] |= (8 << 4); } + memset(weights + i * qk / 2, 8 | (8 << 4), qk / 2); + continue; } - memset(weights + i * qk / 2, 8 | (8 << 4), qk / 2); - continue; - } - - const float id = 1.0f / d; - scales[i] = ov::float16(d); - // For asymmetric quantization, store per-block zero points - if (!is_scalar_zp) { + const float id = 1.0f / d; + scales[i] = ov::float16(d); if (i % 2 == 0) { zp[i / 2] = 8; } else { zp[i / 2] |= (8 << 4); } + for (int j = 0; j < qk / 2; ++j) { + const float x0 = x[i * qk + 2 * j] * id; + const float x1 = x[i * qk + 2 * j + 1] * id; + const uint8_t xi0 = MIN(15, (int8_t) (x0 + 8.5f)); + const uint8_t xi1 = MIN(15, (int8_t) (x1 + 8.5f)); + weights[i * qk / 2 + j] = xi0 | (xi1 << 4); + } } - - for (int j = 0; j < qk / 2; ++j) { - const float x0 = x[i * qk + 2 * j] * id; - const float x1 = x[i * qk + 2 * j + 1] * id; - const uint8_t xi0 = MIN(15, (int8_t) (x0 + 8.5f)); - const uint8_t xi1 = MIN(15, (int8_t) (x1 + 8.5f)); - weights[i * qk / 2 + j] = xi0 | (xi1 << 4); + } else { + // Symmetric: produce signed i4 values in [-8, 7] + for (int i = 0; i < nb; i++) { + float amax = 0.0f; + float max = 0.0f; + for (int j = 0; j < qk; j++) { + const float v = x[i * qk + j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + max = v; + } + } + const float d = max / -8; + if (d == 0) { + scales[i] = ov::float16(1.0f); + // i4 value 0 packed: 0x00 + memset(weights + i * qk / 2, 0, qk / 2); + continue; + } + const float id = 1.0f / d; + scales[i] = ov::float16(d); + for (int j = 0; j < qk / 2; ++j) { + const float x0 = x[i * qk + 2 * j] * id; + const float x1 = x[i * qk + 2 * j + 1] * id; + // Signed i4: range [-8, 7]. Quantize as round(x*id), then pack as 4-bit two's complement. 
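                // Worked example (illustrative): with d = 0.5 (id = 2.0), raw inputs -1.7 and 3.2 scale to
                // x0 = -3.4 and x1 = 6.4, which round to -3 and 6 (both already inside [-8, 7]). Their
                // two's-complement nibbles are 0xD and 0x6, so the packed byte is 0x0D | (0x6 << 4) = 0x6D;
                // dequantizing gives 0.5 * -3 = -1.5 and 0.5 * 6 = 3.0.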
+ int8_t si0 = (int8_t) std::max(-8, std::min(7, (int) roundf(x0))); + int8_t si1 = (int8_t) std::max(-8, std::min(7, (int) roundf(x1))); + weights[i * qk / 2 + j] = (si0 & 0x0F) | ((si1 & 0x0F) << 4); + } } } } @@ -809,36 +879,42 @@ void quantize_q8_0(const float * x, auto * weights = static_cast(weights_arr.data()); auto * scales = scales_arr.data::value_type>(); - auto * zp = static_cast(zp_arr.data()); - bool is_scalar_zp = (zp_arr.get_size() == 1); // Symmetric quantization - - // For Q8_0, zero point is always 128 - if (is_scalar_zp) { - zp[0] = 128; - } - - for (int i = 0; i < nb; i++) { - float amax = 0.0f; // absolute max + bool is_symmetric = (weights_arr.get_element_type() == ov::element::i8); // Signed i8 path - for (int j = 0; j < qk; j++) { - const float v = x[i * qk + j]; - if (amax < fabsf(v)) { - amax = fabsf(v); + if (!is_symmetric) { + auto * zp = static_cast(zp_arr.data()); + for (int i = 0; i < nb; i++) { + float amax = 0.0f; + for (int j = 0; j < qk; j++) { + const float v = x[i * qk + j]; + amax = std::max(amax, fabsf(v)); } - } - - const float d = amax / 127.0f; - const float id = d ? 1.0f / d : 0.0f; - scales[i] = ov::float16(d); - // For asymmetric quantization, store per-block zero points - if (!is_scalar_zp) { + const float d = amax / 127.0f; + const float id = d ? 1.0f / d : 0.0f; + scales[i] = ov::float16(d); zp[i] = 128; + for (int j = 0; j < qk; ++j) { + const float x0 = x[i * qk + j] * id; + const int8_t xi0 = roundf(x0); + weights[i * qk + j] = (uint8_t) (xi0 + 128); + } } - - for (int j = 0; j < qk; ++j) { - const float x0 = x[i * qk + j] * id; - const int8_t xi0 = roundf(x0); - weights[i * qk + j] = (uint8_t) (xi0 + 128); + } else { + // Symmetric: store signed int8 values directly + auto * signed_weights = reinterpret_cast(weights); + for (int i = 0; i < nb; i++) { + float amax = 0.0f; + for (int j = 0; j < qk; j++) { + const float v = x[i * qk + j]; + amax = std::max(amax, fabsf(v)); + } + const float d = amax / 127.0f; + const float id = d ? 
1.0f / d : 0.0f; + scales[i] = ov::float16(d); + for (int j = 0; j < qk; ++j) { + const float x0 = x[i * qk + j] * id; + signed_weights[i * qk + j] = (int8_t) roundf(x0); + } } } } @@ -861,12 +937,8 @@ void quantize_q8_1(const float * x, for (int j = 0; j < qk; j++) { const float v = x[i * qk + j]; - if (v < min) { - min = v; - } - if (v > max) { - max = v; - } + min = std::min(v, min); + max = std::max(v, max); } const float d = (max - min) / ((1 << 8) - 1); diff --git a/ggml/src/ggml-openvino/openvino/op/rope.cpp b/ggml/src/ggml-openvino/openvino/op/rope.cpp index 26dc2d24f82..a8db9b38930 100644 --- a/ggml/src/ggml-openvino/openvino/op/rope.cpp +++ b/ggml/src/ggml-openvino/openvino/op/rope.cpp @@ -9,12 +9,17 @@ #include #include #include +#include +#include +#include #include #include #include +#include #include #include #include +#include #include #include @@ -33,6 +38,12 @@ OutputVector translate_rope(const NodeContext & context) { auto data_node = context.get_input(0).get_node_shared_ptr(); auto output_shape = context.get_output_shape().to_shape(); int32_t * op_params = context.get_output_op_params(); + const int mode = (op_case & 0xFFFF0000) >> 16; + op_case = (op_case & 0x0000FFFF); + + constexpr int TYPE_NORMAL = 0; + constexpr int TYPE_NEOX = 1; + constexpr int TYPE_IMROPE = 2; Output cos_theta_node; Output sin_theta_node; @@ -45,7 +56,7 @@ OutputVector translate_rope(const NodeContext & context) { if (context.get_input_size() == 3) { rope_freqs_weight = context.get_input(2).get_node_shared_ptr(); } - auto sin_cos = make_sin_cos(op_params, inp_pos, rope_freqs_weight); + auto sin_cos = make_sin_cos(op_params, inp_pos, rope_freqs_weight, mode == TYPE_IMROPE); sin_theta_node = sin_cos.first; cos_theta_node = sin_cos.second; } @@ -65,11 +76,7 @@ OutputVector translate_rope(const NodeContext & context) { } } - const int mode = op_params[2]; - constexpr int ROPE_TYPE_NORMAL = 0; - constexpr int ROPE_TYPE_NEOX = 2; - - if (mode == ROPE_TYPE_NORMAL) { + if (mode == TYPE_NORMAL) { auto neg_one = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1}); auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); @@ -97,7 +104,7 @@ OutputVector translate_rope(const NodeContext & context) { auto data_shape = ov::op::v0::Constant::create( ov::element::i64, {4}, std::vector{1, -1, (int64_t) output_shape[2], (int64_t) output_shape[3]}); res = std::make_shared(stack, data_shape, false); - } else if (mode == ROPE_TYPE_NEOX) { + } else if (mode == TYPE_NEOX) { auto data_split = std::make_shared( data_node, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {-1}), 2); Output slice_data_node_0 = data_split->outputs()[0]; @@ -112,6 +119,25 @@ OutputVector translate_rope(const NodeContext & context) { std::make_shared(slice_data_node_1, cos_theta_node)); res = std::make_shared(ov::OutputVector{first_half_node, second_half_node}, -1); + } else if (mode == TYPE_IMROPE) { + int64_t n_dims = data_node->get_shape()[3]; + auto cos_sin_shape = std::make_shared(ov::element::i64, ov::Shape{4}, std::vector{1,-1,1,(n_dims >> 1)}); + auto cos_reshaped = std::make_shared(cos_theta_node, cos_sin_shape, true); + auto sin_reshaped = std::make_shared(sin_theta_node, cos_sin_shape, true); + + auto split_axis = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {3}); + auto split_a = std::make_shared(data_node, split_axis, 2); + auto x0 = split_a->output(0); + auto x1 = split_a->output(1); + auto mul_a = std::make_shared(x0, 
cos_reshaped); + auto mul_b = std::make_shared(x1, sin_reshaped); + auto sub = std::make_shared(mul_a, mul_b); + + auto mul_c = std::make_shared(x0, sin_reshaped); + auto mul_d = std::make_shared(x1, cos_reshaped); + auto add = std::make_shared(mul_c, mul_d); + + res = std::make_shared(ov::OutputVector{sub, add}, 3); } return rename_outputs_with_suffix({res}, context.get_name()); diff --git a/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp b/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp new file mode 100644 index 00000000000..d1e9efc33a5 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp @@ -0,0 +1,25 @@ +#include "../node_context.h" +#include "../op_table.h" +#include "../utils.h" + +#include +#include + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_unary_gelu(const NodeContext & context) { + num_inputs_check(context, 1, 1); + + auto input = context.get_input(0); + auto res = std::make_shared(input); + + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op_table.cpp b/ggml/src/ggml-openvino/openvino/op_table.cpp index beadafe8103..1385539279c 100644 --- a/ggml/src/ggml-openvino/openvino/op_table.cpp +++ b/ggml/src/ggml-openvino/openvino/op_table.cpp @@ -31,6 +31,7 @@ std::unordered_map get_supported_ops() { {"GGML_OP_SOFT_MAX", op::translate_soft_max }, {"GGML_OP_SUB", op::translate_1to1_match_2_inputs}, {"GGML_OP_TRANSPOSE", op::translate_transpose }, + {"GGML_UNARY_OP_GELU", op::translate_unary_gelu }, {"GGML_UNARY_OP_SILU", op::translate_unary_silu }, {"GGML_OP_VIEW", op::translate_view }, {"GGML_GLU_OP_SWIGLU", op::translate_glu_swiglu }, diff --git a/ggml/src/ggml-openvino/openvino/op_table.h b/ggml/src/ggml-openvino/openvino/op_table.h index 37f763117aa..f546796d2ee 100644 --- a/ggml/src/ggml-openvino/openvino/op_table.h +++ b/ggml/src/ggml-openvino/openvino/op_table.h @@ -21,6 +21,7 @@ GGML_OP_CONVERTER(translate_rms_norm); GGML_OP_CONVERTER(translate_rope); GGML_OP_CONVERTER(translate_scale); GGML_OP_CONVERTER(translate_unary_silu); +GGML_OP_CONVERTER(translate_unary_gelu); GGML_OP_CONVERTER(translate_soft_max); GGML_OP_CONVERTER(translate_transpose); GGML_OP_CONVERTER(translate_view); diff --git a/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp b/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp deleted file mode 100644 index ed2a3ab6d1b..00000000000 --- a/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +++ /dev/null @@ -1,123 +0,0 @@ -#include "eliminate_zp.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace ov { -namespace frontend { -namespace ggml { -namespace pass { - -EliminateZeroPoints::EliminateZeroPoints() { - // Find pattern: - // (Multiply Any(scale) - // (Subtract (Convert Constant(data))) - // (Convert Constant(zero_point))) - // where zero_point is a scalar - // If data is u4 and zp value is 8 (q4_0), Replace the Subtract with an i4 Constant whose value is data - zp_val - // If data is u8 and zp value is 128 (q8_0) or 32 (q6_k), Replace the Subtract with an i8 Constant - - auto m_data_constant = ov::pass::pattern::wrap_type(); - auto m_data_convert = ov::pass::pattern::wrap_type({m_data_constant}); - - auto m_zp_constant = ov::pass::pattern::wrap_type(); - auto m_zp_convert = ov::pass::pattern::wrap_type({m_zp_constant}); - - auto m_subtract = 
ov::pass::pattern::wrap_type({m_data_convert, m_zp_convert}); - auto m_scale = ov::pass::pattern::any_input(); - auto m_multiply = ov::pass::pattern::wrap_type({m_scale, m_subtract}); - - const auto callback = [=](ov::pass::pattern::Matcher & m) { - const auto & pattern_map = m.get_pattern_value_map(); - - auto multiply_node = - std::dynamic_pointer_cast(pattern_map.at(m_multiply).get_node_shared_ptr()); - auto subtract_node = - std::dynamic_pointer_cast(pattern_map.at(m_subtract).get_node_shared_ptr()); - auto data_constant = - std::dynamic_pointer_cast(pattern_map.at(m_data_constant).get_node_shared_ptr()); - auto zp_constant = - std::dynamic_pointer_cast(pattern_map.at(m_zp_constant).get_node_shared_ptr()); - - if (!multiply_node || !subtract_node || !data_constant || !zp_constant) { - return false; - } - - if (ov::shape_size(zp_constant->get_shape()) != 1) { - return false; - } - - auto data_type = data_constant->get_element_type(); - auto zp_data = zp_constant->cast_vector(); - - if (zp_data.empty()) { - return false; - } - - int zp_value = zp_data[0]; - - bool should_eliminate = false; - ov::element::Type target_type; - - if (data_type == ov::element::u4 && zp_value == 8) { - should_eliminate = true; - target_type = ov::element::i4; - } else if (data_type == ov::element::u8 && (zp_value == 128 || zp_value == 32)) { - should_eliminate = true; - target_type = ov::element::i8; - } - - if (!should_eliminate) { - return false; - } - - auto data_shape = data_constant->get_shape(); - size_t total_elements = ov::shape_size(data_shape); - - std::shared_ptr new_constant; - - // TODO improve performance - if (data_type == ov::element::u4) { - auto data_values = data_constant->cast_vector(); - std::vector adjusted_values(total_elements); - - ov::parallel_for(total_elements, [&](size_t i) { - adjusted_values[i] = static_cast(static_cast(data_values[i]) - 8); - }); - - new_constant = std::make_shared(target_type, data_shape, adjusted_values); - } else if (data_type == ov::element::u8) { - auto data_values = data_constant->cast_vector(); - std::vector adjusted_values(total_elements); - - ov::parallel_for(total_elements, [&, zp_value](size_t i) { - adjusted_values[i] = static_cast(static_cast(data_values[i]) - zp_value); - }); - - new_constant = std::make_shared(target_type, data_shape, adjusted_values); - } - - auto new_convert = - std::make_shared(new_constant, subtract_node->get_output_element_type(0)); - ov::replace_node(subtract_node, new_convert); - - return true; - }; - - register_matcher( - std::make_shared(m_multiply, "ov::frontend::ggml::pass::EliminateZeroPoints"), - callback); -} - -} // namespace pass -} // namespace ggml -} // namespace frontend -} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h b/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h deleted file mode 100644 index edd3cd718d9..00000000000 --- a/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +++ /dev/null @@ -1,17 +0,0 @@ -#include "openvino/pass/matcher_pass.hpp" - -namespace ov { -namespace frontend { -namespace ggml { -namespace pass { - -class EliminateZeroPoints : public ov::pass::MatcherPass { -public: - OPENVINO_MATCHER_PASS_RTTI("ov::frontend::ggml::pass::EliminateZeroPoints") - EliminateZeroPoints(); -}; - -} // namespace pass -} // namespace ggml -} // namespace frontend -} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp b/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp new file mode 100644 
index 00000000000..f051891c481 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp @@ -0,0 +1,41 @@ +// Copyright (C) 2018-2026 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include + +namespace ov { + +/** + * @brief Holds weightless caching attributes of a single constant. + * + * WeightlessCacheAttribute class represents runtime info attribute that holds + * the values of original size of the constant in bytes and the binary offset of the + * constant's data in the weights file used by the weightless caching mechanism. It's + * not copyable in case the data was changed (the original node was replaced by a new + * one produced during the transformation pipeline) - in that case weightless caching + * can't be used for that constant. + */ +class OPENVINO_API WeightlessCacheAttribute : public RuntimeAttribute { +public: + OPENVINO_RTTI("WeightlessCacheAttribute", "0", RuntimeAttribute) + + WeightlessCacheAttribute() = delete; + + WeightlessCacheAttribute(size_t original_size, size_t bin_offset, ov::element::Type original_dtype) + : original_size(original_size), + bin_offset(bin_offset), + original_dtype(original_dtype) {} + + bool is_copyable() const override; + + size_t original_size; + size_t bin_offset; + ov::element::Type original_dtype; +}; + +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/translate_session.cpp b/ggml/src/ggml-openvino/openvino/translate_session.cpp index 23a1dea2496..0f68a1f5062 100644 --- a/ggml/src/ggml-openvino/openvino/translate_session.cpp +++ b/ggml/src/ggml-openvino/openvino/translate_session.cpp @@ -3,15 +3,16 @@ #include "ggml-openvino/openvino/node_context.h" #include "ggml-openvino/openvino/utils.h" #include "input_model.h" -#include "pass/eliminate_zp.h" #include "pass/mark_decompression_convert_constant_folding.h" #include "pass/squeeze_matmul.h" +#include "rt_info/weightless_caching_attributes.hpp" #include #include #include #include #include +#include #include #include #include @@ -33,7 +34,6 @@ #include #include #include -#include namespace ov { namespace frontend { @@ -240,6 +240,31 @@ std::shared_ptr TranslateSession::translate_graph(const frontend::InputMo resulting_model = std::make_shared(results, used_params); apply_transformations(resulting_model); + + // Set WeightlessCacheAttribute on large constants to avoid unnecessary memory copies + // in the NPUW plugin. Without this attribute, NPUW's LazyTensor constructor + // (lazy_tensor.cpp, op::Const::Const) will memcpy every constant "in case export + // occurs", doubling memory usage per compile_model call. + // + // The bin_offset field serves as a unique key (not a real file offset) — this is + // the same convention the GPU plugin uses for non-IR models (see + // Plugin::set_weightless_cache_attributes in intel_gpu/src/plugin/plugin.cpp). + // Each constant must have a distinct bin_offset, otherwise GPU's weightless cache + // import will map multiple constants to the same data. + // + // Small constants (< 16 elements) are excluded since they may be introduced by + // optimization patterns and the overhead is negligible.
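(Editorial aside, not part of the patch; the loop that follows implements the annotation described above. Assuming the standard OpenVINO C++ API and the weightless_caching_attributes.hpp header introduced in this patch, a sketch of how the distinct-bin_offset requirement could be checked on the resulting model:)

```cpp
// Sketch only: verifies that every annotated constant carries a unique bin_offset,
// mirroring the requirement stated in the comment above. `model` is assumed to be
// the ov::Model produced by translate_graph after the annotation loop has run.
#include <openvino/openvino.hpp>
#include <openvino/op/constant.hpp>

#include <cassert>
#include <memory>
#include <set>

#include "rt_info/weightless_caching_attributes.hpp"  // header added in this patch

static void check_unique_bin_offsets(const std::shared_ptr<ov::Model> & model) {
    std::set<size_t> seen;
    for (const auto & node : model->get_ordered_ops()) {
        auto cnst = ov::as_type_ptr<ov::op::v0::Constant>(node);
        if (!cnst) {
            continue;
        }
        auto & rt_info = cnst->get_rt_info();
        auto it = rt_info.find(ov::WeightlessCacheAttribute::get_type_info_static());
        if (it == rt_info.end()) {
            continue;  // small constants are intentionally left unannotated
        }
        auto & attr = it->second.as<ov::WeightlessCacheAttribute>();
        // Each annotated constant must carry its own offset key.
        assert(seen.insert(attr.bin_offset).second && "duplicate bin_offset");
    }
}
```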
+ size_t offset = 0; + for (auto & node : resulting_model->get_ordered_ops()) { + if (auto cnst = ov::as_type_ptr(node); + cnst && cnst->get_byte_size() / cnst->get_element_type().size() >= 16) { + auto & rt_info = cnst->get_rt_info(); + if (rt_info.find(ov::WeightlessCacheAttribute::get_type_info_static()) == rt_info.end()) { + rt_info[ov::WeightlessCacheAttribute::get_type_info_static()] = + ov::WeightlessCacheAttribute(cnst->get_byte_size(), offset++, cnst->get_element_type()); + } + } + } return resulting_model; } @@ -257,7 +282,6 @@ std::shared_ptr TranslateSession::apply_transformations(std::shared_ptris_static()) { - manager.register_pass(); manager.register_pass(); } manager.run_passes(model); diff --git a/ggml/src/ggml-openvino/openvino/utils.cpp b/ggml/src/ggml-openvino/openvino/utils.cpp index 65356a51b51..0baaf88e17a 100644 --- a/ggml/src/ggml-openvino/openvino/utils.cpp +++ b/ggml/src/ggml-openvino/openvino/utils.cpp @@ -2,6 +2,7 @@ #include "ggml-impl.h" +#include #include #include #include @@ -13,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -87,8 +89,11 @@ ov::Output rope_yarn_ramp_mix(int n_dims, const float corr_dims[2], fl auto ramp_y = std::make_shared(std::make_shared(dim_ids, corr_low), denom); auto ramp_clamped = std::make_shared(ramp_y, 0.0f, 1.0f); + // rope_yarn_ramp returns (1 - clamp(y)), so invert before scaling + auto one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, 1}, {1.0f}); + auto ramp_inverted = std::make_shared(one, ramp_clamped); auto ext_factor_node = ov::op::v0::Constant::create(ov::element::f32, Shape{}, {ext_factor}); - auto ramp_mix = std::make_shared(ramp_clamped, ext_factor_node); + auto ramp_mix = std::make_shared(ramp_inverted, ext_factor_node); return ramp_mix; } @@ -115,6 +120,7 @@ void ggml_rope_yarn_corr_dims(int n_dims, std::pair, ov::Output> make_sin_cos(int32_t * rope_params, std::shared_ptr inp_pos, std::shared_ptr rope_freqs_weight, + bool imrope, bool stateful) { if (stateful) { inp_pos = std::make_shared(inp_pos, ov::op::v0::Constant::create(ov::element::i64, {1}, {0})); @@ -122,6 +128,13 @@ std::pair, ov::Output> make_sin_cos(int32_t * rope_params auto pos_perm = std::make_shared(ov::element::i64, ov::Shape{3}, std::vector{2, 1, 0}); inp_pos = std::make_shared(inp_pos, pos_perm); + } else if (imrope) { + inp_pos = std::make_shared(inp_pos, ov::element::f32); + auto pos_shape = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{5}, {0, 0, 0, 4, -1}); + inp_pos = std::make_shared(inp_pos, pos_shape, true); + auto pos_transpose_shape = + std::make_shared(ov::element::i64, ov::Shape{5}, std::vector{0, 1, 2, 4, 3}); + inp_pos = std::make_shared(inp_pos, pos_transpose_shape); } else { inp_pos = std::make_shared(inp_pos, ov::element::f32); auto pos_perm = @@ -136,6 +149,7 @@ std::pair, ov::Output> make_sin_cos(int32_t * rope_params float beta_fast; float beta_slow; const int n_dims = rope_params[1]; + const size_t n_dims_half = n_dims >> 1; const int n_ctx_orig = rope_params[4]; memcpy(&freq_base, rope_params + 5, sizeof(float)); memcpy(&freq_scale, rope_params + 6, sizeof(float)); @@ -146,57 +160,74 @@ std::pair, ov::Output> make_sin_cos(int32_t * rope_params const float theta_scale = powf(freq_base, -2.0f / n_dims); - float corr_dims[2]; - ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); - - std::vector factor(n_dims / 2); - factor[0] = 1.0f; - for (size_t i = 1; i < factor.size(); i++) { - factor[i] = theta_scale * factor[i - 1]; - } + 
std::vector factor(n_dims_half); Output freq_factors; - if (stateful) { - freq_factors = - std::make_shared(ov::element::f32, ov::Shape{1, 1, factor.size()}, factor); - } else { - freq_factors = - std::make_shared(ov::element::f32, ov::Shape{1, 1, 1, factor.size()}, factor); - } - if (rope_freqs_weight) { - freq_factors = std::make_shared(freq_factors, rope_freqs_weight); - } - - auto theta_extrap = std::make_shared(freq_factors, inp_pos); - auto theta_interp = std::make_shared( - theta_extrap, ov::op::v0::Constant::create(ov::element::f32, {1}, {freq_scale})); Output theta; float mscale = attn_factor; - if (ext_factor == 0.0f) { - theta = theta_interp; + if (imrope) { + std::vector gather_indices(n_dims_half); + for (size_t j = 0; j < n_dims_half; j++) { + gather_indices[j] = j % 3; + factor[j] = std::pow(theta_scale, j); + } + auto gather_indices_const = + std::make_shared(ov::element::i64, ov::Shape{n_dims_half}, gather_indices); + auto gather_axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {4}); + inp_pos = std::make_shared(inp_pos, gather_indices_const, gather_axis); + auto factor_const = std::make_shared(ov::element::f32, ov::Shape{n_dims_half}, factor); + theta = std::make_shared(inp_pos, factor_const); } else { - auto ramp_mix = rope_yarn_ramp_mix(n_dims, corr_dims, ext_factor); - Output one; + float corr_dims[2]; + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); + factor[0] = 1.0f; + for (size_t i = 1; i < factor.size(); i++) { + factor[i] = theta_scale * factor[i - 1]; + } if (stateful) { - one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1}, {1.0f}); + freq_factors = + std::make_shared(ov::element::f32, ov::Shape{1, 1, factor.size()}, factor); } else { - one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, 1}, {1.0f}); + freq_factors = + std::make_shared(ov::element::f32, ov::Shape{1, 1, 1, factor.size()}, factor); + } + if (rope_freqs_weight) { + freq_factors = std::make_shared(freq_factors, rope_freqs_weight); } - auto one_minus_ramp = std::make_shared(one, ramp_mix); - theta = std::make_shared(std::make_shared(theta_interp, one_minus_ramp), - std::make_shared(theta_extrap, ramp_mix)); - mscale *= (1.0f + 0.1f * std::log(1.0f / freq_scale)); + auto theta_extrap = std::make_shared(freq_factors, inp_pos); + auto theta_interp = std::make_shared( + theta_extrap, ov::op::v0::Constant::create(ov::element::f32, {1}, {freq_scale})); + + if (ext_factor == 0.0f) { + theta = theta_interp; + } else { + auto ramp_mix = rope_yarn_ramp_mix(n_dims, corr_dims, ext_factor); + Output one; + if (stateful) { + one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1}, {1.0f}); + } else { + one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, 1}, {1.0f}); + } + auto one_minus_ramp = std::make_shared(one, ramp_mix); + + theta = std::make_shared(std::make_shared(theta_interp, one_minus_ramp), + std::make_shared(theta_extrap, ramp_mix)); + mscale *= (1.0f + 0.1f * std::log(1.0f / freq_scale)); + } } Output cos_theta = std::make_shared(theta); Output sin_theta = std::make_shared(theta); - auto mscale_node = ov::op::v0::Constant::create(ov::element::f32, Shape{}, {mscale}); + if (!imrope) { + auto mscale_node = ov::op::v0::Constant::create(ov::element::f32, Shape{}, {mscale}); + + cos_theta = std::make_shared(cos_theta, mscale_node); + sin_theta = std::make_shared(sin_theta, mscale_node); + } - cos_theta = std::make_shared(cos_theta, mscale_node); - sin_theta = std::make_shared(sin_theta, 
mscale_node); return std::make_pair(sin_theta, cos_theta); } diff --git a/ggml/src/ggml-openvino/openvino/utils.h b/ggml/src/ggml-openvino/openvino/utils.h index 88dcad4c906..767dd4c53ea 100644 --- a/ggml/src/ggml-openvino/openvino/utils.h +++ b/ggml/src/ggml-openvino/openvino/utils.h @@ -67,6 +67,7 @@ OutputVector rename_outputs_with_suffix(const OutputVector& outputs, const std:: std::pair, ov::Output> make_sin_cos(int32_t* rope_params, std::shared_ptr inp_pos, std::shared_ptr rope_freqs_weight = nullptr, + bool imrope = false, bool stateful = false); ov::Output process_view_input(const NodeContext& context, int input_index, int slice_len = 0); diff --git a/ggml/src/ggml-openvino/utils.cpp b/ggml/src/ggml-openvino/utils.cpp index 1b553a0de00..998ef7c9eb4 100644 --- a/ggml/src/ggml-openvino/utils.cpp +++ b/ggml/src/ggml-openvino/utils.cpp @@ -81,8 +81,8 @@ ov::Tensor create_ov_output_tensor(std::shared_ptr ggml_decoder, enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr r_ctx) { auto & core = ov_singleton_core(); const auto & config = ggml_openvino_get_compile_config(); - auto device = r_ctx->device; - bool stateful = r_ctx->stateful; + const auto & device = r_ctx->device; + const auto & stateful = r_ctx->stateful; static auto is_static = false; if (is_naive(cgraph)) { @@ -106,14 +106,26 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr< int64_t infer_end_time; { - std::lock_guard lock(r_ctx->ov_compute_mutex); + std::shared_ptr entry; + ModelParams old_m_params; - auto it = r_ctx->decoder_cache.find(key); + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + auto it = r_ctx->decoder_cache.find(key); + cache_hit = it != r_ctx->decoder_cache.end(); + if (cache_hit) { + entry = it->second; + } else { + auto mutex = std::make_shared(); + entry = std::make_shared(mutex); + r_ctx->decoder_cache[key] = entry; + } + } + + std::lock_guard lock(*(entry->mutex)); - cache_hit = it != r_ctx->decoder_cache.end(); - ModelParams old_m_params; if (cache_hit) { - ggml_decoder = it->second; + ggml_decoder = entry->ptr; old_m_params = ggml_decoder->get_model_params(); cache_hit = old_m_params.can_reuse_dynamically(m_params); } @@ -126,7 +138,10 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr< ggml_decoder->update_io(cgraph); } ggml_decoder->add_extra_inputs(); - infer_request = r_ctx->infer_request_cache.at(key); + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + infer_request = r_ctx->infer_request_cache.at(key); + } if (stateful) { const auto * inp_pos = get_inp_pos_tensor(cgraph); @@ -170,7 +185,10 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr< conversion_end_time = decoder_end_time; compile_end_time = decoder_end_time; } else { - r_ctx->infer_request_cache.erase(key); + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + r_ctx->infer_request_cache.erase(key); + } std::shared_ptr model; auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph); @@ -199,8 +217,7 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr< } compile_end_time = ggml_time_us(); infer_request = std::make_shared(compiled_model.create_infer_request()); - r_ctx->infer_request_cache[key] = infer_request; - r_ctx->decoder_cache[key] = ggml_decoder; + entry->ptr = ggml_decoder; std::vector ov_input_names; std::vector ov_output_names; @@ -210,8 +227,13 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr< for (const auto & ov_output : 
model->get_results()) { ov_output_names.push_back(ov_output->get_friendly_name()); } - r_ctx->ov_input_names_cache[key] = std::move(ov_input_names); - r_ctx->ov_output_names_cache[key] = std::move(ov_output_names); + + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + r_ctx->infer_request_cache[key] = infer_request; + r_ctx->ov_input_names_cache[key] = std::move(ov_input_names); + r_ctx->ov_output_names_cache[key] = std::move(ov_output_names); + } if (stateful) { const auto * inp_pos = get_inp_pos_tensor(cgraph); @@ -224,8 +246,13 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr< } } - auto ov_input_names = r_ctx->ov_input_names_cache[key]; - auto ov_output_names = r_ctx->ov_output_names_cache[key]; + std::vector ov_input_names; + std::vector ov_output_names; + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + ov_input_names = r_ctx->ov_input_names_cache[key]; + ov_output_names = r_ctx->ov_output_names_cache[key]; + } for (size_t i = 0; i < ov_input_names.size(); i++) { auto param_name = ov_input_names[i]; @@ -306,12 +333,26 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptrdecoder_cache.find(key); - - cache_hit = it != r_ctx->decoder_cache.end(); + std::shared_ptr entry; ModelParams old_m_params; + + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + auto it = r_ctx->decoder_cache.find(key); + cache_hit = it != r_ctx->decoder_cache.end(); + if (cache_hit) { + entry = it->second; + } else { + auto mutex = std::make_shared(); + entry = std::make_shared(mutex); + r_ctx->decoder_cache[key] = entry; + } + } + + std::lock_guard lock(*(entry->mutex)); + if (cache_hit) { - ggml_decoder = it->second; + ggml_decoder = entry->ptr; old_m_params = ggml_decoder->get_model_params(); cache_hit = old_m_params.can_reuse_statically(m_params); } @@ -325,14 +366,21 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptrupdate_io(cgraph); } ggml_decoder->add_extra_inputs(); - infer_request = is_prefill ? r_ctx->infer_request_cache_prefill.at(key) : r_ctx->infer_request_cache.at(key); + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + infer_request = + is_prefill ? r_ctx->infer_request_cache_prefill.at(key) : r_ctx->infer_request_cache.at(key); + } decoder_end_time = ggml_time_us(); conversion_end_time = decoder_end_time; compile_end_time = decoder_end_time; } else { - r_ctx->infer_request_cache.erase(key); - r_ctx->infer_request_cache_prefill.erase(key); + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + r_ctx->infer_request_cache.erase(key); + r_ctx->infer_request_cache_prefill.erase(key); + } std::shared_ptr model; auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph); @@ -372,16 +420,14 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptrinfer_request_cache_prefill[key] = - std::make_shared(compiled_model_prefill.create_infer_request()); - r_ctx->infer_request_cache[key] = - std::make_shared(compiled_model_decode.create_infer_request()); + auto infer_request_prefill = std::make_shared(compiled_model_prefill.create_infer_request()); + auto infer_request_decode = std::make_shared(compiled_model_decode.create_infer_request()); compile_end_time = ggml_time_us(); model = is_prefill ? model_prefill : model_decode; ggml_decoder = is_prefill ? ggml_decoder_prefill : ggml_decoder_decode; - infer_request = is_prefill ? r_ctx->infer_request_cache_prefill[key] : r_ctx->infer_request_cache[key]; - r_ctx->decoder_cache[key] = ggml_decoder; + infer_request = is_prefill ? 
infer_request_prefill : infer_request_decode; + entry->ptr = ggml_decoder; std::vector ov_input_names; std::vector ov_output_names; @@ -391,18 +437,29 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptrget_results()) { ov_output_names.push_back(ov_output->get_friendly_name()); } - r_ctx->ov_input_names_cache[key] = std::move(ov_input_names); - r_ctx->ov_output_names_cache[key] = std::move(ov_output_names); + + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + r_ctx->infer_request_cache_prefill[key] = infer_request_prefill; + r_ctx->infer_request_cache[key] = infer_request_decode; + r_ctx->ov_input_names_cache[key] = std::move(ov_input_names); + r_ctx->ov_output_names_cache[key] = std::move(ov_output_names); + } } - auto ov_input_names = r_ctx->ov_input_names_cache[key]; - auto ov_output_names = r_ctx->ov_output_names_cache[key]; + std::vector ov_input_names_local; + std::vector ov_output_names_local; + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + ov_input_names_local = r_ctx->ov_input_names_cache[key]; + ov_output_names_local = r_ctx->ov_output_names_cache[key]; + } if (is_prefill) { auto inp_len = inp_pos->ne[0]; for (int chunk_index = 0; chunk_index * prefill_chunk_size < inp_len; chunk_index++) { - for (size_t i = 0; i < ov_input_names.size(); i++) { - auto param_name = ov_input_names[i]; + for (size_t i = 0; i < ov_input_names_local.size(); i++) { + auto param_name = ov_input_names_local[i]; auto input_tensor = get_ov_input_tensor_static_prefill(ggml_decoder, param_name, chunk_index); infer_request->set_input_tensor(i, input_tensor); @@ -412,8 +469,8 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptrget_model_outputs().at(ov_output_names[i]); + for (size_t i = 0; i < ov_output_names_local.size(); i++) { + auto * ggml_tensor = ggml_decoder->get_model_outputs().at(ov_output_names_local[i]); auto output_tensor = create_ov_output_tensor(ggml_decoder, infer_request, i, ggml_tensor); infer_request->set_output_tensor(i, output_tensor); } @@ -421,16 +478,16 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptrinfer(); if (getenv("GGML_OPENVINO_DEBUG_OUTPUT")) { - for (size_t i = 0; i < ov_output_names.size(); i++) { + for (size_t i = 0; i < ov_output_names_local.size(); i++) { const auto output_tensor = infer_request->get_output_tensor(i); - print_output_tensor_info(ov_output_names[i], output_tensor, output_tensor.data()); + print_output_tensor_info(ov_output_names_local[i], output_tensor, output_tensor.data()); } } } infer_end_time = ggml_time_us(); } else { - for (size_t i = 0; i < ov_input_names.size(); i++) { - auto param_name = ov_input_names[i]; + for (size_t i = 0; i < ov_input_names_local.size(); i++) { + auto param_name = ov_input_names_local[i]; auto input_tensor = get_ov_input_tensor_static_decode(ggml_decoder, param_name); infer_request->set_input_tensor(i, input_tensor); @@ -440,8 +497,8 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptrget_model_outputs().at(ov_output_names[i]); + for (size_t i = 0; i < ov_output_names_local.size(); i++) { + auto * ggml_tensor = ggml_decoder->get_model_outputs().at(ov_output_names_local[i]); auto output_tensor = create_ov_output_tensor(ggml_decoder, infer_request, i, ggml_tensor); infer_request->set_output_tensor(i, output_tensor); } @@ -450,9 +507,9 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptrget_output_tensor(i); - print_output_tensor_info(ov_output_names[i], output_tensor, 
output_tensor.data()); + print_output_tensor_info(ov_output_names_local[i], output_tensor, output_tensor.data()); } } } diff --git a/ggml/src/ggml-openvino/utils.h b/ggml/src/ggml-openvino/utils.h index 656573d1389..2c72e33c352 100644 --- a/ggml/src/ggml-openvino/utils.h +++ b/ggml/src/ggml-openvino/utils.h @@ -3,12 +3,15 @@ #include "ggml-impl.h" #include +#include #include #include +#include #include #include #include #include +#include #include struct graph_key { @@ -40,11 +43,17 @@ struct graph_key_hash { } }; +struct decoder_runtime_ctx { + decoder_runtime_ctx(std::shared_ptr mutex) : mutex(std::move(mutex)) {} + std::shared_ptr mutex; + std::shared_ptr ptr; +}; + struct ov_runtime_context { - std::mutex ov_compute_mutex; + mutable std::mutex ctx_mutex; std::string device; bool stateful; - std::unordered_map, graph_key_hash> decoder_cache; + std::unordered_map, graph_key_hash> decoder_cache; std::unordered_map, graph_key_hash> infer_request_cache; std::unordered_map, graph_key_hash> infer_request_cache_prefill; std::unordered_map, graph_key_hash> ov_input_names_cache; @@ -53,11 +62,22 @@ struct ov_runtime_context { // Simultanous stateful inference request support to be added. size_t stateful_kv_size; std::map kv_state_input_name_map; + std::atomic backend_count; ov_runtime_context() : device("CPU"), stateful(false), - stateful_kv_size(0) {} + stateful_kv_size(0), + backend_count(0) {} + + void clear_caches() { + std::lock_guard lock(ctx_mutex); + decoder_cache.clear(); + infer_request_cache.clear(); + infer_request_cache_prefill.clear(); + ov_input_names_cache.clear(); + ov_output_names_cache.clear(); + } }; enum ggml_status ov_graph_compute(struct ggml_cgraph * cgraph, ggml_backend_t backend); diff --git a/ggml/src/ggml-rpc/CMakeLists.txt b/ggml/src/ggml-rpc/CMakeLists.txt index 8671ce5ceaf..40e11fead63 100644 --- a/ggml/src/ggml-rpc/CMakeLists.txt +++ b/ggml/src/ggml-rpc/CMakeLists.txt @@ -2,6 +2,7 @@ message(STATUS "Using RPC backend") ggml_add_backend_library(ggml-rpc ggml-rpc.cpp + transport.cpp ) if (WIN32) diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 017ef0af360..2ded7397868 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -2,6 +2,7 @@ #include "ggml-impl.h" #include "ggml-backend-impl.h" #include "ggml-cpp.h" +#include "transport.h" #include #include @@ -12,35 +13,11 @@ #include #include #include -#ifdef _WIN32 -# define WIN32_LEAN_AND_MEAN -# ifndef NOMINMAX -# define NOMINMAX -# endif -# include -# include -#else -# include -# include -# include -# include -# include -# include -# include -#endif #include #include #include #include -#ifdef GGML_RPC_RDMA -# include -# include -# ifndef _WIN32 -# include -# endif -#endif // GGML_RPC_RDMA - static const char * RPC_DEBUG = std::getenv("GGML_RPC_DEBUG"); #define LOG_DBG(...) 
\ @@ -49,128 +26,6 @@ static const char * RPC_DEBUG = std::getenv("GGML_RPC_DEBUG"); namespace fs = std::filesystem; -static constexpr size_t MAX_CHUNK_SIZE = 1024ull * 1024ull * 1024ull; // 1 GiB - -#ifdef _WIN32 -typedef SOCKET sockfd_t; -using ssize_t = __int64; -#else -typedef int sockfd_t; -#endif - -// cross-platform socket - -#ifdef GGML_RPC_RDMA -static constexpr size_t RDMA_CHUNK = 256 * 1024; // 256 KiB per send/recv (fits default 8 MiB memlock) -static constexpr int RDMA_RX_DEPTH = 24; // pre-posted recv ring: 24 × 256 KiB = 6 MiB -static constexpr size_t RDMA_GID_SIZE = 16; // RoCE GID / IB GID is always 16 bytes -using rdma_gid_t = std::array; - -struct rdma_conn { - struct ibv_context * ctx = nullptr; - struct ibv_pd * pd = nullptr; - struct ibv_cq * scq = nullptr; // send completions - struct ibv_cq * rcq = nullptr; // recv completions - struct ibv_qp * qp = nullptr; - - void * tx_buf = nullptr; - struct ibv_mr * tx_mr = nullptr; - - void * rx_buf = nullptr; // RDMA_RX_DEPTH × RDMA_CHUNK contiguous - struct ibv_mr * rx_mr = nullptr; - int rx_head = 0; - - uint32_t max_inline = 0; - - uint8_t * rx_slot(int i) const { - return static_cast(rx_buf) + static_cast(i) * RDMA_CHUNK; - } - - bool post_rx(int i) { - struct ibv_sge sge = {}; - sge.addr = (uintptr_t)rx_slot(i); - sge.length = RDMA_CHUNK; - sge.lkey = rx_mr->lkey; - struct ibv_recv_wr wr = {}, * bad = nullptr; - wr.wr_id = (uint64_t)i; - wr.sg_list = &sge; - wr.num_sge = 1; - return ibv_post_recv(qp, &wr, &bad) == 0; - } - - ~rdma_conn() { - if (tx_mr) ibv_dereg_mr(tx_mr); - if (rx_mr) ibv_dereg_mr(rx_mr); - free(tx_buf); - free(rx_buf); - if (qp) ibv_destroy_qp(qp); - if (scq) ibv_destroy_cq(scq); - if (rcq) ibv_destroy_cq(rcq); - if (pd) ibv_dealloc_pd(pd); - if (ctx) ibv_close_device(ctx); - } -}; - -// Local RDMA parameters captured during the probe phase and later consumed -// by rdma_activate() after the remote side's caps arrive via HELLO. -struct rdma_local_info { - uint32_t qpn = 0; - uint32_t psn = 0; - uint8_t gid[RDMA_GID_SIZE] = {}; - uint8_t ib_port = 0; - int gid_idx = 0; - enum ibv_mtu path_mtu = IBV_MTU_1024; -}; -#endif // GGML_RPC_RDMA - -// conn_caps size for transport-agnostic capability exchange -static constexpr size_t RPC_CONN_CAPS_SIZE = 24; - -// conn_caps RDMA layout helper -#ifdef GGML_RPC_RDMA -struct rdma_caps { - uint32_t qpn; - uint32_t psn; - uint8_t gid[RDMA_GID_SIZE]; -}; -static_assert(sizeof(rdma_caps) == RPC_CONN_CAPS_SIZE, "rdma_caps must match conn_caps size"); -#endif // GGML_RPC_RDMA - -// Forward declarations for transport function pointers -struct socket_t; -static bool tcp_send_impl(socket_t * sock, const void * data, size_t size); -static bool tcp_recv_impl(socket_t * sock, void * data, size_t size); - -struct socket_t { - sockfd_t fd; - bool (*fn_send)(socket_t *, const void *, size_t) = tcp_send_impl; - bool (*fn_recv)(socket_t *, void *, size_t) = tcp_recv_impl; -#ifdef GGML_RPC_RDMA - std::unique_ptr rdma; - rdma_local_info rdma_local = {}; -#endif // GGML_RPC_RDMA - socket_t(sockfd_t fd) : fd(fd) {} - ~socket_t() { -#ifdef GGML_RPC_RDMA - rdma.reset(); -#endif // GGML_RPC_RDMA - LOG_DBG("[%s] closing socket %d\n", __func__, this->fd); -#ifdef _WIN32 - if (fd != INVALID_SOCKET) closesocket(this->fd); -#else - if (fd >= 0) close(this->fd); -#endif - } - - // Advertise local transport capabilities into conn_caps. - // May probe RDMA and store the probe on this socket for update_caps. 
- void get_caps(uint8_t * caps); - - // Activate transport upgrade based on remote conn_caps using the probe - // previously stored by get_caps. - void update_caps(const uint8_t * remote_caps); -}; - // macro for nicer error messages on server crash #define RPC_STATUS_ASSERT(x) if (!(x)) GGML_ABORT("Remote RPC server crashed or returned malformed response") @@ -403,540 +258,27 @@ static uint64_t fnv_hash(const uint8_t * data, size_t len) { return hash; } -static std::shared_ptr make_socket(sockfd_t fd) { -#ifdef _WIN32 - if (fd == INVALID_SOCKET) { - return nullptr; - } -#else - if (fd < 0) { - return nullptr; - } -#endif - return std::make_shared(fd); -} - -static bool set_no_delay(sockfd_t sockfd) { - int flag = 1; - // set TCP_NODELAY to disable Nagle's algorithm - int ret = setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int)); - return ret == 0; -} - -static bool set_reuse_addr(sockfd_t sockfd) { - int flag = 1; - int ret = setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, (char *)&flag, sizeof(int)); - return ret == 0; -} - -static std::shared_ptr socket_connect(const char * host, int port) { - struct sockaddr_in addr; - auto sockfd = socket(AF_INET, SOCK_STREAM, 0); - auto sock_ptr = make_socket(sockfd); - if (sock_ptr == nullptr) { - return nullptr; - } - if (!set_no_delay(sockfd)) { - GGML_LOG_ERROR("Failed to set TCP_NODELAY\n"); - return nullptr; - } - addr.sin_family = AF_INET; - addr.sin_port = htons(port); - struct hostent * server = gethostbyname(host); - if (server == NULL) { - GGML_LOG_ERROR("Cannot resolve host '%s'\n", host); - return nullptr; - } - memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length); - if (connect(sock_ptr->fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) { - return nullptr; - } - return sock_ptr; -} - -static std::shared_ptr socket_accept(sockfd_t srv_sockfd) { - auto client_socket_fd = accept(srv_sockfd, NULL, NULL); - auto client_socket = make_socket(client_socket_fd); - if (client_socket == nullptr) { - return nullptr; - } - if (!set_no_delay(client_socket_fd)) { - GGML_LOG_ERROR("Failed to set TCP_NODELAY\n"); - return nullptr; - } - return client_socket; -} - -static std::shared_ptr create_server_socket(const char * host, int port) { - auto sockfd = socket(AF_INET, SOCK_STREAM, 0); - auto sock = make_socket(sockfd); - if (sock == nullptr) { - return nullptr; - } - if (!set_reuse_addr(sockfd)) { - GGML_LOG_ERROR("Failed to set SO_REUSEADDR\n"); - return nullptr; - } - if (inet_addr(host) == INADDR_NONE) { - GGML_LOG_ERROR("Invalid host address: %s\n", host); - return nullptr; - } - struct sockaddr_in serv_addr; - serv_addr.sin_family = AF_INET; - serv_addr.sin_addr.s_addr = inet_addr(host); - serv_addr.sin_port = htons(port); - - if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) { - return nullptr; - } - if (listen(sockfd, 1) < 0) { - return nullptr; - } - return sock; -} - -static bool send_data(sockfd_t sockfd, const void * data, size_t size) { - size_t bytes_sent = 0; - while (bytes_sent < size) { - size_t size_to_send = std::min(size - bytes_sent, MAX_CHUNK_SIZE); - ssize_t n = send(sockfd, (const char *)data + bytes_sent, size_to_send, 0); - if (n < 0) { - GGML_LOG_ERROR("send failed (bytes_sent=%zu, size_to_send=%zu)\n", - bytes_sent, size_to_send); - return false; - } - bytes_sent += (size_t)n; - } - return true; -} - -static bool recv_data(sockfd_t sockfd, void * data, size_t size) { - size_t bytes_recv = 0; - while (bytes_recv < size) { - size_t size_to_recv = std::min(size - bytes_recv, 
MAX_CHUNK_SIZE); - ssize_t n = recv(sockfd, (char *)data + bytes_recv, size_to_recv, 0); - if (n < 0) { - GGML_LOG_ERROR("recv failed (bytes_recv=%zu, size_to_recv=%zu)\n", - bytes_recv, size_to_recv); - return false; - } - if (n == 0) { - LOG_DBG("recv returned 0 (peer closed?)\n"); - return false; - } - bytes_recv += (size_t)n; - } - return true; -} - -// TCP transport implementations (for function-pointer dispatch) - -static bool tcp_send_impl(socket_t * sock, const void * data, size_t size) { - return send_data(sock->fd, data, size); -} - -static bool tcp_recv_impl(socket_t * sock, void * data, size_t size) { - return recv_data(sock->fd, data, size); -} - -// RDMA transport (performance-optimized, auto-negotiated) - -#ifdef GGML_RPC_RDMA - -static bool rdma_send_impl(socket_t * sock, const void * data, size_t size); -static bool rdma_recv_impl(socket_t * sock, void * data, size_t size); - -static inline bool tcp_peer_closed(int fd) { - if (fd < 0) return false; -#ifndef _WIN32 - struct pollfd pfd = { fd, POLLIN | POLLRDHUP, 0 }; - int r = poll(&pfd, 1, 0); - return r > 0 && (pfd.revents & (POLLHUP | POLLERR | POLLRDHUP)); -#else - return false; -#endif -} - -static inline bool rdma_poll(struct ibv_cq * cq, struct ibv_wc * wc, int tcp_fd) { - for (uint64_t s = 0; ; s++) { - int n = ibv_poll_cq(cq, 1, wc); - if (n > 0) { - if (wc->status != IBV_WC_SUCCESS) { - GGML_LOG_ERROR("RDMA CQ wc error: status=%d (%s) vendor_err=0x%x\n", - wc->status, ibv_wc_status_str(wc->status), wc->vendor_err); - } - return wc->status == IBV_WC_SUCCESS; - } - if (n < 0) return false; - if ((s & 0xFFFFF) == 0 && s > 0) { - if (tcp_peer_closed(tcp_fd)) { - return false; - } - } - } -} - -static bool rdma_send(rdma_conn * c, const void * data, size_t size, int tcp_fd) { - const uint8_t * src = (const uint8_t *)data; - size_t rem = size; - while (rem > 0) { - size_t chunk = std::min(rem, RDMA_CHUNK); - - struct ibv_sge sge = {}; - struct ibv_send_wr wr = {}, * bad = nullptr; - wr.opcode = IBV_WR_SEND; - wr.sg_list = &sge; - wr.num_sge = 1; - - if (chunk <= c->max_inline) { - sge.addr = (uintptr_t)src; - sge.length = chunk; - wr.send_flags = IBV_SEND_SIGNALED | IBV_SEND_INLINE; - } else { - memcpy(c->tx_buf, src, chunk); - sge.addr = (uintptr_t)c->tx_buf; - sge.length = chunk; - sge.lkey = c->tx_mr->lkey; - wr.send_flags = IBV_SEND_SIGNALED; - } - - if (ibv_post_send(c->qp, &wr, &bad) != 0) return false; - struct ibv_wc wc; - if (!rdma_poll(c->scq, &wc, tcp_fd)) return false; - - src += chunk; - rem -= chunk; - } - return true; -} - - -static bool rdma_recv(rdma_conn * c, void * data, size_t size, int tcp_fd) { - uint8_t * dst = (uint8_t *)data; - size_t rem = size; - while (rem > 0) { - struct ibv_wc wc; - if (!rdma_poll(c->rcq, &wc, tcp_fd)) return false; - - int slot = (int)wc.wr_id; - size_t got = wc.byte_len; - memcpy(dst, c->rx_slot(slot), got); - - if (!c->post_rx(slot)) return false; - - dst += got; - rem -= got; - } - return true; -} - -static bool rdma_send_impl(socket_t * sock, const void * data, size_t size) { - return rdma_send(sock->rdma.get(), data, size, sock->fd); -} - -static bool rdma_recv_impl(socket_t * sock, void * data, size_t size) { - return rdma_recv(sock->rdma.get(), data, size, sock->fd); -} - -// Build a RoCE GID-shaped 16-byte target from a TCP socket's local address. 
-// Used to match the socket's local IP against the kernel's GID table so that -// a single memcmp handles IPv4, IPv4-mapped IPv6, and native IPv6 uniformly: -// AF_INET -> ::ffff:a.b.c.d (bytes 10-11 = 0xff, last 4 = IPv4) -// AF_INET6 (IPv4-mapped) -> ::ffff:a.b.c.d (already in GID shape) -// AF_INET6 (native v6) -> the 16-byte IPv6 address as-is -// Returns std::nullopt on unsupported family or getsockname failure. -static std::optional rdma_build_target_gid(sockfd_t tcp_fd) { - sockaddr_storage addr = {}; - socklen_t addr_len = sizeof(addr); - if (getsockname(tcp_fd, reinterpret_cast(&addr), &addr_len) != 0) { - return std::nullopt; - } - rdma_gid_t target = {}; - if (addr.ss_family == AF_INET) { - const auto * a = reinterpret_cast(&addr); - target[10] = 0xff; - target[11] = 0xff; - memcpy(&target[12], &a->sin_addr, 4); - return target; - } - if (addr.ss_family == AF_INET6) { - const auto * a = reinterpret_cast(&addr); - memcpy(target.data(), &a->sin6_addr, RDMA_GID_SIZE); - return target; - } - return std::nullopt; -} - -static rdma_conn * rdma_probe(sockfd_t tcp_fd, rdma_local_info * out) { - const char * dev_env = std::getenv("GGML_RDMA_DEV"); - const char * gid_env = std::getenv("GGML_RDMA_GID"); - - auto target_gid = rdma_build_target_gid(tcp_fd); - if (!target_gid) { - return nullptr; - } - - const uint8_t ib_port = 1; - int num_devs = 0; - ibv_device ** devs = ibv_get_device_list(&num_devs); - if (!devs || num_devs == 0) return nullptr; - - ibv_context * ibctx = nullptr; - const char * matched_dev = nullptr; - int gid_idx = gid_env ? atoi(gid_env) : -1; - int gid_version = IBV_GID_TYPE_IB; // 0 = unknown/IB - - for (int d = 0; d < num_devs; d++) { - const char * dn = ibv_get_device_name(devs[d]); - if (dev_env && strcmp(dev_env, dn) != 0) continue; - - ibv_context * ctx = ibv_open_device(devs[d]); - if (!ctx) continue; - - ibv_port_attr pa; - if (ibv_query_port(ctx, ib_port, &pa) != 0) { ibv_close_device(ctx); continue; } - - int found_gid = gid_idx; - int found_version = IBV_GID_TYPE_IB; - if (found_gid < 0) { - // Find a GID on this port whose bytes equal the local TCP address - // (IPv4 or IPv6). Prefer RoCE v2 (UDP/IP, L3-routable) over v1 - // (raw Ethernet, same-L2 only) so silent hangs on L3-routed paths - // are avoided. ibv_query_gid_ex returns gid+type in one call. - int v2_idx = -1; - int v1_idx = -1; - for (int i = 0; i < pa.gid_tbl_len; i++) { - ibv_gid_entry entry = {}; - if (ibv_query_gid_ex(ctx, ib_port, i, &entry, 0) != 0) continue; - if (memcmp(entry.gid.raw, target_gid->data(), RDMA_GID_SIZE) != 0) continue; - if (entry.gid_type == IBV_GID_TYPE_ROCE_V2 && v2_idx < 0) { - v2_idx = i; - } else if (entry.gid_type == IBV_GID_TYPE_ROCE_V1 && v1_idx < 0) { - v1_idx = i; - } - } - if (v2_idx >= 0) { - found_gid = v2_idx; - found_version = IBV_GID_TYPE_ROCE_V2; - } else if (v1_idx >= 0) { - found_gid = v1_idx; - found_version = IBV_GID_TYPE_ROCE_V1; - } - } else { - // Explicit GID index from GGML_RDMA_GID — fetch its type for logging. 
- ibv_gid_entry entry = {}; - if (ibv_query_gid_ex(ctx, ib_port, found_gid, &entry, 0) == 0) { - found_version = entry.gid_type; - } - } - if (found_gid >= 0) { - ibctx = ctx; - gid_idx = found_gid; - gid_version = found_version; - matched_dev = dn; - out->path_mtu = pa.active_mtu; - break; - } - ibv_close_device(ctx); - } - ibv_free_device_list(devs); - if (!ibctx) return nullptr; - - out->ib_port = ib_port; - out->gid_idx = gid_idx; - - // unique_ptr owns ibctx and every subsequent resource via ~rdma_conn(), - // so each failure path is a plain `return nullptr;`. - auto c = std::make_unique(); - c->ctx = ibctx; - - c->pd = ibv_alloc_pd(ibctx); - if (!c->pd) return nullptr; - - c->scq = ibv_create_cq(ibctx, 16, nullptr, nullptr, 0); - c->rcq = ibv_create_cq(ibctx, RDMA_RX_DEPTH + 4, nullptr, nullptr, 0); - if (!c->scq || !c->rcq) return nullptr; - - ibv_qp_init_attr qia = {}; - qia.send_cq = c->scq; - qia.recv_cq = c->rcq; - qia.qp_type = IBV_QPT_RC; - qia.cap.max_send_wr = 4; - qia.cap.max_recv_wr = RDMA_RX_DEPTH + 4; - qia.cap.max_send_sge = 1; - qia.cap.max_recv_sge = 1; - qia.cap.max_inline_data = 256; - - c->qp = ibv_create_qp(c->pd, &qia); - if (!c->qp) return nullptr; - c->max_inline = qia.cap.max_inline_data; - - c->tx_buf = aligned_alloc(4096, RDMA_CHUNK); - c->rx_buf = aligned_alloc(4096, static_cast(RDMA_RX_DEPTH) * RDMA_CHUNK); - if (!c->tx_buf || !c->rx_buf) return nullptr; - - c->tx_mr = ibv_reg_mr(c->pd, c->tx_buf, RDMA_CHUNK, IBV_ACCESS_LOCAL_WRITE); - c->rx_mr = ibv_reg_mr(c->pd, c->rx_buf, static_cast(RDMA_RX_DEPTH) * RDMA_CHUNK, - IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE); - if (!c->tx_mr || !c->rx_mr) return nullptr; - - ibv_gid local_gid; - if (ibv_query_gid(ibctx, ib_port, gid_idx, &local_gid) != 0) return nullptr; - - out->qpn = c->qp->qp_num; - out->psn = c->qp->qp_num & 0xffffff; - memcpy(out->gid, &local_gid, RDMA_GID_SIZE); - - const char * ver_str = ""; - if (gid_version == IBV_GID_TYPE_ROCE_V2) { - ver_str = " RoCEv2"; - } else if (gid_version == IBV_GID_TYPE_ROCE_V1) { - ver_str = " RoCEv1"; - } - GGML_LOG_INFO("RDMA probed: dev=%s gid=%d%s qpn=%u inline=%u\n", - matched_dev, gid_idx, ver_str, out->qpn, c->max_inline); - return c.release(); -} - -// Phase 2: Given remote QPN/PSN/GID, transition QP: RESET->INIT->pre-post->RTR->RTS. -// On success, the connection is live and ready for rdma_send/rdma_recv. 
-static bool rdma_activate(rdma_conn * c, const rdma_local_info * local, - uint32_t remote_qpn, uint32_t remote_psn, const uint8_t * remote_gid) { - // RESET -> INIT - { - struct ibv_qp_attr a = {}; - a.qp_state = IBV_QPS_INIT; - a.port_num = local->ib_port; - a.pkey_index = 0; - a.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_LOCAL_WRITE; - if (ibv_modify_qp(c->qp, &a, - IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) { - return false; - } - } - - for (int i = 0; i < RDMA_RX_DEPTH; i++) { - if (!c->post_rx(i)) return false; - } - - // INIT -> RTR - { - struct ibv_qp_attr a = {}; - a.qp_state = IBV_QPS_RTR; - a.path_mtu = local->path_mtu; - a.dest_qp_num = remote_qpn; - a.rq_psn = remote_psn; - a.max_dest_rd_atomic = 1; - a.min_rnr_timer = 1; - a.ah_attr.is_global = 1; - memcpy(&a.ah_attr.grh.dgid, remote_gid, RDMA_GID_SIZE); - a.ah_attr.grh.hop_limit = 1; - a.ah_attr.grh.sgid_index = local->gid_idx; - a.ah_attr.dlid = 0; - a.ah_attr.port_num = local->ib_port; - if (ibv_modify_qp(c->qp, &a, - IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | - IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER) != 0) { - return false; - } - } - - // RTR -> RTS - { - struct ibv_qp_attr a = {}; - a.qp_state = IBV_QPS_RTS; - a.timeout = 14; - a.retry_cnt = 7; - a.rnr_retry = 7; - a.sq_psn = local->psn; - a.max_rd_atomic = 1; - if (ibv_modify_qp(c->qp, &a, - IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | - IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC) != 0) { - return false; - } - } - - GGML_LOG_INFO("RDMA activated: qpn=%u->%u mtu=%d rx_depth=%d\n", - local->qpn, remote_qpn, 128 << local->path_mtu, RDMA_RX_DEPTH); - return true; -} - -#endif // GGML_RPC_RDMA - -// --------------------------------------------------------------------------- -// socket_t transport capability methods -// --------------------------------------------------------------------------- - -void socket_t::get_caps(uint8_t * caps) { - memset(caps, 0, RPC_CONN_CAPS_SIZE); -#ifdef GGML_RPC_RDMA - rdma_local = {}; - rdma.reset(rdma_probe(fd, &rdma_local)); - if (rdma) { - rdma_caps rc = {}; - rc.qpn = rdma_local.qpn; - rc.psn = rdma_local.psn; - memcpy(rc.gid, rdma_local.gid, RDMA_GID_SIZE); - memcpy(caps, &rc, sizeof(rc)); - } -#endif // GGML_RPC_RDMA -} - -void socket_t::update_caps(const uint8_t * remote_caps) { -#ifdef GGML_RPC_RDMA - if (!rdma) { - return; - } - rdma_caps rc = {}; - memcpy(&rc, remote_caps, sizeof(rc)); - if (rc.qpn == 0) { - rdma.reset(); - return; - } - if (rdma_activate(rdma.get(), &rdma_local, rc.qpn, rc.psn, rc.gid)) { - fn_send = rdma_send_impl; - fn_recv = rdma_recv_impl; - } else { - GGML_LOG_ERROR("RDMA activate failed, staying on TCP\n"); - rdma.reset(); - } -#else - (void)remote_caps; -#endif // GGML_RPC_RDMA -} - -// unified transport dispatch (via function pointers) - -static bool send_data(socket_t * sock, const void * data, size_t size) { - return sock->fn_send(sock, data, size); -} - -static bool recv_data(socket_t * sock, void * data, size_t size) { - return sock->fn_recv(sock, data, size); -} - -static bool send_msg(socket_t * sock, const void * msg, size_t msg_size) { - if (!send_data(sock, &msg_size, sizeof(msg_size))) { +static bool send_msg(socket_ptr sock, const void * msg, size_t msg_size) { + if (!sock->send_data(&msg_size, sizeof(msg_size))) { return false; } - return send_data(sock, msg, msg_size); + return sock->send_data(msg, msg_size); } -static bool recv_msg(socket_t * sock, void * 
msg, size_t msg_size) { +static bool recv_msg(socket_ptr sock, void * msg, size_t msg_size) { uint64_t size; - if (!recv_data(sock, &size, sizeof(size))) { + if (!sock->recv_data(&size, sizeof(size))) { return false; } if (size != msg_size) { return false; } - return recv_data(sock, msg, msg_size); + return sock->recv_data(msg, msg_size); } -static bool recv_msg(socket_t * sock, std::vector & input) { +static bool recv_msg(socket_ptr sock, std::vector & input) { uint64_t size; - if (!recv_data(sock, &size, sizeof(size))) { + if (!sock->recv_data(&size, sizeof(size))) { return false; } try { @@ -945,7 +287,7 @@ static bool recv_msg(socket_t * sock, std::vector & input) { GGML_LOG_ERROR("Failed to allocate input buffer of size %" PRIu64 "\n", size); return false; } - return recv_data(sock, input.data(), size); + return sock->recv_data(input.data(), size); } static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) { @@ -964,15 +306,15 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int // RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) | // No response -static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cmd, const void * input, size_t input_size) { +static bool send_rpc_cmd(socket_ptr sock, enum rpc_cmd cmd, const void * input, size_t input_size) { uint8_t cmd_byte = cmd; - if (!send_data(sock.get(), &cmd_byte, sizeof(cmd_byte))) { + if (!sock->send_data(&cmd_byte, sizeof(cmd_byte))) { return false; } - if (!send_data(sock.get(), &input_size, sizeof(input_size))) { + if (!sock->send_data(&input_size, sizeof(input_size))) { return false; } - if (!send_data(sock.get(), input, input_size)) { + if (!sock->send_data(input, input_size)) { return false; } return true; @@ -980,18 +322,18 @@ static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cm // RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) | // RPC response: | response_size (8 bytes) | response_data (response_size bytes) | -static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) { +static bool send_rpc_cmd(socket_ptr sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) { if (!send_rpc_cmd(sock, cmd, input, input_size)) { return false; } uint64_t out_size; - if (!recv_data(sock.get(), &out_size, sizeof(out_size))) { + if (!sock->recv_data(&out_size, sizeof(out_size))) { return false; } if (out_size != output_size) { return false; } - if (!recv_data(sock.get(), output, output_size)) { + if (!sock->recv_data(output, output_size)) { return false; } return true; @@ -1025,7 +367,6 @@ static std::shared_ptr get_socket(const std::string & endpoint) { static std::mutex mutex; std::lock_guard lock(mutex); static std::unordered_map> sockets; - static bool initialized = false; auto it = sockets.find(endpoint); if (it != sockets.end()) { @@ -1040,19 +381,10 @@ static std::shared_ptr get_socket(const std::string & endpoint) { return nullptr; } -#ifdef _WIN32 - if (!initialized) { - WSADATA wsaData; - int res = WSAStartup(MAKEWORD(2, 2), &wsaData); - if (res != 0) { - return nullptr; - } - initialized = true; + if (!rpc_transport_init()) { + return nullptr; } -#else - GGML_UNUSED(initialized); -#endif - auto sock = socket_connect(host.c_str(), port); + auto sock = socket_t::connect(host.c_str(), port); if (sock == nullptr) { return nullptr; 
} @@ -2110,10 +1442,10 @@ rpc_server::~rpc_server() { } static void rpc_serve_client(const std::vector & backends, const char * cache_dir, - socket_t * sockfd) { + socket_ptr sock) { rpc_server server(backends, cache_dir); uint8_t cmd; - if (!recv_data(sockfd, &cmd, 1)) { + if (!sock->recv_data(&cmd, 1)) { return; } if (cmd != RPC_CMD_HELLO) { @@ -2123,7 +1455,7 @@ static void rpc_serve_client(const std::vector & backends, const // Read input_size and validate protocol version uint64_t hello_input_size; - if (!recv_data(sockfd, &hello_input_size, sizeof(hello_input_size))) { + if (!sock->recv_data(&hello_input_size, sizeof(hello_input_size))) { return; } @@ -2134,24 +1466,22 @@ static void rpc_serve_client(const std::vector & backends, const } rpc_msg_hello_req req = {}; - if (!recv_data(sockfd, &req, sizeof(req))) { + if (!sock->recv_data(&req, sizeof(req))) { return; } rpc_msg_hello_rsp rsp = {}; server.hello(rsp); - // Advertise server transport capabilities based on client's caps - sockfd->get_caps(rsp.conn_caps); - - if (!send_msg(sockfd, &rsp, sizeof(rsp))) { + sock->get_caps(rsp.conn_caps); + if (!send_msg(sock, &rsp, sizeof(rsp))) { return; } // Activate transport upgrade using client's caps - sockfd->update_caps(req.conn_caps); + sock->update_caps(req.conn_caps); while (true) { - if (!recv_data(sockfd, &cmd, 1)) { + if (!sock->recv_data(&cmd, 1)) { break; } if (cmd >= RPC_CMD_COUNT) { @@ -2165,115 +1495,115 @@ static void rpc_serve_client(const std::vector & backends, const return; } case RPC_CMD_DEVICE_COUNT: { - if (!recv_msg(sockfd, nullptr, 0)) { + if (!recv_msg(sock, nullptr, 0)) { return; } rpc_msg_device_count_rsp response; response.device_count = backends.size(); - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_ALLOC_BUFFER: { rpc_msg_alloc_buffer_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_alloc_buffer_rsp response; if (!server.alloc_buffer(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_GET_ALLOC_SIZE: { rpc_msg_get_alloc_size_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_get_alloc_size_rsp response; if (!server.get_alloc_size(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_GET_ALIGNMENT: { rpc_msg_get_alignment_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_get_alignment_rsp response; if (!server.get_alignment(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_GET_MAX_SIZE: { rpc_msg_get_max_size_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_get_max_size_rsp response; if (!server.get_max_size(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_BUFFER_GET_BASE: { rpc_msg_buffer_get_base_req request; - if (!recv_msg(sockfd, &request, 
sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_buffer_get_base_rsp response; if (!server.buffer_get_base(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_FREE_BUFFER: { rpc_msg_free_buffer_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } if (!server.free_buffer(request)) { return; } - if (!send_msg(sockfd, nullptr, 0)) { + if (!send_msg(sock, nullptr, 0)) { return; } break; } case RPC_CMD_BUFFER_CLEAR: { rpc_msg_buffer_clear_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } if (!server.buffer_clear(request)) { return; } - if (!send_msg(sockfd, nullptr, 0)) { + if (!send_msg(sock, nullptr, 0)) { return; } break; } case RPC_CMD_SET_TENSOR: { std::vector input; - if (!recv_msg(sockfd, input)) { + if (!recv_msg(sock, input)) { return; } if (!server.set_tensor(input)) { @@ -2283,62 +1613,62 @@ static void rpc_serve_client(const std::vector & backends, const } case RPC_CMD_SET_TENSOR_HASH: { rpc_msg_set_tensor_hash_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_set_tensor_hash_rsp response; if (!server.set_tensor_hash(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_INIT_TENSOR: { rpc_msg_init_tensor_req request; - if (!recv_msg(sockfd, &request,sizeof(request))) { + if (!recv_msg(sock, &request,sizeof(request))) { return; } if (!server.init_tensor(request)) { return; } - if (!send_msg(sockfd, nullptr, 0)) { + if (!send_msg(sock, nullptr, 0)) { return; } break; } case RPC_CMD_GET_TENSOR: { rpc_msg_get_tensor_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } std::vector response; if (!server.get_tensor(request, response)) { return; } - if (!send_msg(sockfd, response.data(), response.size())) { + if (!send_msg(sock, response.data(), response.size())) { return; } break; } case RPC_CMD_COPY_TENSOR: { rpc_msg_copy_tensor_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_copy_tensor_rsp response; if (!server.copy_tensor(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_GRAPH_COMPUTE: { std::vector input; - if (!recv_msg(sockfd, input)) { + if (!recv_msg(sock, input)) { return; } if (!server.graph_compute(input)) { @@ -2348,7 +1678,7 @@ static void rpc_serve_client(const std::vector & backends, const } case RPC_CMD_GRAPH_RECOMPUTE: { rpc_msg_graph_recompute_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } if (!server.graph_recompute(request)) { @@ -2358,14 +1688,14 @@ static void rpc_serve_client(const std::vector & backends, const } case RPC_CMD_GET_DEVICE_MEMORY: { rpc_msg_get_device_memory_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_get_device_memory_rsp response; if (!server.get_device_memory(request, response)) 
{ return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; @@ -2424,36 +1754,28 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir #else printf(" transport : TCP\n"); #endif // GGML_RPC_RDMA -#ifdef _WIN32 - { - WSADATA wsaData; - int res = WSAStartup(MAKEWORD(2, 2), &wsaData); - if (res != 0) { - fprintf(stderr, "WSAStartup failed: %d\n", res); - return; - } + if (!rpc_transport_init()) { + fprintf(stderr, "Failed to initialize RPC transport\n"); + return; } -#endif - auto server_socket = create_server_socket(host.c_str(), port); + auto server_socket = socket_t::create_server(host.c_str(), port); if (server_socket == nullptr) { fprintf(stderr, "Failed to create server socket\n"); return; } while (true) { - auto client_socket = socket_accept(server_socket->fd); + auto client_socket = server_socket->accept(); if (client_socket == nullptr) { fprintf(stderr, "Failed to accept client connection\n"); return; } printf("Accepted client connection\n"); fflush(stdout); - rpc_serve_client(backends, cache_dir, client_socket.get()); + rpc_serve_client(backends, cache_dir, client_socket); printf("Client connection closed\n"); fflush(stdout); } -#ifdef _WIN32 - WSACleanup(); -#endif + rpc_transport_shutdown(); for (auto backend : backends) { ggml_backend_free(backend); } diff --git a/ggml/src/ggml-rpc/transport.cpp b/ggml/src/ggml-rpc/transport.cpp new file mode 100644 index 00000000000..a728152421f --- /dev/null +++ b/ggml/src/ggml-rpc/transport.cpp @@ -0,0 +1,683 @@ +#include "transport.h" +#include "ggml-impl.h" + +#ifdef _WIN32 +# define WIN32_LEAN_AND_MEAN +# ifndef NOMINMAX +# define NOMINMAX +# endif +# include +# include +#else +# include +# include +# include +# include +# include +# include +# include +#endif +#include +#include +#include + +#ifdef GGML_RPC_RDMA +# include +# include +# ifndef _WIN32 +# include +# endif +#endif // GGML_RPC_RDMA + +#ifdef _WIN32 +typedef SOCKET sockfd_t; +using ssize_t = __int64; +#else +typedef int sockfd_t; +#endif + +static const char * RPC_DEBUG = std::getenv("GGML_RPC_DEBUG"); + +#define LOG_DBG(...) 
\ + do { if (RPC_DEBUG) GGML_LOG_DEBUG(__VA_ARGS__); } while (0) + +#ifdef GGML_RPC_RDMA +static constexpr size_t RDMA_CHUNK = 256 * 1024; // 256 KiB per send/recv (fits default 8 MiB memlock) +static constexpr int RDMA_RX_DEPTH = 24; // pre-posted recv ring: 24 × 256 KiB = 6 MiB +static constexpr size_t RDMA_GID_SIZE = 16; // RoCE GID / IB GID is always 16 bytes +using rdma_gid_t = std::array; + +struct rdma_conn { + struct ibv_context * ctx = nullptr; + struct ibv_pd * pd = nullptr; + struct ibv_cq * scq = nullptr; // send completions + struct ibv_cq * rcq = nullptr; // recv completions + struct ibv_qp * qp = nullptr; + + void * tx_buf = nullptr; + struct ibv_mr * tx_mr = nullptr; + + void * rx_buf = nullptr; // RDMA_RX_DEPTH × RDMA_CHUNK contiguous + struct ibv_mr * rx_mr = nullptr; + int rx_head = 0; + + uint32_t max_inline = 0; + + uint8_t * rx_slot(int i) const { + return static_cast(rx_buf) + static_cast(i) * RDMA_CHUNK; + } + + bool post_rx(int i) { + struct ibv_sge sge = {}; + sge.addr = (uintptr_t)rx_slot(i); + sge.length = RDMA_CHUNK; + sge.lkey = rx_mr->lkey; + struct ibv_recv_wr wr = {}, * bad = nullptr; + wr.wr_id = (uint64_t)i; + wr.sg_list = &sge; + wr.num_sge = 1; + return ibv_post_recv(qp, &wr, &bad) == 0; + } + + ~rdma_conn() { + if (tx_mr) ibv_dereg_mr(tx_mr); + if (rx_mr) ibv_dereg_mr(rx_mr); + free(tx_buf); + free(rx_buf); + if (qp) ibv_destroy_qp(qp); + if (scq) ibv_destroy_cq(scq); + if (rcq) ibv_destroy_cq(rcq); + if (pd) ibv_dealloc_pd(pd); + if (ctx) ibv_close_device(ctx); + } +}; + +// Local RDMA parameters captured during the probe phase and later consumed +// by rdma_activate() after the remote side's caps arrive via HELLO. +struct rdma_local_info { + uint32_t qpn = 0; + uint32_t psn = 0; + uint8_t gid[RDMA_GID_SIZE] = {}; + uint8_t ib_port = 0; + int gid_idx = 0; + enum ibv_mtu path_mtu = IBV_MTU_1024; +}; + +struct rdma_caps { + uint32_t qpn; + uint32_t psn; + uint8_t gid[RDMA_GID_SIZE]; +}; + +static_assert(sizeof(rdma_caps) == RPC_CONN_CAPS_SIZE, "rdma_caps must match conn_caps size"); + +#endif // GGML_RPC_RDMA + +struct socket_t::impl { + impl(sockfd_t fd) : use_rdma(false), fd(fd) {} + ~impl(); + bool send_data(const void * data, size_t size); + bool recv_data(void * data, size_t size); + void get_caps(uint8_t * local_caps); + void update_caps(const uint8_t * remote_caps); + +#ifdef GGML_RPC_RDMA + bool tcp_peer_closed(); + std::optional rdma_build_target_gid(); + bool rdma_probe(); + bool rdma_activate(uint32_t remote_qpn, uint32_t remote_psn, const uint8_t * remote_gid); + bool rdma_poll(struct ibv_cq * cq, struct ibv_wc * wc); + bool rdma_send(const void * data, size_t size); + bool rdma_recv(void * data, size_t size); + + std::unique_ptr rdma; + rdma_local_info rdma_local = {}; +#endif // GGML_RPC_RDMA + bool use_rdma; + sockfd_t fd; +}; + +socket_t::impl::~impl() { +#ifdef GGML_RPC_RDMA + rdma.reset(); +#endif // GGML_RPC_RDMA + LOG_DBG("[%s] closing socket %d\n", __func__, this->fd); +#ifdef _WIN32 + if (fd != INVALID_SOCKET) closesocket(this->fd); +#else + if (fd >= 0) close(this->fd); +#endif +} + +#ifdef GGML_RPC_RDMA + +bool socket_t::impl::tcp_peer_closed() { + if (fd < 0) return false; +#ifndef _WIN32 + struct pollfd pfd = { fd, POLLIN | POLLRDHUP, 0 }; + int r = poll(&pfd, 1, 0); + return r > 0 && (pfd.revents & (POLLHUP | POLLERR | POLLRDHUP)); +#else + return false; +#endif +} + +// Build a RoCE GID-shaped 16-byte target from a TCP socket's local address. 
+// Used to match the socket's local IP against the kernel's GID table so that +// a single memcmp handles IPv4, IPv4-mapped IPv6, and native IPv6 uniformly: +// AF_INET -> ::ffff:a.b.c.d (bytes 10-11 = 0xff, last 4 = IPv4) +// AF_INET6 (IPv4-mapped) -> ::ffff:a.b.c.d (already in GID shape) +// AF_INET6 (native v6) -> the 16-byte IPv6 address as-is +// Returns std::nullopt on unsupported family or getsockname failure. +std::optional socket_t::impl::rdma_build_target_gid() { + sockaddr_storage addr = {}; + socklen_t addr_len = sizeof(addr); + if (getsockname(fd, reinterpret_cast(&addr), &addr_len) != 0) { + return std::nullopt; + } + rdma_gid_t target = {}; + if (addr.ss_family == AF_INET) { + const auto * a = reinterpret_cast(&addr); + target[10] = 0xff; + target[11] = 0xff; + memcpy(&target[12], &a->sin_addr, 4); + return target; + } + if (addr.ss_family == AF_INET6) { + const auto * a = reinterpret_cast(&addr); + memcpy(target.data(), &a->sin6_addr, RDMA_GID_SIZE); + return target; + } + return std::nullopt; +} + +bool socket_t::impl::rdma_probe() { + const char * dev_env = std::getenv("GGML_RDMA_DEV"); + const char * gid_env = std::getenv("GGML_RDMA_GID"); + + auto target_gid = rdma_build_target_gid(); + if (!target_gid) { + return false; + } + + const uint8_t ib_port = 1; + int num_devs = 0; + ibv_device ** devs = ibv_get_device_list(&num_devs); + if (!devs || num_devs == 0) return false; + + ibv_context * ibctx = nullptr; + const char * matched_dev = nullptr; + int gid_idx = gid_env ? atoi(gid_env) : -1; + int gid_version = IBV_GID_TYPE_IB; // 0 = unknown/IB + + for (int d = 0; d < num_devs; d++) { + const char * dn = ibv_get_device_name(devs[d]); + if (dev_env && strcmp(dev_env, dn) != 0) continue; + + ibv_context * ctx = ibv_open_device(devs[d]); + if (!ctx) continue; + + ibv_port_attr pa; + if (ibv_query_port(ctx, ib_port, &pa) != 0) { ibv_close_device(ctx); continue; } + + int found_gid = gid_idx; + int found_version = IBV_GID_TYPE_IB; + if (found_gid < 0) { + // Find a GID on this port whose bytes equal the local TCP address + // (IPv4 or IPv6). Prefer RoCE v2 (UDP/IP, L3-routable) over v1 + // (raw Ethernet, same-L2 only) so silent hangs on L3-routed paths + // are avoided. ibv_query_gid_ex returns gid+type in one call. + int v2_idx = -1; + int v1_idx = -1; + for (int i = 0; i < pa.gid_tbl_len; i++) { + ibv_gid_entry entry = {}; + if (ibv_query_gid_ex(ctx, ib_port, i, &entry, 0) != 0) continue; + if (memcmp(entry.gid.raw, target_gid->data(), RDMA_GID_SIZE) != 0) continue; + if (entry.gid_type == IBV_GID_TYPE_ROCE_V2 && v2_idx < 0) { + v2_idx = i; + } else if (entry.gid_type == IBV_GID_TYPE_ROCE_V1 && v1_idx < 0) { + v1_idx = i; + } + } + if (v2_idx >= 0) { + found_gid = v2_idx; + found_version = IBV_GID_TYPE_ROCE_V2; + } else if (v1_idx >= 0) { + found_gid = v1_idx; + found_version = IBV_GID_TYPE_ROCE_V1; + } + } else { + // Explicit GID index from GGML_RDMA_GID — fetch its type for logging. 
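// Worked example of the match above (address is hypothetical): a socket bound locally to
// 192.168.1.10 is rebuilt as the IPv4-mapped GID ::ffff:192.168.1.10, i.e. bytes 10-11 are
// 0xff and bytes 12-15 are {0xc0, 0xa8, 0x01, 0x0a}, so one memcmp per GID table entry is
// enough to find the RoCE GID the kernel publishes for that interface.
constexpr unsigned char example_target_gid[RDMA_GID_SIZE] = {
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 0xc0, 0xa8, 0x01, 0x0a }; // illustrative only
static_assert(example_target_gid[12] == 192 && example_target_gid[15] == 10, "::ffff:192.168.1.10");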
+ ibv_gid_entry entry = {}; + if (ibv_query_gid_ex(ctx, ib_port, found_gid, &entry, 0) == 0) { + found_version = entry.gid_type; + } + } + if (found_gid >= 0) { + ibctx = ctx; + gid_idx = found_gid; + gid_version = found_version; + matched_dev = dn; + rdma_local.path_mtu = pa.active_mtu; + break; + } + ibv_close_device(ctx); + } + ibv_free_device_list(devs); + if (!ibctx) return false; + + rdma_local.ib_port = ib_port; + rdma_local.gid_idx = gid_idx; + + rdma = std::make_unique(); + rdma->ctx = ibctx; + + rdma->pd = ibv_alloc_pd(ibctx); + if (!rdma->pd) return false; + + rdma->scq = ibv_create_cq(ibctx, 16, nullptr, nullptr, 0); + rdma->rcq = ibv_create_cq(ibctx, RDMA_RX_DEPTH + 4, nullptr, nullptr, 0); + if (!rdma->scq || !rdma->rcq) return false; + + ibv_qp_init_attr qia = {}; + qia.send_cq = rdma->scq; + qia.recv_cq = rdma->rcq; + qia.qp_type = IBV_QPT_RC; + qia.cap.max_send_wr = 4; + qia.cap.max_recv_wr = RDMA_RX_DEPTH + 4; + qia.cap.max_send_sge = 1; + qia.cap.max_recv_sge = 1; + qia.cap.max_inline_data = 256; + + rdma->qp = ibv_create_qp(rdma->pd, &qia); + if (!rdma->qp) return false; + rdma->max_inline = qia.cap.max_inline_data; + + rdma->tx_buf = aligned_alloc(4096, RDMA_CHUNK); + rdma->rx_buf = aligned_alloc(4096, static_cast(RDMA_RX_DEPTH) * RDMA_CHUNK); + if (!rdma->tx_buf || !rdma->rx_buf) return false; + + rdma->tx_mr = ibv_reg_mr(rdma->pd, rdma->tx_buf, RDMA_CHUNK, IBV_ACCESS_LOCAL_WRITE); + rdma->rx_mr = ibv_reg_mr(rdma->pd, rdma->rx_buf, static_cast(RDMA_RX_DEPTH) * RDMA_CHUNK, + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE); + if (!rdma->tx_mr || !rdma->rx_mr) return false; + + ibv_gid local_gid; + if (ibv_query_gid(ibctx, ib_port, gid_idx, &local_gid) != 0) return false; + + rdma_local.qpn = rdma->qp->qp_num; + rdma_local.psn = rdma->qp->qp_num & 0xffffff; + memcpy(&rdma_local.gid, &local_gid, RDMA_GID_SIZE); + + const char * ver_str = ""; + if (gid_version == IBV_GID_TYPE_ROCE_V2) { + ver_str = " RoCEv2"; + } else if (gid_version == IBV_GID_TYPE_ROCE_V1) { + ver_str = " RoCEv1"; + } + GGML_LOG_INFO("RDMA probed: dev=%s gid=%d%s qpn=%u inline=%u\n", + matched_dev, gid_idx, ver_str, rdma_local.qpn, rdma->max_inline); + return true; +} + +// Phase 2: Given remote QPN/PSN/GID, transition QP: RESET->INIT->pre-post->RTR->RTS. +// On success, the connection is live and ready for rdma_send/rdma_recv. 
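// How probe and activate pair up with the HELLO exchange (sketch of the flow driven by
// get_caps()/update_caps() below; the client side is assumed to mirror it with the
// request/response roles swapped):
//   server: get_caps(rsp.conn_caps)    -> rdma_probe()    -> advertise qpn/psn/gid in the HELLO response
//   server: update_caps(req.conn_caps) -> rdma_activate() -> bring the QP up with the client's qpn/psn/gid
// If a probe fails, its caps stay all-zero (qpn == 0) and both ends keep using plain TCP.
static_assert(2 * sizeof(uint32_t) + RDMA_GID_SIZE == RPC_CONN_CAPS_SIZE,
              "qpn + psn + gid exactly fill the 24-byte conn_caps field");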
+bool socket_t::impl::rdma_activate(uint32_t remote_qpn, uint32_t remote_psn, const uint8_t * remote_gid) { + // RESET -> INIT + { + struct ibv_qp_attr a = {}; + a.qp_state = IBV_QPS_INIT; + a.port_num = rdma_local.ib_port; + a.pkey_index = 0; + a.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_LOCAL_WRITE; + if (ibv_modify_qp(rdma->qp, &a, + IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) { + return false; + } + } + + for (int i = 0; i < RDMA_RX_DEPTH; i++) { + if (!rdma->post_rx(i)) return false; + } + + // INIT -> RTR + { + struct ibv_qp_attr a = {}; + a.qp_state = IBV_QPS_RTR; + a.path_mtu = rdma_local.path_mtu; + a.dest_qp_num = remote_qpn; + a.rq_psn = remote_psn; + a.max_dest_rd_atomic = 1; + a.min_rnr_timer = 1; + a.ah_attr.is_global = 1; + memcpy(&a.ah_attr.grh.dgid, remote_gid, RDMA_GID_SIZE); + a.ah_attr.grh.hop_limit = 1; + a.ah_attr.grh.sgid_index = rdma_local.gid_idx; + a.ah_attr.dlid = 0; + a.ah_attr.port_num = rdma_local.ib_port; + if (ibv_modify_qp(rdma->qp, &a, + IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | + IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER) != 0) { + return false; + } + } + + // RTR -> RTS + { + struct ibv_qp_attr a = {}; + a.qp_state = IBV_QPS_RTS; + a.timeout = 14; + a.retry_cnt = 7; + a.rnr_retry = 7; + a.sq_psn = rdma_local.psn; + a.max_rd_atomic = 1; + if (ibv_modify_qp(rdma->qp, &a, + IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | + IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC) != 0) { + return false; + } + } + + GGML_LOG_INFO("RDMA activated: qpn=%u->%u mtu=%d rx_depth=%d\n", + rdma_local.qpn, remote_qpn, 128 << rdma_local.path_mtu, RDMA_RX_DEPTH); + return true; +} + +bool socket_t::impl::rdma_poll(struct ibv_cq * cq, struct ibv_wc * wc) { + for (uint64_t s = 0; ; s++) { + int n = ibv_poll_cq(cq, 1, wc); + if (n > 0) { + if (wc->status != IBV_WC_SUCCESS) { + GGML_LOG_ERROR("RDMA CQ wc error: status=%d (%s) vendor_err=0x%x\n", + wc->status, ibv_wc_status_str(wc->status), wc->vendor_err); + } + return wc->status == IBV_WC_SUCCESS; + } + if (n < 0) return false; + if ((s & 0xFFFFF) == 0 && s > 0) { + if (tcp_peer_closed()) { + return false; + } + } + } +} + +bool socket_t::impl::rdma_send(const void * data, size_t size) { + rdma_conn * c = rdma.get(); + const uint8_t * src = (const uint8_t *)data; + size_t rem = size; + while (rem > 0) { + size_t chunk = std::min(rem, RDMA_CHUNK); + + struct ibv_sge sge = {}; + struct ibv_send_wr wr = {}, * bad = nullptr; + wr.opcode = IBV_WR_SEND; + wr.sg_list = &sge; + wr.num_sge = 1; + + if (chunk <= c->max_inline) { + sge.addr = (uintptr_t)src; + sge.length = chunk; + wr.send_flags = IBV_SEND_SIGNALED | IBV_SEND_INLINE; + } else { + memcpy(c->tx_buf, src, chunk); + sge.addr = (uintptr_t)c->tx_buf; + sge.length = chunk; + sge.lkey = c->tx_mr->lkey; + wr.send_flags = IBV_SEND_SIGNALED; + } + + if (ibv_post_send(c->qp, &wr, &bad) != 0) return false; + struct ibv_wc wc; + if (!rdma_poll(c->scq, &wc)) return false; + + src += chunk; + rem -= chunk; + } + return true; +} + +bool socket_t::impl::rdma_recv(void * data, size_t size) { + rdma_conn * c = rdma.get(); + uint8_t * dst = (uint8_t *)data; + size_t rem = size; + while (rem > 0) { + struct ibv_wc wc; + if (!rdma_poll(c->rcq, &wc)) return false; + + int slot = (int)wc.wr_id; + size_t got = wc.byte_len; + memcpy(dst, c->rx_slot(slot), got); + + if (!c->post_rx(slot)) return false; + + dst += got; + rem -= got; + } + return true; +} + +#endif // 
GGML_RPC_RDMA + +bool socket_t::impl::send_data(const void * data, size_t size) { +#ifdef GGML_RPC_RDMA + if (use_rdma) { + return rdma_send(data, size); + } +#endif + size_t bytes_sent = 0; + while (bytes_sent < size) { + size_t size_to_send = std::min(size - bytes_sent, MAX_CHUNK_SIZE); + ssize_t n = send(fd, (const char *)data + bytes_sent, size_to_send, 0); + if (n < 0) { + GGML_LOG_ERROR("send failed (bytes_sent=%zu, size_to_send=%zu)\n", + bytes_sent, size_to_send); + return false; + } + bytes_sent += (size_t)n; + } + return true; +} + +bool socket_t::impl::recv_data(void * data, size_t size) { +#ifdef GGML_RPC_RDMA + if (use_rdma) { + return rdma_recv(data, size); + } +#endif + size_t bytes_recv = 0; + while (bytes_recv < size) { + size_t size_to_recv = std::min(size - bytes_recv, MAX_CHUNK_SIZE); + ssize_t n = recv(fd, (char *)data + bytes_recv, size_to_recv, 0); + if (n < 0) { + GGML_LOG_ERROR("recv failed (bytes_recv=%zu, size_to_recv=%zu)\n", + bytes_recv, size_to_recv); + return false; + } + if (n == 0) { + LOG_DBG("recv returned 0 (peer closed?)\n"); + return false; + } + bytes_recv += (size_t)n; + } + return true; +} + +void socket_t::impl::get_caps(uint8_t * local_caps) { + memset(local_caps, 0, RPC_CONN_CAPS_SIZE); +#ifdef GGML_RPC_RDMA + rdma_local = {}; + if (rdma_probe()) { + rdma_caps rc = {}; + rc.qpn = rdma_local.qpn; + rc.psn = rdma_local.psn; + memcpy(rc.gid, rdma_local.gid, RDMA_GID_SIZE); + memcpy(local_caps, &rc, sizeof(rc)); + } else { + rdma.reset(); + } +#endif // GGML_RPC_RDMA +} + +void socket_t::impl::update_caps(const uint8_t * remote_caps) { +#ifdef GGML_RPC_RDMA + if (!rdma) { + return; + } + rdma_caps rc = {}; + memcpy(&rc, remote_caps, sizeof(rc)); + if (rc.qpn == 0) { + rdma.reset(); + return; + } + if (rdma_activate(rc.qpn, rc.psn, rc.gid)) { + use_rdma = true; + } else { + GGML_LOG_ERROR("RDMA activate failed, staying on TCP\n"); + rdma.reset(); + } +#else + (void)remote_caps; +#endif // GGML_RPC_RDMA +} + + +///////////////////////////////////////////////////////////////////////////// + +socket_t::socket_t(std::unique_ptr p) : pimpl(std::move(p)) {} + +socket_t::~socket_t() = default; + +bool socket_t::send_data(const void * data, size_t size) { + return pimpl->send_data(data, size); +} + +bool socket_t::recv_data(void * data, size_t size) { + return pimpl->recv_data(data, size); +} + +void socket_t::get_caps(uint8_t * local_caps) { + return pimpl->get_caps(local_caps); +} + +void socket_t::update_caps(const uint8_t * remote_caps) { + return pimpl->update_caps(remote_caps); +} + +static bool is_valid_fd(sockfd_t sockfd) { +#ifdef _WIN32 + return sockfd != INVALID_SOCKET; +#else + return sockfd >= 0; +#endif +} + +static bool set_no_delay(sockfd_t sockfd) { + int flag = 1; + // set TCP_NODELAY to disable Nagle's algorithm + int ret = setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int)); + return ret == 0; +} + +static bool set_reuse_addr(sockfd_t sockfd) { + int flag = 1; + int ret = setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, (char *)&flag, sizeof(int)); + return ret == 0; +} + +socket_ptr socket_t::accept() { + auto client_socket_fd = ::accept(pimpl->fd, NULL, NULL); + if (!is_valid_fd(client_socket_fd)) { + return nullptr; + } + if (!set_no_delay(client_socket_fd)) { + GGML_LOG_ERROR("Failed to set TCP_NODELAY\n"); + return nullptr; + } + return socket_ptr(new socket_t(std::make_unique(client_socket_fd))); +} + +socket_ptr socket_t::create_server(const char * host, int port) { + auto sockfd = socket(AF_INET, SOCK_STREAM, 0); 
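// Minimal usage sketch of the socket_t API declared in transport.h (host/port hypothetical,
// error handling trimmed):
//   if (rpc_transport_init()) {
//       socket_ptr srv = socket_t::create_server("0.0.0.0", 50052);
//       socket_ptr cli = srv ? srv->accept() : nullptr;   // blocks until a client connects
//       uint8_t cmd;
//       if (cli && cli->recv_data(&cmd, sizeof(cmd))) { /* dispatch as in rpc_serve_client() */ }
//       rpc_transport_shutdown();
//   }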
+ if (!is_valid_fd(sockfd)) { + return nullptr; + } + if (!set_reuse_addr(sockfd)) { + GGML_LOG_ERROR("Failed to set SO_REUSEADDR\n"); + return nullptr; + } + if (inet_addr(host) == INADDR_NONE) { + GGML_LOG_ERROR("Invalid host address: %s\n", host); + return nullptr; + } + struct sockaddr_in serv_addr; + serv_addr.sin_family = AF_INET; + serv_addr.sin_addr.s_addr = inet_addr(host); + serv_addr.sin_port = htons(port); + + if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) { + return nullptr; + } + if (listen(sockfd, 1) < 0) { + return nullptr; + } + return socket_ptr(new socket_t(std::make_unique(sockfd))); +} + +socket_ptr socket_t::connect(const char * host, int port) { + auto sockfd = socket(AF_INET, SOCK_STREAM, 0); + if (!is_valid_fd(sockfd)) { + return nullptr; + } + if (!set_no_delay(sockfd)) { + GGML_LOG_ERROR("Failed to set TCP_NODELAY\n"); + return nullptr; + } + struct sockaddr_in addr; + addr.sin_family = AF_INET; + addr.sin_port = htons(port); + struct hostent * server = gethostbyname(host); + if (server == NULL) { + GGML_LOG_ERROR("Cannot resolve host '%s'\n", host); + return nullptr; + } + memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length); + if (::connect(sockfd, (struct sockaddr *)&addr, sizeof(addr)) < 0) { + return nullptr; + } + return socket_ptr(new socket_t(std::make_unique(sockfd))); +} + +#ifdef _WIN32 +static std::mutex g_rpc_transport_mu; +static bool g_rpc_transport_wsa_started = false; +#endif + +bool rpc_transport_init() { +#ifdef _WIN32 + std::lock_guard lock(g_rpc_transport_mu); + if (g_rpc_transport_wsa_started) { + return true; + } + WSADATA wsaData; + int res = WSAStartup(MAKEWORD(2, 2), &wsaData); + if (res != 0) { + return false; + } + g_rpc_transport_wsa_started = true; + return true; +#else + return true; +#endif +} + +void rpc_transport_shutdown() { +#ifdef _WIN32 + std::lock_guard lock(g_rpc_transport_mu); + if (!g_rpc_transport_wsa_started) { + return; + } + WSACleanup(); + g_rpc_transport_wsa_started = false; +#endif +} diff --git a/ggml/src/ggml-rpc/transport.h b/ggml/src/ggml-rpc/transport.h new file mode 100644 index 00000000000..73b85cc530a --- /dev/null +++ b/ggml/src/ggml-rpc/transport.h @@ -0,0 +1,34 @@ +#pragma once + +#include +#include +#include + +struct socket_t; +typedef std::shared_ptr socket_ptr; + +static constexpr size_t MAX_CHUNK_SIZE = 1024ull * 1024ull * 1024ull; // 1 GiB +static constexpr size_t RPC_CONN_CAPS_SIZE = 24; + +struct socket_t { + ~socket_t(); + + bool send_data(const void * data, size_t size); + bool recv_data(void * data, size_t size); + + socket_ptr accept(); + + void get_caps(uint8_t * local_caps); + void update_caps(const uint8_t * remote_caps); + + static socket_ptr create_server(const char * host, int port); + static socket_ptr connect(const char * host, int port); + +private: + struct impl; + explicit socket_t(std::unique_ptr p); + std::unique_ptr pimpl; +}; + +bool rpc_transport_init(); +void rpc_transport_shutdown(); diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index fd84c917853..0101b27640a 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -28,6 +28,13 @@ namespace syclexp = sycl::ext::oneapi::experimental; +#if defined(__INTEL_LLVM_COMPILER) && __has_include() + #include + #ifndef GGML_SYCL_HAS_BF16 + #define GGML_SYCL_HAS_BF16 + #endif +#endif + #if GGML_SYCL_DNNL #include "dnnl.hpp" #include "dnnl_sycl.hpp" diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp index 
f3c521b45f6..67b9c06f3e4 100644 --- a/ggml/src/ggml-sycl/convert.cpp +++ b/ggml/src/ggml-sycl/convert.cpp @@ -2,13 +2,6 @@ #include "dequantize.hpp" #include "presets.hpp" -#if defined(__INTEL_LLVM_COMPILER) - #if __has_include() - #include - #define GGML_SYCL_HAS_BF16 - #endif -#endif - template static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, const sycl::nd_item<3> &item_ct1) { @@ -767,6 +760,22 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) { } +#ifdef GGML_SYCL_HAS_BF16 +to_bf16_sycl_t ggml_get_to_bf16_sycl(ggml_type type, ggml_tensor * /*dst*/) { + switch (type) { + case GGML_TYPE_F32: + return convert_unary_sycl; + case GGML_TYPE_F16: + return convert_unary_sycl; + case GGML_TYPE_BF16: + return convert_unary_sycl; + default: + GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(type)); + return nullptr; + } +} +#endif + to_fp16_nc_sycl_t ggml_get_to_fp16_nc_sycl(ggml_type type) { switch (type) { case GGML_TYPE_F32: diff --git a/ggml/src/ggml-sycl/convert.hpp b/ggml/src/ggml-sycl/convert.hpp index 6e621f2154d..8de79d10ff6 100644 --- a/ggml/src/ggml-sycl/convert.hpp +++ b/ggml/src/ggml-sycl/convert.hpp @@ -23,6 +23,11 @@ typedef to_t_sycl_t to_fp16_sycl_t; to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst); to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor * dst); +#ifdef GGML_SYCL_HAS_BF16 +typedef to_t_sycl_t to_bf16_sycl_t; +to_bf16_sycl_t ggml_get_to_bf16_sycl(ggml_type type, ggml_tensor * dst); +#endif + // Nc = Non-contiguous template using to_t_nc_sycl_t = void (*)(const void * x, T * y, int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03, @@ -35,15 +40,19 @@ template inline dst_t ggml_sycl_cast(src_t x) { if constexpr (std::is_same_v) { return x; +#ifdef GGML_SYCL_HAS_BF16 } else if constexpr (std::is_same_v) { return sycl::ext::oneapi::bfloat16(float(x)); } else if constexpr (std::is_same_v) { return static_cast(x); +#endif } else if constexpr (std::is_same_v && std::is_same_v) { return x.template convert(); +#ifdef GGML_SYCL_HAS_BF16 } else if constexpr (std::is_same_v && std::is_same_v>) { return {x.x, x.y}; +#endif } else if constexpr(std::is_same_v) { return int32_t(x); } else { diff --git a/ggml/src/ggml-sycl/gemm.hpp b/ggml/src/ggml-sycl/gemm.hpp index dcf6c7aeeb4..c202da110be 100644 --- a/ggml/src/ggml-sycl/gemm.hpp +++ b/ggml/src/ggml-sycl/gemm.hpp @@ -29,6 +29,9 @@ class DnnlGemmWrapper { static constexpr dt to_dt() { if constexpr (std::is_same_v) return dt::f32; else if constexpr (std::is_same_v) return dt::f16; +#ifdef GGML_SYCL_HAS_BF16 + else if constexpr (std::is_same_v) return dt::bf16; +#endif else static_assert(0); } diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index c02a41ad862..36923160d72 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -2176,6 +2176,31 @@ inline void ggml_sycl_op_mul_mat_sycl( #else bool use_fp16 = false; #endif + +#if GGML_SYCL_DNNL && defined(GGML_SYCL_HAS_BF16) + // Fast path for bf16 src0 + if (src0->type == GGML_TYPE_BF16 && !g_ggml_sycl_disable_dnn && ggml_is_contiguous(src0) && + row_diff == src0->ne[1]) { + using bf16_t = sycl::ext::oneapi::bfloat16; + ggml_sycl_pool_alloc src1_as_bf16(ctx.pool(), src1_ncols*ne10); + if (src1->type != GGML_TYPE_BF16) { + const to_bf16_sycl_t to_bf16_sycl = ggml_get_to_bf16_sycl(src1->type, dst); + GGML_ASSERT(to_bf16_sycl != nullptr); + to_bf16_sycl(src1_ddf_i, src1_as_bf16.get(), 
src1_ncols*ne10, stream); + } else { + stream->memcpy(src1_as_bf16.get(), src1_ddf_i, src1_ncols*ne10*sizeof(bf16_t)); + } + DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, + src0_dd_i, DnnlGemmWrapper::to_dt(), + src1_as_bf16.get(), DnnlGemmWrapper::to_dt(), + dst_dd_i, DnnlGemmWrapper::to_dt(), stream); + GGML_UNUSED(dst); + GGML_UNUSED(src1_ddq_i); + GGML_UNUSED(src1_padded_row_size); + return; + } +#endif + if ((src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && use_fp16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) { ggml_sycl_pool_alloc src0_as_f16(ctx.pool()); @@ -3783,6 +3808,51 @@ __dpct_inline__ static void k_copy_dst_from_contiguous( } } +// Fused MoE TG fast path. Returns false to fall back to the per-expert loop below. +static bool ggml_sycl_mul_mat_id_mmvq_fused( + ggml_backend_sycl_context & ctx, const ggml_tensor * src0, + const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) +{ + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + const int64_t ne12 = src1->ne[2]; + if (ne12 != 1) return false; + if (src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) return false; + if (ne10 != src0->ne[0] || ne10 % QK8_1 != 0) return false; + if (!ggml_is_contiguous(src1)) return false; + + // Reorder layout not supported; fall back. + const ggml_tensor_extra_gpu * src0_extra = + static_cast(src0->extra); + if (src0_extra && src0_extra->optimized_feature.reorder) return false; + + const int64_t n_ids_per_group = ids->ne[0]; + if (ids->ne[1] != 1) return false; + if (ne11 != 1 && ne11 != n_ids_per_group) return false; + + const queue_ptr stream = ctx.stream(); + const int src1_padded_cols = GGML_PAD((int) ne10, MATRIX_ROW_PADDING); + const int n_experts_used = (int) n_ids_per_group; + const int nrows = (int) src0->ne[1]; + + ggml_sycl_pool_alloc src1_q8_alloc(ctx.pool(), + (size_t) ne11 * src1_padded_cols * sizeof(block_q8_1) / QK8_1); + char * src1_ddq = src1_q8_alloc.get(); + quantize_row_q8_1_sycl( + (const float *) src1->data, src1_ddq, (int) ne10, (int) ne11, + src1_padded_cols, stream); + + const size_t bytes_per_qrow = (size_t) src1_padded_cols * sizeof(block_q8_1) / QK8_1; + const size_t src1_row_stride = (ne11 == 1) ? 
0 : bytes_per_qrow; + + return ggml_sycl_mul_mat_vec_q_id( + src0->type, src0->data, src1_ddq, (const int32_t *) ids->data, + (float *) dst->data, (int) ne10, nrows, n_experts_used, + /*expert_weight_stride=*/ src0->nb[2], + /*dst_row_stride=*/ dst->nb[1], + src1_row_stride, stream); +} + static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, ggml_tensor *dst) try { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3); @@ -3798,6 +3868,12 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const int64_t n_as = ne02; const int64_t n_ids = ids->ne[0]; + if (ne12 == 1) { + if (ggml_sycl_mul_mat_id_mmvq_fused(ctx, src0, src1, ids, dst)) { + return; + } + } + std::vector ids_host(ggml_nbytes(ids)); const char * ids_dev = (const char *) ids->data; @@ -3848,8 +3924,9 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, } } } else { - ggml_sycl_pool_alloc src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1)); - ggml_sycl_pool_alloc dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); + const int64_t n_routed_rows = ids->ne[1] * n_ids; + ggml_sycl_pool_alloc src1_contiguous(ctx.pool(), sizeof(float)*n_routed_rows*ne10); + ggml_sycl_pool_alloc dst_contiguous(ctx.pool(), sizeof(float)*n_routed_rows*ne0); src1_row.data = src1_contiguous.get(); dst_row.data = dst_contiguous.get(); diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index af22b98dddb..8fa2198f35a 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -537,9 +537,9 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx, static void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK4_0 == 0); - const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); - constexpr size_t num_subgroups = 16; - GGML_ASSERT(block_num_y % num_subgroups == 0); + // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel. + constexpr size_t num_subgroups = WARP_SIZE; + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups; const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE)); const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); @@ -682,9 +682,9 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy, static void reorder_mul_mat_vec_q8_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK8_0 == 0); - const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); - constexpr size_t num_subgroups = 16; - GGML_ASSERT(block_num_y % num_subgroups == 0); + // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel. 
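// Worked example of the rounding above (WARP_SIZE == 32 and GGML_SYCL_MMV_Y == 1 assumed):
//   nrows = 4097  ->  block_num_y = ceil_div(4097, 1 * 32) * 32 = 129 * 32 = 4128
// i.e. the launch covers 31 rows past nrows, which the in-kernel row bound check skips,
// whereas the old code asserted ceil_div(nrows, GGML_SYCL_MMV_Y) % 16 == 0 and aborted
// on such row counts.
static_assert((4097 + 31) / 32 * 32 == 4128, "4097 rows round up to 4128 launched rows");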
+ constexpr size_t num_subgroups = WARP_SIZE; + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups; const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE)); const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); @@ -798,9 +798,9 @@ static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK_K == 0); - const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); - constexpr size_t num_subgroups = 16; - GGML_ASSERT(block_num_y % num_subgroups == 0); + // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel. + constexpr size_t num_subgroups = WARP_SIZE; + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups; const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); @@ -842,9 +842,9 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy, static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK_K == 0); - const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); - constexpr size_t num_subgroups = 16; - GGML_ASSERT(block_num_y % num_subgroups == 0); + // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel. + constexpr size_t num_subgroups = WARP_SIZE; + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups; const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); @@ -1199,3 +1199,154 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens GGML_UNUSED(src1_ddf_i); GGML_UNUSED(ctx); } + +// src1_row_stride: 0 for shared src1 (gate/up proj), else per-expert stride (down proj). 
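// Addressing sketch for the two cases (i is the workgroup index along the expert dimension):
//   weights:     vx  = vx_base  + ids_dev[i] * expert_weight_stride    (always per selected expert)
//   activations: vy  = vy_base  + i * src1_row_stride
//                  gate/up proj (ne11 == 1):     src1_row_stride == 0, every expert reads the same Q8_1 row
//                  down proj    (ne11 == n_ids): src1_row_stride == bytes per quantized src1 row
//   output:      dst = dst_base + i * dst_row_stride
// which matches the pointer setup at the top of mul_mat_vec_q_moe() below.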
+template +static void mul_mat_vec_q_moe( + const void * __restrict__ vx_base, const void * __restrict__ vy_base, + float * __restrict__ dst_base, const int32_t * __restrict__ ids_dev, + const int ncols, const int nrows, + const size_t expert_weight_stride, const size_t dst_row_stride, + const size_t src1_row_stride, + const sycl::nd_item<3> & item_ct1) { + + const int expert_idx = item_ct1.get_group(1); + const int i02 = ids_dev[expert_idx]; + + const char * vx = (const char *) vx_base + (size_t) i02 * expert_weight_stride; + const char * vy = (const char *) vy_base + (size_t) expert_idx * src1_row_stride; + float * dst = (float *) ((char *) dst_base + (size_t) expert_idx * dst_row_stride); + + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1); + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / qk; + constexpr int blocks_per_warp = (vdr * WARP_SIZE + qi - 1) / qi; + + float tmp = 0.0f; + + const block_q_t * x = (const block_q_t *) vx; + const block_q8_1 * y = (const block_q8_1 *) vy; + + for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; i += blocks_per_warp) { + const int ibx = row * blocks_per_row + i; + const int iby = i * (qk / QK8_1); + + for (size_t elem = 0; elem < qi / vdr; elem += WARP_SIZE) { + const int iqs = elem + vdr * (item_ct1.get_local_id(2) % (qi / vdr)); + tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs); + } + } + +#pragma unroll + for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (item_ct1.get_local_id(2) == 0) { + dst[row] = tmp; + } +} + +template +static void launch_mul_mat_vec_q_moe( + const void * vx_base, const void * vy, const int32_t * ids_dev, + float * dst_base, const int ncols, const int nrows, const int n_experts_used, + const size_t expert_weight_stride, const size_t dst_row_stride, + const size_t src1_row_stride, + dpct::queue_ptr stream) { + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, (unsigned) n_experts_used, (unsigned) block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_moe( + vx_base, vy, dst_base, ids_dev, ncols, nrows, + expert_weight_stride, dst_row_stride, src1_row_stride, item); + }); + }); +} + +bool ggml_sycl_mul_mat_vec_q_id( + enum ggml_type src0_type, + const void * vx_base, + const void * vy, + const int32_t * ids_dev, + float * dst_base, + int ncols, + int nrows, + int n_experts_used, + size_t expert_weight_stride, + size_t dst_row_stride, + size_t src1_row_stride, + dpct::queue_ptr stream) { + switch (src0_type) { + case GGML_TYPE_Q4_0: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q4_1: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q5_0: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q5_1: + launch_mul_mat_vec_q_moe( + vx_base, vy, 
ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q8_0: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q2_K: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q3_K: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q4_K: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q5_K: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q6_K: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_MXFP4: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_NVFP4: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + default: + return false; + } +} diff --git a/ggml/src/ggml-sycl/mmvq.hpp b/ggml/src/ggml-sycl/mmvq.hpp index 049b43d4535..d674dc1d61e 100644 --- a/ggml/src/ggml-sycl/mmvq.hpp +++ b/ggml/src/ggml-sycl/mmvq.hpp @@ -24,4 +24,20 @@ void ggml_sycl_op_mul_mat_vec_q( const int64_t src1_ncols, const int64_t src1_padded_row_size, const dpct::queue_ptr &stream); +// Requires standard (non-reorder) block layout for src0. +// Returns false if src0_type isn't handled; caller should fall back. 
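// Caller sketch (mirrors the dispatch added to ggml_sycl_mul_mat_id): try the fused MoE
// token-generation path first, fall back to the generic per-expert loop when it declines:
//   if (ne12 == 1 && ggml_sycl_mul_mat_id_mmvq_fused(ctx, src0, src1, ids, dst)) {
//       return;   // handled by ggml_sycl_mul_mat_vec_q_id()
//   }
//   // ... otherwise continue with the existing gather/compute/scatter path ...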
+bool ggml_sycl_mul_mat_vec_q_id( + enum ggml_type src0_type, + const void * vx_base, // start of stacked expert weights + const void * vy, // pre-quantized src1 (Q8_1) + const int32_t * ids_dev, // device-side int32, length n_experts_used + float * dst_base, + int ncols, + int nrows, + int n_experts_used, + size_t expert_weight_stride, // bytes between experts in vx_base + size_t dst_row_stride, // bytes between dst rows + size_t src1_row_stride, // 0 = shared src1, else per-expert stride in bytes + dpct::queue_ptr stream); + #endif // GGML_SYCL_MMVQ_HPP diff --git a/ggml/src/ggml-sycl/set_rows.cpp b/ggml/src/ggml-sycl/set_rows.cpp index a641c100913..8fb41943525 100644 --- a/ggml/src/ggml-sycl/set_rows.cpp +++ b/ggml/src/ggml-sycl/set_rows.cpp @@ -4,7 +4,11 @@ namespace utils { template static constexpr bool is_arithmetic_v() { - return std::is_arithmetic_v || std::is_same_v || std::is_same_v; + return std::is_arithmetic_v || std::is_same_v +#ifdef GGML_SYCL_HAS_BF16 + || std::is_same_v +#endif + ; } } @@ -181,6 +185,7 @@ static void set_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * s stream ); break; +#ifdef GGML_SYCL_HAS_BF16 case GGML_TYPE_BF16: set_rows_sycl( src0_d, src1_d, (char *)dst->data, @@ -193,6 +198,7 @@ static void set_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * s stream ); break; +#endif case GGML_TYPE_Q8_0: set_rows_sycl_q(src0_d, src1_d, (block_q8_0 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream); break; diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 702a249d754..d4acee8b1df 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -792,6 +792,7 @@ struct vk_device_struct { vk_pipeline pipeline_arange_f32; vk_pipeline pipeline_fill_f32; + vk_pipeline pipeline_fill_f16; vk_pipeline pipeline_geglu[2]; vk_pipeline pipeline_reglu[2]; @@ -4577,6 +4578,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_arange_f32, "arange_f32", arange_f32_len, arange_f32_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_fill_f32, "fill_f32", fill_f32_len, fill_f32_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_fill_f16, "fill_f16", fill_f16_len, fill_f16_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); #define CREATE_GLU(name) \ ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ @@ -9844,6 +9846,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const if (dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_fill_f32; } + if (dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_fill_f16; + } return nullptr; default: return nullptr; @@ -15713,8 +15718,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm || (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32) || (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16); case GGML_OP_ARANGE: - case GGML_OP_FILL: return op->type == GGML_TYPE_F32; + case GGML_OP_FILL: + return op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16; case GGML_OP_SCALE: return ggml_is_contiguous(op->src[0]) && 
op->src[0]->type == GGML_TYPE_F32; case GGML_OP_PAD: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 54b9b327333..ff836615330 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -889,6 +889,7 @@ void process_shaders() { string_to_spv("add1_f32_f32", "add1.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("arange_f32", "arange.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("fill_f32", "fill.comp", {{"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("fill_f16", "fill.comp", {{"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}}); string_to_spv("step_f16", "step.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("step_f32", "step.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("round_f16", "round.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 3de6258c74d..449eae808e4 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -44,18 +44,9 @@ // Matrix-vector multiplication parameters #define WEBGPU_MUL_MAT_VEC_WG_SIZE 256 -// Must be multiple of 4 to work with vectorized paths, and must divide -// mul_mat_vec wg size -#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 64 -#define WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K 256 - -#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 64 -#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K 256 - -// Requires 32 threads per output (wg_size/outputs_per_wg == 32) -#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 8 -// Requires at least two (and multiple of 2) k-quant blocks per tile -#define WEBGPU_MUL_MAT_VEC_K_Q_TILE_K 512 +#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 4 +#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 4 +#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 4 // default size for legacy matrix multiplication #define WEBGPU_MUL_MAT_WG_SIZE 256 @@ -78,6 +69,7 @@ struct ggml_webgpu_shader_lib_context { bool inplace = false; bool overlap = false; bool src_overlap = false; + bool supports_subgroups = false; bool supports_subgroup_matrix = false; uint32_t sg_mat_m = 0; uint32_t sg_mat_n = 0; @@ -202,6 +194,28 @@ struct ggml_webgpu_row_norm_pipeline_key_hash { } }; +/** RMS_NORM + MUL **/ + +struct ggml_webgpu_rms_norm_mul_pipeline_key { + bool inplace; // rn_src == dst + bool overlap; // mul_src == dst + bool src_overlap; // rn_src == mul_src + + bool operator==(const ggml_webgpu_rms_norm_mul_pipeline_key & other) const { + return inplace == other.inplace && overlap == other.overlap && src_overlap == other.src_overlap; + } +}; + +struct ggml_webgpu_rms_norm_mul_pipeline_key_hash { + size_t operator()(const ggml_webgpu_rms_norm_mul_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.inplace); + ggml_webgpu_hash_combine(seed, key.overlap); + ggml_webgpu_hash_combine(seed, key.src_overlap); + return seed; + } +}; + /** Pad **/ struct ggml_webgpu_pad_pipeline_key { bool circular; @@ -248,6 +262,46 @@ struct ggml_webgpu_ssm_conv_pipeline_key { } }; +/** CONV 2D */ +struct ggml_webgpu_conv2d_pipeline_key { + ggml_type weight_type; + ggml_type input_type; + ggml_type output_type; + + bool operator==(const 
ggml_webgpu_conv2d_pipeline_key & other) const { + return weight_type == other.weight_type && input_type == other.input_type && output_type == other.output_type; + } +}; + +struct ggml_webgpu_conv2d_pipeline_key_hash { + size_t operator()(const ggml_webgpu_conv2d_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.weight_type); + ggml_webgpu_hash_combine(seed, key.input_type); + ggml_webgpu_hash_combine(seed, key.output_type); + return seed; + } +}; + +/** Im2Col **/ +struct ggml_webgpu_im2col_pipeline_key { + ggml_type input_type; + ggml_type output_type; + + bool operator==(const ggml_webgpu_im2col_pipeline_key & other) const { + return input_type == other.input_type && output_type == other.output_type; + } +}; + +struct ggml_webgpu_im2col_pipeline_key_hash { + size_t operator()(const ggml_webgpu_im2col_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.input_type); + ggml_webgpu_hash_combine(seed, key.output_type); + return seed; + } +}; + /** Gated Delta Net **/ struct ggml_webgpu_gated_delta_net_pipeline_key { int type; @@ -390,12 +444,11 @@ struct ggml_webgpu_flash_attn_pipeline_key { bool has_mask; bool has_sinks; bool uses_logit_softcap; - bool use_vec; bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const { return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks && - uses_logit_softcap == other.uses_logit_softcap && use_vec == other.use_vec; + uses_logit_softcap == other.uses_logit_softcap; } }; @@ -409,47 +462,37 @@ struct ggml_webgpu_flash_attn_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.has_mask); ggml_webgpu_hash_combine(seed, key.has_sinks); ggml_webgpu_hash_combine(seed, key.uses_logit_softcap); - ggml_webgpu_hash_combine(seed, key.use_vec); return seed; } }; -struct ggml_webgpu_flash_attn_shader_lib_context { - ggml_webgpu_flash_attn_pipeline_key key; - uint32_t sg_mat_m; - uint32_t sg_mat_n; - uint32_t sg_mat_k; - size_t wg_mem_limit_bytes; - uint32_t max_subgroup_size; -}; - -struct ggml_webgpu_flash_attn_shader_decisions { +struct ggml_webgpu_flash_attn_decisions { uint32_t q_tile = 0; uint32_t kv_tile = 0; uint32_t wg_size = 0; }; -inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_pipeline_key & key) { - // Keep conservative defaults unless this is the f16 vec-split shape family. - if (key.kv_type != GGML_TYPE_F16 || key.head_dim_qk != key.head_dim_v) { - return 1u; - } +struct ggml_webgpu_flash_attn_vec_decisions { + uint32_t kv_tile = 0; + uint32_t wg_size = 0; +}; - // Head-dim specializations used by the tuned vec f16 path. 
- switch (key.head_dim_qk) { - case 64: - return 2u; - case 96: - return 4u; - case 128: - return 1u; - case 192: - return 2u; - case 576: - return 2u; - default: - return 1u; - } +inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_key( + const ggml_webgpu_shader_lib_context & context) { + const bool has_mask = context.src3 != nullptr; + const bool has_sinks = context.src4 != nullptr; + const bool kv_direct = (context.src1->type == GGML_TYPE_F16) && (context.src0->ne[0] % context.sg_mat_k == 0) && + (context.src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); + + ggml_webgpu_flash_attn_pipeline_key key = {}; + key.kv_type = context.src1->type; + key.head_dim_qk = (uint32_t) context.src0->ne[0]; + key.head_dim_v = (uint32_t) context.src2->ne[0]; + key.kv_direct = kv_direct; + key.has_mask = has_mask; + key.has_sinks = has_sinks; + key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f; + return key; } struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key { @@ -471,79 +514,20 @@ inline bool operator==(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & lh return lhs.head_dim_v == rhs.head_dim_v && lhs.wg_size == rhs.wg_size; } -struct ggml_webgpu_flash_attn_vec_reduce_shader_lib_context { - ggml_webgpu_flash_attn_vec_reduce_pipeline_key key; - uint32_t max_wg_size; -}; - -inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_vec_reduce_shader( - pre_wgsl::Preprocessor & preprocessor, - const char * shader_src, - const ggml_webgpu_flash_attn_vec_reduce_shader_lib_context & context) { - std::vector defines; - std::string variant = "flash_attn_vec_reduce"; - - defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v)); - variant += std::string("_hsv") + std::to_string(context.key.head_dim_v); - - defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); - variant += std::string("_wg") + std::to_string(context.max_wg_size); - - ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - return result; -} - struct ggml_webgpu_flash_attn_blk_pipeline_key { - uint32_t q_tile; uint32_t kv_tile; - bool operator==(const ggml_webgpu_flash_attn_blk_pipeline_key & other) const { - return q_tile == other.q_tile && kv_tile == other.kv_tile; - } + bool operator==(const ggml_webgpu_flash_attn_blk_pipeline_key & other) const { return kv_tile == other.kv_tile; } }; struct ggml_webgpu_flash_attn_blk_pipeline_key_hash { size_t operator()(const ggml_webgpu_flash_attn_blk_pipeline_key & key) const { size_t seed = 0; - ggml_webgpu_hash_combine(seed, key.q_tile); ggml_webgpu_hash_combine(seed, key.kv_tile); return seed; } }; -struct ggml_webgpu_flash_attn_blk_shader_lib_context { - ggml_webgpu_flash_attn_blk_pipeline_key key; - uint32_t max_wg_size; -}; - -inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_blk_shader( - pre_wgsl::Preprocessor & preprocessor, - const char * shader_src, - const ggml_webgpu_flash_attn_blk_shader_lib_context & context) { - std::vector defines; - std::string variant = "flash_attn_vec_blk"; - - defines.push_back(std::string("Q_TILE=") + std::to_string(context.key.q_tile)); - variant += std::string("_qt") + std::to_string(context.key.q_tile); - - defines.push_back(std::string("KV_TILE=") + std::to_string(context.key.kv_tile)); - variant += std::string("_kvt") + std::to_string(context.key.kv_tile); - - uint32_t wg_size = 1; - while ((wg_size << 1) <= context.max_wg_size) { - wg_size <<= 1; - } - 
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); - variant += std::string("_wg") + std::to_string(wg_size); - - ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - return result; -} - // This is exposed because it's necessary in supports_op inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, uint32_t kv_tile, @@ -568,6 +552,41 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES; } +inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_context & context, + const ggml_webgpu_flash_attn_pipeline_key & key) { + const size_t limit_bytes = context.wg_mem_limit_bytes; + const size_t q_tile = context.sg_mat_m; + const size_t base_q_bytes = (key.head_dim_qk + key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + + 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; + size_t bytes_per_kv = 0; + if (!key.kv_direct) { + bytes_per_kv += std::max(key.head_dim_qk, key.head_dim_v); + } + if (key.has_mask) { + bytes_per_kv += q_tile; + } + bytes_per_kv += q_tile; + bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES; + const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv; + return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n; +} + +inline uint32_t ggml_webgpu_flash_attn_vec_get_kv_tile(const ggml_webgpu_shader_lib_context & context) { + const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context); + const uint32_t min_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key); + uint32_t kv_tile = std::max(context.sg_mat_n, std::min(32u, min_kv_tile)); + kv_tile = (kv_tile / context.sg_mat_n) * context.sg_mat_n; + + if (key.kv_direct) { + kv_tile = std::min(kv_tile, GGML_WEBGPU_KV_SEQ_PAD); + while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { + kv_tile -= context.sg_mat_n; + } + } + + return kv_tile; +} + /** Matrix Multiplication **/ struct ggml_webgpu_legacy_mul_mat_pipeline_key { @@ -610,7 +629,6 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key_hash { struct ggml_webgpu_mul_mat_vec_shader_decisions { uint32_t wg_size; - uint32_t tile_k; uint32_t outputs_per_wg; uint32_t vec_size; }; @@ -778,16 +796,17 @@ class ggml_webgpu_shader_lib { std::unordered_map cumsum_pipelines; // key is fixed, no variants yet std::unordered_map row_norm_pipelines; // op/inplace + std::unordered_map - get_rows_pipelines; // src_type, vectorized + get_rows_pipelines; // src_type, vectorized std::unordered_map - unary_pipelines; // type/op/inplace + unary_pipelines; // type/op/inplace std::unordered_map - scale_pipelines; // inplace + scale_pipelines; // inplace std::unordered_map - solve_tri_pipelines; // type + solve_tri_pipelines; // type std::unordered_map - ssm_conv_pipelines; // type/vectorized + ssm_conv_pipelines; // type/vectorized std::unordered_map @@ -802,6 +821,8 @@ class ggml_webgpu_shader_lib { repeat_pipelines; // type std::unordered_map flash_attn_pipelines; + std::unordered_map + flash_attn_vec_pipelines; std::unordered_map @@ -831,6 +852,15 @@ class ggml_webgpu_shader_lib { rope_pipelines; std::unordered_map soft_max_pipelines; + std::unordered_map + conv2d_pipelines; + std::unordered_map + im2col_pipelines; + + std::unordered_map + rms_norm_mul_pipelines; public: ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; } @@ -849,10 +879,9 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_row_norm_pipeline(const 
ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_row_norm_pipeline_key key = { - .op = context.dst->op, - .inplace = context.inplace, - }; + ggml_webgpu_row_norm_pipeline_key key = {}; + key.op = context.dst->op; + key.inplace = context.inplace; auto it = row_norm_pipelines.find(key); if (it != row_norm_pipelines.end()) { @@ -908,9 +937,10 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_set_rows_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_set_rows_pipeline_key key = { .dst_type = context.dst->type, - .vec4 = context.src0->ne[0] % 4 == 0, - .i64_idx = context.src1->type == GGML_TYPE_I64 }; + ggml_webgpu_set_rows_pipeline_key key = {}; + key.dst_type = context.dst->type; + key.vec4 = context.src0->ne[0] % 4 == 0; + key.i64_idx = context.src1->type == GGML_TYPE_I64; auto it = set_rows_pipelines.find(key); if (it != set_rows_pipelines.end()) { @@ -955,7 +985,9 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_set_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_set_pipeline_key key = { .type = context.dst->type, .inplace = context.inplace }; + ggml_webgpu_set_pipeline_key key = {}; + key.type = context.dst->type; + key.inplace = context.inplace; auto it = set_pipelines.find(key); if (it != set_pipelines.end()) { @@ -1062,10 +1094,9 @@ class ggml_webgpu_shader_lib { webgpu_pipeline get_get_rows_pipeline(const ggml_webgpu_shader_lib_context & context) { const bool vectorized = context.src0->type == GGML_TYPE_F32 && context.dst->ne[0] % 4 == 0; - ggml_webgpu_get_rows_pipeline_key key = { - .src_type = context.src0->type, - .vectorized = (int) vectorized, - }; + ggml_webgpu_get_rows_pipeline_key key = {}; + key.src_type = context.src0->type; + key.vectorized = (int) vectorized; auto it = get_rows_pipelines.find(key); if (it != get_rows_pipelines.end()) { @@ -1115,8 +1146,7 @@ class ggml_webgpu_shader_lib { std::string type_upper = type_str; std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); - switch (key.src_type) - { + switch (key.src_type) { case GGML_TYPE_Q4_0: case GGML_TYPE_Q5_0: case GGML_TYPE_Q8_0: @@ -1136,9 +1166,9 @@ class ggml_webgpu_shader_lib { break; } default: - { - defines.push_back(std::string("SRC_TYPE=") + type_str); - } + { + defines.push_back(std::string("SRC_TYPE=") + type_str); + } } defines.push_back("BYTE_HELPERS"); @@ -1181,7 +1211,8 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_scale_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_scale_pipeline_key key = { .inplace = context.inplace }; + ggml_webgpu_scale_pipeline_key key = {}; + key.inplace = context.inplace; auto it = scale_pipelines.find(key); if (it != scale_pipelines.end()) { @@ -1208,11 +1239,10 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_solve_tri_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_solve_tri_pipeline_key key = { - .type = context.dst->type, - .n = (int) context.src0->ne[0], - .k = (int) context.src1->ne[0], - }; + ggml_webgpu_solve_tri_pipeline_key key = {}; + key.type = context.dst->type; + key.n = (int) context.src0->ne[0]; + key.k = (int) context.src1->ne[0]; auto it = solve_tri_pipelines.find(key); if (it != solve_tri_pipelines.end()) { @@ -1250,10 +1280,9 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_ssm_conv_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_ssm_conv_pipeline_key key = { - .type = context.dst->type, - .vectorized = context.src1->ne[0] == 4, - }; + 
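// --- reviewer sketch (not part of the patch) ---------------------------------
// The key-construction changes in this hunk all follow the same shape: C++20
// designated initializers are replaced by value-initialization plus member
// assignments. The two forms below are equivalent; the struct name is made up
// for illustration, and the motivation (presumably broader compiler support) is
// my reading rather than something stated in the patch.
struct sketch_key {
    int  op      = 0;
    bool inplace = false;
};

void build_keys(int op, bool inplace) {
    // old style (C++20 designated initializers):
    sketch_key a = { .op = op, .inplace = inplace };

    // new style used throughout the patch: value-init, then assign members.
    sketch_key b = {};
    b.op      = op;
    b.inplace = inplace;

    (void) a;
    (void) b;
}
// -----------------------------------------------------------------------------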
ggml_webgpu_ssm_conv_pipeline_key key = {}; + key.type = context.dst->type; + key.vectorized = context.src1->ne[0] == 4; auto it = ssm_conv_pipelines.find(key); if (it != ssm_conv_pipelines.end()) { @@ -1293,11 +1322,10 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_gated_delta_net_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_gated_delta_net_pipeline_key key = { - .type = context.dst->type, - .s_v = (int) context.src2->ne[0], - .kda = context.src3->ne[0] == context.src2->ne[0], - }; + ggml_webgpu_gated_delta_net_pipeline_key key = {}; + key.type = context.dst->type; + key.s_v = (int) context.src2->ne[0]; + key.kda = context.src3->ne[0] == context.src2->ne[0]; auto it = gated_delta_net_pipelines.find(key); if (it != gated_delta_net_pipelines.end()) { @@ -1330,7 +1358,8 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_pad_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_pad_pipeline_key key = { .circular = ggml_get_op_params_i32(context.dst, 8) != 0 }; + ggml_webgpu_pad_pipeline_key key = {}; + key.circular = ggml_get_op_params_i32(context.dst, 8) != 0; auto it = pad_pipelines.find(key); if (it != pad_pipelines.end()) { @@ -1357,15 +1386,13 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_mul_mat_vec_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_mul_mat_vec_pipeline_key key = { - .src0_type = context.src0->type, - .src1_type = context.src1->type, - // Quantized mat-vec path currently runs scalar; only allow vectorization when both inputs are float - .vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && - (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0, - }; + ggml_webgpu_mul_mat_vec_pipeline_key key = {}; + key.src0_type = context.src0->type; + key.src1_type = context.src1->type; + key.vectorized = (context.src0->ne[0] % 4 == 0 && + (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? + 1 : + 0; auto it = mul_mat_vec_pipelines.find(key); if (it != mul_mat_vec_pipelines.end()) { @@ -1373,7 +1400,8 @@ class ggml_webgpu_shader_lib { } std::vector defines; - std::string variant = "mul_mat_vec"; + std::string variant = "mul_mat_vec"; + const char * shader_src = wgsl_mul_mat_vec; // src0 type (matrix row) switch (context.src0->type) { @@ -1422,25 +1450,25 @@ class ggml_webgpu_shader_lib { defines.push_back(key.vectorized ? "VEC" : "SCALAR"); uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE; - uint32_t tile_k = WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K; uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG; if (key.src0_type >= GGML_TYPE_Q2_K) { - tile_k = WEBGPU_MUL_MAT_VEC_K_Q_TILE_K; outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG; } else if (key.src0_type >= GGML_TYPE_Q4_0) { - tile_k = WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K; outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG; } defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); - defines.push_back(std::string("TILE_K=") + std::to_string(tile_k)); defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg)); + defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION"); + variant += context.supports_subgroups ? 
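// --- reviewer sketch (not part of the patch) ---------------------------------
// The mul_mat_vec vectorization check is relaxed here: only the reduction
// dimension (src0->ne[0]) and the float input types are considered, and the
// dst->ne[0] % 4 condition is dropped. A standalone restatement, with a
// stand-in enum instead of ggml_type:
#include <cstdint>

enum sketch_type { SK_F32, SK_F16, SK_Q4_0 };

int pick_mul_mat_vec_vectorized(sketch_type src0_type, int64_t src0_ne0) {
    const bool is_float = (src0_type == SK_F32 || src0_type == SK_F16);
    return (src0_ne0 % 4 == 0 && is_float) ? 1 : 0;
}
// e.g. pick_mul_mat_vec_vectorized(SK_F16, 4096) == 1, (SK_Q4_0, 4096) == 0
// -----------------------------------------------------------------------------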
"_sg_reduce" : "_wg_reduce"; + if (key.vectorized) { + variant += "_vectorized"; + } - auto processed = preprocessor.preprocess(wgsl_mul_mat_vec, defines); + auto processed = preprocessor.preprocess(shader_src, defines); auto decisions = std::make_shared(); decisions->wg_size = wg_size; - decisions->tile_k = tile_k; decisions->outputs_per_wg = outputs_per_wg; decisions->vec_size = key.vectorized ? 4 : 1; @@ -1451,15 +1479,14 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_mul_mat_fast_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_mul_mat_pipeline_key key = { - .src0_type = context.src0->type, - .src1_type = context.src1->type, - .vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && context.dst->ne[1] % 4 == 0 && - (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0, - .use_subgroup_matrix = context.supports_subgroup_matrix - }; + ggml_webgpu_mul_mat_pipeline_key key = {}; + key.src0_type = context.src0->type; + key.src1_type = context.src1->type; + key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && context.dst->ne[1] % 4 == 0 && + (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? + 1 : + 0; + key.use_subgroup_matrix = context.supports_subgroup_matrix; auto it = mul_mat_fast_pipelines.find(key); if (it != mul_mat_fast_pipelines.end()) { @@ -1578,8 +1605,9 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_mul_mat_legacy_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_legacy_mul_mat_pipeline_key key = { .src0_type = context.src0->type, - .src1_type = context.src1->type }; + ggml_webgpu_legacy_mul_mat_pipeline_key key = {}; + key.src0_type = context.src0->type; + key.src1_type = context.src1->type; auto it = mul_mat_legacy_pipelines.find(key); if (it != mul_mat_legacy_pipelines.end()) { @@ -1621,8 +1649,7 @@ class ggml_webgpu_shader_lib { std::string type_upper = src0_name; std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); - switch (context.src0->type) - { + switch (context.src0->type) { case GGML_TYPE_Q4_0: case GGML_TYPE_Q5_0: case GGML_TYPE_Q8_0: @@ -1642,9 +1669,9 @@ class ggml_webgpu_shader_lib { break; } default: - { - defines.push_back(std::string("SRC0_TYPE=") + src0_name); - } + { + defines.push_back(std::string("SRC0_TYPE=") + src0_name); + } } defines.push_back("BYTE_HELPERS"); @@ -1689,10 +1716,9 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_mul_mat_id_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_mul_mat_id_pipeline_key key = { - .src0_type = context.src0->type, - .src1_type = context.src1->type, - }; + ggml_webgpu_mul_mat_id_pipeline_key key = {}; + key.src0_type = context.src0->type; + key.src1_type = context.src1->type; auto it = mul_mat_id_pipelines.find(key); if (it != mul_mat_id_pipelines.end()) { @@ -1782,13 +1808,12 @@ class ggml_webgpu_shader_lib { webgpu_pipeline get_unary_pipeline(const ggml_webgpu_shader_lib_context & context) { const bool is_unary = context.dst->op == GGML_OP_UNARY; const int op = is_unary ? 
(int) ggml_get_unary_op(context.dst) : context.dst->op; - ggml_webgpu_unary_pipeline_key key = { - .type = context.dst->type, - .op = op, - .is_unary = is_unary, - .inplace = context.inplace, - .ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0), - }; + ggml_webgpu_unary_pipeline_key key = {}; + key.type = context.dst->type; + key.op = op; + key.is_unary = is_unary; + key.inplace = context.inplace; + key.ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0); auto it = unary_pipelines.find(key); if (it != unary_pipelines.end()) { @@ -1852,14 +1877,50 @@ class ggml_webgpu_shader_lib { return unary_pipelines[key]; } + webgpu_pipeline get_rms_norm_mul_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_rms_norm_mul_pipeline_key key = {}; + key.inplace = context.inplace; + key.overlap = context.overlap; + key.src_overlap = context.src_overlap; + + auto it = rms_norm_mul_pipelines.find(key); + if (it != rms_norm_mul_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string op_name = "RMS_NORM_MUL"; + std::string variant = op_name; + + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } else if (key.overlap) { + defines.push_back("OVERLAP"); + variant += "_overlap"; + } else if (key.src_overlap) { + defines.push_back("SRC_OVERLAP"); + variant += "_src_overlap"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_rms_norm_mul, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + rms_norm_mul_pipelines[key] = pipeline; + return rms_norm_mul_pipelines[key]; + } + webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_binary_pipeline_key key = { - .type = context.dst->type, - .op = context.dst->op, - .inplace = context.inplace, - .overlap = context.overlap, - .src_overlap = context.src_overlap, - }; + ggml_webgpu_binary_pipeline_key key = {}; + key.type = context.dst->type; + key.op = context.dst->op; + key.inplace = context.inplace; + key.overlap = context.overlap; + key.src_overlap = context.src_overlap; auto it = binary_pipelines.find(key); if (it != binary_pipelines.end()) { @@ -1908,9 +1969,8 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_concat_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_concat_pipeline_key key = { - .type = context.dst->type, - }; + ggml_webgpu_concat_pipeline_key key = {}; + key.type = context.dst->type; auto it = concat_pipelines.find(key); if (it != concat_pipelines.end()) { @@ -1945,9 +2005,8 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_repeat_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_repeat_pipeline_key key = { - .type = context.dst->type, - }; + ggml_webgpu_repeat_pipeline_key key = {}; + key.type = context.dst->type; auto it = repeat_pipelines.find(key); if (it != repeat_pipelines.end()) { @@ -1985,16 +2044,16 @@ class ggml_webgpu_shader_lib { return repeat_pipelines[key]; } - webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_flash_attn_shader_lib_context & context) { - auto it = flash_attn_pipelines.find(context.key); + webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) { + const ggml_webgpu_flash_attn_pipeline_key key = 
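// --- reviewer sketch (not part of the patch) ---------------------------------
// Variant selection for the new fused RMS_NORM_MUL pipeline, reduced to a pure
// function. Note the else-if chain: inplace wins over overlap, which wins over
// src_overlap, mirroring get_rms_norm_mul_pipeline above.
#include <string>

std::string rms_norm_mul_variant(bool inplace, bool overlap, bool src_overlap) {
    std::string variant = "RMS_NORM_MUL";
    if (inplace) {
        variant += "_inplace";
    } else if (overlap) {
        variant += "_overlap";
    } else if (src_overlap) {
        variant += "_src_overlap";
    }
    return variant;
}
// e.g. rms_norm_mul_variant(false, true, true) == "RMS_NORM_MUL_overlap"
// -----------------------------------------------------------------------------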
ggml_webgpu_flash_attn_make_pipeline_key(context); + auto it = flash_attn_pipelines.find(key); if (it != flash_attn_pipelines.end()) { return it->second; } - std::vector defines; std::string variant = "flash_attn"; - switch (context.key.kv_type) { + switch (key.kv_type) { case GGML_TYPE_F32: defines.push_back("KV_F32"); break; @@ -2010,111 +2069,206 @@ class ggml_webgpu_shader_lib { default: GGML_ABORT("Unsupported KV type for flash attention shader"); } - variant += std::string("_") + ggml_type_name(context.key.kv_type); + variant += std::string("_") + ggml_type_name(key.kv_type); - if (context.key.has_mask) { + if (key.has_mask) { defines.push_back("MASK"); variant += "_mask"; } - if (context.key.has_sinks) { + if (key.has_sinks) { defines.push_back("SINKS"); variant += "_sinks"; } - if (context.key.uses_logit_softcap) { + if (key.uses_logit_softcap) { defines.push_back("LOGIT_SOFTCAP"); variant += "_lgsc"; } - if (context.key.kv_direct) { + if (key.kv_direct) { defines.push_back("KV_DIRECT"); variant += "_kvdirect"; } - if (context.key.has_mask && context.key.use_vec) { - defines.push_back("BLK"); - variant += "_blk"; - } - defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.key.head_dim_qk)); - variant += std::string("_hsqk") + std::to_string(context.key.head_dim_qk); + defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk)); + variant += std::string("_hsqk") + std::to_string(key.head_dim_qk); - defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v)); - variant += std::string("_hsv") + std::to_string(context.key.head_dim_v); + defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v)); + variant += std::string("_hsv") + std::to_string(key.head_dim_v); defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); - uint32_t q_tile = context.sg_mat_m; - uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context), - context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); - if (context.key.use_vec) { - q_tile = 1; - kv_tile = std::max(context.sg_mat_n, std::min(32u, ggml_webgpu_flash_attn_max_kv_tile(context))); - kv_tile = (kv_tile / context.sg_mat_n) * context.sg_mat_n; - const uint32_t vec_ne = ggml_webgpu_flash_attn_pick_vec_ne(context.key); - defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u"); - } - if (context.key.kv_direct) { - GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); + auto decisions = std::make_shared(); + decisions->q_tile = context.sg_mat_m; + + const uint32_t min_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key); + uint32_t kv_tile = std::min(min_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); + + if (key.kv_direct) { + kv_tile = std::min(kv_tile, GGML_WEBGPU_KV_SEQ_PAD); while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { kv_tile -= context.sg_mat_n; } } - defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile)); - defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile)); + decisions->kv_tile = kv_tile; + decisions->wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); - uint32_t wg_size = 0; - if (context.key.use_vec) { - wg_size = std::max(1u, std::min(32u, context.max_subgroup_size)); - } else { - wg_size = std::max(context.max_subgroup_size, 
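// --- reviewer sketch (not part of the patch) ---------------------------------
// Worked example of the kv_direct clamp on kv_tile: the tile must evenly divide
// the KV padding constant. The rationale (direct loads should not straddle a
// pad boundary) is my reading; 64 and 16 below are placeholders for
// GGML_WEBGPU_KV_SEQ_PAD and sg_mat_n.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t kv_seq_pad = 64;  // placeholder for GGML_WEBGPU_KV_SEQ_PAD
    const uint32_t sg_mat_n   = 16;  // placeholder
    uint32_t kv_tile = 48;           // candidate tile from the memory-budget bound

    kv_tile = std::min(kv_tile, kv_seq_pad);
    while (kv_seq_pad % kv_tile != 0) {
        kv_tile -= sg_mat_n;         // 48 -> 32, and 64 % 32 == 0
    }
    std::printf("kv_tile = %u\n", kv_tile);  // prints 32
    return 0;
}
// -----------------------------------------------------------------------------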
GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); + defines.push_back(std::string("Q_TILE=") + std::to_string(decisions->q_tile)); + defines.push_back(std::string("KV_TILE=") + std::to_string(decisions->kv_tile)); + defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions->wg_size)); + + webgpu_pipeline pipeline = + ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn, defines), variant); + pipeline.context = decisions; + flash_attn_pipelines[key] = pipeline; + return flash_attn_pipelines[key]; + } + + webgpu_pipeline get_flash_attn_vec_pipeline(const ggml_webgpu_shader_lib_context & context) { + const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context); + auto it = flash_attn_vec_pipelines.find(key); + if (it != flash_attn_vec_pipelines.end()) { + return it->second; } - defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); - const char * shader_src = context.key.use_vec ? wgsl_flash_attn_vec_split : wgsl_flash_attn; + std::vector defines; + std::string variant = "flash_attn_vec"; + + switch (key.kv_type) { + case GGML_TYPE_F32: + defines.push_back("KV_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("KV_F16"); + break; + case GGML_TYPE_Q4_0: + defines.push_back("KV_Q4_0"); + break; + case GGML_TYPE_Q8_0: + defines.push_back("KV_Q8_0"); + break; + default: + GGML_ABORT("Unsupported KV type for flash attention shader"); + } + variant += std::string("_") + ggml_type_name(key.kv_type); + + if (key.has_mask) { + defines.push_back("MASK"); + defines.push_back("BLK"); + variant += "_mask_blk"; + } + if (key.has_sinks) { + defines.push_back("SINKS"); + variant += "_sinks"; + } + if (key.uses_logit_softcap) { + defines.push_back("LOGIT_SOFTCAP"); + variant += "_lgsc"; + } + if (key.kv_direct) { + defines.push_back("KV_DIRECT"); + variant += "_kvdirect"; + } + + defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk)); + variant += std::string("_hsqk") + std::to_string(key.head_dim_qk); + + defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v)); + variant += std::string("_hsv") + std::to_string(key.head_dim_v); + + defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); + defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); + defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); + defines.push_back("Q_TILE=1"); + + auto decisions = std::make_shared(); + decisions->kv_tile = ggml_webgpu_flash_attn_vec_get_kv_tile(context); + decisions->wg_size = std::max(1u, std::min(32u, context.max_subgroup_size)); + uint32_t vec_ne = 1u; + + // Keep conservative defaults unless this is the f16 vec-split shape family. 
+ if (key.kv_type == GGML_TYPE_F16 && key.head_dim_qk == key.head_dim_v) { + switch (key.head_dim_qk) { + case 64: + case 192: + case 576: + vec_ne = 2u; + break; + case 96: + vec_ne = 4u; + break; + default: + break; + } + } + + defines.push_back(std::string("KV_TILE=") + std::to_string(decisions->kv_tile)); + defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions->wg_size)); + defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u"); + webgpu_pipeline pipeline = - ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant); - auto decisions = std::make_shared(); - decisions->q_tile = q_tile; - decisions->kv_tile = kv_tile; - decisions->wg_size = wg_size; - pipeline.context = decisions; - flash_attn_pipelines[context.key] = pipeline; - return flash_attn_pipelines[context.key]; - } - - webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_flash_attn_blk_shader_lib_context & context) { - auto it = flash_attn_blk_pipelines.find(context.key); + ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_split, defines), variant); + pipeline.context = decisions; + flash_attn_vec_pipelines[key] = pipeline; + return flash_attn_vec_pipelines[key]; + } + + webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_flash_attn_blk_pipeline_key key = {}; + key.kv_tile = ggml_webgpu_flash_attn_vec_get_kv_tile(context); + auto it = flash_attn_blk_pipelines.find(key); if (it != flash_attn_blk_pipelines.end()) { return it->second; } - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_flash_attn_blk_shader(preprocessor, wgsl_flash_attn_vec_blk, context); - webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant); - flash_attn_blk_pipelines[context.key] = pipeline; - return flash_attn_blk_pipelines[context.key]; + std::vector defines; + std::string variant = "flash_attn_vec_blk"; + + defines.push_back(std::string("KV_TILE=") + std::to_string(key.kv_tile)); + variant += std::string("_kvt") + std::to_string(key.kv_tile); + + uint32_t wg_size = 1; + while ((wg_size << 1) <= context.max_wg_size) { + wg_size <<= 1; + } + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + variant += std::string("_wg") + std::to_string(wg_size); + + webgpu_pipeline pipeline = + ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_blk, defines), variant); + flash_attn_blk_pipelines[key] = pipeline; + return flash_attn_blk_pipelines[key]; } - webgpu_pipeline get_flash_attn_vec_reduce_pipeline( - const ggml_webgpu_flash_attn_vec_reduce_shader_lib_context & context) { - auto it = flash_attn_vec_reduce_pipelines.find(context.key); + webgpu_pipeline get_flash_attn_vec_reduce_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_flash_attn_vec_reduce_pipeline_key key = {}; + key.head_dim_v = (uint32_t) context.src2->ne[0]; + key.wg_size = context.max_wg_size; + auto it = flash_attn_vec_reduce_pipelines.find(key); if (it != flash_attn_vec_reduce_pipelines.end()) { return it->second; } - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_flash_attn_vec_reduce_shader(preprocessor, wgsl_flash_attn_vec_reduce, context); - webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant); - flash_attn_vec_reduce_pipelines[context.key] = pipeline; - return flash_attn_vec_reduce_pipelines[context.key]; + std::vector defines; + std::string 
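// --- reviewer sketch (not part of the patch) ---------------------------------
// The block pipeline's workgroup size is the largest power of two not exceeding
// max_wg_size, computed with the same loop as the patch. A couple of example
// values are shown in the comment below.
#include <cstdint>

uint32_t floor_pow2(uint32_t max_wg_size) {
    uint32_t wg_size = 1;
    while ((wg_size << 1) <= max_wg_size) {
        wg_size <<= 1;
    }
    return wg_size;
}
// floor_pow2(256) == 256, floor_pow2(384) == 256, floor_pow2(1024) == 1024
// -----------------------------------------------------------------------------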
variant = "flash_attn_vec_reduce"; + + defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v)); + variant += std::string("_hsv") + std::to_string(key.head_dim_v); + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + variant += std::string("_wg") + std::to_string(context.max_wg_size); + + webgpu_pipeline pipeline = + ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_reduce, defines), variant); + flash_attn_vec_reduce_pipelines[key] = pipeline; + return flash_attn_vec_reduce_pipelines[key]; } webgpu_pipeline get_cpy_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_cpy_pipeline_key key = { - .src_type = context.src0->type, - .dst_type = context.dst->type, - }; + ggml_webgpu_cpy_pipeline_key key = {}; + key.src_type = context.src0->type; + key.dst_type = context.dst->type; auto it = cpy_pipelines.find(key); if (it != cpy_pipelines.end()) { @@ -2166,11 +2320,10 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_glu_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_glu_pipeline_key key = { - .glu_op = ggml_get_glu_op(context.dst), - .type = context.dst->type, - .split = (context.src1 != nullptr), - }; + ggml_webgpu_glu_pipeline_key key = {}; + key.glu_op = ggml_get_glu_op(context.dst); + key.type = context.dst->type; + key.split = (context.src1 != nullptr); auto it = glu_pipelines.find(key); if (it != glu_pipelines.end()) { @@ -2239,11 +2392,10 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_rope_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_rope_pipeline_key key = { - .type = context.dst->type, - .inplace = context.inplace, - .has_ff = (context.src2 != nullptr), - }; + ggml_webgpu_rope_pipeline_key key = {}; + key.type = context.dst->type; + key.inplace = context.inplace; + key.has_ff = (context.src2 != nullptr); auto it = rope_pipelines.find(key); if (it != rope_pipelines.end()) { @@ -2288,12 +2440,11 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_soft_max_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_soft_max_pipeline_key key = { - .mask_type = context.src1 ? context.src1->type : GGML_TYPE_F32, - .has_mask = (context.src1 != nullptr), - .has_sink = (context.src2 != nullptr), - .inplace = context.inplace, - }; + ggml_webgpu_soft_max_pipeline_key key = {}; + key.mask_type = context.src1 ? 
context.src1->type : GGML_TYPE_F32; + key.has_mask = (context.src1 != nullptr); + key.has_sink = (context.src2 != nullptr); + key.inplace = context.inplace; auto it = soft_max_pipelines.find(key); if (it != soft_max_pipelines.end()) { @@ -2340,6 +2491,84 @@ class ggml_webgpu_shader_lib { return soft_max_pipelines[key]; } + webgpu_pipeline get_conv2d_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_conv2d_pipeline_key key = {}; + key.weight_type = context.src0->type; + key.input_type = context.src1->type; + key.output_type = context.dst->type; + + auto it = conv2d_pipelines.find(key); + if (it != conv2d_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "conv_2d"; + + auto push_type_defines = [&](const char * prefix, ggml_type type) { + std::string s_prefix = prefix; + if (type == GGML_TYPE_F32) { + defines.push_back(s_prefix + "_F32"); + } else if (type == GGML_TYPE_F16) { + defines.push_back(s_prefix + "_F16"); + } else { + GGML_ABORT("Unsupported type for CONV_2D shader"); + } + }; + + push_type_defines("WEIGHT", key.weight_type); + push_type_defines("INPUT", key.input_type); + push_type_defines("OUTPUT", key.output_type); + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_conv2d, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + conv2d_pipelines[key] = pipeline; + return conv2d_pipelines[key]; + } + + webgpu_pipeline get_im2col_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_im2col_pipeline_key key = {}; + key.input_type = context.src1->type; + key.output_type = context.dst->type; + + auto it = im2col_pipelines.find(key); + if (it != im2col_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "im2col"; + + auto push_type_defines = [&](const char * prefix, ggml_type type) { + std::string s_prefix = prefix; + if (type == GGML_TYPE_F32) { + defines.push_back(s_prefix + "_F32"); + } else if (type == GGML_TYPE_F16) { + defines.push_back(s_prefix + "_F16"); + } else { + GGML_ABORT("Unsupported type for IM2COL shader"); + } + }; + + push_type_defines("INPUT", key.input_type); + push_type_defines("OUTPUT", key.output_type); + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_im2col, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + im2col_pipelines[key] = pipeline; + return im2col_pipelines[key]; + } + private: static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device, std::string shader_code, @@ -2359,25 +2588,6 @@ class ggml_webgpu_shader_lib { pipeline_desc.layout = nullptr; // nullptr means auto layout return { device.CreateComputePipeline(&pipeline_desc), label }; } - - static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) { - const size_t limit_bytes = context.wg_mem_limit_bytes; - const size_t q_tile = context.sg_mat_m; - const size_t base_q_bytes = - (context.key.head_dim_qk + context.key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + - 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; - size_t bytes_per_kv 
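// --- reviewer sketch (not part of the patch) ---------------------------------
// The conv2d/im2col getters above share a small per-tensor type-define helper
// (the push_type_defines lambda). Reduced to a standalone function with a
// stand-in enum instead of ggml_type:
#include <string>
#include <vector>

enum sketch_type2 { SK2_F32, SK2_F16 };

void push_type_defines(std::vector<std::string> & defines, const char * prefix, sketch_type2 type) {
    std::string s_prefix = prefix;
    if (type == SK2_F32) {
        defines.push_back(s_prefix + "_F32");   // e.g. "WEIGHT_F32"
    } else if (type == SK2_F16) {
        defines.push_back(s_prefix + "_F16");   // e.g. "INPUT_F16"
    }
    // the real lambda aborts on any other type
}
// -----------------------------------------------------------------------------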
= 0; - if (!context.key.kv_direct) { - bytes_per_kv += std::max(context.key.head_dim_qk, context.key.head_dim_v); - } - if (context.key.has_mask) { - bytes_per_kv += q_tile; - } - bytes_per_kv += q_tile; - bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES; - const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv; - return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n; - } }; #endif // GGML_WEBGPU_SHADER_LIB_HPP diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 01637e2ddab..acc486cfdda 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -8,6 +8,7 @@ #include "ggml-backend-impl.h" #include "ggml-impl.h" #include "ggml-webgpu-shader-lib.hpp" +#include "ggml.h" #ifdef __EMSCRIPTEN__ # include @@ -41,6 +42,12 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim wg_x = CEIL_DIV(total_wg, wg_y); } +static inline uint32_t ggml_webgpu_u32_from_f32(float value) { + uint32_t bits; + memcpy(&bits, &value, sizeof(bits)); + return bits; +} + #ifdef GGML_WEBGPU_DEBUG # define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl # define WEBGPU_DEBUG_BUF_ELEMS 512 @@ -175,6 +182,7 @@ struct webgpu_dispatch_desc { struct webgpu_capabilities { wgpu::Limits limits; + bool supports_subgroups = false; bool supports_subgroup_matrix = false; uint32_t sg_mat_m = 0; @@ -204,6 +212,7 @@ struct webgpu_global_context_struct { wgpu::Buffer memset_params_buf; webgpu_pipeline memset_pipeline; + // TODO: We should rework the CPU profiling time handling to make it more useful. ref: https://github.com/ggml-org/llama.cpp/pull/22050 #ifdef GGML_WEBGPU_CPU_PROFILE // Profiling: labeled CPU time in ms (total) std::unordered_map cpu_time_ms; @@ -211,11 +220,6 @@ struct webgpu_global_context_struct { std::unordered_map cpu_detail_ms; #endif -#ifdef GGML_WEBGPU_GPU_PROFILE - // Profiling: per-shader GPU time in ms - std::unordered_map shader_gpu_time_ms; -#endif - #ifdef GGML_WEBGPU_DEBUG wgpu::Buffer debug_host_buf; wgpu::Buffer debug_dev_buf; @@ -261,10 +265,12 @@ struct webgpu_context_struct { size_t memset_bytes_per_thread; #ifdef GGML_WEBGPU_GPU_PROFILE - wgpu::Buffer profile_timestamp_dev_buf; - wgpu::Buffer profile_timestamp_host_buf; - wgpu::QuerySet profile_timestamp_query_set; - uint32_t profile_timestamp_query_count = 0; + // Profiling: per-shader GPU time in ms + std::unordered_map shader_gpu_time_ms; + wgpu::Buffer profile_timestamp_dev_buf; + wgpu::Buffer profile_timestamp_host_buf; + wgpu::QuerySet profile_timestamp_query_set; + uint32_t profile_timestamp_query_count = 0; #endif ~webgpu_context_struct() { @@ -369,6 +375,96 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, buffer = device.CreateBuffer(&buffer_desc); } +static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) { + return webgpu_tensor_offset(tensor) + tensor->view_offs; +} + +static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) { + ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context; + return ctx->buffer; +} + +static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_tensor * t) { + size_t offset = ggml_webgpu_tensor_offset(t); + return offset & (ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1); +} + +static bool ggml_webgpu_flash_attn_use_vec(webgpu_global_context & global_ctx, + const ggml_tensor * Q, + const ggml_tensor * K, + const ggml_tensor * V) { + const size_t 
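// --- reviewer sketch (not part of the patch) ---------------------------------
// The new ggml_webgpu_u32_from_f32 helper reinterprets a float's bits via
// memcpy, which is the portable pre-C++20 idiom (std::bit_cast<uint32_t>(value)
// is the C++20 equivalent). The standalone copy below only renames the helper.
#include <cstdint>
#include <cstring>

static inline uint32_t u32_from_f32(float value) {
    uint32_t bits;
    std::memcpy(&bits, &value, sizeof(bits));  // well-defined, unlike a pointer cast
    return bits;
}
// u32_from_f32(1.0f) == 0x3f800000
// -----------------------------------------------------------------------------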
alignment = global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; + const uint32_t k_offset_elems = + (uint32_t) ((ggml_webgpu_tensor_offset(K) & (alignment - 1)) / ggml_type_size(K->type)); + const uint32_t v_offset_elems = + (uint32_t) ((ggml_webgpu_tensor_offset(V) & (alignment - 1)) / ggml_type_size(V->type)); + const bool f16_vec4_aligned = (k_offset_elems % 4u == 0u) && (v_offset_elems % 4u == 0u); + const bool kv_vec_type_supported = + K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; + + return (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported && + (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (V->type == K->type); +} + +static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) { + size_t offset = ggml_webgpu_tensor_offset(t); + return offset & ~(ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1); +} + +static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) { + return ROUNDUP_POW2(ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t), WEBGPU_STORAGE_BUF_BINDING_MULT); +} + +// Used to determine if two tensors are the same for in-place operations +static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) { + return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) && + (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b)); +} + +// Used to determine if two tensors share the same buffer and their byte ranges overlap, +static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) { + return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) && + ggml_webgpu_tensor_offset(a) < (ggml_webgpu_tensor_offset(b) + ggml_nbytes(b)) && + ggml_webgpu_tensor_offset(b) < (ggml_webgpu_tensor_offset(a) + ggml_nbytes(a)); +} + +struct binary_overlap_flags { + bool inplace; // src0 == dst + bool overlap; // src1 == dst + bool src_overlap; +}; + +static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + binary_overlap_flags flags = {}; + flags.inplace = ggml_webgpu_tensor_equal(src0, dst); + flags.overlap = ggml_webgpu_tensor_overlap(src1, dst); + flags.src_overlap = ggml_webgpu_tensor_overlap(src0, src1); + + return flags; +} + +static wgpu::BindGroupEntry ggml_webgpu_make_bind_group_entry(uint32_t binding, + wgpu::Buffer buffer, + uint64_t offset, + uint64_t size) { + wgpu::BindGroupEntry entry = {}; + entry.binding = binding; + entry.buffer = std::move(buffer); + entry.offset = offset; + entry.size = size; + return entry; +} + +static wgpu::BindGroupEntry ggml_webgpu_make_tensor_bind_group_entry(webgpu_context & ctx, + uint32_t binding, + ggml_tensor * tensor) { + return ggml_webgpu_make_bind_group_entry(binding, ggml_webgpu_tensor_buf(tensor), + ggml_webgpu_tensor_align_offset(ctx, tensor), + ggml_webgpu_tensor_binding_size(ctx, tensor)); +} + /** End WebGPU object initializations */ /** WebGPU Actions */ @@ -480,10 +576,8 @@ static webgpu_encoded_op ggml_backend_webgpu_build_multi(webgpu_context & std::vector entries = dispatch.bind_group_entries; uint32_t params_binding_num = entries.size(); - entries.push_back({ .binding = params_binding_num, - .buffer = ctx->param_arena.buffer, - .offset = param_offset, - .size = ctx->param_arena.slot_size }); + entries.push_back(ggml_webgpu_make_bind_group_entry(params_binding_num, ctx->param_arena.buffer, param_offset, + 
ctx->param_arena.slot_size)); wgpu::BindGroupDescriptor bind_group_desc; bind_group_desc.layout = dispatch.pipeline.pipeline.GetBindGroupLayout(0); @@ -502,13 +596,17 @@ static webgpu_encoded_op ggml_backend_webgpu_build_multi(webgpu_context & #ifdef GGML_WEBGPU_GPU_PROFILE for (size_t i = 0; i < dispatches.size(); i++) { GGML_ASSERT(ctx->profile_timestamp_query_count + 2 <= WEBGPU_MAX_PROFILE_QUERY_COUNT); - const uint32_t query_begin = ctx->profile_timestamp_query_count++; - const uint32_t query_end = ctx->profile_timestamp_query_count++; - wgpu::PassTimestampWrites ts_writes = { .querySet = ctx->profile_timestamp_query_set, - .beginningOfPassWriteIndex = query_begin, - .endOfPassWriteIndex = query_end }; - wgpu::ComputePassDescriptor pass_desc = { .timestampWrites = &ts_writes }; - wgpu::ComputePassEncoder pass = ctx->active_command_encoder.BeginComputePass(&pass_desc); + const uint32_t query_begin = ctx->profile_timestamp_query_count++; + const uint32_t query_end = ctx->profile_timestamp_query_count++; + + wgpu::PassTimestampWrites ts_writes = {}; + ts_writes.querySet = ctx->profile_timestamp_query_set; + ts_writes.beginningOfPassWriteIndex = query_begin; + ts_writes.endOfPassWriteIndex = query_end; + wgpu::ComputePassDescriptor pass_desc = {}; + pass_desc.timestampWrites = &ts_writes; + + wgpu::ComputePassEncoder pass = ctx->active_command_encoder.BeginComputePass(&pass_desc); pass.SetPipeline(dispatches[i].pipeline.pipeline); pass.SetBindGroup(0, bind_groups[i]); @@ -544,17 +642,19 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx, uint32_t value, size_t offset, size_t size) { - std::vector params = { (uint32_t) offset, (uint32_t) size, value }; - std::vector entries = { - { .binding = 0, .buffer = buf, .offset = 0, .size = buf.GetSize() } - }; - size_t bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->capabilities.memset_bytes_per_thread; - uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg); + std::vector params = { (uint32_t) offset, (uint32_t) size, value }; + std::vector entries = { ggml_webgpu_make_bind_group_entry(0, buf, 0, buf.GetSize()) }; + size_t bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->capabilities.memset_bytes_per_thread; + uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg); ctx->queue.WriteBuffer(ctx->memset_params_buf, 0, params.data(), params.size() * sizeof(uint32_t)); - entries.push_back( - { .binding = 1, .buffer = ctx->memset_params_buf, .offset = 0, .size = WEBGPU_PARAMS_BUF_SIZE_BYTES }); + wgpu::BindGroupEntry params_entry = {}; + params_entry.binding = 1; + params_entry.buffer = ctx->memset_params_buf; + params_entry.offset = 0; + params_entry.size = WEBGPU_PARAMS_BUF_SIZE_BYTES; + entries.push_back(params_entry); wgpu::BindGroupDescriptor bind_group_desc; bind_group_desc.layout = ctx->memset_pipeline.pipeline.GetBindGroupLayout(0); @@ -612,12 +712,12 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) { #ifdef GGML_WEBGPU_GPU_PROFILE std::cout << "\n[ggml_webgpu gpu profiling summary]\n"; double total_gpu = 0.0; - for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) { + for (const auto & kv : ctx->webgpu_ctx->shader_gpu_time_ms) { total_gpu += kv.second; } std::cout << "ggml_webgpu: total gpu time (all shaders): " << total_gpu << " ms\n"; std::cout << "\nggml_webgpu: gpu breakdown:\n"; - for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) { + for (const auto & kv : ctx->webgpu_ctx->shader_gpu_time_ms) { double pct = (total_gpu > 0.0) ? 
(kv.second / total_gpu * 100.0) : 0.0; std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << std::fixed << std::setprecision(2) << pct << "%)\n"; @@ -632,65 +732,11 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) { delete backend; } -static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) { - return webgpu_tensor_offset(tensor) + tensor->view_offs; -} - -static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) { - ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context; - return ctx->buffer; -} - -static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_tensor * t) { - size_t offset = ggml_webgpu_tensor_offset(t); - return offset & (ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1); -} - -static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) { - size_t offset = ggml_webgpu_tensor_offset(t); - return offset & ~(ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1); -} - -static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) { - return ROUNDUP_POW2(ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t), WEBGPU_STORAGE_BUF_BINDING_MULT); -} - -// Used to determine if two tensors are the same for in-place operations -static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) { - return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) && - (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b)); -} - -// Used to determine if two tensors share the same buffer and their byte ranges overlap, -static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) { - return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) && - ggml_webgpu_tensor_offset(a) < (ggml_webgpu_tensor_offset(b) + ggml_nbytes(b)) && - ggml_webgpu_tensor_offset(b) < (ggml_webgpu_tensor_offset(a) + ggml_nbytes(a)); -} - -struct binary_overlap_flags { - bool inplace; // src0 == dst - bool overlap; // src1 == dst - bool src_overlap; -}; - -static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { - binary_overlap_flags flags = {}; - flags.inplace = ggml_webgpu_tensor_equal(src0, dst); - flags.overlap = ggml_webgpu_tensor_overlap(src1, dst); - flags.src_overlap = ggml_webgpu_tensor_overlap(src0, src1); - - return flags; -} - static webgpu_encoded_op ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_cpy_pipeline(shader_lib_ctx); @@ -712,14 +758,8 @@ static webgpu_encoded_op ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + 
ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst), }; uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); @@ -732,13 +772,12 @@ static webgpu_encoded_op ggml_webgpu_set(webgpu_context & ctx, ggml_tensor * dst) { const bool inplace = ggml_webgpu_tensor_equal(src0, dst); - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .inplace = inplace, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.inplace = inplace; webgpu_pipeline pipeline = ctx->shader_lib->get_set_pipeline(shader_lib_ctx); @@ -772,29 +811,21 @@ static webgpu_encoded_op ggml_webgpu_set(webgpu_context & ctx, std::vector entries; uint32_t binding_index = 0; if (!inplace) { - entries.push_back({ .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0)); binding_index++; } - entries.push_back({ .binding = binding_index, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }); - entries.push_back({ .binding = binding_index + 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index, src1)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index + 1, dst)); uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } static webgpu_encoded_op ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_pad_pipeline(shader_lib_ctx); @@ -832,14 +863,8 @@ static webgpu_encoded_op ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst), }; uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); @@ -850,13 +875,12 @@ static webgpu_encoded_op ggml_webgpu_solve_tri(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .dst = dst, - .max_wg_size = 
ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; webgpu_pipeline pipeline = ctx->shader_lib->get_solve_tri_pipeline(shader_lib_ctx); @@ -888,18 +912,9 @@ static webgpu_encoded_op ggml_webgpu_solve_tri(webgpu_context & ctx, }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), }; const uint32_t wg_x = CEIL_DIV((uint32_t) src1->ne[0], decisions->wg_size); @@ -907,16 +922,179 @@ static webgpu_encoded_op ggml_webgpu_solve_tri(webgpu_context & ctx, return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } +static webgpu_encoded_op ggml_webgpu_conv_2d(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + const int32_t s0 = ggml_get_op_params_i32(dst, 0); + const int32_t s1 = ggml_get_op_params_i32(dst, 1); + const int32_t p0 = ggml_get_op_params_i32(dst, 2); + const int32_t p1 = ggml_get_op_params_i32(dst, 3); + const int32_t d0 = ggml_get_op_params_i32(dst, 4); + const int32_t d1 = ggml_get_op_params_i32(dst, 5); + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + + (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + + (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), + + (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + + (uint32_t) src0->ne[0], + (uint32_t) src0->ne[1], + (uint32_t) src0->ne[2], + + (uint32_t) src1->ne[0], + (uint32_t) src1->ne[1], + + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + (uint32_t) dst->ne[3], + + (uint32_t) s0, + (uint32_t) s1, + (uint32_t) p0, + (uint32_t) p1, + (uint32_t) d0, + (uint32_t) d1, + }; + + std::vector entries = { + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + 
ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), + }; + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_conv2d_pipeline(shader_lib_ctx); + + auto * decisions = static_cast(pipeline.context.get()); + + uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size); + uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg); + uint32_t wg_y = CEIL_DIV(total_wg, wg_x); + + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); +} + +static webgpu_encoded_op ggml_webgpu_im2col(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + const int32_t s0 = ggml_get_op_params_i32(dst, 0); + const int32_t s1 = ggml_get_op_params_i32(dst, 1); + const int32_t p0 = ggml_get_op_params_i32(dst, 2); + const int32_t p1 = ggml_get_op_params_i32(dst, 3); + const int32_t d0 = ggml_get_op_params_i32(dst, 4); + const int32_t d1 = ggml_get_op_params_i32(dst, 5); + const bool is_2D = ggml_get_op_params_i32(dst, 6) == 1; + + const uint32_t KW = src0->ne[0]; + const uint32_t KH = is_2D ? src0->ne[1] : 1; + const uint32_t IC = is_2D ? src0->ne[2] : src0->ne[1]; + + const uint32_t IW = src1->ne[0]; + const uint32_t IH = is_2D ? src1->ne[1] : 1; + const uint32_t N = is_2D ? src1->ne[3] : src1->ne[2]; + + const uint32_t OW = dst->ne[1]; + const uint32_t OH = is_2D ? dst->ne[2] : 1; + + const uint32_t si0 = (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)); + const uint32_t si1 = is_2D ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0; + const uint32_t si2 = is_2D ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)); + const uint32_t si3 = is_2D ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)); + + const uint32_t so0 = (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)); + const uint32_t so1 = (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)); + const uint32_t so2 = is_2D ? (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)) : 0; + const uint32_t so3 = is_2D ? 
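// --- reviewer sketch (not part of the patch) ---------------------------------
// The im2col dispatch reads kernel/input shapes differently for the 1D and 2D
// cases; the ne[] indexing above is restated here as a standalone helper with a
// made-up struct name.
#include <cstdint>

struct im2col_shape {
    uint32_t KW, KH, IC, IW, IH, N;
};

im2col_shape pick_im2col_shape(const int64_t kernel_ne[4], const int64_t input_ne[4], bool is_2D) {
    im2col_shape s;
    s.KW = (uint32_t) kernel_ne[0];
    s.KH = is_2D ? (uint32_t) kernel_ne[1] : 1;
    s.IC = is_2D ? (uint32_t) kernel_ne[2] : (uint32_t) kernel_ne[1];
    s.IW = (uint32_t) input_ne[0];
    s.IH = is_2D ? (uint32_t) input_ne[1] : 1;
    s.N  = is_2D ? (uint32_t) input_ne[3] : (uint32_t) input_ne[2];
    return s;
}
// -----------------------------------------------------------------------------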
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)) : + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)); + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + + si0, + si1, + si2, + si3, + so0, + so1, + so2, + so3, + + KW, + KH, + IC, + + IW, + IH, + N, + + OW, + OH, + + (uint32_t) s0, + (uint32_t) s1, + (uint32_t) p0, + (uint32_t) p1, + (uint32_t) d0, + (uint32_t) d1, + }; + + std::vector entries = { + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst), + }; + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_im2col_pipeline(shader_lib_ctx); + + auto * decisions = static_cast(pipeline.context.get()); + + uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size); + uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg); + uint32_t wg_y = CEIL_DIV(total_wg, wg_x); + + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); +} + static webgpu_encoded_op ggml_webgpu_ssm_conv(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_ssm_conv_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); @@ -944,18 +1122,9 @@ static webgpu_encoded_op ggml_webgpu_ssm_conv(webgpu_context & ctx, }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), }; const uint32_t wg_x = CEIL_DIV((uint32_t) src0->ne[1], decisions->block_size); @@ -971,15 +1140,14 @@ static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx, ggml_tensor * src4, ggml_tensor * src5, ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .src2 = src2, - .src3 = src3, - .src4 = src4, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.src2 = src2; + shader_lib_ctx.src3 = 
src3; + shader_lib_ctx.src4 = src4; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_gated_delta_net_pipeline(shader_lib_ctx); @@ -1015,34 +1183,10 @@ static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx, }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(src2), - .offset = ggml_webgpu_tensor_align_offset(ctx, src2), - .size = ggml_webgpu_tensor_binding_size(ctx, src2) }, - { .binding = 3, - .buffer = ggml_webgpu_tensor_buf(src3), - .offset = ggml_webgpu_tensor_align_offset(ctx, src3), - .size = ggml_webgpu_tensor_binding_size(ctx, src3) }, - { .binding = 4, - .buffer = ggml_webgpu_tensor_buf(src4), - .offset = ggml_webgpu_tensor_align_offset(ctx, src4), - .size = ggml_webgpu_tensor_binding_size(ctx, src4) }, - { .binding = 5, - .buffer = ggml_webgpu_tensor_buf(src5), - .offset = ggml_webgpu_tensor_align_offset(ctx, src5), - .size = ggml_webgpu_tensor_binding_size(ctx, src5) }, - { .binding = 6, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2), ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, src3), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 4, src4), ggml_webgpu_make_tensor_bind_group_entry(ctx, 5, src5), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 6, dst), }; return ggml_backend_webgpu_build(ctx, pipeline, params, entries, h, n_seqs); @@ -1058,12 +1202,11 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ct return std::nullopt; } - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .src1 = idx, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.src1 = idx; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_set_rows_pipeline(shader_lib_ctx); @@ -1086,25 +1229,14 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ct }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(idx), - .offset = ggml_webgpu_tensor_align_offset(ctx, idx), - .size = ggml_webgpu_tensor_binding_size(ctx, idx) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, idx), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), }; if (decisions->i64_idx) { - entries.push_back({ .binding = 3, - .buffer = 
ctx->set_rows_dev_error_buf, - .offset = 0, - .size = ctx->set_rows_dev_error_buf.GetSize() }); + entries.push_back(ggml_webgpu_make_bind_group_entry(3, ctx->set_rows_dev_error_buf, 0, + ctx->set_rows_dev_error_buf.GetSize())); } uint32_t threads; @@ -1131,12 +1263,11 @@ static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx, ggml_tensor * dst) { const bool float_parallel = src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16 || src->type == GGML_TYPE_I32; - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .src1 = nullptr, - .dst = dst, - .max_wg_size = WEBGPU_MAX_WG_SIZE, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.src1 = nullptr; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = WEBGPU_MAX_WG_SIZE; webgpu_pipeline pipeline = ctx->shader_lib->get_get_rows_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); @@ -1160,20 +1291,9 @@ static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx, (uint32_t) (idx->ne[1]), (uint32_t) (idx->ne[2]) }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(idx), - .offset = ggml_webgpu_tensor_align_offset(ctx, idx), - .size = ggml_webgpu_tensor_binding_size(ctx, idx) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } - }; + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, idx), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst) }; uint32_t blocks_per_row = (uint32_t) (dst->ne[0] / (src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0 ? 
4 : 1)); uint32_t total_rows = (uint32_t) (dst->ne[1] * dst->ne[2] * dst->ne[3]); @@ -1208,14 +1328,11 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: case GGML_TYPE_Q6_K: - use_fast = true; - break; - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: - // we don't have fast mat-vec for these types, but we do have (semi) fast mat-mat - use_fast = !is_vec; + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q2_K: + use_fast = true; break; default: break; @@ -1225,17 +1342,18 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, break; } - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix, - .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m, - .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n, - .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k, - .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups; + shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix; + shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m; + shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n; + shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k; + shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size; // Get or create pipeline webgpu_pipeline pipeline; @@ -1270,18 +1388,9 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, // Build bind group entries std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), }; // Calculate workgroup dimensions @@ -1333,16 +1442,16 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, ggml_tensor * src1, ggml_tensor * src2, ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .src2 = src2, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.src2 = src2; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; // Get or create pipeline - webgpu_pipeline gather_pipeline, main_pipeline; + webgpu_pipeline gather_pipeline; + 
webgpu_pipeline main_pipeline; std::vector dispatches; @@ -1380,22 +1489,14 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, // bind group entries for mul_mat_id_gather.wgsl std::vector gather_entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src2), - .offset = ggml_webgpu_tensor_align_offset(ctx, src2), - .size = ggml_webgpu_tensor_binding_size(ctx, src2) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = gathered_expert_used_align_offset, - .size = gathered_binding_size }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = gathered_tokens_align_offset, - .size = gathered_binding_size }, - { .binding = 3, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = gathered_count_ids_align_offset, - .size = gathered_count_ids_binding_size }, + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src2), ggml_webgpu_tensor_align_offset(ctx, src2), + ggml_webgpu_tensor_binding_size(ctx, src2)), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), gathered_expert_used_align_offset, + gathered_binding_size), + ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(dst), gathered_tokens_align_offset, + gathered_binding_size), + ggml_webgpu_make_bind_group_entry(3, ggml_webgpu_tensor_buf(dst), gathered_count_ids_align_offset, + gathered_count_ids_binding_size), }; const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; @@ -1427,30 +1528,18 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, // bind group entries for mul_mat_id.wgsl std::vector main_entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, - { .binding = 3, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = gathered_expert_used_align_offset, - .size = gathered_binding_size }, - { .binding = 4, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = gathered_tokens_align_offset, - .size = gathered_binding_size }, - { .binding = 5, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = gathered_count_ids_align_offset, - .size = gathered_count_ids_binding_size }, + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), ggml_webgpu_tensor_align_offset(ctx, src0), + ggml_webgpu_tensor_binding_size(ctx, src0)), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(src1), ggml_webgpu_tensor_align_offset(ctx, src1), + ggml_webgpu_tensor_binding_size(ctx, src1)), + ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(dst), ggml_webgpu_tensor_align_offset(ctx, dst), + ggml_webgpu_tensor_binding_size(ctx, dst)), + ggml_webgpu_make_bind_group_entry(3, ggml_webgpu_tensor_buf(dst), gathered_expert_used_align_offset, + gathered_binding_size), + ggml_webgpu_make_bind_group_entry(4, ggml_webgpu_tensor_buf(dst), gathered_tokens_align_offset, + gathered_binding_size), + ggml_webgpu_make_bind_group_entry(5, ggml_webgpu_tensor_buf(dst), gathered_count_ids_align_offset, + gathered_count_ids_binding_size), }; // Calculate workgroup dimensions @@ -1486,11 +1575,9 @@ static webgpu_encoded_op 
ggml_webgpu_flash_attn(webgpu_context & ctx, ggml_tensor * mask, ggml_tensor * sinks, ggml_tensor * dst) { - float scale = *(float *) dst->op_params; - float max_bias; - memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); - float logit_softcap; - memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float)); + float scale = ggml_get_op_params_f32(dst, 0); + float max_bias = ggml_get_op_params_f32(dst, 1); + float logit_softcap = ggml_get_op_params_f32(dst, 2); if (logit_softcap != 0.0f) { scale /= logit_softcap; } @@ -1522,86 +1609,53 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, (uint32_t) (V->nb[3] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 3 has_mask ? (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)) : 0, // stride of mask dim 3 (uint32_t) (Q->ne[2] / K->ne[2]), // repeat factor for K/V in dim 2 (MHA/MQA/GQA) - *(uint32_t *) &scale, // scale (possibly adjusted for logit softcap) - *(uint32_t *) &max_bias, - *(uint32_t *) &logit_softcap, - *(uint32_t *) &n_head_log2, - *(uint32_t *) &m0, - *(uint32_t *) &m1 + ggml_webgpu_u32_from_f32(scale), // scale (possibly adjusted for logit softcap) + ggml_webgpu_u32_from_f32(max_bias), + ggml_webgpu_u32_from_f32(logit_softcap), + ggml_webgpu_u32_from_f32(n_head_log2), + ggml_webgpu_u32_from_f32(m0), + ggml_webgpu_u32_from_f32(m1) }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(Q), - .offset = ggml_webgpu_tensor_align_offset(ctx, Q), - .size = ggml_webgpu_tensor_binding_size(ctx, Q) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(K), - .offset = ggml_webgpu_tensor_align_offset(ctx, K), - .size = ggml_webgpu_tensor_binding_size(ctx, K) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(V), - .offset = ggml_webgpu_tensor_align_offset(ctx, V), - .size = ggml_webgpu_tensor_binding_size(ctx, V) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, Q), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, K), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, V), }; uint32_t binding_index = 3; if (has_mask) { - entries.push_back({ .binding = binding_index++, - .buffer = ggml_webgpu_tensor_buf(mask), - .offset = ggml_webgpu_tensor_align_offset(ctx, mask), - .size = ggml_webgpu_tensor_binding_size(ctx, mask) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, mask)); } if (has_sinks) { - entries.push_back({ .binding = binding_index++, - .buffer = ggml_webgpu_tensor_buf(sinks), - .offset = ggml_webgpu_tensor_align_offset(ctx, sinks), - .size = ggml_webgpu_tensor_binding_size(ctx, sinks) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, sinks)); + } + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, dst)); + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = Q; + shader_lib_ctx.src1 = K; + shader_lib_ctx.src2 = V; + shader_lib_ctx.src3 = mask; + shader_lib_ctx.src4 = sinks; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; + shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m; + shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n; + shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k; + shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size; + const bool 
use_vec = ggml_webgpu_flash_attn_use_vec(ctx->global_ctx, Q, K, V); + webgpu_pipeline pipeline = use_vec ? ctx->shader_lib->get_flash_attn_vec_pipeline(shader_lib_ctx) : + ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx); + + if (!use_vec) { + auto * decisions = static_cast(pipeline.context.get()); + uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile); + uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } - entries.push_back({ .binding = binding_index++, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); - - const uint32_t k_offset_elems = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)); - const uint32_t v_offset_elems = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)); - const bool f16_vec4_aligned = (k_offset_elems % 4u == 0u) && (v_offset_elems % 4u == 0u); - - const bool kv_direct = (K->type == GGML_TYPE_F16) && f16_vec4_aligned && - (Q->ne[0] % ctx->global_ctx->capabilities.sg_mat_k == 0) && - (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); - - const bool kv_vec_type_supported = - K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; - const bool use_vec = (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported && - (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (V->type == K->type); - const uint32_t vec_nwg_cap = std::max(1u, std::min(32u, ctx->global_ctx->capabilities.max_subgroup_size)); - const bool use_blk = use_vec && has_mask; - - ggml_webgpu_flash_attn_pipeline_key key = { - .kv_type = K->type, - .head_dim_qk = (uint32_t) Q->ne[0], - .head_dim_v = (uint32_t) V->ne[0], - .kv_direct = kv_direct, - .has_mask = static_cast(has_mask), - .has_sinks = static_cast(has_sinks), - .uses_logit_softcap = logit_softcap != 0.0f, - .use_vec = use_vec, - }; - - ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { - .key = key, - .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m, - .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n, - .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k, - .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, - .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size, - }; - webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx); - - auto * decisions = static_cast(pipeline.context.get()); - uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile); - uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches + auto * decisions = static_cast(pipeline.context.get()); wgpu::Buffer blk_buf = {}; uint64_t blk_size_bytes = 0; @@ -1609,197 +1663,162 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, uint32_t blk_nblk1 = 0; uint32_t blk_batch_count = 0; - if (use_vec) { - uint32_t nwg = 1u; - const uint64_t kv_span = (uint64_t) std::max(1u, decisions->kv_tile); - while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { - nwg <<= 1; - } - nwg = std::min(nwg, vec_nwg_cap); - GGML_ASSERT(nwg <= ctx->global_ctx->capabilities.max_subgroup_size); - const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; - const bool use_vec_reduce = nwg > 1u; - GGML_ASSERT(nrows <= UINT32_MAX); - - uint64_t tmp_stats_base = 0; - uint64_t tmp_size_bytes = 
0; - wgpu::Buffer tmp_buf = {}; - uint64_t tmp_bind_offset = 0; - uint64_t tmp_bind_size = 0; - const size_t align_bytes = ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; - const size_t dst_offset = ggml_webgpu_tensor_offset(dst); - size_t scratch_offset = ROUNDUP_POW2(dst_offset + ggml_nbytes(dst), align_bytes); - - if (use_vec_reduce) { - const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg; - const uint64_t tmp_stats_elems = nrows * 2u * nwg; - tmp_stats_base = tmp_data_elems; - tmp_size_bytes = - ROUNDUP_POW2((tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT); - GGML_ASSERT(tmp_stats_base <= UINT32_MAX); - tmp_buf = ggml_webgpu_tensor_buf(dst); - tmp_bind_offset = scratch_offset; - tmp_bind_size = tmp_size_bytes; - scratch_offset = ROUNDUP_POW2(scratch_offset + tmp_size_bytes, align_bytes); - } else { - // nwg==1 writes final dst directly in vec-split; keep tmp binding valid without extra allocation. - tmp_buf = ggml_webgpu_tensor_buf(dst); - tmp_bind_offset = ggml_webgpu_tensor_align_offset(ctx, dst); - tmp_bind_size = ggml_webgpu_tensor_binding_size(ctx, dst); - } - - webgpu_pipeline blk_pipeline; - std::vector blk_params; - std::vector blk_entries; - if (use_blk) { - GGML_ASSERT(has_mask); - - blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], decisions->kv_tile); - blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], decisions->q_tile); - blk_buf = ggml_webgpu_tensor_buf(dst); - const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)); - blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u; - const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; - blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); - ggml_webgpu_flash_attn_blk_shader_lib_context blk_shader_ctx = { - .key = - { - .q_tile = decisions->q_tile, - .kv_tile = decisions->kv_tile, - }, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; - blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx); - - blk_params = { - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)), // offset_mask - (uint32_t) Q->ne[1], // seq_len_q - (uint32_t) K->ne[1], // seq_len_kv - stride_mask3, // stride_mask3 - blk_nblk0, // nblk0 - blk_nblk1, // nblk1 - }; - blk_entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(mask), - .offset = ggml_webgpu_tensor_align_offset(ctx, mask), - .size = ggml_webgpu_tensor_binding_size(ctx, mask) }, - { .binding = 1, .buffer = blk_buf, .offset = scratch_offset, .size = blk_size_bytes }, - }; - scratch_offset = ROUNDUP_POW2(scratch_offset + blk_size_bytes, align_bytes); - } + const uint32_t vec_nwg_cap = std::max(1u, std::min(32u, ctx->global_ctx->capabilities.max_subgroup_size)); + uint32_t nwg = 1u; + const uint64_t kv_span = (uint64_t) std::max(1u, decisions->kv_tile); + while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { + nwg <<= 1; + } + nwg = std::min(nwg, vec_nwg_cap); + const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; + const bool use_vec_reduce = nwg > 1u; + GGML_ASSERT(nrows <= UINT32_MAX); + + uint64_t tmp_stats_base = 0; + uint64_t tmp_size_bytes = 0; + wgpu::Buffer tmp_buf = {}; + uint64_t tmp_bind_offset = 0; + uint64_t tmp_bind_size = 0; + const size_t align_bytes = ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; + const size_t dst_offset = ggml_webgpu_tensor_offset(dst); + size_t 
scratch_offset = ROUNDUP_POW2(dst_offset + ggml_nbytes(dst), align_bytes); + + if (use_vec_reduce) { + const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg; + const uint64_t tmp_stats_elems = nrows * 2u * nwg; + tmp_stats_base = tmp_data_elems; + tmp_size_bytes = + ROUNDUP_POW2((tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT); + GGML_ASSERT(tmp_stats_base <= UINT32_MAX); + tmp_buf = ggml_webgpu_tensor_buf(dst); + tmp_bind_offset = scratch_offset; + tmp_bind_size = tmp_size_bytes; + scratch_offset = ROUNDUP_POW2(scratch_offset + tmp_size_bytes, align_bytes); + } else { + // nwg==1 writes final dst directly in vec-split; keep tmp binding valid without extra allocation. + tmp_buf = ggml_webgpu_tensor_buf(dst); + tmp_bind_offset = ggml_webgpu_tensor_align_offset(ctx, dst); + tmp_bind_size = ggml_webgpu_tensor_binding_size(ctx, dst); + } - std::vector split_params = params; - if (use_blk) { - split_params.push_back(0u); // blk_base - split_params.push_back(blk_nblk0); // blk_nblk0 - split_params.push_back(blk_nblk1); // blk_nblk1 - } - split_params.push_back(0u); // tmp_data_base - split_params.push_back((uint32_t) tmp_stats_base); // tmp_stats_base - split_params.push_back(nwg); // nwg - - std::vector split_entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(Q), - .offset = ggml_webgpu_tensor_align_offset(ctx, Q), - .size = ggml_webgpu_tensor_binding_size(ctx, Q) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(K), - .offset = ggml_webgpu_tensor_align_offset(ctx, K), - .size = ggml_webgpu_tensor_binding_size(ctx, K) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(V), - .offset = ggml_webgpu_tensor_align_offset(ctx, V), - .size = ggml_webgpu_tensor_binding_size(ctx, V) }, + webgpu_pipeline blk_pipeline; + std::vector blk_params; + std::vector blk_entries; + if (has_mask) { + blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], decisions->kv_tile); + blk_nblk1 = (uint32_t) Q->ne[1]; + blk_buf = ggml_webgpu_tensor_buf(dst); + const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)); + blk_batch_count = stride_mask3 > 0 ? 
(uint32_t) Q->ne[3] : 1u; + const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; + blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); + const ggml_webgpu_shader_lib_context blk_shader_ctx = shader_lib_ctx; + blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx); + + blk_params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)), // offset_mask + (uint32_t) Q->ne[1], // seq_len_q + (uint32_t) K->ne[1], // seq_len_kv + stride_mask3, // stride_mask3 + blk_nblk0, // nblk0 + blk_nblk1, // nblk1 }; - uint32_t split_binding_index = 3; - if (has_mask) { - split_entries.push_back({ .binding = split_binding_index++, - .buffer = ggml_webgpu_tensor_buf(mask), - .offset = ggml_webgpu_tensor_align_offset(ctx, mask), - .size = ggml_webgpu_tensor_binding_size(ctx, mask) }); - } - if (has_sinks) { - split_entries.push_back({ .binding = split_binding_index++, - .buffer = ggml_webgpu_tensor_buf(sinks), - .offset = ggml_webgpu_tensor_align_offset(ctx, sinks), - .size = ggml_webgpu_tensor_binding_size(ctx, sinks) }); - } - if (use_blk) { - split_entries.push_back({ .binding = split_binding_index++, - .buffer = blk_buf, - .offset = blk_entries[1].offset, - .size = blk_size_bytes }); - } + blk_entries = { + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(mask), + ggml_webgpu_tensor_align_offset(ctx, mask), + ggml_webgpu_tensor_binding_size(ctx, mask)), + ggml_webgpu_make_bind_group_entry(1, blk_buf, scratch_offset, blk_size_bytes), + }; + scratch_offset = ROUNDUP_POW2(scratch_offset + blk_size_bytes, align_bytes); + } + + std::vector split_params = params; + if (has_mask) { + split_params.push_back(0u); // blk_base + split_params.push_back(blk_nblk0); // blk_nblk0 + split_params.push_back(blk_nblk1); // blk_nblk1 + } + split_params.push_back(0u); // tmp_data_base + split_params.push_back((uint32_t) tmp_stats_base); // tmp_stats_base + split_params.push_back(nwg); // nwg + + std::vector split_entries = { + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(Q), ggml_webgpu_tensor_align_offset(ctx, Q), + ggml_webgpu_tensor_binding_size(ctx, Q)), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), ggml_webgpu_tensor_align_offset(ctx, K), + ggml_webgpu_tensor_binding_size(ctx, K)), + ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(V), ggml_webgpu_tensor_align_offset(ctx, V), + ggml_webgpu_tensor_binding_size(ctx, V)), + }; + uint32_t split_binding_index = 3; + if (has_mask) { + split_entries.push_back(ggml_webgpu_make_bind_group_entry(split_binding_index++, ggml_webgpu_tensor_buf(mask), + ggml_webgpu_tensor_align_offset(ctx, mask), + ggml_webgpu_tensor_binding_size(ctx, mask))); + } + if (has_sinks) { + split_entries.push_back(ggml_webgpu_make_bind_group_entry(split_binding_index++, ggml_webgpu_tensor_buf(sinks), + ggml_webgpu_tensor_align_offset(ctx, sinks), + ggml_webgpu_tensor_binding_size(ctx, sinks))); + } + if (has_mask) { split_entries.push_back( - { .binding = split_binding_index++, .buffer = tmp_buf, .offset = tmp_bind_offset, .size = tmp_bind_size }); - split_entries.push_back({ .binding = split_binding_index++, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); - - webgpu_pipeline reduce_pipeline; - std::vector reduce_params; - std::vector reduce_entries; - if (use_vec_reduce) { - const uint32_t reduce_wg_size = std::max( - 32u, - 
std::min(nwg * 32u, ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup)); - ggml_webgpu_flash_attn_vec_reduce_shader_lib_context reduce_shader_ctx = { - .key = - { - .head_dim_v = (uint32_t) V->ne[0], - .wg_size = reduce_wg_size, - }, - .max_wg_size = reduce_wg_size, - }; - reduce_pipeline = ctx->shader_lib->get_flash_attn_vec_reduce_pipeline(reduce_shader_ctx); - - reduce_params = { - (uint32_t) nrows, // nrows - (uint32_t) Q->ne[1], // seq_len_q - (uint32_t) Q->ne[2], // n_heads - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), // offset_dst - nwg, // nwg - 0u, // tmp_data_base - (uint32_t) tmp_stats_base, // tmp_stats_base - }; - - reduce_entries = { - { .binding = 0, .buffer = tmp_buf, .offset = tmp_bind_offset, .size = tmp_size_bytes }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, - }; - } + ggml_webgpu_make_bind_group_entry(split_binding_index++, blk_buf, blk_entries[1].offset, blk_size_bytes)); + } + split_entries.push_back( + ggml_webgpu_make_bind_group_entry(split_binding_index++, tmp_buf, tmp_bind_offset, tmp_bind_size)); + split_entries.push_back(ggml_webgpu_make_bind_group_entry(split_binding_index++, ggml_webgpu_tensor_buf(dst), + ggml_webgpu_tensor_align_offset(ctx, dst), + ggml_webgpu_tensor_binding_size(ctx, dst))); + + webgpu_pipeline reduce_pipeline; + std::vector reduce_params; + std::vector reduce_entries; + if (use_vec_reduce) { + const uint32_t reduce_wg_size = std::max( + 32u, std::min(nwg * 32u, ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup)); + ggml_webgpu_shader_lib_context reduce_shader_ctx = shader_lib_ctx; + reduce_shader_ctx.max_wg_size = reduce_wg_size; + reduce_pipeline = ctx->shader_lib->get_flash_attn_vec_reduce_pipeline(reduce_shader_ctx); + + reduce_params = { + (uint32_t) nrows, // nrows + (uint32_t) Q->ne[1], // seq_len_q + (uint32_t) Q->ne[2], // n_heads + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), // offset_dst + nwg, // nwg + 0u, // tmp_data_base + (uint32_t) tmp_stats_base, // tmp_stats_base + }; - const uint64_t split_wg_total = (uint64_t) wg_x * nwg; - GGML_ASSERT(split_wg_total <= UINT32_MAX); - std::vector dispatches; - - if (use_blk) { - dispatches.push_back({ - blk_pipeline, - std::move(blk_params), - std::move(blk_entries), - { blk_nblk0, blk_nblk1 * blk_batch_count } - }); - } + reduce_entries = { + ggml_webgpu_make_bind_group_entry(0, tmp_buf, tmp_bind_offset, tmp_size_bytes), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), ggml_webgpu_tensor_align_offset(ctx, dst), + ggml_webgpu_tensor_binding_size(ctx, dst)), + }; + } + + uint32_t wg_x = Q->ne[1] * Q->ne[2] * Q->ne[3]; + const uint64_t split_wg_total = (uint64_t) wg_x * nwg; + GGML_ASSERT(split_wg_total <= UINT32_MAX); + + std::vector dispatches; + + if (has_mask) { dispatches.push_back({ - pipeline, std::move(split_params), std::move(split_entries), { (uint32_t) split_wg_total, 1u } + blk_pipeline, std::move(blk_params), std::move(blk_entries), { blk_nblk0, blk_nblk1 * blk_batch_count } + }); + } + dispatches.push_back({ + pipeline, std::move(split_params), std::move(split_entries), { (uint32_t) split_wg_total, 1u } + }); + if (use_vec_reduce) { + dispatches.push_back({ + reduce_pipeline, std::move(reduce_params), std::move(reduce_entries), { (uint32_t) nrows, 1u } }); - if (use_vec_reduce) { - dispatches.push_back({ - 
reduce_pipeline, std::move(reduce_params), std::move(reduce_entries), { (uint32_t) nrows, 1u } - }); - } - - return ggml_backend_webgpu_build_multi(ctx, dispatches); } - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build_multi(ctx, dispatches); } #endif // __EMSCRIPTEN__ @@ -1807,13 +1826,12 @@ static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor bool is_unary = dst->op == GGML_OP_UNARY; bool inplace = ggml_webgpu_tensor_equal(src, dst) || (dst->op == GGML_OP_FILL); - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .src1 = nullptr, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .inplace = inplace, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.src1 = nullptr; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.inplace = inplace; webgpu_pipeline pipeline = ctx->shader_lib->get_unary_pipeline(shader_lib_ctx); @@ -1844,10 +1862,10 @@ static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor float alpha_p = ggml_get_op_params_f32(dst, 2); float beta = ggml_get_op_params_f32(dst, 3); float eps = ggml_get_op_params_f32(dst, 4); - params.push_back(*reinterpret_cast(&alpha_n)); - params.push_back(*reinterpret_cast(&alpha_p)); - params.push_back(*reinterpret_cast(&beta)); - params.push_back(*reinterpret_cast(&eps)); + params.push_back(ggml_webgpu_u32_from_f32(alpha_n)); + params.push_back(ggml_webgpu_u32_from_f32(alpha_p)); + params.push_back(ggml_webgpu_u32_from_f32(beta)); + params.push_back(ggml_webgpu_u32_from_f32(eps)); break; } default: @@ -1856,25 +1874,19 @@ static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor } else if (dst->op == GGML_OP_CLAMP) { float clamp_min = ggml_get_op_params_f32(dst, 0); float clamp_max = ggml_get_op_params_f32(dst, 1); - params.push_back(*reinterpret_cast(&clamp_min)); - params.push_back(*reinterpret_cast(&clamp_max)); + params.push_back(ggml_webgpu_u32_from_f32(clamp_min)); + params.push_back(ggml_webgpu_u32_from_f32(clamp_max)); } else if (dst->op == GGML_OP_FILL) { float fill_val = ggml_get_op_params_f32(dst, 0); - params.push_back(*reinterpret_cast(&fill_val)); + params.push_back(ggml_webgpu_u32_from_f32(fill_val)); effective_src = dst; // fill simply fills dst } std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(effective_src), - .offset = ggml_webgpu_tensor_align_offset(ctx, effective_src), - .size = ggml_webgpu_tensor_binding_size(ctx, effective_src) }, + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, effective_src), }; if (!inplace) { - entries.push_back({ .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); } uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); @@ -1887,15 +1899,14 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx, ggml_tensor * dst) { binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst); - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .inplace = flags.inplace, - .overlap = 
flags.overlap, - .src_overlap = flags.src_overlap, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.inplace = flags.inplace; + shader_lib_ctx.overlap = flags.overlap; + shader_lib_ctx.src_overlap = flags.src_overlap; webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx); @@ -1944,38 +1955,18 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx, size_t merged_offset = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset); size_t merged_end = std::max(src0_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src0), src1_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src1)); - entries.push_back({ - .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = merged_offset, - .size = merged_end - merged_offset, - }); - entries.push_back({ - .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst), - }); + entries.push_back(ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), merged_offset, + merged_end - merged_offset)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); } else { - entries.push_back({ - .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = src0_webgpu_tensor_align_offset, - .size = ggml_webgpu_tensor_binding_size(ctx, src0), - }); - entries.push_back({ - .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = src1_webgpu_tensor_align_offset, - .size = ggml_webgpu_tensor_binding_size(ctx, src1), - }); + entries.push_back(ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), + src0_webgpu_tensor_align_offset, + ggml_webgpu_tensor_binding_size(ctx, src0))); + entries.push_back(ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(src1), + src1_webgpu_tensor_align_offset, + ggml_webgpu_tensor_binding_size(ctx, src1))); if (!flags.inplace && !flags.overlap) { - entries.push_back({ - .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst), - }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst)); } } @@ -2012,26 +2003,16 @@ static webgpu_encoded_op ggml_webgpu_concat(webgpu_context & ctx, }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), }; - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = 
src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); @@ -2059,21 +2040,14 @@ static webgpu_encoded_op ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * (uint32_t) (dst->ne[2]) }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst), }; - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_repeat_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); @@ -2081,6 +2055,96 @@ static webgpu_encoded_op ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } +static std::optional ggml_webgpu_rms_norm_mul(webgpu_context & ctx, + ggml_tensor * rn_src, + ggml_tensor * rn_dst, + ggml_tensor * mul_src0, + ggml_tensor * mul_src1, + ggml_tensor * dst) { + ggml_tensor * mul_src; + + if (ggml_webgpu_tensor_equal(rn_dst, mul_src0)) { + mul_src = mul_src1; + } else if (ggml_webgpu_tensor_equal(rn_dst, mul_src1)) { + mul_src = mul_src0; + } else { + GGML_ABORT("rms_norm must be equal to the one of mul_src0 and mul_src1"); + } + + bool overlap = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) || + (ggml_webgpu_tensor_equal(rn_dst, mul_src1) && ggml_webgpu_tensor_equal(mul_src0, dst)); + bool inplace = ggml_webgpu_tensor_equal(rn_src, dst); + bool src_overlap = ggml_webgpu_tensor_overlap(rn_src, mul_src); + + uint32_t offset_merged_rn_src = 0; + uint32_t offset_merged_mul_src = 0; + size_t rn_src_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, rn_src); + size_t mul_src_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, mul_src); + + if (src_overlap) { + size_t min_offset = std::min(rn_src_webgpu_tensor_align_offset, mul_src_webgpu_tensor_align_offset); + offset_merged_rn_src = + (uint32_t) ((rn_src_webgpu_tensor_align_offset - min_offset) / ggml_type_size(rn_src->type)); + offset_merged_mul_src = + (uint32_t) ((mul_src_webgpu_tensor_align_offset - min_offset) / ggml_type_size(mul_src->type)); + } + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, rn_src) / ggml_type_size(rn_src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mul_src) / ggml_type_size(mul_src->type)), + offset_merged_rn_src, + offset_merged_mul_src, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (rn_src->nb[1] / ggml_type_size(rn_src->type)), + (uint32_t) (rn_src->nb[2] / ggml_type_size(rn_src->type)), + (uint32_t) (rn_src->nb[3] / 
ggml_type_size(rn_src->type)), + (uint32_t) (mul_src->nb[1] / ggml_type_size(mul_src->type)), + (uint32_t) (mul_src->nb[2] / ggml_type_size(mul_src->type)), + (uint32_t) (mul_src->nb[3] / ggml_type_size(mul_src->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + (uint32_t) mul_src->ne[0], + (uint32_t) mul_src->ne[1], + (uint32_t) mul_src->ne[2], + (uint32_t) mul_src->ne[3], + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + (uint32_t) dst->ne[3], + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(rn_dst, 0)) // epsilon, treated as f32 in the shader + }; + + std::vector entries; + + if (inplace || overlap) { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, rn_src)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, mul_src)); + } else if (src_overlap) { + size_t merged_offset = std::min(rn_src_webgpu_tensor_align_offset, mul_src_webgpu_tensor_align_offset); + size_t merged_end = + std::max(rn_src_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, rn_src), + mul_src_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, mul_src)); + entries.push_back(ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(rn_src), merged_offset, + merged_end - merged_offset)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); + } else { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, rn_src)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, mul_src)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst)); + } + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.inplace = inplace; + shader_lib_ctx.overlap = overlap; + shader_lib_ctx.src_overlap = src_overlap; + + webgpu_pipeline pipeline = ctx->shader_lib->get_rms_norm_mul_pipeline(shader_lib_ctx); + + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(dst)); +} + static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool inplace = ggml_webgpu_tensor_equal(src, dst); @@ -2097,28 +2161,19 @@ static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3], - *(uint32_t *) dst->op_params // epsilon, treated as f32 in the shader + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 0)) // epsilon, treated as f32 in the shader }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) } - }; + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src) }; if (!inplace) { - entries.push_back({ .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); } - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .inplace = inplace, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.dst = 
dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.inplace = inplace; webgpu_pipeline pipeline = ctx->shader_lib->get_row_norm_pipeline(shader_lib_ctx); return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(src)); @@ -2129,14 +2184,13 @@ static webgpu_encoded_op ggml_webgpu_rope(webgpu_context & ctx, ggml_tensor * src1, ggml_tensor * src2, ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .src2 = src2, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .inplace = ggml_webgpu_tensor_equal(src0, dst), - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.src2 = src2; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.inplace = ggml_webgpu_tensor_equal(src0, dst); webgpu_pipeline pipeline = ctx->shader_lib->get_rope_pipeline(shader_lib_ctx); @@ -2187,41 +2241,27 @@ static webgpu_encoded_op ggml_webgpu_rope(webgpu_context & ctx, (uint32_t) src0->ne[2], (uint32_t) n_dims, (uint32_t) mode, - *(uint32_t *) &theta_scale, - *(uint32_t *) &attn_factor, - *(uint32_t *) &freq_scale, - *(uint32_t *) &ext_factor, - *(uint32_t *) &corr_dims[0], - *(uint32_t *) &corr_dims[1], + ggml_webgpu_u32_from_f32(theta_scale), + ggml_webgpu_u32_from_f32(attn_factor), + ggml_webgpu_u32_from_f32(freq_scale), + ggml_webgpu_u32_from_f32(ext_factor), + ggml_webgpu_u32_from_f32(corr_dims[0]), + ggml_webgpu_u32_from_f32(corr_dims[1]), (uint32_t) sections[0], (uint32_t) sections[1], (uint32_t) sections[2], (uint32_t) sections[3] }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) } - }; - uint32_t dst_binding = 2; + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1) }; + uint32_t dst_binding = 2; if (has_freq_factor) { dst_binding = 3; - entries.push_back({ .binding = 2, - .buffer = ggml_webgpu_tensor_buf(src2), - .offset = ggml_webgpu_tensor_align_offset(ctx, src2), - .size = ggml_webgpu_tensor_binding_size(ctx, src2) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2)); } if (!inplace) { - entries.push_back({ .binding = dst_binding, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, dst_binding, dst)); } uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); @@ -2232,12 +2272,11 @@ static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + 
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_glu_pipeline(shader_lib_ctx); @@ -2265,29 +2304,20 @@ static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], - (uint32_t) ((int32_t *) dst->op_params)[1], // swapped - *(uint32_t *) &dst->op_params[2], // alpha, for swiglu_oai - *(uint32_t *) &dst->op_params[3], // limit, for swiglu_oai + (uint32_t) ((int32_t *) dst->op_params)[1], // swapped + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 2)), // alpha, for swiglu_oai + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 3)), // limit, for swiglu_oai }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), }; uint32_t dst_binding = 1; if (split) { dst_binding = 2; - entries.push_back({ .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1)); } - entries.push_back({ .binding = dst_binding, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, dst_binding, dst)); uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); @@ -2296,13 +2326,12 @@ static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx, static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool inplace = ggml_webgpu_tensor_equal(src, dst); - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .src1 = nullptr, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .inplace = inplace, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.src1 = nullptr; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.inplace = inplace; webgpu_pipeline pipeline = ctx->shader_lib->get_scale_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); @@ -2321,23 +2350,15 @@ static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * s (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], - *(uint32_t *) dst->op_params, // scale - *(uint32_t *) &dst->op_params[1] // bias + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 0)), // scale + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 1)) // bias }; // bindgroups unchanged - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) } - }; + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src) }; if (!inplace) { - entries.push_back({ .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + 
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); } uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); @@ -2349,25 +2370,23 @@ static webgpu_encoded_op ggml_webgpu_soft_max(webgpu_context & ctx, ggml_tensor * src1, ggml_tensor * src2, ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .src2 = src2, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .inplace = ggml_webgpu_tensor_equal(src0, dst), - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.src2 = src2; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.inplace = ggml_webgpu_tensor_equal(src0, dst); webgpu_pipeline pipeline = ctx->shader_lib->get_soft_max_pipeline(shader_lib_ctx); - const int inplace = ggml_webgpu_tensor_equal(src0, dst); - const int has_mask = (src1 != nullptr); - const int has_sink = (src2 != nullptr); - float max_bias; - memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); - float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2]))); - float m0 = powf(2.0f, -(max_bias) / n_head_log2); - float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + const int inplace = ggml_webgpu_tensor_equal(src0, dst); + const int has_mask = (src1 != nullptr); + const int has_sink = (src2 != nullptr); + float max_bias = ggml_get_op_params_f32(dst, 1); + float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2]))); + float m0 = powf(2.0f, -(max_bias) / n_head_log2); + float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), @@ -2389,39 +2408,29 @@ static webgpu_encoded_op ggml_webgpu_soft_max(webgpu_context & ctx, (uint32_t) src0->ne[2], has_mask ? (uint32_t) src1->ne[2] : 0, has_mask ? 
(uint32_t) src1->ne[3] : 0, - *(uint32_t *) dst->op_params, // scale - *(uint32_t *) &max_bias, - *(uint32_t *) &n_head_log2, - *(uint32_t *) &m0, - *(uint32_t *) &m1 + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 0)), // scale + ggml_webgpu_u32_from_f32(max_bias), + ggml_webgpu_u32_from_f32(n_head_log2), + ggml_webgpu_u32_from_f32(m0), + ggml_webgpu_u32_from_f32(m1) }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) } - }; - uint32_t binding_num = 1; + std::vector entries = { ggml_webgpu_make_bind_group_entry( + 0, ggml_webgpu_tensor_buf(src0), ggml_webgpu_tensor_align_offset(ctx, src0), + ggml_webgpu_tensor_binding_size(ctx, src0)) }; + uint32_t binding_num = 1; if (has_mask) { - entries.push_back({ .binding = binding_num, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }); + entries.push_back(ggml_webgpu_make_bind_group_entry(binding_num, ggml_webgpu_tensor_buf(src1), + ggml_webgpu_tensor_align_offset(ctx, src1), + ggml_webgpu_tensor_binding_size(ctx, src1))); binding_num++; } if (has_sink) { - entries.push_back({ .binding = binding_num, - .buffer = ggml_webgpu_tensor_buf(src2), - .offset = ggml_webgpu_tensor_align_offset(ctx, src2), - .size = ggml_webgpu_tensor_binding_size(ctx, src2) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_num, src2)); binding_num++; } if (!inplace) { - entries.push_back({ .binding = binding_num, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_num, dst)); } return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(dst)); @@ -2432,20 +2441,13 @@ static webgpu_encoded_op ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), (uint32_t) src->ne[0] }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } - }; + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst) }; - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_argmax_pipeline(shader_lib_ctx); uint32_t wg_x = ggml_nelements(dst); @@ -2455,13 +2457,12 @@ static webgpu_encoded_op ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool is_top_k = dst->op == GGML_OP_TOP_K; - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .src1 = nullptr, - .dst = dst, - .max_wg_size = 
ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.src1 = nullptr; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; webgpu_pipeline argsort_pipeline = ctx->shader_lib->get_argsort_pipeline(shader_lib_ctx); auto * argsort_decisions = static_cast(argsort_pipeline.context.get()); @@ -2527,11 +2528,8 @@ static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * const uint32_t wg_x_init = std::min(total_wg_init, max_wg); const uint32_t wg_y_init = CEIL_DIV(total_wg_init, wg_x_init); std::vector init_entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), .offset = init_align_offset, .size = init_binding_size } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), init_align_offset, init_binding_size) }; dispatches.push_back({ @@ -2580,12 +2578,9 @@ static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * nrows }; std::vector merge_entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), .offset = align_in, .size = size_in }, - { .binding = 2, .buffer = ggml_webgpu_tensor_buf(dst), .offset = align_out, .size = size_out } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), align_in, size_in), + ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(dst), align_out, size_out) }; const uint32_t total_wg_merge = nm * nrows; @@ -2607,23 +2602,14 @@ static webgpu_encoded_op ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), (uint32_t) src->ne[0] }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } - }; + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst) }; - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .src1 = nullptr, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.src1 = nullptr; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_cumsum_pipeline(shader_lib_ctx); uint32_t wg_x = ggml_nrows(dst); @@ -2641,20 +2627,13 @@ static webgpu_encoded_op 
ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor total_sum ? 1 : (uint32_t) src->ne[1], total_sum ? 1 : (uint32_t) src->ne[2] }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } - }; + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst) }; - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_sum_rows_pipeline(shader_lib_ctx); @@ -2662,15 +2641,48 @@ static webgpu_encoded_op ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } +static bool ggml_webgpu_can_fuse_rms_norm_mul(const struct ggml_cgraph * cgraph, int node_idx) { + if (!ggml_can_fuse(cgraph, node_idx, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { + return false; + } + + // additional constraints specific to this fusion + const ggml_tensor * rms_norm = cgraph->nodes[node_idx]; + const ggml_tensor * mul = cgraph->nodes[node_idx + 1]; + + GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(rms_norm->type == GGML_TYPE_F32); + // rms_norm only supports f32 + if (mul->src[0]->type != GGML_TYPE_F32 || mul->src[1]->type != GGML_TYPE_F32 || mul->type != GGML_TYPE_F32) { + return false; + } + // if rms_norm is the B operand, then we don't handle broadcast + if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm)) { + return false; + } + // rms_norm shader assumes contiguous rows + if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) { + return false; + } + + return true; +} + // Returns the encoded command, or std::nullopt if the operation is a no-op -static std::optional ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { +static std::optional ggml_webgpu_encode(webgpu_context ctx, + ggml_cgraph * cgraph, + int node_idx, + int & num_encoded_ops) { + ggml_tensor ** nodes = cgraph->nodes; + ggml_tensor * node = nodes[node_idx]; + if (ggml_is_empty(node)) { return std::nullopt; } if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { return std::nullopt; } - WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")"); + WEBGPU_LOG_DEBUG("ggml_webgpu_encode(" << node << ", " << ggml_op_name(node->op) << ")"); ggml_tensor * src0 = node->src[0]; ggml_tensor * src1 = node->src[1]; @@ -2713,6 +2725,13 @@ static std::optional ggml_webgpu_encode_node(webgpu_context c case GGML_OP_REPEAT: return ggml_webgpu_repeat(ctx, src0, node); case GGML_OP_RMS_NORM: + if (ggml_webgpu_can_fuse_rms_norm_mul(cgraph, node_idx)) { + num_encoded_ops = 2; + ggml_tensor * mul_node = nodes[node_idx + 1]; + return ggml_webgpu_rms_norm_mul(ctx, src0, node, mul_node->src[0], mul_node->src[1], mul_node); + } else { + return ggml_webgpu_row_norm(ctx, src0, node); + } case GGML_OP_L2_NORM: return ggml_webgpu_row_norm(ctx, src0, 
node); case GGML_OP_ROPE: @@ -2753,6 +2772,10 @@ static std::optional ggml_webgpu_encode_node(webgpu_context c case GGML_OP_SUM: case GGML_OP_SUM_ROWS: return ggml_webgpu_sum_rows(ctx, src0, node); + case GGML_OP_CONV_2D: + return ggml_webgpu_conv_2d(ctx, src0, src1, node); + case GGML_OP_IM2COL: + return ggml_webgpu_im2col(ctx, src0, src1, node); default: return std::nullopt; } @@ -2785,7 +2808,7 @@ static void ggml_backend_webgpu_collect_profile_results(webgpu_context & for (size_t i = 0; i < pipeline_names.size(); ++i) { // WebGPU timestamps are in ns; convert to ms. const double elapsed_ms = double(ts_data[2 * i + 1] - ts_data[2 * i]) * 1e-6; - ctx->global_ctx->shader_gpu_time_ms[pipeline_names[i]] += elapsed_ms; + ctx->shader_gpu_time_ms[pipeline_names[i]] += elapsed_ms; } ctx->profile_timestamp_host_buf.Unmap(); @@ -2821,6 +2844,8 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str uint32_t num_inflight_batches = 0; bool contains_set_rows = false; bool batch_compute_passes = true; + int num_encoded_ops = 1; + int node_idx = 0; #ifdef GGML_WEBGPU_GPU_PROFILE ctx->profile_timestamp_query_count = 0; @@ -2833,11 +2858,11 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str ctx->active_compute_pass = ctx->active_command_encoder.BeginComputePass(); } - for (int i = 0; i < cgraph->n_nodes; i++) { - if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) { + while (node_idx < cgraph->n_nodes) { + if (cgraph->nodes[node_idx]->op == GGML_OP_SET_ROWS) { contains_set_rows = true; } - if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) { + if (auto cmd = ggml_webgpu_encode(ctx, cgraph, node_idx, num_encoded_ops)) { commands.push_back(*cmd); num_batched_kernels += cmd.value().num_kernels; #ifdef GGML_WEBGPU_GPU_PROFILE @@ -2862,6 +2887,9 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str ctx->param_arena.reset(); commands.clear(); } + + node_idx += num_encoded_ops; + num_encoded_ops = 1; } if (ctx->active_compute_pass) { @@ -2891,22 +2919,107 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str return GGML_STATUS_SUCCESS; } +struct ggml_backend_webgpu_event_context { + webgpu_global_context global_ctx; + wgpu::Future future; + bool recorded = false; +}; + +static ggml_backend_event_t ggml_backend_webgpu_device_event_new(ggml_backend_dev_t device) { + ggml_backend_webgpu_device_context * dev_ctx = (ggml_backend_webgpu_device_context *) device->context; + + auto * event_ctx = new ggml_backend_webgpu_event_context(); + event_ctx->global_ctx = dev_ctx->webgpu_global_ctx; + + auto * event = new ggml_backend_event; + event->device = device; + event->context = event_ctx; + return event; +} + +static void ggml_backend_webgpu_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) { + GGML_UNUSED(dev); + delete static_cast(event->context); + delete event; +} + +static void ggml_backend_webgpu_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) { + GGML_UNUSED(dev); + ggml_backend_webgpu_event_context * event_ctx = (ggml_backend_webgpu_event_context *) event->context; + if (!event_ctx->recorded) { + return; + } + wgpu::WaitStatus status = + event_ctx->global_ctx->instance.WaitAny(event_ctx->future, WEBGPU_RUNTIME_WAIT_TIMEOUT_NS); + if (status == wgpu::WaitStatus::TimedOut) { + GGML_ABORT("ggml_webgpu: event_synchronize timed out after %u ms\n", WEBGPU_RUNTIME_WAIT_TIMEOUT_MS); + } + event_ctx->recorded = false; +} + +static void 
ggml_backend_webgpu_event_record(ggml_backend_t backend, ggml_backend_event_t event) { + ggml_backend_webgpu_context * backend_ctx = (ggml_backend_webgpu_context *) backend->context; + ggml_backend_webgpu_event_context * event_ctx = (ggml_backend_webgpu_event_context *) event->context; + + event_ctx->future = backend_ctx->webgpu_ctx->global_ctx->queue.OnSubmittedWorkDone( + wgpu::CallbackMode::AllowSpontaneous, [](wgpu::QueueWorkDoneStatus, wgpu::StringView) {}); + event_ctx->recorded = true; +} + +static void ggml_backend_webgpu_event_wait(ggml_backend_t backend, ggml_backend_event_t event) { + GGML_UNUSED(backend); + ggml_backend_webgpu_device_event_synchronize(nullptr, event); +} + +static void ggml_backend_webgpu_set_tensor_async(ggml_backend_t backend, + ggml_tensor * tensor, + const void * data, + size_t offset, + size_t size) { + GGML_UNUSED(backend); + auto * buf_ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context; + size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset; + + // Write aligned portion + buf_ctx->global_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4); + + if (size % 4 != 0) { + // If size is not a multiple of 4, we need to memset the remaining bytes + size_t remaining_size = size % 4; + + // pack the remaining bytes into a uint32_t + uint32_t val32 = 0; + + for (size_t i = 0; i < remaining_size; i++) { + ((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i]; + } + // memset the remaining bytes + ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32, + total_offset + (size - remaining_size), remaining_size); + } +} + +static void ggml_backend_webgpu_synchronize(ggml_backend_t backend) { + ggml_backend_webgpu_context * backend_ctx = (ggml_backend_webgpu_context *) backend->context; + ggml_backend_webgpu_wait_queue(backend_ctx->webgpu_ctx->global_ctx); +} + static ggml_backend_i ggml_backend_webgpu_i = { /* .get_name = */ ggml_backend_webgpu_name, /* .free = */ ggml_backend_webgpu_free, - /* .set_tensor_async = */ NULL, + /* .set_tensor_async = */ ggml_backend_webgpu_set_tensor_async, /* .get_tensor_async = */ NULL, /* .get_tensor_2d_async = */ NULL, /* .set_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, - /* .synchronize = */ NULL, + /* .synchronize = */ ggml_backend_webgpu_synchronize, /* .graph_plan_create = */ NULL, /* .graph_plan_free = */ NULL, /* .graph_plan_update = */ NULL, /* .graph_plan_compute = */ NULL, /* .graph_compute = */ ggml_backend_webgpu_graph_compute, - /* .event_record = */ NULL, - /* .event_wait = */ NULL, + /* .event_record = */ ggml_backend_webgpu_event_record, + /* .event_wait = */ ggml_backend_webgpu_event_wait, /* .graph_optimize = */ NULL, }; @@ -3133,40 +3246,24 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer const ggml_tensor * mask = tensor->src[3]; const ggml_tensor * sinks = tensor->src[4]; if (Q && K && V) { - GGML_UNUSED(sinks); - const bool kv_direct = (K->type == GGML_TYPE_F16) && - (Q->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k == 0) && - (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); - const bool kv_vec_type_supported = - K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; - const bool use_vec = (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && - kv_vec_type_supported && (V->type == K->type); - if (use_vec) { - const uint32_t sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m; - const uint32_t sg_mat_n = 
ctx->webgpu_global_ctx->capabilities.sg_mat_n; - const size_t limit_bytes = - ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; - const size_t q_tile = sg_mat_m; - const size_t base_q_bytes = (Q->ne[0] + V->ne[0]) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + - 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; - size_t bytes_per_kv = 0; - if (!kv_direct) { - bytes_per_kv += std::max(Q->ne[0], V->ne[0]); - } - if (mask != nullptr) { - bytes_per_kv += q_tile; - } - bytes_per_kv += q_tile; - bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES; - uint32_t kv_tile = ((limit_bytes - base_q_bytes) / bytes_per_kv / sg_mat_n) * sg_mat_n; - kv_tile = std::max(sg_mat_n, std::min(32u, kv_tile)); - kv_tile = (kv_tile / sg_mat_n) * sg_mat_n; - if (kv_direct) { - GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); - while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { - kv_tile -= sg_mat_n; - } - } + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = const_cast(Q); + shader_lib_ctx.src1 = const_cast(K); + shader_lib_ctx.src2 = const_cast(V); + shader_lib_ctx.src3 = const_cast(mask); + shader_lib_ctx.src4 = const_cast(sinks); + shader_lib_ctx.dst = const_cast(tensor); + shader_lib_ctx.max_wg_size = + ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.wg_mem_limit_bytes = + ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; + shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m; + shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n; + shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k; + shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size; + + if (ggml_webgpu_flash_attn_use_vec(ctx->webgpu_global_ctx, Q, K, V)) { + const uint32_t kv_tile = ggml_webgpu_flash_attn_vec_get_kv_tile(shader_lib_ctx); const uint32_t vec_nwg_cap = std::max( 1u, std::min(32u, ctx->webgpu_global_ctx->capabilities.max_subgroup_size)); @@ -3271,8 +3368,9 @@ static void ggml_backend_webgpu_device_get_props(ggml_backend_dev_t dev, struct } static ggml_guid_t ggml_backend_webgpu_guid(void) { - static const char * guid_str = "__ggml_webgpu :)"; - return reinterpret_cast((void *) guid_str); + static ggml_guid guid = { 0x67, 0xc7, 0xa4, 0xb1, 0x78, 0x74, 0x4f, 0x51, + 0x9d, 0x65, 0x44, 0x6d, 0xe4, 0x1b, 0x82, 0x9a }; + return &guid; } static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) { @@ -3330,6 +3428,8 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { ctx->webgpu_global_ctx->adapter.GetFeatures(&features); // we require f16 support GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16)); + ctx->webgpu_global_ctx->capabilities.supports_subgroups = + ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::Subgroups); #ifndef __EMSCRIPTEN__ // Accept f16 subgroup matrix configurations (square or non-square). 
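The set_tensor_async path added above uploads the 4-byte-aligned prefix with a plain queue write and folds the trailing size % 4 bytes into one 32-bit word before patching them in through the buffer-memset helper. A standalone sketch of that tail-packing step, in plain C++ with illustrative names that are not part of the backend:

    #include <cstdint>
    #include <cstddef>

    // Pack the last (size % 4) bytes of `data` into one little-endian word so the
    // unaligned tail can be written at word granularity by a memset-style helper.
    static uint32_t pack_unaligned_tail(const void * data, size_t size) {
        const size_t    remaining = size % 4;   // caller only invokes this when remaining != 0
        const uint8_t * bytes     = (const uint8_t *) data;
        uint32_t        val32     = 0;
        for (size_t i = 0; i < remaining; i++) {
            ((uint8_t *) &val32)[i] = bytes[size - remaining + i];
        }
        return val32;
    }

The aligned prefix of (size / 4) * 4 bytes goes through the normal write; only the packed word takes the byte-masked path, written at offset total_offset + (size - remaining) as in the diff above.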
@@ -3362,11 +3462,14 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { #ifndef __EMSCRIPTEN__ required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization); if (ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) { - required_features.push_back(wgpu::FeatureName::Subgroups); required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix); } #endif + if (ctx->webgpu_global_ctx->capabilities.supports_subgroups) { + required_features.push_back(wgpu::FeatureName::Subgroups); + } + #ifdef GGML_WEBGPU_GPU_PROFILE required_features.push_back(wgpu::FeatureName::TimestampQuery); #endif @@ -3781,6 +3884,15 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_OP_SOLVE_TRI: supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32; break; + case GGML_OP_CONV_2D: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) && + (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); + break; + case GGML_OP_IM2COL: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + break; case GGML_OP_SSM_CONV: supports_op = op->type == GGML_TYPE_F32; break; @@ -3874,9 +3986,9 @@ static struct ggml_backend_device_i ggml_backend_webgpu_device_i = { /* .supports_op = */ ggml_backend_webgpu_device_supports_op, /* .supports_buft = */ ggml_backend_webgpu_device_supports_buft, /* .offload_op = */ NULL, - /* .event_new = */ NULL, - /* .event_free = */ NULL, - /* .event_synchronize = */ NULL, + /* .event_new = */ ggml_backend_webgpu_device_event_new, + /* .event_free = */ ggml_backend_webgpu_device_event_free, + /* .event_synchronize = */ ggml_backend_webgpu_device_event_synchronize, }; /* End GGML Backend Device Interface */ @@ -3931,20 +4043,23 @@ static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = { ggml_backend_reg_t ggml_backend_webgpu_reg() { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()"); - static ggml_backend_webgpu_reg_context ctx; - static ggml_backend_reg reg = { + // Intentionally leak the global registry context to avoid crashing inside + // Dawn/Vulkan static teardown during process exit. + static ggml_backend_webgpu_reg_context * ctx = new ggml_backend_webgpu_reg_context(); + + static ggml_backend_reg reg = { /* .api_version = */ GGML_BACKEND_API_VERSION, /* .iface = */ ggml_backend_webgpu_reg_i, - /* .context = */ &ctx, + /* .context = */ ctx, }; - ctx.name = GGML_WEBGPU_NAME; - ctx.device_count = 0; + ctx->name = GGML_WEBGPU_NAME; + ctx->device_count = 0; // Keep one Dawn/WebGPU instance alive for the lifetime of the static backend // registry. Recreating it on repeated registry lookups can invalidate // adapter/device references that are still held by the backend/device layer. 
- if (ctx.webgpu_global_ctx != nullptr && ctx.webgpu_global_ctx->instance != nullptr) { + if (ctx->webgpu_global_ctx != nullptr && ctx->webgpu_global_ctx->instance != nullptr) { return &reg; } @@ -3961,17 +4076,18 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { instance_descriptor.nextInChain = &instanceTogglesDesc; #endif - wgpu::Instance inst = wgpu::CreateInstance(&instance_descriptor); - ctx.webgpu_global_ctx = webgpu_global_context(new webgpu_global_context_struct()); - ctx.webgpu_global_ctx->instance = std::move(inst); + wgpu::Instance inst = wgpu::CreateInstance(&instance_descriptor); + ctx->webgpu_global_ctx = webgpu_global_context(new webgpu_global_context_struct()); + ctx->webgpu_global_ctx->instance = std::move(inst); // Probe for adapter support wgpu::Adapter adapter; - if (ctx.webgpu_global_ctx->instance != nullptr) { + if (ctx->webgpu_global_ctx->instance != nullptr) { wgpu::RequestAdapterOptions options = {}; - ctx.webgpu_global_ctx->instance.WaitAny( - ctx.webgpu_global_ctx->instance.RequestAdapter( + // probe for adapter support + ctx->webgpu_global_ctx->instance.WaitAny( + ctx->webgpu_global_ctx->instance.RequestAdapter( &options, wgpu::CallbackMode::AllowSpontaneous, [&adapter](wgpu::RequestAdapterStatus status, wgpu::Adapter _adapter, const char * message) { if (status != wgpu::RequestAdapterStatus::Success) { @@ -3984,7 +4100,7 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { } if (adapter != nullptr) { - ctx.device_count = 1; + ctx->device_count = 1; } return &reg; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl index 62fe72ee3b1..14c045b0ba6 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl @@ -45,6 +45,13 @@ fn load_u16_at_src0(byte_offset: u32) -> u32 { return (word >> shift) & 0xFFFFu; } +// Always reads the 4-byte-aligned word containing byte_offset. +// Caller extracts the 16-bit half it needs via & 0xFFFFu or >> 16u. 
+// this is used in k-quants for better performance +fn load_u32_at_src0_aligned(byte_offset: u32) -> u32 { + return src0[(byte_offset & ~3u) / 4u]; +} + fn load_u32_at_src0(byte_offset: u32) -> u32 { let word_idx = byte_offset / 4u; let shift = (byte_offset & 0x3u) * 8u; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl new file mode 100644 index 00000000000..9eb131dc221 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl @@ -0,0 +1,165 @@ +#include "common_decls.tmpl" +enable f16; + +@group(0) @binding(0) +#if defined(WEIGHT_F32) +var weights: array; +#elif defined(WEIGHT_F16) +var weights: array; +#endif + +@group(0) @binding(1) +#if defined(INPUT_F32) +var input: array; +#elif defined(INPUT_F16) +var input: array; +#endif + +@group(0) @binding(2) +#if defined(OUTPUT_F32) +var output: array; +#elif defined(OUTPUT_F16) +var output: array; +#endif + +struct Params { + offset_w: u32, + offset_i: u32, + offset_o: u32, + + // element strides + sw0: u32, sw1: u32, sw2: u32, sw3: u32, + si0: u32, si1: u32, si2: u32, si3: u32, + so0: u32, so1: u32, so2: u32, so3: u32, + + // kernel dimensions + KW: u32, KH: u32, IC: u32, + // input dimensions + IW: u32, IH: u32, + // output dimensions + OW: u32, OH: u32, OC_out: u32, N_out: u32, + + // stride + s0: u32, s1: u32, + // padding + p0: u32, p1: u32, + // dilation + d0: u32, d1: u32, +}; + +@group(0) @binding(3) +var params: Params; + +fn load_weight(idx: u32) -> f32 { + #if defined(WEIGHT_F32) + return weights[idx]; + #elif defined(WEIGHT_F16) + return f32(weights[idx]); + #endif +} + +fn load_input(idx: u32) -> f32 { + #if defined(INPUT_F32) + return input[idx]; + #elif defined(INPUT_F16) + return f32(input[idx]); + #endif +} + +fn store_output(idx: u32, val: f32) { + #if defined(OUTPUT_F32) + output[idx] = val; + #elif defined(OUTPUT_F16) + output[idx] = f16(val); + #endif +} + +fn ceil_div_u32(x: u32, y: u32) -> u32 { + return (x + y - 1) / y; +} + +// returns the first valid kernel index k such that base + k * step >= 0 +fn first_valid_k(base: i32, step: u32) -> u32 { + if (base >= 0) { + return 0; + } + + return ceil_div_u32(u32(-base), step); +} + +// returns the first invalid kernel index k such that base + k * step >= limit so valid k are in [0, end_valid_k) +fn end_valid_k(base: i32, step: u32, limit: u32, k_max: u32) -> u32 { + let remaining = i32(limit) - base; + if (remaining <= 0) { + return 0; + } + + return min(k_max, ceil_div_u32(u32(remaining), step)); +} + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(global_invocation_id) gid: vec3, + @builtin(num_workgroups) num_wg: vec3 +) { + + let threads_per_group = u32(WG_SIZE); + let i_out = gid.x + (num_wg.x * threads_per_group) * gid.y; + let n_out = params.OW * params.OH * params.OC_out * params.N_out; + + var sum: f32 = 0.0; + if (i_out >= n_out) { + return; + } + + // Kernel layout: [KW, KH, IC, ..] + // Input layout: [IW, IH, .., ..] 
+ // Output layout: [OW, OH, OC, N] + + var i = i_out; + let n = i / (params.OC_out * params.OH * params.OW); + i = i % (params.OC_out * params.OH * params.OW); + let oc = i / (params.OH * params.OW); + i = i % (params.OH * params.OW); + let oh = i / params.OW; + let ow = i % params.OW; + + let ow_base = i32(ow * params.s0) - i32(params.p0); + let oh_base = i32(oh * params.s1) - i32(params.p1); + + // clip the valid kernel window once + let kw_begin = first_valid_k(ow_base, params.d0); + let kw_end = end_valid_k(ow_base, params.d0, params.IW, params.KW); + let kh_begin = first_valid_k(oh_base, params.d1); + let kh_end = end_valid_k(oh_base, params.d1, params.IH, params.KH); + + // entire receptive field is out of bounds + if (kw_begin >= kw_end || kh_begin >= kh_end) { + let out_idx = params.offset_o + ow * params.so0 + oh * params.so1 + oc * params.so2 + n * params.so3; + store_output(out_idx, 0.0); + return; + } + + let weight_oc_base = params.offset_w + oc * params.sw3; + let input_n_base = params.offset_i + n * params.si3; + + for (var ic: u32 = 0; ic < params.IC; ic += 1) { + let w_base_ic = ic * params.sw2 + weight_oc_base; + let in_base = ic * params.si2 + input_n_base; + + for (var kh: u32 = kh_begin; kh < kh_end; kh += 1) { + let ih = u32(oh_base + i32(kh * params.d1)); + let w_row_base = w_base_ic + kh * params.sw1; + let in_row_base = in_base + ih * params.si1; + for (var kw: u32 = kw_begin; kw < kw_end; kw += 1) { + let iw = u32(ow_base + i32(kw * params.d0)); + let w_idx = w_row_base + kw * params.sw0; + let in_idx = in_row_base + iw * params.si0; + sum += load_weight(w_idx) * load_input(in_idx); + } + } + } + + let out_idx = params.offset_o + ow * params.so0 + oh * params.so1 + oc * params.so2 + n * params.so3; + store_output(out_idx, sum); +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl index 82d072be73a..61107c6a985 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl @@ -1,7 +1,6 @@ diagnostic(off, subgroup_uniformity); enable f16; -#define Q_TILE 1 #define KV_TILE 32 #define WG_SIZE 32 @@ -11,7 +10,7 @@ struct Params { seq_len_kv: u32, stride_mask3: u32, // Number of KV blocks and Q blocks per batch. - // nblk0 = ceil(seq_len_kv / KV_TILE), nblk1 = ceil(seq_len_q / Q_TILE). + // nblk0 = ceil(seq_len_kv / KV_TILE), nblk1 = seq_len_q. 
nblk0: u32, nblk1: u32, }; @@ -40,7 +39,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, return; } - let q_start = q_blk * Q_TILE; + let q_start = q_blk; let k_start = kv_blk * KV_TILE; let mask_batch = select(0u, batch_idx, params.stride_mask3 > 0u); @@ -54,11 +53,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3, var local_max = -MASK_MAX; var local_any = 0u; - for (var q_rel = 0u; q_rel < Q_TILE; q_rel += 1u) { - let q_row = q_start + q_rel; - if (q_row >= params.seq_len_q) { - continue; - } + let q_row = q_start; + if (q_row < params.seq_len_q) { let row_base = mask_batch_base + q_row * params.seq_len_kv; for (var k_rel = local_id.x; k_rel < KV_TILE; k_rel += WG_SIZE) { let k_col = k_start + k_rel; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl new file mode 100644 index 00000000000..386ebab879f --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl @@ -0,0 +1,101 @@ +#include "common_decls.tmpl" +enable f16; + +@group(0) @binding(0) +#if defined(INPUT_F32) +var input: array; +#elif defined(INPUT_F16) +var input: array; +#endif + +@group(0) @binding(1) +#if defined(OUTPUT_F32) +var output: array; +#elif defined(OUTPUT_F16) +var output: array; +#endif + +struct Params { + offset_i: u32, + offset_o: u32, + + // element strides + si0: u32, si1: u32, si2: u32, si3: u32, + so0: u32, so1: u32, so2: u32, so3: u32, + + KW: u32, KH: u32, IC: u32, + IW: u32, IH: u32, N: u32, + OW: u32, OH: u32, + + // stride + s0: u32, s1: u32, + // padding + p0: u32, p1: u32, + // dilation + d0: u32, d1: u32, +} + +@group(0) @binding(2) +var params: Params; + +fn load_input(idx: u32) -> f32 { + #if defined(INPUT_F32) + return input[idx]; + #elif defined(INPUT_F16) + return f32(input[idx]); + #endif +} + +fn store_output(idx: u32, val: f32) { + #if defined(OUTPUT_F32) + output[idx] = val; + #elif defined(OUTPUT_F16) + output[idx] = f16(val); + #endif +} + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(global_invocation_id) gid: vec3, + @builtin(num_workgroups) num_wg: vec3 +) { + + let threads_per_group = u32(WG_SIZE); + let i_out = gid.x + (num_wg.x * threads_per_group) * gid.y; + let K = params.KW * params.KH * params.IC; + let M = params.OW * params.OH; + let total = K * M * params.N; + + if (i_out >= total) { + return; + } + + // decode (k, m, n) + var i = i_out; + let n = i / (K * M); + i = i % (K * M); + let m = i / K; + let k = i % K; + + // decode (oh, ow) + let oh = m / params.OW; + let ow = m % params.OW; + + // decode (kw, kh, ic) + let kw = k % params.KW; + let tmp = k / params.KW; + let kh = tmp % params.KH; + let ic = tmp / params.KH; + + let iw_i32 = i32(ow * params.s0 + kw * params.d0) - i32(params.p0); + let ih_i32 = i32(oh * params.s1 + kh * params.d1) - i32(params.p1); + + if (iw_i32 >= 0 && iw_i32 < i32(params.IW) && ih_i32 >= 0 && ih_i32 < i32(params.IH)) { + let iw = u32(iw_i32); + let ih = u32(ih_i32); + let in_idx = params.offset_i + iw * params.si0 + ih * params.si1 + ic * params.si2 + n * params.si3; + store_output(params.offset_o + k * params.so0 + ow * params.so1 + oh * params.so2 + n * params.so3, load_input(in_idx)); + } else { + store_output(params.offset_o + k * params.so0 + ow * params.so1 + oh * params.so2 + n * params.so3, 0.0); + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl index 9f7b3e32eca..97c9f6d7a09 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl 
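The im2col kernel above decodes a flat output index into (k, m, n), splits k into (kw, kh, ic) and m into (ow, oh), and only then applies stride, dilation and padding to find the input coordinate. Below is a scalar C++ rendering of the same index arithmetic, handy as a host-side reference when checking the shader; the function and struct names are mine, not part of the patch:

    #include <cstdint>

    // Reference decode for one im2col output element, mirroring the WGSL above.
    // K = KW * KH * IC, M = OW * OH; i_out ranges over [0, K * M * N).
    struct im2col_coord { uint32_t kw, kh, ic, ow, oh, n; int32_t iw, ih; };

    static im2col_coord im2col_decode(uint32_t i_out,
                                      uint32_t KW, uint32_t KH, uint32_t IC,
                                      uint32_t OW, uint32_t OH,
                                      uint32_t s0, uint32_t s1,   // stride
                                      uint32_t p0, uint32_t p1,   // padding
                                      uint32_t d0, uint32_t d1) { // dilation
        const uint32_t K = KW * KH * IC;
        const uint32_t M = OW * OH;

        im2col_coord c = {};
        uint32_t i = i_out;
        c.n = i / (K * M);
        i   = i % (K * M);
        const uint32_t m = i / K;
        const uint32_t k = i % K;

        c.oh = m / OW;
        c.ow = m % OW;
        c.kw = k % KW;
        const uint32_t t = k / KW;
        c.kh = t % KH;
        c.ic = t / KH;

        // Negative or out-of-range (iw, ih) is the padding case: the shader stores 0.0 there.
        c.iw = (int32_t) (c.ow * s0 + c.kw * d0) - (int32_t) p0;
        c.ih = (int32_t) (c.oh * s1 + c.kh * d1) - (int32_t) p1;
        return c;
    }

One thread handles one output element, so padding never needs a separate pass: an out-of-range (iw, ih) simply writes zero.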
@@ -1,465 +1,865 @@ +#ifdef USE_SUBGROUP_REDUCTION +enable subgroups; +#endif enable f16; #define DECLARE_BYTE_LOADERS_SRC0 #include "common_decls.tmpl" +#ifdef U32_DEQUANT_HELPERS +#define SRC0_TYPE u32 -#ifdef VEC +fn byte_of(v: u32, b: u32) -> u32 { + return (v >> (b * 8u)) & 0xFFu; +} + +fn sbyte_of(v: u32, b: u32) -> i32 { + let raw = i32((v >> (b * 8u)) & 0xFFu); + return select(raw, raw - 256, raw >= 128); +} +#endif -#define VEC_SIZE 4 -#define DST_TYPE vec4 +#ifdef VEC +#define VEC_SIZE 4u #define SRC0_TYPE vec4 #define SRC1_TYPE vec4 fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 { return f32(dot(SRC1_TYPE(src0_val), src1_val)); } - -fn store_val(group_base: u32) -> vec4 { - return vec4(partial_sums[group_base], - partial_sums[group_base + THREADS_PER_OUTPUT], - partial_sums[group_base + THREADS_PER_OUTPUT * 2], - partial_sums[group_base + THREADS_PER_OUTPUT * 3]); -} #endif #ifdef SCALAR - -#define VEC_SIZE 1 -#define DST_TYPE f32 +#define VEC_SIZE 1u #define SRC0_TYPE SRC0_INNER_TYPE #define SRC1_TYPE SRC1_INNER_TYPE fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 { return f32(src0_val) * f32(src1_val); } +#endif + +struct MulMatParams { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + m: u32, + n: u32, + k: u32, + stride_01: u32, + stride_11: u32, + stride_02: u32, + stride_12: u32, + stride_03: u32, + stride_13: u32, + bs02: u32, + bs03: u32, + broadcast2: u32, + broadcast3: u32 +}; -fn store_val(group_base: u32) -> f32 { - return partial_sums[group_base]; +@group(0) @binding(0) var src0: array; +@group(0) @binding(1) var src1: array; +@group(0) @binding(2) var dst: array; + +@group(0) @binding(3) var params: MulMatParams; + +// Flattened as [row][thread] to keep each row's reduction contiguous in memory. 
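// Illustrative indexing example (assuming, say, WG_SIZE = 256): the partials for row r
// live in partial_sums[r * 256 .. r * 256 + 255], so partial_index(2, 17) = 2 * 256 + 17 = 529
// and each row's reduction walks one contiguous 256-element span.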
+var partial_sums: array; + +fn partial_index(row: u32, thread: u32) -> u32 { + return row * WG_SIZE + thread; } + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) wg_id: vec3, + @builtin(num_workgroups) num_wg: vec3 +#ifdef USE_SUBGROUP_REDUCTION + , @builtin(subgroup_id) subgroup_id: u32, + @builtin(subgroup_invocation_id) subgroup_invocation_id: u32, + @builtin(num_subgroups) num_subgroups: u32, + @builtin(subgroup_size) subgroup_size: u32 #endif +) { + let thread_id = local_id.x; + + let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; + let wg_linear = wg_id.y * num_wg.x + wg_id.x; + let output_groups = (params.m + OUTPUTS_PER_WG - 1u) / OUTPUTS_PER_WG; + let batch_idx = wg_linear / output_groups; + if (batch_idx >= total_batches) { + return; + } + + let row_base = (wg_linear % output_groups) * OUTPUTS_PER_WG; + + let dst2_stride = params.m * params.n; + let dst2_idx = batch_idx % (params.bs02 * params.broadcast2); + let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; + let dst3_idx = batch_idx / (params.bs02 * params.broadcast2); + let src03_idx = dst3_idx / params.broadcast3; + let src13_idx = dst3_idx; + let src02_idx = dst2_idx / params.broadcast2; + let src12_idx = dst2_idx; + + let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02; + let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; + let dst_idx_base = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row_base; + + var acc: array; #ifdef MUL_ACC_FLOAT -fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - var local_sum = 0.0; - for (var i = tig * VEC_SIZE; i < tile_size; i += THREADS_PER_OUTPUT * VEC_SIZE) { - let a = src0[(idx_base + k_outer + i) / VEC_SIZE]; - let b = shared_vector[i / VEC_SIZE]; - local_sum += inner_dot(a, b); + let k_vec = params.k / VEC_SIZE; + let src1_idx_base_vec = src1_idx_base / VEC_SIZE; + + // Each thread walks K, loads from the vector, and updates + // a small block of output rows held in registers. 
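// Illustrative access pattern (again assuming WG_SIZE = 256): thread 3 visits the vec4-packed
// columns k = 3, 259, 515, ... and, for every visited k, accumulates into acc[0 .. OUTPUTS_PER_WG),
// one register-resident partial sum per output row owned by this workgroup.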
+ for (var k = thread_id; k < k_vec; k += WG_SIZE) { + let x = src1[src1_idx_base_vec + k]; + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let src0_idx = (src0_batch_offset + output_row * params.stride_01) / VEC_SIZE + k; + acc[row] += inner_dot(src0[src0_idx], x); + } + } } - return local_sum; -} #endif #ifdef MUL_ACC_Q4_0 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 18 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % 4; + for (var block = thread_id/THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE/THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4] = f32(src1[x_base + i + 16]); + } -const BLOCK_SIZE = 32; -const BLOCK_SIZE_BYTES = 18u; -const NQ = 16u; // number of weights per thread -const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; - -fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - var local_sum = 0.0; - for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; - // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_f16_at_src0(block_byte_base)); - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_u32_at_src0(q_byte_offset); - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d; - let q_lo = (f32(q_byte & 0xF) - 8.0) * d; - local_sum += q_lo * shared_vector[shmem_idx + j * 2 + k]; - local_sum += q_hi * shared_vector[shmem_idx + j * 2 + k + 16]; + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + var row_sum = 0.0; + + let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * thread_within_block); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = (f32(q_byte & 0xFu) - 8.0) * d; + let q_hi = (f32((q_byte >> 4u) & 0xFu) - 8.0) * d; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; } } } - return local_sum; -} #endif #ifdef MUL_ACC_Q4_1 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 20 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4] = f32(src1[x_base + i + 16]); + } -const BLOCK_SIZE 
= 32; -const BLOCK_SIZE_BYTES = 20u; -const NQ = 16u; // number of weights per thread -const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; - -fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - var local_sum = 0.0; - for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; - // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_f16_at_src0(block_byte_base)); - let m = f32(load_f16_at_src0(block_byte_base + 2u)); - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); - let q_packed = load_u32_at_src0(q_byte_offset); - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = f32((q_byte >> 4) & 0xF) * d + m; - let q_lo = f32(q_byte & 0xF) * d + m; - local_sum += q_lo * shared_vector[shmem_idx + j * 2 + k]; - local_sum += q_hi * shared_vector[shmem_idx + j * 2 + k + 16]; + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let m = f32(load_f16_at_src0(block_byte_base + 2u)); + var row_sum = 0.0; + + let q_packed = load_u32_at_src0(block_byte_base + 4u + 4u * thread_within_block); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = f32(q_byte & 0xFu) * d + m; + let q_hi = f32((q_byte >> 4u) & 0xFu) * d + m; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; } } } - return local_sum; -} #endif #ifdef MUL_ACC_Q5_0 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 22 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4] = f32(src1[x_base + i + 16]); + } -const BLOCK_SIZE = 32; -const BLOCK_SIZE_BYTES = 22u; -const NQ = 16u; // number of weights per thread -const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; - -fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - var local_sum = 0.0; - for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; - // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_f16_at_src0(block_byte_base)); - let qh_packed = load_u32_at_src0(block_byte_base + 2u); - - for (var j = 0u; j < 2; j++) { - let q_byte_offset = block_byte_base + 6u + 2u * 
(block_offset + j * 2u); - let q_packed = load_u32_at_src0(q_byte_offset); - - let j_adjusted = j + (block_offset / 2u); - - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - - let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10; - let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d; - let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10; - let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d; - - local_sum += q_lo * shared_vector[shmem_idx + j * 4 + k]; - local_sum += q_hi * shared_vector[shmem_idx + j * 4 + k + 16]; + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let qh_packed = load_u32_at_src0(block_byte_base + 2u); + let q_packed = load_u32_at_src0(block_byte_base + 6u + 4u * thread_within_block); + let qh_shift = thread_within_block * 4u; + var row_sum = 0.0; + + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u; + let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; + let q_lo = (f32((q_byte & 0xFu) | qh_lo) - 16.0) * d; + let q_hi = (f32(((q_byte >> 4u) & 0xFu) | qh_hi) - 16.0) * d; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; } - } } - return local_sum; -} #endif - #ifdef MUL_ACC_Q5_1 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 24 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4] = f32(src1[x_base + i + 16]); + } -const BLOCK_SIZE = 32; -const BLOCK_SIZE_BYTES = 24u; -const NQ = 16u; // number of weights per thread -const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; - -fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - var local_sum = 0.0; - for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; - // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_f16_at_src0(block_byte_base)); - let m = load_f16_at_src0(block_byte_base + 2u); - let qh_packed = load_u32_at_src0(block_byte_base + 4u); - - for (var j = 0u; j < 2; j++) { - let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u); - let q_packed = load_u32_at_src0(q_byte_offset); - - let j_adjusted = j + (block_offset / 2u); - - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - - let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10; - let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + f32(m); - let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10; - let q_lo = f32((q_byte & 0xF) | 
qh_lo) * d + f32(m); - - local_sum += q_lo * shared_vector[shmem_idx + j * 4 + k]; - local_sum += q_hi * shared_vector[shmem_idx + j * 4 + k + 16]; + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let m = f32(load_f16_at_src0(block_byte_base + 2u)); + let qh_packed = load_u32_at_src0(block_byte_base + 4u); + let q_packed = load_u32_at_src0(block_byte_base + 8u + 4u * thread_within_block); + let qh_shift = thread_within_block * 4u; + var row_sum = 0.0; + + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u; + let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; + let q_lo = f32((q_byte & 0xFu) | qh_lo) * d + m; + let q_hi = f32(((q_byte >> 4u) & 0xFu) | qh_hi) * d + m; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; } - } } - return local_sum; -} #endif - #ifdef MUL_ACC_Q8_0 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 34 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD; i++) { + x_block[i] = f32(src1[x_base + i]); + } -const BLOCK_SIZE = 32; -const BLOCK_SIZE_BYTES = 34u; -const NQ = 16u; // number of weights per thread -const WEIGHTS_PER_F16 = 2u; -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; - -fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - var local_sum = 0.0; - for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; - // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_f16_at_src0(block_byte_base)); - - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_u32_at_src0(q_byte_offset); - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte_i32(q_packed, k); - let q_val = f32(q_byte) * d; - local_sum += q_val * shared_vector[shmem_idx + j * 2 + k]; + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + var row_sum = 0.0; + + for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { + let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * (thread_within_block * 2u + packed_idx)); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_val = f32(get_byte_i32(q_packed, byte_idx)) * d; + row_sum += q_val * x_block[packed_idx * 4u + byte_idx]; + } + } + acc[row] += row_sum; } } } - return local_sum; -} 
#endif - #ifdef MUL_ACC_Q8_1 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 36 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD; i++) { + x_block[i] = f32(src1[x_base + i]); + } -const BLOCK_SIZE = 32; -const BLOCK_SIZE_BYTES = 36u; -const NQ = 16u; // number of weights per thread -const WEIGHTS_PER_F16 = 2u; -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; - -fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - var local_sum = 0.0; - for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; - // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_f16_at_src0(block_byte_base)); - let m = load_f16_at_src0(block_byte_base + 2u); - - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); - let q_packed = load_u32_at_src0(q_byte_offset); - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte_i32(q_packed, k); - let q_val = f32(q_byte) * d + f32(m); - local_sum += q_val * shared_vector[shmem_idx + j * 2 + k]; + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let m = f32(load_f16_at_src0(block_byte_base + 2u)); + var row_sum = 0.0; + + for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { + let q_packed = load_u32_at_src0(block_byte_base + 4u + 4u * (thread_within_block * 2u + packed_idx)); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_val = f32(get_byte_i32(q_packed, byte_idx)) * d + m; + row_sum += q_val * x_block[packed_idx * 4u + byte_idx]; + } + } + acc[row] += row_sum; } } } - return local_sum; -} #endif -#ifdef MUL_ACC_Q6_K - -const BLOCK_SIZE = 256u; -const BLOCK_SIZE_BYTES = 210u; - -fn byte_of(v: u32, b: u32) -> u32 { - return (v >> (b * 8u)) & 0xFFu; -} +#ifdef MUL_ACC_Q2_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 84 +#define THREADS_PER_BLOCK 16 + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let lane = tid / 2u; + let phase = tid % 2u; + let iq = lane / 4u; + let ir = lane % 4u; + let is = ir / 2u; + + let y_offset = 128u * iq + 8u * ir + 4u * phase; + let sc0_byte = 8u * iq + is; + let sc2_byte = 8u * iq + is + 2u; + let sc4_byte = 8u * iq + is + 4u; + let sc6_byte = 8u * iq + is + 6u; + let qs_byte = 16u + (16u * iq + 4u * ir) * 2u + 4u * phase; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 4u; i++) { + x_block[i] = 
f32(src1[x_base + i]); + x_block[i + 4u] = f32(src1[x_base + 32u + i]); + x_block[i + 8u] = f32(src1[x_base + 64u + i]); + x_block[i + 12u] = f32(src1[x_base + 96u + i]); + } -fn sbyte_of(v: u32, b: u32) -> i32 { - let raw = i32((v >> (b * 8u)) & 0xFFu); - return select(raw, raw - 256, raw >= 128); -} + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let dall = f32(load_f16_at_src0(block_byte_base + 80u)); + let dmin = f32(load_f16_at_src0(block_byte_base + 82u)) * (1.0 / 16.0); + + let sc0 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc0_byte), sc0_byte & 3u); + let sc2 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc2_byte), sc2_byte & 3u); + let sc4 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc4_byte), sc4_byte & 3u); + let sc6 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc6_byte), sc6_byte & 3u); + + let q_u32 = load_u32_at_src0_aligned(block_byte_base + qs_byte); + let qs0 = q_u32 & 0xFFFFu; + let qs1 = q_u32 >> 16u; + + var sumy = vec4(0.0, 0.0, 0.0, 0.0); + var acc1 = vec4(0.0, 0.0, 0.0, 0.0); + var acc2 = vec4(0.0, 0.0, 0.0, 0.0); + + sumy[0] = x_block[0] + x_block[1] + x_block[2] + x_block[3]; + sumy[1] = x_block[4] + x_block[5] + x_block[6] + x_block[7]; + sumy[2] = x_block[8] + x_block[9] + x_block[10] + x_block[11]; + sumy[3] = x_block[12] + x_block[13] + x_block[14] + x_block[15]; + + acc1[0] = x_block[0] * f32(qs0 & 0x0003u) + x_block[2] * f32(qs1 & 0x0003u); + acc2[0] = x_block[1] * f32(qs0 & 0x0300u) + x_block[3] * f32(qs1 & 0x0300u); + acc1[1] = x_block[4] * f32(qs0 & 0x000Cu) + x_block[6] * f32(qs1 & 0x000Cu); + acc2[1] = x_block[5] * f32(qs0 & 0x0C00u) + x_block[7] * f32(qs1 & 0x0C00u); + acc1[2] = x_block[8] * f32(qs0 & 0x0030u) + x_block[10] * f32(qs1 & 0x0030u); + acc2[2] = x_block[9] * f32(qs0 & 0x3000u) + x_block[11] * f32(qs1 & 0x3000u); + acc1[3] = x_block[12] * f32(qs0 & 0x00C0u) + x_block[14] * f32(qs1 & 0x00C0u); + acc2[3] = x_block[13] * f32(qs0 & 0xC000u) + x_block[15] * f32(qs1 & 0xC000u); + + acc[row] += dall * ((acc1[0] + (1.0/256.0) * acc2[0]) * f32(sc0 & 0xFu) + + (acc1[1] + (1.0/256.0) * acc2[1]) * f32(sc2 & 0xFu) / 4.0 + + (acc1[2] + (1.0/256.0) * acc2[2]) * f32(sc4 & 0xFu) / 16.0 + + (acc1[3] + (1.0/256.0) * acc2[3]) * f32(sc6 & 0xFu) / 64.0) + - dmin * (sumy[0] * f32(sc0 & 0xF0u) + sumy[1] * f32(sc2 & 0xF0u) + + sumy[2] * f32(sc4 & 0xF0u) + sumy[3] * f32(sc6 & 0xF0u)); + } + } + } +#endif -fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - let tid = tig / 2u; - let ix = tig % 2u; - let ip = tid / 8u; - let il = tid % 8u; - let l0 = 4u * il; - let is = 8u * ip + l0 / 16u; - let y_offset = 128u * ip + l0; - let q_offset_l = 64u * ip + l0; - let q_offset_h = 32u * ip + l0; +#ifdef MUL_ACC_Q3_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 110 +#define THREADS_PER_BLOCK 16 - let nb = tile_size / BLOCK_SIZE; - let k_block_start = k_outer / BLOCK_SIZE; + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - // Aligned scale byte position (is can be odd) - let sc_base_byte = 192u + (is & ~3u); - let sc_byte_pos = is & 3u; + let lane = tid / 2u; + let phase = tid % 2u; + let ip = lane / 4u; + let il = 2u * ((lane % 4u) / 2u); + let ir = lane % 2u; + let l0 = 8u * ir; - var local_sum = 0.0; + let q_byte = 32u + 32u * ip + l0 + 
16u * phase; + let h_byte = l0 + 16u * phase; + let y_offset = 128u * ip + 32u * il + l0 + 16u * phase; - for (var i = ix; i < nb; i += 2u) { - let bbase = (idx_base + k_block_start + i) * BLOCK_SIZE_BYTES; + let s_shift1 = 4u * ip; + let s_shift2 = s_shift1 + il; - let d = f32(load_f16_at_src0(bbase + 208u)); + let v1 = select(64.0, 4.0, il == 0u); + let v2 = 4.0 * v1; + let shift = 2u * il; - let ql1_u32 = load_u32_at_src0(bbase + q_offset_l); - let ql2_u32 = load_u32_at_src0(bbase + q_offset_l + 32u); - let qh_u32 = load_u32_at_src0(bbase + 128u + q_offset_h); - let sc_u32_0 = load_u32_at_src0(bbase + sc_base_byte); - let sc_u32_1 = load_u32_at_src0(bbase + sc_base_byte + 4u); + var qm0: u32; var qm1: u32; var qm2: u32; var qm3: u32; + if (il == 0u) { + qm0 = 0x0003u; qm1 = 0x0300u; qm2 = 0x000Cu; qm3 = 0x0C00u; + } else { + qm0 = 0x0030u; qm1 = 0x3000u; qm2 = 0x00C0u; qm3 = 0xC000u; + } - let sc0 = sbyte_of(sc_u32_0, sc_byte_pos); - let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u); - let sc4 = sbyte_of(sc_u32_1, sc_byte_pos); - let sc6 = sbyte_of(sc_u32_1, sc_byte_pos + 2u); + let mm_idx = 2u * ip + il / 2u; + var hm0: u32; var hm1: u32; var hm2: u32; var hm3: u32; + switch (mm_idx) { + case 0u: { hm0=0x0001u; hm1=0x0100u; hm2=0x0002u; hm3=0x0200u; } + case 1u: { hm0=0x0004u; hm1=0x0400u; hm2=0x0008u; hm3=0x0800u; } + case 2u: { hm0=0x0010u; hm1=0x1000u; hm2=0x0020u; hm3=0x2000u; } + default: { hm0=0x0040u; hm1=0x4000u; hm2=0x0080u; hm3=0x8000u; } + } - var sums = vec4(0.0, 0.0, 0.0, 0.0); + let num_blocks = params.k / BLOCK_SIZE; - for (var l = 0u; l < 4u; l++) { - let y_base = i * BLOCK_SIZE + y_offset + l; - let yl0 = f32(shared_vector[y_base]); - let yl1 = f32(shared_vector[y_base + 32u]); - let yl2 = f32(shared_vector[y_base + 64u]); - let yl3 = f32(shared_vector[y_base + 96u]); - - let q1b = byte_of(ql1_u32, l); - let q2b = byte_of(ql2_u32, l); - let qhb = byte_of(qh_u32, l); - - let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32); - let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32); - let dq2 = f32(i32((q1b >> 4u) | ((qhb & 0x30u) )) - 32); - let dq3 = f32(i32((q2b >> 4u) | ((qhb & 0xC0u) >> 2u)) - 32); - - sums[0] += yl0 * dq0; - sums[1] += yl1 * dq1; - sums[2] += yl2 * dq2; - sums[3] += yl3 * dq3; + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 8u; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 8u] = f32(src1[x_base + 32u + i]); } - local_sum += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) + - sums[2] * f32(sc4) + sums[3] * f32(sc6)); + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let d = f32(load_f16_at_src0(block_byte_base + 108u)); + let a_base = 96u; + let a_il0 = load_u16_at_src0(block_byte_base + a_base + il * 2u); + let a_il1 = load_u16_at_src0(block_byte_base + a_base + (il + 1u) * 2u); + let a_4 = load_u16_at_src0(block_byte_base + a_base + 8u); + let a_5 = load_u16_at_src0(block_byte_base + a_base + 10u); + + var scales32 = a_4 | (a_5 << 16u); + let aux32 = ((scales32 >> s_shift2) << 4u) & 0x30303030u; + scales32 = a_il0 | (a_il1 << 16u); + scales32 = ((scales32 >> s_shift1) & 0x0F0F0F0Fu) | aux32; + + let scale0 = f32(i32(byte_of(scales32, phase + 0u)) - 32); + let scale1 = f32(i32(byte_of(scales32, phase + 2u)) - 32); 
+ + let q_u32_0 = load_u32_at_src0(block_byte_base + q_byte + 0u); + let q_u32_1 = load_u32_at_src0(block_byte_base + q_byte + 4u); + let h_u32_0 = load_u32_at_src0(block_byte_base + h_byte + 0u); + let h_u32_1 = load_u32_at_src0(block_byte_base + h_byte + 4u); + + var s1 = 0.0; var s2 = 0.0; var s3 = 0.0; + var s4 = 0.0; var s5 = 0.0; var s6 = 0.0; + + for (var l = 0u; l < 8u; l += 2u) { + let q_u32 = select(q_u32_0, q_u32_1, l >= 4u); + let qs = select(q_u32 & 0xFFFFu, q_u32 >> 16u, (l & 2u) != 0u); + let h_u32 = select(h_u32_0, h_u32_1, l >= 4u); + let hv = select(h_u32 & 0xFFFFu, h_u32 >> 16u, (l & 2u) != 0u); + + s1 += x_block[l + 0u] * f32(qs & qm0); + s2 += x_block[l + 1u] * f32(qs & qm1); + s3 += select(0.0, x_block[l + 0u], (hv & hm0) == 0u) + + select(0.0, x_block[l + 1u], (hv & hm1) == 0u); + s4 += x_block[l + 8u] * f32(qs & qm2); + s5 += x_block[l + 9u] * f32(qs & qm3); + s6 += select(0.0, x_block[l + 8u], (hv & hm2) == 0u) + + select(0.0, x_block[l + 9u], (hv & hm3) == 0u); + } + + let d1 = d * (s1 + (1.0/256.0) * s2 - s3 * v1); + let d2 = d * (s4 + (1.0/256.0) * s5 - s6 * v2); + acc[row] += (d1 * scale0 + 0.25 * d2 * scale1) / f32(1u << shift); + } + } } - - return local_sum; -} #endif -struct MulMatParams { - offset_src0: u32, - offset_src1: u32, - offset_dst: u32, - m: u32, - n: u32, - k: u32, - stride_01: u32, - stride_11: u32, - stride_02: u32, - stride_12: u32, - stride_03: u32, - stride_13: u32, - bs02: u32, - bs03: u32, - broadcast2: u32, - broadcast3: u32 -}; - -// SRC0_TYPE and SRC1_TYPE are defined in mul_mat_decls, which is included -@group(0) @binding(0) var src0: array; // M rows, K columns -@group(0) @binding(1) var src1: array; // K rows, N columns (transposed) -@group(0) @binding(2) var dst: array; // M rows, N columns (transposed) - -@group(0) @binding(3) var params: MulMatParams; - -const THREADS_PER_OUTPUT = WG_SIZE / OUTPUTS_PER_WG; +#ifdef MUL_ACC_Q4_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 144 +#define THREADS_PER_BLOCK 16 + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let il = tid / 4u; + let ir = tid % 4u; + let im = il / 2u; + let in = il % 2u; + let l0 = 4u * (2u * ir + in); + + let y_offset = 64u * im + l0; + let q_offset = 32u * im + l0; + let sc0_byte = 4u + im * 2u; + let sc2_byte = 4u + (im + 2u) * 2u; + let sc4_byte = 4u + (im + 4u) * 2u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 4u; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4u] = f32(src1[x_base + 32u + i]); + x_block[i + 8u] = f32(src1[x_base + 128u + i]); + x_block[i + 12u] = f32(src1[x_base + 160u + i]); + } -// Shared memory for collaborative loading and reduction -var shared_vector: array; // Cache vector tile -var partial_sums: array; // For reduction + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let d = f32(load_f16_at_src0(block_byte_base + 0u)); + let dmin = f32(load_f16_at_src0(block_byte_base + 2u)); + + let sc0_u32 = load_u32_at_src0_aligned(block_byte_base + sc0_byte); + let sc0 = select(sc0_u32 & 0xFFFFu, sc0_u32 >> 16u, (sc0_byte & 2u) != 0u); + let sc2_u32 = 
load_u32_at_src0_aligned(block_byte_base + sc2_byte); + let sc2 = select(sc2_u32 & 0xFFFFu, sc2_u32 >> 16u, (sc2_byte & 2u) != 0u); + let sc4_u32 = load_u32_at_src0_aligned(block_byte_base + sc4_byte); + let sc4 = select(sc4_u32 & 0xFFFFu, sc4_u32 >> 16u, (sc4_byte & 2u) != 0u); + + let sc16_0 = sc0 & 0x3F3Fu; + let sc16_1 = sc2 & 0x3F3Fu; + let sc16_2 = (sc4 & 0x0F0Fu) | ((sc0 & 0xC0C0u) >> 2u); + let sc16_3 = ((sc4 >> 4u) & 0x0F0Fu) | ((sc2 & 0xC0C0u) >> 2u); + + let scale0 = f32(sc16_0 & 0xFFu); + let scale1 = f32((sc16_0 >> 8u) & 0xFFu); + let min0 = f32(sc16_1 & 0xFFu); + let min1 = f32((sc16_1 >> 8u) & 0xFFu); + let scale2 = f32(sc16_2 & 0xFFu); + let scale3 = f32((sc16_2 >> 8u) & 0xFFu); + let min2 = f32(sc16_3 & 0xFFu); + let min3 = f32((sc16_3 >> 8u) & 0xFFu); + + let q1_u32 = load_u32_at_src0_aligned(block_byte_base + 16u + q_offset); + let q2_u32 = load_u32_at_src0_aligned(block_byte_base + 80u + q_offset); + + var dot = vec4(0.0, 0.0, 0.0, 0.0); + var sumx = vec4(0.0, 0.0, 0.0, 0.0); + for (var i = 0u; i < 4u; i++) { + let q1b = byte_of(q1_u32, i); + let q2b = byte_of(q2_u32, i); + dot[0] += x_block[i] * f32(q1b & 0x0Fu); + dot[1] += x_block[i + 4u] * f32(q1b >> 4u); + dot[2] += x_block[i + 8u] * f32(q2b & 0x0Fu); + dot[3] += x_block[i + 12u] * f32(q2b >> 4u); + sumx[0] += x_block[i]; + sumx[1] += x_block[i + 4u]; + sumx[2] += x_block[i + 8u]; + sumx[3] += x_block[i + 12u]; + } + + acc[row] += d * (dot[0] * scale0 + dot[1] * scale1 + dot[2] * scale2 + dot[3] * scale3) + - dmin * (sumx[0] * min0 + sumx[1] * min1 + sumx[2] * min2 + sumx[3] * min3); + } + } + } +#endif -@compute @workgroup_size(WG_SIZE) -fn main( - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) wg_id: vec3, - @builtin(num_workgroups) num_wg: vec3) { - let thread_id = local_id.x; +#ifdef MUL_ACC_Q5_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 176 +#define THREADS_PER_BLOCK 16 + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let il = tid / 4u; + let ir = tid % 4u; + let im = il / 2u; + let in = il % 2u; + let l0 = 4u * (2u * ir + in); + + let y_offset = 64u * im + l0; + let q_offset = 48u + 32u * im + l0; + let qh_offset = 16u + 8u * ir + 4u * in; + let sc0_byte = 4u + im * 2u; + let sc2_byte = 4u + (im + 2u) * 2u; + let sc4_byte = 4u + (im + 4u) * 2u; + + let hm1 = 1u << (2u * im); + let hm2 = hm1 << 1u; + let hm3 = hm1 << 4u; + let hm4 = hm2 << 4u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 4u; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4u] = f32(src1[x_base + 32u + i]); + x_block[i + 8u] = f32(src1[x_base + 128u + i]); + x_block[i + 12u] = f32(src1[x_base + 160u + i]); + } - // Handle batch dimensions - let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; - let wg_linear = wg_id.y * num_wg.x + wg_id.x; - let output_groups = (params.m + OUTPUTS_PER_WG - 1u) / OUTPUTS_PER_WG; - let batch_idx = wg_linear / output_groups; - if (batch_idx >= total_batches) { - return; + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let d = f32(load_f16_at_src0(block_byte_base + 0u)); + 
let dmin = f32(load_f16_at_src0(block_byte_base + 2u)); + + let sc0_u32 = load_u32_at_src0_aligned(block_byte_base + sc0_byte); + let sc0 = select(sc0_u32 & 0xFFFFu, sc0_u32 >> 16u, (sc0_byte & 2u) != 0u); + let sc2_u32 = load_u32_at_src0_aligned(block_byte_base + sc2_byte); + let sc2 = select(sc2_u32 & 0xFFFFu, sc2_u32 >> 16u, (sc2_byte & 2u) != 0u); + let sc4_u32 = load_u32_at_src0_aligned(block_byte_base + sc4_byte); + let sc4 = select(sc4_u32 & 0xFFFFu, sc4_u32 >> 16u, (sc4_byte & 2u) != 0u); + + let sc16_0 = sc0 & 0x3F3Fu; + let sc16_1 = sc2 & 0x3F3Fu; + let sc16_2 = (sc4 & 0x0F0Fu) | ((sc0 & 0xC0C0u) >> 2u); + let sc16_3 = ((sc4 >> 4u) & 0x0F0Fu) | ((sc2 & 0xC0C0u) >> 2u); + + let f0 = f32(sc16_0 & 0xFFu); + let f1 = f32((sc16_0 >> 8u) & 0xFFu); + let m0 = f32(sc16_1 & 0xFFu); + let m1 = f32((sc16_1 >> 8u) & 0xFFu); + let f4 = f32(sc16_2 & 0xFFu); + let f5 = f32((sc16_2 >> 8u) & 0xFFu); + let m4 = f32(sc16_3 & 0xFFu); + let m5 = f32((sc16_3 >> 8u) & 0xFFu); + + let q1_u32 = load_u32_at_src0_aligned(block_byte_base + q_offset); + let q2_u32 = load_u32_at_src0_aligned(block_byte_base + q_offset + 64u); + let qh_u32 = load_u32_at_src0_aligned(block_byte_base + qh_offset); + + var vals = vec4(0.0, 0.0, 0.0, 0.0); + var sumy = vec4(0.0, 0.0, 0.0, 0.0); + for (var i = 0u; i < 4u; i++) { + let q1b = byte_of(q1_u32, i); + let q2b = byte_of(q2_u32, i); + let qhb = byte_of(qh_u32, i); + + let yl0 = x_block[i]; + let yl8 = x_block[i + 4u]; + let yh0 = x_block[i + 8u]; + let yh8 = x_block[i + 12u]; + + sumy[0] += yl0; + sumy[1] += yl8; + sumy[2] += yh0; + sumy[3] += yh8; + + let q0 = f32((q1b & 0x0Fu) | select(0u, 0x10u, (qhb & hm1) != 0u)); + let q1 = f32((q1b >> 4u) | select(0u, 0x10u, (qhb & hm2) != 0u)); + let q2 = f32((q2b & 0x0Fu) | select(0u, 0x10u, (qhb & hm3) != 0u)); + let q3 = f32((q2b >> 4u) | select(0u, 0x10u, (qhb & hm4) != 0u)); + + vals[0] += yl0 * q0; + vals[1] += yl8 * q1; + vals[2] += yh0 * q2; + vals[3] += yh8 * q3; + } + + acc[row] += d * (f0 * vals[0] + f1 * vals[1] + f4 * vals[2] + f5 * vals[3]) + - dmin * (sumy[0] * m0 + sumy[1] * m1 + + sumy[2] * m4 + sumy[3] * m5); + } + } } +#endif - // Which of the outputs does this thread belong to? 
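// Editorial note (illustrative, not part of the upstream patch): the MUL_ACC_Q4_K and
// MUL_ACC_Q5_K paths above assume the standard ggml 256-element super-block layouts,
// consistent with the byte offsets used in the code:
//
//   block_q4_K, 144 bytes: d: f16 @ 0 | dmin: f16 @ 2 | scales[12] @ 4 | qs[128] @ 16
//   block_q5_K, 176 bytes: d: f16 @ 0 | dmin: f16 @ 2 | scales[12] @ 4 | qh[32] @ 16 | qs[128] @ 48
//
// Each weight is reconstructed as d * scale * q - dmin * min, where q is a 4-bit
// nibble from qs (plus one high bit taken from qh in the Q5_K case) and the eight
// per-sub-block (scale, min) pairs are packed as 6-bit values in scales[12]; the
// sc16_0..sc16_3 mask/shift sequence unpacks two pairs at a time.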
- let thread_group = thread_id / THREADS_PER_OUTPUT; - let thread_in_group = thread_id % THREADS_PER_OUTPUT; +#ifdef MUL_ACC_Q6_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 210 +#define THREADS_PER_BLOCK 16 - // Each workgroup computes OUTPUTS_PER_WG consecutive outputs - let output_row = (wg_linear % output_groups) * OUTPUTS_PER_WG + thread_group; + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - let dst2_stride = params.m * params.n; - let dst2_idx = batch_idx % (params.bs02 * params.broadcast2); - let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; - let dst3_idx = batch_idx / (params.bs02 * params.broadcast2); - let src03_idx = dst3_idx / params.broadcast3; - let src13_idx = dst3_idx; - let src02_idx = dst2_idx / params.broadcast2; - let src12_idx = dst2_idx; + let ip = tid / 8u; + let il = tid % 8u; + let l0 = 4u * il; + let is = 8u * ip + l0 / 16u; - let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + output_row * params.stride_01; - let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; - let dst_idx = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + output_row; + let y_offset = 128u * ip + l0; + let q_offset_l = 64u * ip + l0; + let q_offset_h = 32u * ip + l0; - var local_sum = 0.0; + let num_blocks = params.k / BLOCK_SIZE; + let sc_base_byte = 192u + (is & ~3u); + let sc_byte_pos = is & 3u; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var l = 0u; l < 4u; l++) { + x_block[l] = f32(src1[x_base + l]); + x_block[l + 4u] = f32(src1[x_base + 32u + l]); + x_block[l + 8u] = f32(src1[x_base + 64u + l]); + x_block[l + 12u] = f32(src1[x_base + 96u + l]); + } - // Each thread processes multiple K elements and accumulates - for (var k_tile = 0u; k_tile < params.k; k_tile += TILE_K) { - let tile_size = min(TILE_K, params.k - k_tile); + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let d = f32(load_f16_at_src0(block_byte_base + 208u)); + let ql1_u32 = load_u32_at_src0(block_byte_base + q_offset_l); + let ql2_u32 = load_u32_at_src0(block_byte_base + q_offset_l + 32u); + let qh_u32 = load_u32_at_src0(block_byte_base + 128u + q_offset_h); + let sc_u32_0 = load_u32_at_src0(block_byte_base + sc_base_byte); + let sc_u32_1 = load_u32_at_src0(block_byte_base + sc_base_byte + 4u); + + let sc0 = sbyte_of(sc_u32_0, sc_byte_pos); + let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u); + let sc4 = sbyte_of(sc_u32_1, sc_byte_pos); + let sc6 = sbyte_of(sc_u32_1, sc_byte_pos + 2u); + + var sums = vec4(0.0, 0.0, 0.0, 0.0); + + for (var l = 0u; l < 4u; l++) { + let q1b = byte_of(ql1_u32, l); + let q2b = byte_of(ql2_u32, l); + let qhb = byte_of(qh_u32, l); + + let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32); + let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32); + let dq2 = f32(i32((q1b >> 4u) | (qhb & 0x30u)) - 32); + let dq3 = f32(i32((q2b >> 4u) | ((qhb & 0xC0u) >> 2u)) - 32); + + sums[0] += x_block[l] * dq0; + sums[1] += x_block[l + 4u] * dq1; + sums[2] += x_block[l + 8u] * dq2; + sums[3] += x_block[l + 12u] * dq3; + } + + acc[row] += d * (sums[0] * 
f32(sc0) + sums[1] * f32(sc2) + + sums[2] * f32(sc4) + sums[3] * f32(sc6)); + } + } + } +#endif - // Cooperatively load vector tile into shared memory (all threads) - for (var i = thread_id * VEC_SIZE; i < tile_size; i += WG_SIZE * VEC_SIZE) { - shared_vector[i / VEC_SIZE] = src1[(src1_idx_base + k_tile + i) / VEC_SIZE]; +#ifdef USE_SUBGROUP_REDUCTION + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let subgroup_total = subgroupAdd(acc[row]); + if (subgroup_invocation_id == 0u) { + partial_sums[partial_index(row, subgroup_id)] = subgroup_total; } + } - workgroupBarrier(); + workgroupBarrier(); - if (output_row < params.m) { - local_sum += mul_acc(thread_in_group, tile_size, src0_idx_base, k_tile); + for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) { + let output_row = row_base + row; + var row_acc = 0.0f; + for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) { + row_acc += partial_sums[partial_index(row, k)]; } + let row_total = subgroupAdd(row_acc); + if (subgroup_invocation_id == 0) { + dst[dst_idx_base + row] = row_total; + } + } +#endif - workgroupBarrier(); +#ifdef USE_WORKGROUP_REDUCTION + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + partial_sums[partial_index(row, thread_id)] = acc[row]; } - // Store partial sums and reduce within each partition - partial_sums[thread_id] = local_sum; workgroupBarrier(); - let group_base = thread_group * THREADS_PER_OUTPUT; - let thread_base = group_base + thread_in_group; - var offset: u32 = THREADS_PER_OUTPUT / 2; - while (offset > 0) { - if (thread_in_group < offset) { - partial_sums[thread_base] += partial_sums[thread_base + offset]; + + var stride = WG_SIZE / 2u; + + while (stride > 0) { + if (thread_id < stride) { + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)]; + } } - offset = offset / 2; + workgroupBarrier(); + stride = stride / 2; } - // Store back to global memory - if (output_row < params.m && thread_group % VEC_SIZE == 0 && thread_in_group == 0) { - dst[dst_idx / VEC_SIZE] = store_val(group_base); + if (thread_id < OUTPUTS_PER_WG) { + let output_row = row_base + thread_id; + if (output_row < params.m) { + dst[dst_idx_base + thread_id] = partial_sums[partial_index(thread_id, 0)]; + } } +#endif } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl new file mode 100644 index 00000000000..74aaa2753ae --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl @@ -0,0 +1,154 @@ +#ifdef OVERLAP + +@group(0) @binding(0) +var rn_src: array; + +@group(0) @binding(1) +var mul_src: array; + +@group(0) @binding(2) +var params: Params; + +fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32) { + mul_src[dst_offset] = scale * rn_src[rn_src_offset] * mul_src[mul_src_offset]; +} + +#elif INPLACE + +@group(0) @binding(0) +var rn_src: array; + +@group(0) @binding(1) +var mul_src: array; + +@group(0) @binding(2) +var params: Params; + +fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32) { + rn_src[dst_offset] = scale * rn_src[rn_src_offset] * mul_src[mul_src_offset]; +} + +#elif SRC_OVERLAP + +@group(0) @binding(0) +var merged_src: array; + +@group(0) @binding(1) +var dst: array; + +@group(0) @binding(2) +var params: Params; + +fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32) { + dst[dst_offset] = 
scale * merged_src[rn_src_offset] * merged_src[mul_src_offset]; +} + +#else + +@group(0) @binding(0) +var rn_src: array; + +@group(0) @binding(1) +var mul_src: array; + +@group(0) @binding(2) +var dst: array; + +@group(0) @binding(3) +var params: Params; + +fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32) { + dst[dst_offset] = scale * rn_src[rn_src_offset] * mul_src[mul_src_offset]; +} + +#endif + +struct Params { + offset_rn_src: u32, + offset_mul_src: u32, + offset_merged_rn_src: u32, + offset_merged_mul_src: u32, + offset_dst: u32, + + stride_rn_src1: u32, + stride_rn_src2: u32, + stride_rn_src3: u32, + + stride_mul_src1: u32, + stride_mul_src2: u32, + stride_mul_src3: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + mul_src_ne0: u32, + mul_src_ne1: u32, + mul_src_ne2: u32, + mul_src_ne3: u32, + + ne0: u32, + ne1: u32, + ne2: u32, + ne3: u32, + + eps: f32 +}; + +var scratch: array; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wid: vec3, + @builtin(local_invocation_id) lid: vec3) { + + // one thread per row + var i = wid.x; + let i3 = i / (params.ne2 * params.ne1); + i = i % (params.ne2 * params.ne1); + let i2 = i / params.ne1; + let i1 = i % params.ne1; + let i_rn_src_row = params.offset_rn_src + params.offset_merged_rn_src + i3 * params.stride_rn_src3 + i2 * params.stride_rn_src2 + i1 * params.stride_rn_src1; + let i_mul_src_row = params.offset_mul_src + params.offset_merged_mul_src + (i3 % params.mul_src_ne3) * params.stride_mul_src3 + (i2 % params.mul_src_ne2) * params.stride_mul_src2 + (i1 % params.mul_src_ne1) * params.stride_mul_src1; + let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; + + let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE; + + var sum = 0.0f; + var col = lid.x; + for (var j: u32 = 0; j < elems; j++) { + if (col >= params.ne0) { + break; + } +#ifdef SRC_OVERLAP + sum += pow(merged_src[i_rn_src_row + col], 2.0); +#else + sum += pow(rn_src[i_rn_src_row + col], 2.0); +#endif + col += WG_SIZE; + } + + scratch[lid.x] = sum; + + workgroupBarrier(); + + var offset: u32 = WG_SIZE / 2; + while (offset > 0) { + if (lid.x < offset) { + scratch[lid.x] += scratch[lid.x + offset]; + } + offset = offset / 2; + workgroupBarrier(); + } + sum = scratch[0]; + + let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps); + + col = lid.x; + for (var j: u32 = 0; j < elems; j++) { + if (col >= params.ne0) { + break; + } + update(i_rn_src_row + col, i_dst_row + col, scale, i_mul_src_row + col % params.mul_src_ne0); + col += WG_SIZE; + } +} diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index eda041f4518..54d3eae3e4d 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -7656,7 +7656,7 @@ size_t ggml_quantize_chunk( int64_t nrows, int64_t n_per_row, const float * imatrix) { - const int64_t n = (int64_t) nrows * n_per_row; + const int64_t n = nrows * n_per_row; if (ggml_quantize_requires_imatrix(type)) { GGML_ASSERT(imatrix != NULL); @@ -7673,21 +7673,21 @@ size_t ggml_quantize_chunk( size_t result = 0; switch (type) { - case GGML_TYPE_Q1_0: result = quantize_q1_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q4_0: result = quantize_q4_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q4_1: result = quantize_q4_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q5_0: result = quantize_q5_0(src + 
start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q5_1: result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q8_0: result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_MXFP4: result = quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_NVFP4: result = quantize_nvfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q2_K: result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q3_K: result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q4_K: result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q5_K: result = quantize_q5_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q6_K: result = quantize_q6_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_TQ1_0: result = quantize_tq1_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_TQ2_0: result = quantize_tq2_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q1_0: result = quantize_q1_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q4_0: result = quantize_q4_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q4_1: result = quantize_q4_1 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q5_0: result = quantize_q5_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q5_1: result = quantize_q5_1 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q8_0: result = quantize_q8_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_MXFP4: result = quantize_mxfp4 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_NVFP4: result = quantize_nvfp4 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q2_K: result = quantize_q2_K (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q3_K: result = quantize_q3_K (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q4_K: result = quantize_q4_K (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q5_K: result = quantize_q5_K (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q6_K: result = quantize_q6_K (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_TQ1_0: result = quantize_tq1_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_TQ2_0: result = quantize_tq2_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_XXS: result = 
quantize_iq2_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_XS: result = quantize_iq2_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ3_XXS: result = quantize_iq3_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; @@ -7752,9 +7752,9 @@ struct ggml_threadpool_params ggml_threadpool_params_default(int n_threads) { } bool ggml_threadpool_params_match(const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1) { - if (p0->n_threads != p1->n_threads ) return false; - if (p0->prio != p1->prio ) return false; - if (p0->poll != p1->poll ) return false; - if (p0->strict_cpu != p1->strict_cpu ) return false; + if (p0->n_threads != p1->n_threads ) return false; + if (p0->prio != p1->prio ) return false; + if (p0->poll != p1->poll ) return false; + if (p0->strict_cpu != p1->strict_cpu ) return false; return memcmp(p0->cpumask, p1->cpumask, GGML_MAX_N_THREADS) == 0; } diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index c5297a2f440..c3b3cb37fae 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -152,6 +152,9 @@ class LLM: SWIGLU_CLAMP_SHEXP = "{arch}.swiglu_clamp_shexp" DENSE_FEAT_IN_SIZE = "{arch}.{dense}_feat_in" DENSE_FEAT_OUT_SIZE = "{arch}.{dense}_feat_out" + EAGLE3_EXTRACT_LAYERS = "{arch}.extract_layers" + EAGLE3_TARGET_HIDDEN_SIZE = "{arch}.target_hidden_size" + EAGLE3_NORM_BEFORE_RESIDUAL = "{arch}.norm_before_residual" class Attention: HEAD_COUNT = "{arch}.attention.head_count" @@ -197,6 +200,7 @@ class Rope: FREQ_BASE_SWA = "{arch}.rope.freq_base_swa" SCALING_TYPE = "{arch}.rope.scaling.type" SCALING_FACTOR = "{arch}.rope.scaling.factor" + SCALING_ALPHA = "{arch}.rope.scaling.alpha" SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor" SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length" SCALING_FINETUNED = "{arch}.rope.scaling.finetuned" @@ -471,6 +475,7 @@ class MODEL_ARCH(IntEnum): ERNIE4_5_MOE = auto() HUNYUAN_MOE = auto() HUNYUAN_DENSE = auto() + HUNYUAN_VL = auto() SMOLLM3 = auto() GPT_OSS = auto() LFM2 = auto() @@ -488,6 +493,8 @@ class MODEL_ARCH(IntEnum): PANGU_EMBED = auto() MISTRAL3 = auto() MISTRAL4 = auto() + EAGLE3 = auto() + DFLASH = auto() PADDLEOCR = auto() MIMO2 = auto() STEP35 = auto() @@ -842,6 +849,13 @@ class MODEL_TENSOR(IntEnum): NEXTN_HNORM = auto() NEXTN_SHARED_HEAD_HEAD = auto() NEXTN_SHARED_HEAD_NORM = auto() + # EAGLE3 specific tensors + EAGLE3_FC = auto() # feature fusion layer + EAGLE3_HIDDEN_NORM = auto() # hidden normalization + EAGLE3_D2T = auto() # draft to target vocabulary mapping + # DFlash + DFLASH_FC = auto() # feature fusion layer + DFLASH_HIDDEN_NORM = auto() # hidden normalization # lfm2 audio A_ENC_NORM_CONV = auto() A_ENC_LINEAR_POS = auto() @@ -957,6 +971,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.FALCON_H1: "falcon-h1", MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe", MODEL_ARCH.HUNYUAN_DENSE: "hunyuan-dense", + MODEL_ARCH.HUNYUAN_VL: "hunyuan_vl", MODEL_ARCH.SMOLLM3: "smollm3", MODEL_ARCH.GPT_OSS: "gpt-oss", MODEL_ARCH.LFM2: "lfm2", @@ -974,6 +989,8 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.PANGU_EMBED: "pangu-embedded", MODEL_ARCH.MISTRAL3: "mistral3", MODEL_ARCH.MISTRAL4: "mistral4", + MODEL_ARCH.EAGLE3: "eagle3", + MODEL_ARCH.DFLASH: "dflash", MODEL_ARCH.PADDLEOCR: "paddleocr", MODEL_ARCH.MIMO2: "mimo2", MODEL_ARCH.STEP35: "step35", @@ -1337,6 +1354,11 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.NEXTN_HNORM: 
"blk.{bid}.nextn.hnorm", MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD: "blk.{bid}.nextn.shared_head_head", MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM: "blk.{bid}.nextn.shared_head_norm", + MODEL_TENSOR.EAGLE3_FC: "fc", + MODEL_TENSOR.EAGLE3_HIDDEN_NORM: "blk.{bid}.hidden_norm", + MODEL_TENSOR.EAGLE3_D2T: "d2t", + MODEL_TENSOR.DFLASH_FC: "fc", + MODEL_TENSOR.DFLASH_HIDDEN_NORM: "hidden_norm", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -3489,6 +3511,22 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.HUNYUAN_VL: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], MODEL_ARCH.SMOLLM3: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -3723,6 +3761,40 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, ], + MODEL_ARCH.EAGLE3: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.EAGLE3_FC, + MODEL_TENSOR.EAGLE3_HIDDEN_NORM, + MODEL_TENSOR.EAGLE3_D2T, + ], + MODEL_ARCH.DFLASH: [ + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.DFLASH_FC, + MODEL_TENSOR.DFLASH_HIDDEN_NORM, + ], MODEL_ARCH.MISTRAL4: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -4138,6 +4210,7 @@ class VisionProjectorType: YOUTUVL = "youtuvl" NEMOTRON_V2_VL = "nemotron_v2_vl" HUNYUANOCR = "hunyuanocr" + HUNYUANVL = "hunyuanvl" # Items here are (block size, type size) diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 90d500dc771..6a81ca37d8c 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -973,6 +973,9 @@ def add_rope_scaling_type(self, value: RopeScalingType) -> None: def add_rope_scaling_factor(self, value: float) -> None: self.add_float32(Keys.Rope.SCALING_FACTOR.format(arch=self.arch), value) + def add_rope_scaling_alpha(self, value: float) -> None: + self.add_float32(Keys.Rope.SCALING_ALPHA.format(arch=self.arch), value) + def add_rope_scaling_attn_factors(self, value: float) -> None: self.add_float32(Keys.Rope.SCALING_ATTN_FACTOR.format(arch=self.arch), value) diff --git a/include/llama.h b/include/llama.h index ac267b5089a..fc629fd5c55 100644 --- a/include/llama.h +++ b/include/llama.h @@ -375,6 +375,10 @@ extern "C" { // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix // ref: https://github.com/ggml-org/llama.cpp/pull/14363 + // EAGLE3 extraction configuration + const struct llama_model * target_model; // reference to target model + // only used to share embedding layer with eagle3 model + // [EXPERIMENTAL] // backend sampler chain configuration (make sure the caller keeps the sampler chains alive) // note: the samplers must be sampler chains (i.e. 
use llama_sampler_chain_init) @@ -511,27 +515,6 @@ extern "C" { // Frees all allocated memory LLAMA_API void llama_free(struct llama_context * ctx); - enum llama_params_fit_status { - LLAMA_PARAMS_FIT_STATUS_SUCCESS = 0, // found allocations that are projected to fit - LLAMA_PARAMS_FIT_STATUS_FAILURE = 1, // could not find allocations that are projected to fit - LLAMA_PARAMS_FIT_STATUS_ERROR = 2, // a hard error occurred, e.g. because no model could be found at the specified path - }; - - // fits mparams and cparams to free device memory (assumes system memory is unlimited) - // - returns true if the parameters could be successfully modified to fit device memory - // - this function is NOT thread safe because it modifies the global llama logger state - // - only parameters that have the same value as in llama_default_model_params are modified - // with the exception of the context size which is modified if and only if equal to 0 - LLAMA_API enum llama_params_fit_status llama_params_fit( - const char * path_model, - struct llama_model_params * mparams, - struct llama_context_params * cparams, - float * tensor_split, // writable buffer for tensor split, needs at least llama_max_devices elements - struct llama_model_tensor_buft_override * tensor_buft_overrides, // writable buffer for overrides, needs at least llama_max_tensor_buft_overrides elements - size_t * margins, // margins of memory to leave per device in bytes - uint32_t n_ctx_min, // minimum context size to set when trying to reduce memory use - enum ggml_log_level log_level); // minimum log level to print during fitting, lower levels go to debug log - LLAMA_API int64_t llama_time_us(void); LLAMA_API size_t llama_max_devices(void); @@ -575,6 +558,12 @@ extern "C" { LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model); LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model); + // DFlash draft model: block size used as number of draft tokens + LLAMA_API int32_t llama_model_dflash_block_size(const struct llama_model * model); + + // DFlash draft model: mask token id used as filler in the noise block + LLAMA_API int32_t llama_model_dflash_mask_token_id(const struct llama_model * model); + // Get the model's RoPE frequency scaling factor LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model); @@ -707,6 +696,14 @@ extern "C" { int32_t il_start, int32_t il_end); + // + // eagle3 (tmp) + // + + LLAMA_API void llama_set_eagle3( + struct llama_context * ctx, + const struct llama_model * model); + // // Memory // @@ -906,6 +903,41 @@ extern "C" { llama_seq_id dest_seq_id, llama_state_seq_flags flags); + // + // EAGLE3 draft model support + // + + // Get pointer to target model features extracted for EAGLE3 encoder + // Returns NULL if no features are available + // Format: [3*n_embd, n_tokens] - use model.hparams.n_embd and batch.n_tokens for dimensions + LLAMA_API const float * llama_get_eagle3_target_features(struct llama_context * ctx); + + // Set g_embeddings from EAGLE3 encoder output for decoder input + // g_embd: pointer to encoder output embeddings + LLAMA_API void llama_set_eagle3_g_embeddings( + struct llama_context * ctx, + const float * g_embd, + int32_t n_embd, + int32_t n_tokens); + + // + // DFlash draft model support (similar to EAGLE3) + // + + // Enable DFlash target feature extraction on the target context + LLAMA_API void llama_set_dflash( + struct llama_context * ctx, + const struct llama_model * model); + + LLAMA_API const float * 
llama_get_dflash_target_features(struct llama_context * ctx); + + // Set accumulated target_ctx for DFlash decoder + LLAMA_API void llama_set_dflash_accumulated_target_ctx( + struct llama_context * ctx, + const float * data, + int32_t n_embd, + int32_t n_tokens); + // // Decoding // @@ -1546,9 +1578,6 @@ extern "C" { LLAMA_API void llama_perf_sampler_print(const struct llama_sampler * chain); LLAMA_API void llama_perf_sampler_reset( struct llama_sampler * chain); - // print a breakdown of per-device memory use via LLAMA_LOG: - LLAMA_API void llama_memory_breakdown_print(const struct llama_context * ctx); - // // training // diff --git a/scripts/server-test-parallel-tc.py b/scripts/server-test-parallel-tc.py new file mode 100755 index 00000000000..a166c6d7208 --- /dev/null +++ b/scripts/server-test-parallel-tc.py @@ -0,0 +1,991 @@ +#!/usr/bin/env python3 +""" +Test parallel tool-calling capability via chat completions endpoint. + +Only run this against models that actually support parallel tool calls — this +script does not attempt to toggle that setting on the server. Each scenario is +explicitly worded so that a capable model SHOULD emit multiple tool calls in a +single assistant turn (either the same tool N times, or several different +tools at once). + +Each test case contains: + - tools: list of tool definitions (OpenAI-compatible) + - messages: initial conversation messages + - mock_tool_responses: dict mapping tool_name -> callable(arguments) -> str (JSON) + - expected_parallel: dict describing what constitutes a successful parallel turn + {"min_parallel": int, # minimum tool_calls in one turn + "require_same_tool": Optional[str], # all parallel calls must be this tool + "require_distinct_tools": Optional[int], # >= N distinct tool names in one turn + "min_distinct_args_key": Optional[str]} # parallel calls must span this + # many distinct values of this arg key + - validate: callable(turns, all_tool_calls, final_content) -> (passed, reason) +""" + +import argparse +import json +import requests +import sys + +# --------------------------------------------------------------------------- +# Color / formatting helpers +# --------------------------------------------------------------------------- + +RESET = "\x1b[0m" +BOLD = "\x1b[1m" +DIM = "\x1b[2m" +CYAN = "\x1b[36m" +YELLOW = "\x1b[33m" +GREEN = "\x1b[32m" +RED = "\x1b[31m" +BLUE = "\x1b[34m" +WHITE = "\x1b[97m" +MAGENTA = "\x1b[35m" + + +def _print(text="", end="\n"): + sys.stdout.write(text + end) + sys.stdout.flush() + + +def print_header(title): + bar = "─" * 60 + _print(f"\n{BOLD}{CYAN}┌{bar}┐{RESET}") + _print( + f"{BOLD}{CYAN}│ {WHITE}{title}{CYAN}{' ' * max(0, 58 - len(title))}│{RESET}" + ) + _print(f"{BOLD}{CYAN}└{bar}┘{RESET}") + + +def print_turn_banner(turn_idx, n_calls): + color = MAGENTA if n_calls >= 2 else DIM + _print(f"\n {BOLD}{color}▶ turn {turn_idx} — {n_calls} tool call(s){RESET}") + + +def print_tool_call(name, args): + args_str = json.dumps(args) + _print( + f" {BOLD}{YELLOW}⚙ {name}{RESET}{DIM}({args_str}){RESET}" + ) + + +def print_tool_result(result): + preview = result[:140] + ("…" if len(result) > 140 else "") + _print(f" {DIM}{BLUE}↳ {preview}{RESET}") + + +def print_model_output(text): + sys.stdout.write(text) + sys.stdout.flush() + + +def print_pass(reason): + _print(f"\n{BOLD}{GREEN}✔ PASS{RESET} {reason}") + + +def print_fail(reason): + _print(f"\n{BOLD}{RED}✘ FAIL{RESET} {reason}") + + +def print_info(msg): + _print(f"{DIM}{msg}{RESET}") + + +def print_warn(msg): + _print(f"{BOLD}{YELLOW}⚠ 
{msg}{RESET}") + + +# --------------------------------------------------------------------------- +# HTTP helpers +# --------------------------------------------------------------------------- + + +def chat_completion(url, messages, tools=None, stream=False): + payload = { + "messages": messages, + "stream": stream, + "max_tokens": 4096, + } + if tools: + payload["tools"] = tools + payload["tool_choice"] = "auto" + + try: + response = requests.post(url, json=payload, stream=stream) + response.raise_for_status() + except requests.exceptions.RequestException as e: + body = e.response.content if (e.response is not None) else b"" + print_fail(f"Request error: {e} | body: {body}") + return None + + full_content = "" + reasoning_content = "" + tool_calls: list[dict] = [] + + if stream: + for line in response.iter_lines(): + if not line: + continue + decoded = line.decode("utf-8") + if not decoded.startswith("data: "): + continue + data_str = decoded[6:] + if data_str == "[DONE]": + break + try: + data = json.loads(data_str) + except json.JSONDecodeError: + continue + choices = data.get("choices", []) + if not choices: + continue + delta = choices[0].get("delta", {}) + if delta.get("reasoning_content"): + reasoning_content += delta["reasoning_content"] + if delta.get("content"): + full_content += delta["content"] + print_model_output(delta["content"]) + for tc in delta.get("tool_calls", []): + idx = tc.get("index", 0) + while len(tool_calls) <= idx: + tool_calls.append( + { + "id": "", + "type": "function", + "function": {"name": "", "arguments": ""}, + } + ) + if "id" in tc: + tool_calls[idx]["id"] += tc["id"] + if "function" in tc: + if "name" in tc["function"]: + tool_calls[idx]["function"]["name"] += tc["function"]["name"] + if "arguments" in tc["function"]: + tool_calls[idx]["function"]["arguments"] += tc["function"][ + "arguments" + ] + else: + data = response.json() + choices = data.get("choices", []) + if choices: + msg = choices[0].get("message", {}) + full_content = msg.get("content") or "" + reasoning_content = msg.get("reasoning_content") or "" + tool_calls = msg.get("tool_calls") or [] + if full_content: + print_model_output(full_content) + + result = {"content": full_content, "tool_calls": tool_calls} + if reasoning_content: + result["reasoning_content"] = reasoning_content + return result + + +def run_agentic_loop(url, messages, tools, mock_tool_responses, stream, max_turns=6): + """ + Drive the multi-turn tool-call loop, but record each turn's tool calls + separately so parallelism can be validated. + + Returns (turns, all_tool_calls, final_content) where `turns` is a list + of dicts: {"index": int, "tool_calls": [...], "content": str}. 
+ """ + msgs = list(messages) + turns: list[dict] = [] + all_tool_calls: list[dict] = [] + + for turn_idx in range(max_turns): + result = chat_completion(url, msgs, tools=tools, stream=stream) + if result is None: + return turns, all_tool_calls, None + + tcs = result.get("tool_calls") or [] + content = result.get("content") or "" + + turns.append( + {"index": turn_idx, "tool_calls": list(tcs), "content": content} + ) + + if not tcs: + if content: + _print(f"\n{DIM}{'·' * 60}{RESET}") + _print(f"{DIM} model response:{RESET}\n") + return turns, all_tool_calls, content + + print_turn_banner(turn_idx, len(tcs)) + all_tool_calls.extend(tcs) + + assistant_msg: dict = { + "role": "assistant", + "content": content, + "tool_calls": tcs, + } + reasoning = result.get("reasoning_content") + if reasoning: + assistant_msg["reasoning_content"] = reasoning + msgs.append(assistant_msg) + + for tc in tcs: + tool_name = tc["function"]["name"] + try: + args = json.loads(tc["function"]["arguments"]) + except json.JSONDecodeError: + args = {} + + print_tool_call(tool_name, args) + + mock_fn = mock_tool_responses.get(tool_name) + if mock_fn: + tool_result = mock_fn(args) + else: + tool_result = json.dumps({"error": f"Unknown tool: {tool_name}"}) + + print_tool_result(tool_result) + + msgs.append( + { + "role": "tool", + "tool_call_id": tc.get("id", ""), + "content": tool_result, + } + ) + + return turns, all_tool_calls, None + + +# --------------------------------------------------------------------------- +# Parallelism helpers +# --------------------------------------------------------------------------- + + +def _best_parallel_turn(turns): + """Return the turn (dict) with the most tool calls, or None if no tools.""" + tool_turns = [t for t in turns if t["tool_calls"]] + if not tool_turns: + return None + return max(tool_turns, key=lambda t: len(t["tool_calls"])) + + +def _distinct_tool_names(turn): + return {tc["function"]["name"] for tc in turn["tool_calls"]} + + +def _distinct_arg_values(turn, key): + values = set() + for tc in turn["tool_calls"]: + try: + args = json.loads(tc["function"]["arguments"]) + except json.JSONDecodeError: + continue + v = args.get(key) + if v is not None: + if isinstance(v, str): + values.add(v.strip().lower()) + else: + values.add(v) + return values + + +def _check_parallel(turns, expected): + """ + Check that at least one turn satisfies the parallel-call expectations. + Returns (ok, reason). 
+ """ + best = _best_parallel_turn(turns) + if best is None: + return False, "No tool calls were made at all" + + min_parallel = expected.get("min_parallel", 2) + if len(best["tool_calls"]) < min_parallel: + by_turn = [len(t["tool_calls"]) for t in turns] + return False, ( + f"No turn had >= {min_parallel} parallel tool calls " + f"(per-turn counts: {by_turn})" + ) + + require_same = expected.get("require_same_tool") + if require_same is not None: + names = [tc["function"]["name"] for tc in best["tool_calls"]] + if any(n != require_same for n in names): + return False, ( + f"Parallel turn mixed tools; expected all {require_same!r}, got {names}" + ) + + require_distinct = expected.get("require_distinct_tools") + if require_distinct is not None: + distinct = _distinct_tool_names(best) + if len(distinct) < require_distinct: + return False, ( + f"Parallel turn had only {len(distinct)} distinct tool names " + f"({distinct}); need >= {require_distinct}" + ) + + distinct_key = expected.get("min_distinct_args_key") + distinct_count = expected.get("min_distinct_args_count", min_parallel) + if distinct_key is not None: + values = _distinct_arg_values(best, distinct_key) + if len(values) < distinct_count: + return False, ( + f"Parallel turn had only {len(values)} distinct {distinct_key!r} " + f"values ({values}); need >= {distinct_count}" + ) + + return True, ( + f"Parallel turn had {len(best['tool_calls'])} calls across " + f"{len(_distinct_tool_names(best))} distinct tool(s)" + ) + + +# --------------------------------------------------------------------------- +# Test case runner +# --------------------------------------------------------------------------- + + +def run_test(url, test_case, stream): + name = test_case["name"] + mode = f"{'stream' if stream else 'non-stream'}" + print_header(f"{name} [{mode}]") + + turns, all_tool_calls, final_content = run_agentic_loop( + url, + messages=test_case["messages"], + tools=test_case["tools"], + mock_tool_responses=test_case["mock_tool_responses"], + stream=stream, + ) + + if not turns: + print_fail("No response from server.") + return False + + parallel_ok, parallel_reason = _check_parallel(turns, test_case["expected_parallel"]) + if not parallel_ok: + print_fail(parallel_reason) + return False + + passed, reason = test_case["validate"](turns, all_tool_calls, final_content) + if passed: + print_pass(f"{parallel_reason}; {reason}") + else: + print_fail(reason) + return passed + + +# --------------------------------------------------------------------------- +# Test case definitions +# --------------------------------------------------------------------------- + +# ---- Test 1: Multi-file read (same tool, multiple distinct paths) ---- + +_FILE_TOOLS = [ + { + "type": "function", + "function": { + "name": "read_file", + "description": ( + "Read the full contents of a file from the local filesystem. " + "Call this tool in parallel when asked to read several files — " + "each path needs its own call." 
+ ), + "parameters": { + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Absolute or repo-relative path to a file", + }, + }, + "required": ["path"], + }, + }, + }, +] + +_FILE_CONTENTS = { + "config/database.yml": "host: db.internal\nport: 5432\nuser: svc_app\n", + "config/redis.yml": "host: cache.internal\nport: 6379\ndb: 0\n", + "config/queue.yml": "broker: rabbitmq.internal\nport: 5672\nvhost: prod\n", + "config/auth.yml": "provider: oidc\nissuer: https://auth.internal\n", +} + + +def _read_file_mock(args): + path = args.get("path", "") + norm = path.lstrip("./").lstrip("/") + content = _FILE_CONTENTS.get(norm) + if content is None: + for k, v in _FILE_CONTENTS.items(): + if path.endswith(k): + content = v + break + if content is None: + return json.dumps({"path": path, "error": "not found"}) + return json.dumps({"path": path, "content": content}) + + +MULTIFILE_READ_TEST = { + "name": "Parallel multi-file read (same tool, 4 distinct paths)", + "tools": _FILE_TOOLS, + "messages": [ + { + "role": "user", + "content": ( + "Please read all four of these config files so I can review them " + "together: config/database.yml, config/redis.yml, config/queue.yml, " + "and config/auth.yml. Call read_file for every path in parallel in " + "a single batch — do NOT read them one by one sequentially across " + "turns. After you have all four, give me a one-line summary of each." + ), + } + ], + "mock_tool_responses": {"read_file": _read_file_mock}, + "expected_parallel": { + "min_parallel": 4, + "require_same_tool": "read_file", + "min_distinct_args_key": "path", + "min_distinct_args_count": 4, + }, + "validate": lambda turns, tcs, content: _validate_multifile(turns, tcs, content), +} + + +def _validate_multifile(turns, tcs, content): + del turns + if not content: + return False, "No final summary produced" + return True, f"{len(tcs)} total read_file calls; content length={len(content)}" + + +# ---- Test 2: Batch TODO marking (same tool, N calls in one turn) ---- + +_TODO_TOOLS = [ + { + "type": "function", + "function": { + "name": "mark_todo_complete", + "description": ( + "Mark a single TODO item as complete by ID. When the user wants " + "several items marked at once, call this tool in parallel — " + "one call per item — rather than sequentially across turns." + ), + "parameters": { + "type": "object", + "properties": { + "todo_id": { + "type": "string", + "description": "Identifier of the TODO item", + }, + "note": { + "type": "string", + "description": "Optional completion note", + }, + }, + "required": ["todo_id"], + }, + }, + }, +] + +_TODO_DB = { + "T-101": "Draft onboarding doc", + "T-102": "Update dependency lockfile", + "T-103": "Fix flaky login test", + "T-104": "Rotate service credentials", + "T-105": "Archive Q4 reports", +} + + +def _mark_todo_mock(args): + tid = args.get("todo_id", "") + if tid in _TODO_DB: + return json.dumps({"todo_id": tid, "title": _TODO_DB[tid], "status": "done"}) + return json.dumps({"todo_id": tid, "error": "unknown id"}) + + +TODO_BATCH_TEST = { + "name": "Batch TODO completion (same tool, 5 IDs in one turn)", + "tools": _TODO_TOOLS, + "messages": [ + { + "role": "user", + "content": ( + "I finished every item on today's list. Please mark all of the " + "following TODOs as complete, in one parallel batch: T-101, T-102, " + "T-103, T-104, T-105. Don't mark them one at a time across separate " + "turns — issue all five mark_todo_complete calls at once. Afterwards " + "confirm which ones succeeded." 
+ ), + } + ], + "mock_tool_responses": {"mark_todo_complete": _mark_todo_mock}, + "expected_parallel": { + "min_parallel": 5, + "require_same_tool": "mark_todo_complete", + "min_distinct_args_key": "todo_id", + "min_distinct_args_count": 5, + }, + "validate": lambda turns, tcs, content: _validate_todo(turns, tcs, content), +} + + +def _validate_todo(turns, tcs, content): + del turns + if not content: + return False, "No confirmation summary produced" + return True, f"{len(tcs)} total mark_todo_complete calls" + + +# ---- Test 3: Multi-city weather (same tool, N parallel locations) ---- + +_WEATHER_TOOLS = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": ( + "Fetch current weather for ONE city. When the user asks about " + "several cities, call this tool in parallel — one call per city — " + "instead of sequentially." + ), + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string", "description": "City name"}, + "units": { + "type": "string", + "enum": ["metric", "imperial"], + "default": "metric", + }, + }, + "required": ["city"], + }, + }, + }, +] + +_WEATHER_DB = { + "tokyo": {"city": "Tokyo", "temp_c": 18.4, "condition": "partly cloudy", "humidity": 64}, + "london": {"city": "London", "temp_c": 9.1, "condition": "overcast", "humidity": 81}, + "new york": {"city": "New York", "temp_c": 12.7, "condition": "clear", "humidity": 55}, + "paris": {"city": "Paris", "temp_c": 11.3, "condition": "light rain", "humidity": 78}, +} + + +def _weather_mock(args): + city = args.get("city", "").strip().lower() + if city.startswith("new york"): + city = "new york" + if city in _WEATHER_DB: + return json.dumps(_WEATHER_DB[city]) + return json.dumps({"city": args.get("city", ""), "error": "unknown city"}) + + +MULTI_WEATHER_TEST = { + "name": "Parallel multi-city weather (same tool, 4 cities)", + "tools": _WEATHER_TOOLS, + "messages": [ + { + "role": "user", + "content": ( + "I'm comparing today's weather across four cities for a travel " + "decision: Tokyo, London, New York, and Paris. Please call " + "get_weather for all four in parallel in a single turn — don't " + "fetch them one at a time. Then rank them from warmest to coolest." + ), + } + ], + "mock_tool_responses": {"get_weather": _weather_mock}, + "expected_parallel": { + "min_parallel": 4, + "require_same_tool": "get_weather", + "min_distinct_args_key": "city", + "min_distinct_args_count": 4, + }, + "validate": lambda turns, tcs, content: _validate_weather(turns, tcs, content), +} + + +def _validate_weather(turns, tcs, content): + del turns + if not content or not any( + kw in content.lower() for kw in ("warmest", "rank", "hot", "cool") + ): + return False, f"Final content missing a ranking: {content!r}" + return True, f"{len(tcs)} total get_weather calls; ranking produced" + + +# ---- Test 4: Trip planning (different tools, parallel in one turn) ---- + +_TRIP_TOOLS = [ + { + "type": "function", + "function": { + "name": "search_flights", + "description": "Search one-way flights between two airports on a given date.", + "parameters": { + "type": "object", + "properties": { + "from_airport": {"type": "string", "description": "IATA code, e.g. SFO"}, + "to_airport": {"type": "string", "description": "IATA code, e.g. 
JFK"}, + "date": {"type": "string", "description": "YYYY-MM-DD"}, + }, + "required": ["from_airport", "to_airport", "date"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "search_hotels", + "description": "Search hotels in a city for a date range.", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string"}, + "check_in": {"type": "string", "description": "YYYY-MM-DD"}, + "check_out": {"type": "string", "description": "YYYY-MM-DD"}, + "max_price": {"type": "integer"}, + }, + "required": ["city", "check_in", "check_out"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "search_restaurants", + "description": "Search restaurants in a city by cuisine.", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string"}, + "cuisine": {"type": "string"}, + }, + "required": ["city"], + }, + }, + }, +] + +_FLIGHTS_RESULT = { + "results": [ + {"flight": "UA 1552", "depart": "08:15", "arrive": "16:45", "price": 389}, + {"flight": "AA 20", "depart": "10:00", "arrive": "18:35", "price": 412}, + ] +} +_HOTELS_RESULT = { + "results": [ + {"name": "Midtown Grand", "nightly_rate": 245, "rating": 4.3}, + {"name": "Harbour Boutique", "nightly_rate": 312, "rating": 4.6}, + ] +} +_RESTAURANTS_RESULT = { + "results": [ + {"name": "Trattoria Nona", "cuisine": "italian", "rating": 4.5}, + {"name": "Osteria Blu", "cuisine": "italian", "rating": 4.4}, + ] +} + +TRIP_PLAN_TEST = { + "name": "Trip planning (3 different tools in parallel)", + "tools": _TRIP_TOOLS, + "messages": [ + { + "role": "user", + "content": ( + "I'm flying from SFO to JFK on 2026-06-12 and staying four nights " + "(check out 2026-06-16). I'd also like some Italian restaurant " + "suggestions in New York. Please call search_flights, search_hotels, " + "and search_restaurants in parallel — all three in a single turn, " + "since they don't depend on each other. Then give me a concise " + "travel summary." + ), + } + ], + "mock_tool_responses": { + "search_flights": lambda _: json.dumps(_FLIGHTS_RESULT), + "search_hotels": lambda _: json.dumps(_HOTELS_RESULT), + "search_restaurants": lambda _: json.dumps(_RESTAURANTS_RESULT), + }, + "expected_parallel": { + "min_parallel": 3, + "require_distinct_tools": 3, + }, + "validate": lambda turns, tcs, content: _validate_trip(turns, tcs, content), +} + + +def _validate_trip(turns, tcs, content): + del turns + names = {tc["function"]["name"] for tc in tcs} + required = {"search_flights", "search_hotels", "search_restaurants"} + missing = required - names + if missing: + return False, f"Missing tool calls: {missing}" + if not content: + return False, "No travel summary produced" + return True, f"All three tools called; summary length={len(content)}" + + +# ---- Test 5: Portfolio check (same tool, parallel tickers) ---- + +_STOCK_TOOLS = [ + { + "type": "function", + "function": { + "name": "get_stock_quote", + "description": ( + "Get the latest quote for ONE ticker. When the user asks about " + "multiple tickers, call this tool in parallel — one per symbol — " + "rather than sequentially." 
+ ), + "parameters": { + "type": "object", + "properties": { + "symbol": {"type": "string", "description": "Ticker symbol"}, + }, + "required": ["symbol"], + }, + }, + }, +] + +_STOCK_DB = { + "AAPL": {"symbol": "AAPL", "price": 218.45, "change_pct": "+0.8%"}, + "MSFT": {"symbol": "MSFT", "price": 421.10, "change_pct": "+1.2%"}, + "GOOGL":{"symbol": "GOOGL","price": 175.22, "change_pct": "-0.3%"}, + "AMZN": {"symbol": "AMZN", "price": 189.76, "change_pct": "+0.5%"}, + "NVDA": {"symbol": "NVDA", "price": 140.88, "change_pct": "+2.4%"}, +} + + +def _stock_mock(args): + sym = args.get("symbol", "").strip().upper() + if sym in _STOCK_DB: + return json.dumps(_STOCK_DB[sym]) + return json.dumps({"symbol": sym, "error": "unknown ticker"}) + + +PORTFOLIO_TEST = { + "name": "Portfolio check (same tool, 5 tickers in parallel)", + "tools": _STOCK_TOOLS, + "messages": [ + { + "role": "user", + "content": ( + "Pull the latest quote for every ticker in my portfolio — AAPL, " + "MSFT, GOOGL, AMZN, and NVDA — in a single parallel batch. These " + "lookups are independent, so please don't chain them across turns. " + "Once you have all five, tell me which ticker had the biggest " + "percentage change today." + ), + } + ], + "mock_tool_responses": {"get_stock_quote": _stock_mock}, + "expected_parallel": { + "min_parallel": 5, + "require_same_tool": "get_stock_quote", + "min_distinct_args_key": "symbol", + "min_distinct_args_count": 5, + }, + "validate": lambda turns, tcs, content: _validate_portfolio(turns, tcs, content), +} + + +def _validate_portfolio(turns, tcs, content): + del turns + if not content or ("nvda" not in content.lower() and "NVDA" not in content): + return False, f"Expected NVDA to be identified as the biggest mover: {content!r}" + return True, f"{len(tcs)} total quotes pulled" + + +# ---- Test 6: Mixed — translate + dictionary in parallel for the same word ---- + +_LANG_TOOLS = [ + { + "type": "function", + "function": { + "name": "translate_text", + "description": "Translate a short text into a target language.", + "parameters": { + "type": "object", + "properties": { + "text": {"type": "string"}, + "target_language": {"type": "string", + "description": "ISO 639-1 language code, e.g. 
'es'"}, + }, + "required": ["text", "target_language"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_definition", + "description": "Get the English dictionary definition of a word.", + "parameters": { + "type": "object", + "properties": { + "word": {"type": "string"}, + }, + "required": ["word"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_synonyms", + "description": "Get English synonyms for a word.", + "parameters": { + "type": "object", + "properties": { + "word": {"type": "string"}, + }, + "required": ["word"], + }, + }, + }, +] + + +def _translate_mock(args): + t = args.get("text", "") + lang = args.get("target_language", "") + return json.dumps({"source": t, "target_language": lang, "translation": f"[{lang}] {t}"}) + + +def _definition_mock(args): + w = args.get("word", "") + return json.dumps({ + "word": w, + "definition": f"A standard dictionary definition of {w!r}.", + }) + + +def _synonyms_mock(args): + w = args.get("word", "") + return json.dumps({ + "word": w, + "synonyms": ["synonym_a", "synonym_b", "synonym_c"], + }) + + +LANG_TOOLKIT_TEST = { + "name": "Language toolkit (translate + definition + synonyms in parallel)", + "tools": _LANG_TOOLS, + "messages": [ + { + "role": "user", + "content": ( + "For the English word 'resilient', I need three independent " + "look-ups at once: (a) translate it into Spanish, (b) fetch its " + "dictionary definition, and (c) list its synonyms. These three " + "calls don't depend on each other — please issue them in parallel " + "in a single turn. Then present the combined results as a short " + "language note." + ), + } + ], + "mock_tool_responses": { + "translate_text": _translate_mock, + "get_definition": _definition_mock, + "get_synonyms": _synonyms_mock, + }, + "expected_parallel": { + "min_parallel": 3, + "require_distinct_tools": 3, + }, + "validate": lambda turns, tcs, content: _validate_lang(turns, tcs, content), +} + + +def _validate_lang(turns, tcs, content): + del turns + names = {tc["function"]["name"] for tc in tcs} + required = {"translate_text", "get_definition", "get_synonyms"} + missing = required - names + if missing: + return False, f"Missing tool calls: {missing}" + if not content: + return False, "No language note produced" + return True, f"All three lookup tools called; note length={len(content)}" + + +# --------------------------------------------------------------------------- +# All test cases +# --------------------------------------------------------------------------- + +ALL_TEST_CASES = [ + MULTIFILE_READ_TEST, + TODO_BATCH_TEST, + MULTI_WEATHER_TEST, + TRIP_PLAN_TEST, + PORTFOLIO_TEST, + LANG_TOOLKIT_TEST, +] + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + + +def main(): + parser = argparse.ArgumentParser( + description=( + "Test llama-server parallel tool-calling capability. Run this only " + "against models configured for parallel tool calls — this script " + "does not configure that itself." 
+ ) + ) + parser.add_argument("--host", default="localhost") + parser.add_argument("--port", default=8080, type=int) + parser.add_argument( + "--no-stream", action="store_true", help="Disable streaming mode tests" + ) + parser.add_argument( + "--stream-only", action="store_true", help="Only run streaming mode tests" + ) + parser.add_argument( + "--test", + help="Run only the test whose name contains this substring (case-insensitive)", + ) + args = parser.parse_args() + + url = f"http://{args.host}:{args.port}/v1/chat/completions" + print_info(f"Testing server at {url}") + print_warn( + "This script expects the target model to emit multiple tool calls in a " + "single assistant turn. Run it only against parallel-tool-capable models." + ) + + modes: list[bool] = [] + if not args.stream_only: + modes.append(False) + if not args.no_stream: + modes.append(True) + + cases: list[dict] = ALL_TEST_CASES + if args.test: + name_filter = args.test.lower() + cases = [c for c in cases if name_filter in str(c["name"]).lower()] + if not cases: + print_fail(f"No test cases matched '{args.test}'") + sys.exit(1) + + total = 0 + passed = 0 + for stream in modes: + for case in cases: + total += 1 + if run_test(url, case, stream=stream): + passed += 1 + + color = GREEN if passed == total else RED + _print(f"\n{BOLD}{color}{'─' * 60}{RESET}") + _print(f"{BOLD}{color} Results: {passed}/{total} passed{RESET}") + _print(f"{BOLD}{color}{'─' * 60}{RESET}\n") + sys.exit(0 if passed == total else 1) + + +if __name__ == "__main__": + main() diff --git a/scripts/server-test-structured.py b/scripts/server-test-structured.py new file mode 100755 index 00000000000..98ff473b9fe --- /dev/null +++ b/scripts/server-test-structured.py @@ -0,0 +1,980 @@ +#!/usr/bin/env python3 +""" +Test structured output capability via chat completions endpoint. + +Each test case contains: + - response_format: OpenAI-compatible response_format specification + (json_schema only — llama.cpp does not support json_object) + - messages: initial conversation messages + - tools (optional): tool definitions (for mixed tool + structured tests) + - mock_tool_responses (optional): dict mapping tool_name -> callable(arguments) -> str (JSON) + - apply_stage: "always" to apply response_format to every request, + "after_tools" to run the tool loop plain, then request a + structured summary in a follow-up user turn. + - followup (optional, for after_tools): user message appended before the + final structured call. 
+ - validate: callable(parsed_json, tool_calls_history, raw_content) -> (passed: bool, reason: str) +""" + +import argparse +import json +import requests +import sys +from typing import Any, cast + +# --------------------------------------------------------------------------- +# Color / formatting helpers +# --------------------------------------------------------------------------- + +RESET = "\x1b[0m" +BOLD = "\x1b[1m" +DIM = "\x1b[2m" +CYAN = "\x1b[36m" +YELLOW = "\x1b[33m" +GREEN = "\x1b[32m" +RED = "\x1b[31m" +BLUE = "\x1b[34m" +WHITE = "\x1b[97m" +MAGENTA = "\x1b[35m" + + +def _print(text="", end="\n"): + sys.stdout.write(text + end) + sys.stdout.flush() + + +def print_header(title): + bar = "─" * 60 + _print(f"\n{BOLD}{CYAN}┌{bar}┐{RESET}") + _print( + f"{BOLD}{CYAN}│ {WHITE}{title}{CYAN}{' ' * max(0, 58 - len(title))}│{RESET}" + ) + _print(f"{BOLD}{CYAN}└{bar}┘{RESET}") + + +def print_tool_call(name, args): + args_str = json.dumps(args) + _print( + f"\n {BOLD}{YELLOW}⚙ tool call{RESET} {CYAN}{name}{RESET}{DIM}({args_str}){RESET}" + ) + + +def print_tool_result(result): + preview = result[:160] + ("…" if len(result) > 160 else "") + _print(f" {DIM}{BLUE}↳ result{RESET} {DIM}{preview}{RESET}") + + +def print_model_output(text): + sys.stdout.write(text) + sys.stdout.flush() + + +def print_pass(reason): + _print(f"\n{BOLD}{GREEN}✔ PASS{RESET} {reason}") + + +def print_fail(reason): + _print(f"\n{BOLD}{RED}✘ FAIL{RESET} {reason}") + + +def print_info(msg): + _print(f"{DIM}{msg}{RESET}") + + +def print_schema_note(label, rf): + kind = rf.get("type", "?") + name = "" + if kind == "json_schema": + name = rf.get("json_schema", {}).get("name", "") + _print(f"{DIM}{MAGENTA} ⟐ response_format [{label}]: {kind}" + f"{(' / ' + name) if name else ''}{RESET}") + + +# --------------------------------------------------------------------------- +# HTTP helpers +# --------------------------------------------------------------------------- + + +def chat_completion(url, messages, tools=None, response_format=None, stream=False): + payload = { + "messages": messages, + "stream": stream, + "max_tokens": 4096, + } + if tools: + payload["tools"] = tools + payload["tool_choice"] = "auto" + if response_format is not None: + payload["response_format"] = response_format + + try: + response = requests.post(url, json=payload, stream=stream) + response.raise_for_status() + except requests.exceptions.RequestException as e: + body = e.response.content if (e.response is not None) else b"" + print_fail(f"Request error: {e} | body: {body}") + return None + + full_content = "" + reasoning_content = "" + tool_calls: list[dict] = [] + + if stream: + for line in response.iter_lines(): + if not line: + continue + decoded = line.decode("utf-8") + if not decoded.startswith("data: "): + continue + data_str = decoded[6:] + if data_str == "[DONE]": + break + try: + data = json.loads(data_str) + except json.JSONDecodeError: + continue + choices = data.get("choices", []) + if not choices: + continue + delta = choices[0].get("delta", {}) + if delta.get("reasoning_content"): + reasoning_content += delta["reasoning_content"] + if delta.get("content"): + full_content += delta["content"] + print_model_output(delta["content"]) + for tc in delta.get("tool_calls", []): + idx = tc.get("index", 0) + while len(tool_calls) <= idx: + tool_calls.append( + { + "id": "", + "type": "function", + "function": {"name": "", "arguments": ""}, + } + ) + if "id" in tc: + tool_calls[idx]["id"] += tc["id"] + if "function" in tc: + if "name" in 
tc["function"]: + tool_calls[idx]["function"]["name"] += tc["function"]["name"] + if "arguments" in tc["function"]: + tool_calls[idx]["function"]["arguments"] += tc["function"][ + "arguments" + ] + else: + data = response.json() + choices = data.get("choices", []) + if choices: + msg = choices[0].get("message", {}) + full_content = msg.get("content") or "" + reasoning_content = msg.get("reasoning_content") or "" + tool_calls = msg.get("tool_calls") or [] + if full_content: + print_model_output(full_content) + + result = {"content": full_content, "tool_calls": tool_calls} + if reasoning_content: + result["reasoning_content"] = reasoning_content + return result + + +def run_tool_loop( + url, messages, tools, mock_tool_responses, stream, response_format=None, + max_turns=6, +): + """ + Drive the tool-call loop. If response_format is provided it is applied to + every request. Returns (all_tool_calls, final_messages, final_content). + """ + msgs = list(messages) + all_tool_calls: list[dict] = [] + + for _ in range(max_turns): + result = chat_completion( + url, msgs, tools=tools, response_format=response_format, stream=stream + ) + if result is None: + return all_tool_calls, msgs, None + + tcs = result.get("tool_calls") or [] + content = result.get("content") or "" + + if not tcs: + if content: + _print(f"\n{DIM}{'·' * 60}{RESET}") + return all_tool_calls, msgs, content + + all_tool_calls.extend(tcs) + + assistant_msg: dict = { + "role": "assistant", + "content": content, + "tool_calls": tcs, + } + reasoning = result.get("reasoning_content") + if reasoning: + assistant_msg["reasoning_content"] = reasoning + msgs.append(assistant_msg) + + for tc in tcs: + tool_name = tc["function"]["name"] + try: + args = json.loads(tc["function"]["arguments"]) + except json.JSONDecodeError: + args = {} + + print_tool_call(tool_name, args) + + mock_fn = mock_tool_responses.get(tool_name) if mock_tool_responses else None + if mock_fn: + tool_result = mock_fn(args) + else: + tool_result = json.dumps({"error": f"Unknown tool: {tool_name}"}) + + print_tool_result(tool_result) + + msgs.append( + { + "role": "tool", + "tool_call_id": tc.get("id", ""), + "content": tool_result, + } + ) + + return all_tool_calls, msgs, None + + +# --------------------------------------------------------------------------- +# Test case runner +# --------------------------------------------------------------------------- + + +def _try_parse_json(text): + """Attempt to parse text as JSON, trimming common markdown fences.""" + if text is None: + return None + stripped = text.strip() + if stripped.startswith("```"): + lines = stripped.splitlines() + if lines and lines[0].startswith("```"): + lines = lines[1:] + if lines and lines[-1].strip().startswith("```"): + lines = lines[:-1] + stripped = "\n".join(lines).strip() + try: + return json.loads(stripped) + except json.JSONDecodeError: + return None + + +def run_test(url, test_case, stream): + name = test_case["name"] + mode = f"{'stream' if stream else 'non-stream'}" + apply_stage = test_case.get("apply_stage", "always") + print_header(f"{name} [{mode}] ({apply_stage})") + + response_format = test_case["response_format"] + print_schema_note(apply_stage, response_format) + + tools = test_case.get("tools") + mocks = test_case.get("mock_tool_responses") or {} + + all_tcs: list[dict] = [] + final_content = None + + if apply_stage == "always": + all_tcs, _msgs, final_content = run_tool_loop( + url, + messages=list(test_case["messages"]), + tools=tools, + mock_tool_responses=mocks, + stream=stream, 
+ response_format=response_format, + ) + elif apply_stage == "after_tools": + # Phase 1: plain tool loop, no response_format applied yet. + all_tcs, msgs, interim_content = run_tool_loop( + url, + messages=list(test_case["messages"]), + tools=tools, + mock_tool_responses=mocks, + stream=stream, + response_format=None, + ) + if interim_content: + msgs.append({"role": "assistant", "content": interim_content}) + followup = test_case.get( + "followup", + "Now output the answer strictly as JSON matching the provided schema. " + "Do not include commentary.", + ) + msgs.append({"role": "user", "content": followup}) + + # Phase 2: request final structured output. Tools are not passed so the + # model focuses on producing the schema-constrained answer. + _print(f"\n{DIM}{MAGENTA} ⟐ follow-up turn with response_format applied{RESET}") + result = chat_completion( + url, msgs, tools=None, response_format=response_format, stream=stream + ) + final_content = result["content"] if result else None + else: + print_fail(f"Unknown apply_stage: {apply_stage}") + return False + + if final_content is None: + print_fail("No final content from server.") + return False + + parsed = _try_parse_json(final_content) + if parsed is None: + print_fail(f"Final content is not valid JSON: {final_content[:200]!r}") + return False + + passed, reason = test_case["validate"](parsed, all_tcs, final_content) + if passed: + print_pass(reason) + else: + print_fail(reason) + return passed + + +# --------------------------------------------------------------------------- +# Test case definitions +# --------------------------------------------------------------------------- + +# ---- Test 1: Book metadata extraction (always / json_schema) ---- + +_BOOK_SCHEMA = { + "type": "json_schema", + "json_schema": { + "name": "book_metadata", + "strict": True, + "schema": { + "type": "object", + "additionalProperties": False, + "properties": { + "title": {"type": "string"}, + "author": {"type": "string"}, + "year": {"type": "integer"}, + "genre": { + "type": "string", + "enum": [ + "fiction", + "non-fiction", + "fantasy", + "sci-fi", + "mystery", + "biography", + "history", + "other", + ], + }, + "page_count": {"type": "integer"}, + }, + "required": ["title", "author", "year", "genre", "page_count"], + }, + }, +} + +BOOK_TEST_CASE = { + "name": "Book metadata extraction (json_schema, always)", + "response_format": _BOOK_SCHEMA, + "apply_stage": "always", + "messages": [ + { + "role": "user", + "content": ( + "Extract book metadata from this description: " + "'Dune is a 1965 science fiction epic by Frank Herbert, spanning roughly " + "688 pages in its first edition, set on the desert planet Arrakis.' " + "Return the data as JSON." 
+ ), + } + ], + "validate": lambda parsed, tcs, raw: _validate_book(parsed), +} + + +def _validate_book(parsed): + required = {"title", "author", "year", "genre", "page_count"} + missing = required - parsed.keys() + if missing: + return False, f"Missing fields: {missing}" + if not isinstance(parsed["title"], str) or not parsed["title"]: + return False, "title must be a non-empty string" + if not isinstance(parsed["author"], str) or "herbert" not in parsed["author"].lower(): + return False, f"author unexpected: {parsed['author']!r}" + if not isinstance(parsed["year"], int) or parsed["year"] != 1965: + return False, f"year should be 1965, got {parsed['year']!r}" + if parsed["genre"] not in { + "fiction", "non-fiction", "fantasy", "sci-fi", "mystery", + "biography", "history", "other", + }: + return False, f"genre not in enum: {parsed['genre']!r}" + if not isinstance(parsed["page_count"], int) or parsed["page_count"] <= 0: + return False, f"page_count should be positive int: {parsed['page_count']!r}" + return True, f"Book: {parsed['title']} ({parsed['year']}) / {parsed['genre']}" + + +# ---- Test 2: Sentiment classification (always / enum-constrained) ---- + +_SENTIMENT_SCHEMA = { + "type": "json_schema", + "json_schema": { + "name": "sentiment_analysis", + "strict": True, + "schema": { + "type": "object", + "additionalProperties": False, + "properties": { + "sentiment": { + "type": "string", + "enum": ["positive", "negative", "neutral"], + }, + "confidence": {"type": "number"}, + "keywords": { + "type": "array", + "items": {"type": "string"}, + "minItems": 1, + "maxItems": 5, + }, + }, + "required": ["sentiment", "confidence", "keywords"], + }, + }, +} + +SENTIMENT_TEST_CASE = { + "name": "Sentiment analysis with enum and array", + "response_format": _SENTIMENT_SCHEMA, + "apply_stage": "always", + "messages": [ + { + "role": "user", + "content": ( + "Analyse the sentiment of this review and return JSON with the " + "detected sentiment label, a confidence score between 0 and 1, " + "and up to five keyword strings that drove the classification:\n\n" + "'This product completely exceeded my expectations. 
The build " + "quality is phenomenal, it arrived a day early, and customer " + "support was delightful when I had a setup question.'" + ), + } + ], + "validate": lambda parsed, tcs, raw: _validate_sentiment(parsed), +} + + +def _validate_sentiment(parsed): + if parsed.get("sentiment") not in {"positive", "negative", "neutral"}: + return False, f"sentiment not in enum: {parsed.get('sentiment')!r}" + if parsed["sentiment"] != "positive": + return False, f"expected positive sentiment, got {parsed['sentiment']}" + conf = parsed.get("confidence") + if not isinstance(conf, (int, float)) or not (0.0 <= conf <= 1.0): + return False, f"confidence not in [0,1]: {conf!r}" + kws = parsed.get("keywords") + if not isinstance(kws, list) or not (1 <= len(kws) <= 5): + return False, f"keywords length out of range: {kws!r}" + if not all(isinstance(k, str) and k for k in kws): + return False, f"keywords must be non-empty strings: {kws!r}" + return True, f"sentiment={parsed['sentiment']} conf={conf} kws={kws}" + + +# ---- Test 3: Nested recipe schema (always) ---- + +_RECIPE_SCHEMA = { + "type": "json_schema", + "json_schema": { + "name": "recipe", + "strict": True, + "schema": { + "type": "object", + "additionalProperties": False, + "properties": { + "name": {"type": "string"}, + "servings": {"type": "integer"}, + "ingredients": { + "type": "array", + "minItems": 2, + "items": { + "type": "object", + "additionalProperties": False, + "properties": { + "item": {"type": "string"}, + "quantity": {"type": "string"}, + }, + "required": ["item", "quantity"], + }, + }, + "steps": { + "type": "array", + "minItems": 2, + "items": {"type": "string"}, + }, + "prep_time_minutes": {"type": "integer"}, + }, + "required": ["name", "servings", "ingredients", "steps", "prep_time_minutes"], + }, + }, +} + +RECIPE_TEST_CASE = { + "name": "Nested recipe with arrays of objects", + "response_format": _RECIPE_SCHEMA, + "apply_stage": "always", + "messages": [ + { + "role": "user", + "content": ( + "Give me a simple 4-serving scrambled eggs recipe as structured JSON. " + "Include the recipe name, servings, ingredients (each with item and " + "quantity), preparation steps, and total prep time in minutes." 
+ ), + } + ], + "validate": lambda parsed, tcs, raw: _validate_recipe(parsed), +} + + +def _validate_recipe(parsed): + required = {"name", "servings", "ingredients", "steps", "prep_time_minutes"} + missing = required - parsed.keys() + if missing: + return False, f"Missing fields: {missing}" + if not isinstance(parsed["name"], str) or not parsed["name"]: + return False, "name must be a non-empty string" + if not isinstance(parsed["servings"], int) or parsed["servings"] <= 0: + return False, f"servings must be positive int: {parsed['servings']!r}" + ings = parsed["ingredients"] + if not isinstance(ings, list) or len(ings) < 2: + return False, f"ingredients must be array of >=2: got {ings!r}" + for i, ing in enumerate(ings): + if not isinstance(ing, dict): + return False, f"ingredient[{i}] is not an object: {ing!r}" + ing_d = cast(dict[str, Any], ing) + item_val = ing_d.get("item") + qty_val = ing_d.get("quantity") + if item_val is None or qty_val is None: + return False, f"ingredient[{i}] missing item/quantity: {ing!r}" + if not isinstance(item_val, str) or not isinstance(qty_val, str): + return False, f"ingredient[{i}] fields must be strings: {ing!r}" + steps = parsed["steps"] + if not isinstance(steps, list) or len(steps) < 2: + return False, f"steps must be array of >=2 strings: got {steps!r}" + if not all(isinstance(s, str) and s for s in steps): + return False, "all steps must be non-empty strings" + pt = parsed["prep_time_minutes"] + if not isinstance(pt, int) or pt <= 0: + return False, f"prep_time_minutes must be positive int: {pt!r}" + return True, f"recipe '{parsed['name']}' with {len(ings)} ingredients, {len(steps)} steps" + + +# ---- Test 4: Tool call -> structured product comparison (after_tools) ---- + +_SHOP_TOOLS = [ + { + "type": "function", + "function": { + "name": "search_products", + "description": "Search a product catalogue by keyword.", + "parameters": { + "type": "object", + "properties": { + "query": {"type": "string"}, + }, + "required": ["query"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_product_details", + "description": "Get detailed specs for a product by ID.", + "parameters": { + "type": "object", + "properties": { + "product_id": {"type": "string"}, + }, + "required": ["product_id"], + }, + }, + }, +] + +_SHOP_SEARCH_RESULT = { + "results": [ + {"product_id": "LAP-001", "title": "AeroBook 13 Pro", "price": 1399.0, "rating": 4.7}, + {"product_id": "LAP-002", "title": "QuantumSlim 14", "price": 1199.0, "rating": 4.4}, + {"product_id": "LAP-003", "title": "NimbusWork Ultra 15", "price": 999.0, "rating": 4.2}, + ], +} +_SHOP_PRODUCT_DETAILS = { + "LAP-001": { + "product_id": "LAP-001", + "title": "AeroBook 13 Pro", + "cpu": "M-series 10-core", + "ram_gb": 16, + "storage_gb": 512, + "battery_hours": 18, + "weight_kg": 1.24, + "price": 1399.0, + }, + "LAP-002": { + "product_id": "LAP-002", + "title": "QuantumSlim 14", + "cpu": "Core i7 12-core", + "ram_gb": 16, + "storage_gb": 512, + "battery_hours": 12, + "weight_kg": 1.35, + "price": 1199.0, + }, + "LAP-003": { + "product_id": "LAP-003", + "title": "NimbusWork Ultra 15", + "cpu": "Ryzen 7 8-core", + "ram_gb": 16, + "storage_gb": 1024, + "battery_hours": 10, + "weight_kg": 1.70, + "price": 999.0, + }, +} + + +def _shop_details_mock(args): + pid = args.get("product_id", "") + if pid in _SHOP_PRODUCT_DETAILS: + return json.dumps(_SHOP_PRODUCT_DETAILS[pid]) + return json.dumps({"error": f"unknown product_id: {pid}"}) + + +_SHOP_COMPARISON_SCHEMA = { + "type": "json_schema", + 
"json_schema": { + "name": "laptop_comparison", + "strict": True, + "schema": { + "type": "object", + "additionalProperties": False, + "properties": { + "recommendation": {"type": "string"}, + "ranked_candidates": { + "type": "array", + "minItems": 2, + "items": { + "type": "object", + "additionalProperties": False, + "properties": { + "product_id": {"type": "string"}, + "title": {"type": "string"}, + "score": {"type": "number"}, + "reason": {"type": "string"}, + }, + "required": ["product_id", "title", "score", "reason"], + }, + }, + }, + "required": ["recommendation", "ranked_candidates"], + }, + }, +} + +SHOP_COMPARISON_TEST_CASE = { + "name": "Tool calls then structured laptop comparison (after_tools)", + "response_format": _SHOP_COMPARISON_SCHEMA, + "apply_stage": "after_tools", + "tools": _SHOP_TOOLS, + "mock_tool_responses": { + "search_products": lambda _: json.dumps(_SHOP_SEARCH_RESULT), + "get_product_details": _shop_details_mock, + }, + "messages": [ + { + "role": "user", + "content": ( + "I need a lightweight laptop for travel. Please search the catalogue " + "for 'ultraportable laptop', then fetch detailed specs for at least two " + "of the top candidates. Once you've gathered the data I'll ask you to " + "produce a structured comparison." + ), + } + ], + "followup": ( + "Thanks. Now produce the final comparison strictly as JSON matching the " + "laptop_comparison schema: your single best recommendation (the product_id), " + "and a ranked_candidates array of at least two laptops, each with " + "product_id, title, a numeric score, and a short reason." + ), + "validate": lambda parsed, tcs, raw: _validate_shop_comparison(parsed, tcs), +} + + +def _validate_shop_comparison(parsed, tcs): + names = [tc["function"]["name"] for tc in tcs] + if "search_products" not in names: + return False, f"expected search_products tool call, got {names}" + if "get_product_details" not in names: + return False, f"expected get_product_details tool call, got {names}" + if "recommendation" not in parsed or not isinstance(parsed["recommendation"], str): + return False, f"recommendation missing or not a string: {parsed!r}" + cands = parsed.get("ranked_candidates") + if not isinstance(cands, list) or len(cands) < 2: + return False, f"ranked_candidates must be >=2: {cands!r}" + valid_ids = set(_SHOP_PRODUCT_DETAILS.keys()) + candidate_pids: list = [] + for i, c in enumerate(cands): + if not isinstance(c, dict): + return False, f"candidate[{i}] not an object: {c!r}" + c_d = cast(dict[str, Any], c) + pid = c_d.get("product_id") + title = c_d.get("title") + score = c_d.get("score") + reason = c_d.get("reason") + for k, v in (("product_id", pid), ("title", title), + ("score", score), ("reason", reason)): + if v is None: + return False, f"candidate[{i}] missing {k}: {c!r}" + if pid not in valid_ids: + return False, f"candidate[{i}].product_id not in catalogue: {pid!r}" + if not isinstance(score, (int, float)): + return False, f"candidate[{i}].score not numeric: {score!r}" + candidate_pids.append(pid) + recommendation = parsed["recommendation"] + if recommendation not in valid_ids and recommendation not in candidate_pids: + return False, f"recommendation {recommendation!r} not in candidates" + return True, ( + f"tools={names}; recommended={parsed['recommendation']}; " + f"{len(cands)} ranked candidates" + ) + + +# ---- Test 5: Multi-step research then structured report (after_tools) ---- + +_RESEARCH_TOOLS = [ + { + "type": "function", + "function": { + "name": "get_country_stats", + "description": "Fetch basic 
statistics for a country (population, GDP, capital).", + "parameters": { + "type": "object", + "properties": { + "country": {"type": "string"}, + }, + "required": ["country"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_climate_info", + "description": "Fetch climate information for a country.", + "parameters": { + "type": "object", + "properties": { + "country": {"type": "string"}, + }, + "required": ["country"], + }, + }, + }, +] + +_COUNTRY_STATS = { + "norway": { + "country": "Norway", + "capital": "Oslo", + "population": 5_480_000, + "gdp_usd_trillion": 0.48, + "currency": "NOK", + } +} +_CLIMATE_INFO = { + "norway": { + "country": "Norway", + "climate_zone": "subarctic / temperate coastal", + "avg_winter_temp_c": -4.5, + "avg_summer_temp_c": 16.0, + "annual_precipitation_mm": 1400, + } +} + + +def _country_stats_mock(args): + c = args.get("country", "").strip().lower() + if c in _COUNTRY_STATS: + return json.dumps(_COUNTRY_STATS[c]) + return json.dumps({"error": f"unknown country: {c}"}) + + +def _climate_info_mock(args): + c = args.get("country", "").strip().lower() + if c in _CLIMATE_INFO: + return json.dumps(_CLIMATE_INFO[c]) + return json.dumps({"error": f"unknown country: {c}"}) + + +_RESEARCH_REPORT_SCHEMA = { + "type": "json_schema", + "json_schema": { + "name": "country_report", + "strict": True, + "schema": { + "type": "object", + "additionalProperties": False, + "properties": { + "country": {"type": "string"}, + "capital": {"type": "string"}, + "population": {"type": "integer"}, + "climate_summary": {"type": "string"}, + "highlights": { + "type": "array", + "minItems": 2, + "maxItems": 5, + "items": {"type": "string"}, + }, + "suitable_for_tourism": {"type": "boolean"}, + }, + "required": [ + "country", "capital", "population", + "climate_summary", "highlights", "suitable_for_tourism", + ], + }, + }, +} + +COUNTRY_REPORT_TEST_CASE = { + "name": "Research pipeline then structured country report (after_tools)", + "response_format": _RESEARCH_REPORT_SCHEMA, + "apply_stage": "after_tools", + "tools": _RESEARCH_TOOLS, + "mock_tool_responses": { + "get_country_stats": _country_stats_mock, + "get_climate_info": _climate_info_mock, + }, + "messages": [ + { + "role": "user", + "content": ( + "I'm preparing a short briefing on Norway. Please call the " + "get_country_stats and get_climate_info tools to gather data " + "first. Afterwards I'll ask for a structured summary." + ), + } + ], + "followup": ( + "Based on the tool results, produce the briefing as JSON matching the " + "country_report schema. Populate every required field and provide between " + "two and five highlights." 
+ ), + "validate": lambda parsed, tcs, raw: _validate_country_report(parsed, tcs), +} + + +def _validate_country_report(parsed, tcs): + names = [tc["function"]["name"] for tc in tcs] + for required_tool in ("get_country_stats", "get_climate_info"): + if required_tool not in names: + return False, f"missing tool call {required_tool!r}: got {names}" + required = { + "country", "capital", "population", + "climate_summary", "highlights", "suitable_for_tourism", + } + missing = required - parsed.keys() + if missing: + return False, f"missing report fields: {missing}" + if "norway" not in parsed["country"].lower(): + return False, f"country should reference Norway: {parsed['country']!r}" + if "oslo" not in parsed["capital"].lower(): + return False, f"capital should be Oslo: {parsed['capital']!r}" + if not isinstance(parsed["population"], int) or parsed["population"] < 1_000_000: + return False, f"population implausible: {parsed['population']!r}" + if not isinstance(parsed["climate_summary"], str) or not parsed["climate_summary"]: + return False, "climate_summary must be a non-empty string" + hls = parsed["highlights"] + if not isinstance(hls, list) or not (2 <= len(hls) <= 5): + return False, f"highlights length out of range: {hls!r}" + if not all(isinstance(h, str) and h for h in hls): + return False, "each highlight must be a non-empty string" + if not isinstance(parsed["suitable_for_tourism"], bool): + return False, f"suitable_for_tourism must be bool: {parsed['suitable_for_tourism']!r}" + return True, ( + f"tools={names}; report for {parsed['country']} " + f"(pop {parsed['population']}, {len(hls)} highlights)" + ) + + +# --------------------------------------------------------------------------- +# All test cases +# --------------------------------------------------------------------------- + +ALL_TEST_CASES = [ + BOOK_TEST_CASE, + SENTIMENT_TEST_CASE, + RECIPE_TEST_CASE, + SHOP_COMPARISON_TEST_CASE, + COUNTRY_REPORT_TEST_CASE, +] + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + + +def main(): + parser = argparse.ArgumentParser( + description="Test llama-server structured-output capability." 
+ ) + parser.add_argument("--host", default="localhost") + parser.add_argument("--port", default=8080, type=int) + parser.add_argument( + "--no-stream", action="store_true", help="Disable streaming mode tests" + ) + parser.add_argument( + "--stream-only", action="store_true", help="Only run streaming mode tests" + ) + parser.add_argument( + "--test", + help="Run only the test whose name contains this substring (case-insensitive)", + ) + args = parser.parse_args() + + url = f"http://{args.host}:{args.port}/v1/chat/completions" + print_info(f"Testing server at {url}") + + modes: list[bool] = [] + if not args.stream_only: + modes.append(False) + if not args.no_stream: + modes.append(True) + + cases: list[dict] = ALL_TEST_CASES + if args.test: + name_filter = args.test.lower() + cases = [c for c in cases if name_filter in str(c["name"]).lower()] + if not cases: + print_fail(f"No test cases matched '{args.test}'") + sys.exit(1) + + total = 0 + passed = 0 + for stream in modes: + for case in cases: + total += 1 + if run_test(url, case, stream=stream): + passed += 1 + + color = GREEN if passed == total else RED + _print(f"\n{BOLD}{color}{'─' * 60}{RESET}") + _print(f"{BOLD}{color} Results: {passed}/{total} passed{RESET}") + _print(f"{BOLD}{color}{'─' * 60}{RESET}\n") + sys.exit(0 if passed == total else 1) + + +if __name__ == "__main__": + main() diff --git a/scripts/snapdragon/adb/run-bench.sh b/scripts/snapdragon/adb/run-bench.sh index 36c908da74e..27459df241b 100755 --- a/scripts/snapdragon/adb/run-bench.sh +++ b/scripts/snapdragon/adb/run-bench.sh @@ -23,10 +23,10 @@ verbose= [ "$V" != "" ] && verbose="GGML_HEXAGON_VERBOSE=$V" cli_opts="$cli_opts -v" profile= -[ "$PROF" != "" ] && profile="GGML_HEXAGON_PROFILE=$PROF GGML_HEXAGON_OPSYNC=1" cli_opts="$cli_opts -v" +[ "$PROF" != "" ] && profile="GGML_HEXAGON_PROFILE=$PROF" cli_opts="$cli_opts -v" opmask= -[ "$OPMASK" != "" ] && opmask="GGML_HEXAGON_OPMASK=$OPMASK" +[ "$OPSTAGE" != "" ] && opmask="GGML_HEXAGON_OPSTAGE=$OPSTAGE" nhvx= [ "$NHVX" != "" ] && nhvx="GGML_HEXAGON_NHVX=$NHVX" diff --git a/scripts/snapdragon/adb/run-cli.sh b/scripts/snapdragon/adb/run-cli.sh index 901d7eff13f..e1f0ac0eb8e 100755 --- a/scripts/snapdragon/adb/run-cli.sh +++ b/scripts/snapdragon/adb/run-cli.sh @@ -28,10 +28,10 @@ sched= [ "$SCHED" != "" ] && sched="GGML_SCHED_DEBUG=2" cli_opts="$cli_opts -v" profile= -[ "$PROF" != "" ] && profile="GGML_HEXAGON_PROFILE=$PROF GGML_HEXAGON_OPSYNC=1" cli_opts="$cli_opts -v" +[ "$PROF" != "" ] && profile="GGML_HEXAGON_PROFILE=$PROF" cli_opts="$cli_opts -v" opmask= -[ "$OPMASK" != "" ] && opmask="GGML_HEXAGON_OPMASK=$OPMASK" +[ "$OPSTAGE" != "" ] && opmask="GGML_HEXAGON_OPSTAGE=$OPSTAGE" nhvx= [ "$NHVX" != "" ] && nhvx="GGML_HEXAGON_NHVX=$NHVX" diff --git a/scripts/snapdragon/adb/run-completion.sh b/scripts/snapdragon/adb/run-completion.sh index f7290825ad5..7b84106dc83 100755 --- a/scripts/snapdragon/adb/run-completion.sh +++ b/scripts/snapdragon/adb/run-completion.sh @@ -28,10 +28,10 @@ sched= [ "$SCHED" != "" ] && sched="GGML_SCHED_DEBUG=2" cli_opts="$cli_opts -v" profile= -[ "$PROF" != "" ] && profile="GGML_HEXAGON_PROFILE=$PROF GGML_HEXAGON_OPSYNC=1" cli_opts="$cli_opts -v" +[ "$PROF" != "" ] && profile="GGML_HEXAGON_PROFILE=$PROF" cli_opts="$cli_opts -v" opmask= -[ "$OPMASK" != "" ] && opmask="GGML_HEXAGON_OPMASK=$OPMASK" +[ "$OPSTAGE" != "" ] && opmask="GGML_HEXAGON_OPSTAGE=$OPSTAGE" nhvx= [ "$NHVX" != "" ] && nhvx="GGML_HEXAGON_NHVX=$NHVX" diff --git a/scripts/snapdragon/adb/run-mtmd.sh b/scripts/snapdragon/adb/run-mtmd.sh 
index 0c1cf892800..38467beba3d 100755 --- a/scripts/snapdragon/adb/run-mtmd.sh +++ b/scripts/snapdragon/adb/run-mtmd.sh @@ -37,10 +37,10 @@ sched= [ "$SCHED" != "" ] && sched="GGML_SCHED_DEBUG=2" cli_opts="$cli_opts -v" profile= -[ "$PROF" != "" ] && profile="GGML_HEXAGON_PROFILE=$PROF GGML_HEXAGON_OPSYNC=1" +[ "$PROF" != "" ] && profile="GGML_HEXAGON_PROFILE=$PROF" opmask= -[ "$OPMASK" != "" ] && opmask="GGML_HEXAGON_OPMASK=$OPMASK" +[ "$OPSTAGE" != "" ] && opmask="GGML_HEXAGON_OPSTAGE=$OPSTAGE" nhvx= [ "$NHVX" != "" ] && nhvx="GGML_HEXAGON_NHVX=$NHVX" diff --git a/scripts/snapdragon/adb/run-tool.sh b/scripts/snapdragon/adb/run-tool.sh index 70ed407e87b..27cbb2b6d05 100755 --- a/scripts/snapdragon/adb/run-tool.sh +++ b/scripts/snapdragon/adb/run-tool.sh @@ -25,10 +25,10 @@ sched= [ "$SCHED" != "" ] && sched="GGML_SCHED_DEBUG=2" cli_opts="$cli_opts -v" profile= -[ "$PROF" != "" ] && profile="GGML_HEXAGON_PROFILE=$PROF GGML_HEXAGON_OPSYNC=1" +[ "$PROF" != "" ] && profile="GGML_HEXAGON_PROFILE=$PROF" opmask= -[ "$OPMASK" != "" ] && opmask="GGML_HEXAGON_OPMASK=$OPMASK" +[ "$OPSTAGE" != "" ] && opmask="GGML_HEXAGON_OPSTAGE=$OPSTAGE" nhvx= [ "$NHVX" != "" ] && nhvx="GGML_HEXAGON_NHVX=$NHVX" diff --git a/scripts/snapdragon/ggml-hexagon-profile.py b/scripts/snapdragon/ggml-hexagon-profile.py new file mode 100755 index 00000000000..3edaacd2749 --- /dev/null +++ b/scripts/snapdragon/ggml-hexagon-profile.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python3 + +import sys +import os +import re +import argparse +import statistics +import logging + +from collections import defaultdict + +# Mapping of cli-friendly names to (internal_data_key, Display Header, numeric_sort_key) +COL_MAP = { + "op": ("op", "Op", "op"), + "dims": ("dims", "Dims", "dims"), + "dtypes": ("dtypes", "DTypes", "dtypes"), + "count": ("count", "Count", "_sort_count"), + "max-usec": ("max_usec", "Max usec", "_sort_max_usec"), + "avg-usec": ("avg_usec", "Avg usec", "_sort_avg_usec"), + "max-cycles": ("max_cycles", "Max Cycles", "_sort_max_cycles"), + "avg-cycles": ("avg_cycles", "Avg Cycles", "_sort_avg_cycles"), + "max-pmu": ("max_pmu", "Max PMU", "_sort_max_pmu"), + "avg-pmu": ("avg_pmu", "Avg PMU", "_sort_avg_pmu"), +} + +op_pattern = re.compile( + r"profile-op\s+(?P<op_name>[A-Z_0-9]+):\s+.*?\s+:\s+(?P<dims>[\d:x\s\->!]+)\s+:\s+(?P<types>[a-z\d_\s\->x]+)\s+:\s+.*?\s+usec\s+(?P<usec>\d+)\s+cycles\s+(?P<cycles>\d+)(?:\s+pmu\s+\[(?P<pmu>[\d,\s]+)\])?"
+) + +logger = logging.getLogger("ggml-hexagon-profile") + + +def parse_log(file_path, pmu_index=None): + try: + if file_path != "-": + f = open(file_path, 'r', encoding='utf-8', errors='ignore') + else: + f = os.fdopen(0, 'r', encoding='utf-8', errors='ignore') + except FileNotFoundError: + logger.error(f"file '{file_path}' not found.") + sys.exit(1) + + all_ops = [] + for line in f: + match = op_pattern.search(line) + if not match: continue + + pmu_raw = match.group('pmu') + pmu_val = None + if pmu_raw and pmu_index is not None: + try: + pmu_list = [int(x.strip()) for x in pmu_raw.split(',')] + if len(pmu_list) > pmu_index: + pmu_val = pmu_list[pmu_index] + except (ValueError, IndexError): + pmu_val = None + + all_ops.append({ + 'name': match.group('op_name'), + 'dims': match.group('dims').strip(), + 'types': match.group('types').strip(), + 'usec': int(match.group('usec')), + 'cycles': int(match.group('cycles')), + 'pmu_val': pmu_val + }) + + f.close() + + return all_ops + + +def generate_report(ops, top_n, width_overrides, sort_col, pmu_name=None): + if not ops: + logger.info("No valid records found.") + return + + grouped = defaultdict(list) + for op in ops: + key = (op['name'], op['dims'], op['types']) + grouped[key].append(op) + + group_stats = [] + for (name, dims, types), group_ops in grouped.items(): + usecs = [o['usec'] for o in group_ops] + cycles = [o['cycles'] for o in group_ops] + pmu_vals = [o['pmu_val'] for o in group_ops if o['pmu_val'] is not None] + + group_stats.append({ + 'op': name, + 'dims': dims, + 'dtypes': types, + 'count': str(len(group_ops)), + 'max_usec': str(max(usecs)), + 'avg_usec': f"{statistics.mean(usecs):.2f}", + 'max_cycles': str(max(cycles)), + 'avg_cycles': f"{statistics.mean(cycles):.2f}", + 'max_pmu': str(max(pmu_vals)) if pmu_vals else "0", + 'avg_pmu': f"{statistics.mean(pmu_vals):.2f}" if pmu_vals else "0.00", + # Numeric values for accurate sorting + '_sort_count': len(group_ops), + '_sort_max_usec': max(usecs), + '_sort_avg_usec': statistics.mean(usecs), + '_sort_max_cycles': max(cycles), + '_sort_avg_cycles': statistics.mean(cycles), + '_sort_max_pmu': max(pmu_vals) if pmu_vals else 0, + '_sort_avg_pmu': statistics.mean(pmu_vals) if pmu_vals else 0 + }) + + # Sorting logic + actual_sort_key = COL_MAP[sort_col][2] + # We sort numeric fields descending, strings (op/dims) ascending + is_numeric = actual_sort_key.startswith("_") or actual_sort_key == "count" + sorted_groups = sorted(group_stats, key=lambda x: x[actual_sort_key], reverse=is_numeric)[:top_n] + + # Define initial column order + active_cols = ["op", "dims", "dtypes"] + if pmu_name: + active_cols += ["max-pmu", "avg-pmu"] + active_cols += ["max-usec", "avg-usec", "max-cycles", "avg-cycles", "count"] + + final_headers, final_keys, final_widths = [], [], [] + + for col_name in active_cols: + data_key, header_text, _ = COL_MAP[col_name] + if "pmu" in col_name and pmu_name: + header_text = header_text.replace("PMU", pmu_name) + + natural_width = max([len(row[data_key]) for row in sorted_groups] + [len(header_text)]) + target_width = width_overrides.get(col_name, natural_width) + + if target_width == 0: + continue + + final_headers.append(header_text) + final_keys.append(data_key) + final_widths.append(target_width) + + # Print Report + logger.info(f"\n# Profile Report (Top {top_n} Ops sorted by {sort_col})\n") + header_line = "| " + " | ".join(f"{h:<{final_widths[i]}}" for i, h in enumerate(final_headers)) + " |" + sep_line = "| " + " | ".join("-" * final_widths[i] for i in 
range(len(final_headers))) + " |" + logger.info(header_line) + logger.info(sep_line) + + for group in sorted_groups: + row_vals = [] + for i, key in enumerate(final_keys): + val = group[key] + if len(val) > final_widths[i]: + val = val[:final_widths[i] - 3] + "..." + row_vals.append(f"{val:<{final_widths[i]}}") + logger.info("| " + " | ".join(row_vals) + " |") + + +def main(): + parser = argparse.ArgumentParser(description="Post-process Op profile info.") + parser.add_argument("logfile") + parser.add_argument("-n", "--top", type=int, default=100) + parser.add_argument("--sort", type=str, default="max-usec", choices=list(COL_MAP.keys())) + parser.add_argument("--pmu-index", type=int) + parser.add_argument("--pmu-name", type=str) + parser.add_argument("--width", action='append', default=['dims:40'], help="Override column width, e.g. --width dims:50") + + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO, format='%(message)s') + + # Sort validation: can't sort by PMU if index isn't provided + if "pmu" in args.sort and args.pmu_index is None: + logger.error(f"Cannot sort by '{args.sort}' without --pmu-index.") + sys.exit(1) + + overrides = {} + if args.width: + for w in args.width: + try: + name, val = w.split(':') + overrides[name.lower()] = int(val) + except ValueError: + logger.warning(f"Invalid width format '{w}'") + + final_pmu_name = (args.pmu_name or f"#{args.pmu_index}") if args.pmu_index is not None else None + ops = parse_log(args.logfile, pmu_index=args.pmu_index) + generate_report(ops, args.top, overrides, args.sort, pmu_name=final_pmu_name) + + +if __name__ == "__main__": + main() diff --git a/scripts/snapdragon/qdc/readme.md b/scripts/snapdragon/qdc/readme.md deleted file mode 100644 index b92cf243aaa..00000000000 --- a/scripts/snapdragon/qdc/readme.md +++ /dev/null @@ -1 +0,0 @@ -This directory includes pytest based scripts for running CI jobs on Qualcomm Device Cloud (QDC). diff --git a/scripts/snapdragon/qdc/requirements.txt b/scripts/snapdragon/qdc/requirements.txt index f04bd682ea0..5e0f85917e3 100644 --- a/scripts/snapdragon/qdc/requirements.txt +++ b/scripts/snapdragon/qdc/requirements.txt @@ -8,12 +8,9 @@ iniconfig==2.1.0 outcome==1.3.0.post0 packaging==25.0 pluggy==1.6.0 -Pygments==2.19.2 PySocks==1.7.1 pytest==8.4.2 -pytest-dependency==0.6.0 selenium==4.36.0 -setuptools==80.9.0 sniffio==1.3.1 sortedcontainers==2.4.0 tomli==2.3.0 diff --git a/scripts/snapdragon/qdc/run_qdc_jobs.py b/scripts/snapdragon/qdc/run_qdc_jobs.py new file mode 100644 index 00000000000..b4eede3d019 --- /dev/null +++ b/scripts/snapdragon/qdc/run_qdc_jobs.py @@ -0,0 +1,401 @@ +"""Run llama.cpp Hexagon Android tests in a single QDC Appium job. + +Bundles test scripts into one artifact and submits a single QDC job: + + 1. run_bench_tests_posix.py — llama-cli and llama-bench on CPU / GPU / NPU + (from scripts/snapdragon/qdc/) + +Results are written to $GITHUB_STEP_SUMMARY when set (GitHub Actions). 
+ +Prerequisites: + pip install /path/to/qualcomm_device_cloud_sdk*.whl + +Required environment variables: + QDC_API_KEY API key from QDC UI -> Users -> Settings -> API Keys + +Usage: + python run_qdc_jobs.py \\ + --pkg-dir pkg-snapdragon/llama.cpp \\ + --model-url https://.../Llama-3.2-1B-Instruct-Q4_0.gguf \\ + --device SM8750 +""" + +from __future__ import annotations + +import argparse +import logging +import os +import re +import shutil +import sys +import tempfile +import time +import xml.etree.ElementTree as ET +from dataclasses import dataclass, field +from pathlib import Path + +from qualcomm_device_cloud_sdk.api import qdc_api # ty: ignore[unresolved-import] +from qualcomm_device_cloud_sdk.logging import configure_logging # ty: ignore[unresolved-import] +from qualcomm_device_cloud_sdk.models import ArtifactType, JobMode, JobState, JobSubmissionParameter, JobType, TestFramework # ty: ignore[unresolved-import] + +configure_logging(level=logging.INFO, handlers=[logging.StreamHandler()]) +log = logging.getLogger(__name__) + +POLL_INTERVAL = 30 +JOB_TIMEOUT = 3600 +LOG_UPLOAD_TIMEOUT = 600 +CAPACITY_TIMEOUT = 1800 +CAPACITY_POLL = 60 +MAX_CONCURRENT_JOBS = 5 +TERMINAL_STATES = {JobState.COMPLETED, JobState.CANCELED} +NON_TERMINAL_STATES = {JobState.DISPATCHED, JobState.RUNNING, JobState.SETUP, JobState.SUBMITTED} + +_SCRIPTS_DIR = Path(__file__).parent +_TESTS_DIR = _SCRIPTS_DIR / "tests" +_RUN_BENCH = _TESTS_DIR / "run_bench_tests_posix.py" +_RUN_BACKEND_OPS = _TESTS_DIR / "run_backend_ops_posix.py" +_UTILS = _TESTS_DIR / "utils.py" +_CONFTEST = _TESTS_DIR / "conftest.py" +_REQUIREMENTS = _SCRIPTS_DIR / "requirements.txt" + +_PYTEST_LINE_RE = re.compile( + r"(?:[\w/]+\.py::)?(?:\w+::)?([\w\[\].-]+)\s+(PASSED|FAILED|ERROR|SKIPPED)" +) +_EXCLUDED_LOGS = {"qdc_android_whole_host-000.log", "qdc_kernel_host-000.log"} +_NON_TERMINAL_STATE_VALUES = {s.value for s in NON_TERMINAL_STATES} + + +@dataclass +class JobResult: + passed: bool + tests: dict[str, bool] = field(default_factory=dict) + raw_logs: dict[str, str] = field(default_factory=dict) + failure_details: dict[str, str] = field(default_factory=dict) + + +def build_artifact_zip( + pkg_dir: Path, + stage_dir: Path, + *, + test_mode: str = "bench", + model_url: str | None = None, +) -> Path: + """Bundle everything into a single QDC artifact zip. 
+ + Zip structure (extracted by QDC to /qdc/appium/ on the runner): + llama_cpp_bundle/ installed package (adb pushed to /data/local/tmp/) + tests/ + utils.py shared helpers (paths, run_adb_command, …) + conftest.py shared pytest fixtures (driver) + test_bench_posix.py bench + cli tests (<> substituted) + AND/OR + test_backend_ops_posix.py test-backend-ops -b HTP0 + requirements.txt + """ + shutil.copytree(pkg_dir, stage_dir / "llama_cpp_bundle") + + tests_dir = stage_dir / "tests" + tests_dir.mkdir() + + shutil.copy(_UTILS, tests_dir / "utils.py") + shutil.copy(_CONFTEST, tests_dir / "conftest.py") + + if test_mode in ("bench", "all"): + assert model_url is not None, "--model-url is required for bench/all test modes" + (tests_dir / "test_bench_posix.py").write_text( + _RUN_BENCH.read_text().replace("<>", model_url) + ) + if test_mode in ("backend-ops", "all"): + shutil.copy(_RUN_BACKEND_OPS, tests_dir / "test_backend_ops_posix.py") + + shutil.copy(_REQUIREMENTS, stage_dir / "requirements.txt") + (stage_dir / "pytest.ini").write_text("[pytest]\naddopts = --junitxml=results.xml\n") + + zip_base = str(stage_dir / "artifact") + shutil.make_archive(zip_base, "zip", stage_dir) + return Path(f"{zip_base}.zip") + + +def wait_for_job(client, job_id: str, timeout: int) -> str: + elapsed = 0 + while elapsed < timeout: + raw = qdc_api.get_job_status(client, job_id) + try: + status = JobState(raw) + except ValueError: + status = raw + if status in TERMINAL_STATES: + return raw.lower() + log.info("Job %s: %s", job_id, raw) + time.sleep(POLL_INTERVAL) + elapsed += POLL_INTERVAL + raise TimeoutError(f"Job {job_id} did not finish within {timeout}s") + + +def wait_for_log_upload(client, job_id: str) -> None: + elapsed = 0 + while elapsed <= LOG_UPLOAD_TIMEOUT: + status = (qdc_api.get_job_log_upload_status(client, job_id) or "").lower() + if status in {"completed", "failed"}: + return + log.info("Waiting for log upload (status=%s) ...", status) + time.sleep(POLL_INTERVAL) + elapsed += POLL_INTERVAL + log.warning("Timed out waiting for log upload after %ds", LOG_UPLOAD_TIMEOUT) + + +def wait_for_capacity(client, max_jobs: int = MAX_CONCURRENT_JOBS) -> None: + """Block until the user's active (non-terminal) QDC job count is below max_jobs.""" + elapsed = 0 + while elapsed < CAPACITY_TIMEOUT: + jobs_page = qdc_api.get_jobs_list(client, page_number=0, page_size=50) + if jobs_page is None: + log.warning("Could not retrieve job list; proceeding without capacity check") + return + items = getattr(jobs_page, "data", []) or [] + active = sum(1 for j in items if getattr(j, "state", None) in _NON_TERMINAL_STATE_VALUES) + if active < max_jobs: + log.info("Active QDC jobs: %d / %d — proceeding", active, max_jobs) + return + log.info("Active QDC jobs: %d / %d — waiting %ds ...", active, max_jobs, CAPACITY_POLL) + time.sleep(CAPACITY_POLL) + elapsed += CAPACITY_POLL + log.warning("Capacity wait timed out after %ds; proceeding anyway", CAPACITY_TIMEOUT) + + +def _parse_junit_xml(content: str) -> tuple[dict[str, bool], dict[str, str]]: + try: + root = ET.fromstring(content) + except ET.ParseError: + return {}, {} + results: dict[str, bool] = {} + failures: dict[str, str] = {} + for tc in root.iter("testcase"): + name = tc.get("name", "") + if classname := tc.get("classname", ""): + name = f"{classname}.{name}" + failure_el = tc.find("failure") + if failure_el is None: + failure_el = tc.find("error") + results[name] = failure_el is None + if failure_el is not None: + parts = [failure_el.get("message", ""), failure_el.text or 
""] + failures[name] = "\n".join(p for p in parts if p).strip() + return results, failures + + +def _parse_pytest_output(content: str) -> dict[str, bool]: + results: dict[str, bool] = {} + for m in _PYTEST_LINE_RE.finditer(content): + results[m.group(1)] = m.group(2) == "PASSED" + return results + + +def fetch_logs_and_parse_tests( + client, job_id: str +) -> tuple[dict[str, bool], dict[str, str], dict[str, str]]: + """Returns (test_results, raw_logs, failure_details).""" + log_files = qdc_api.get_job_log_files(client, job_id) + if not log_files: + log.warning("No log files returned for job %s", job_id) + return {}, {}, {} + + test_results: dict[str, bool] = {} + pytest_fallback: dict[str, bool] = {} + raw_logs: dict[str, str] = {} + failure_details: dict[str, str] = {} + + with tempfile.TemporaryDirectory() as tmpdir: + for lf in log_files: + log.info("Downloading log file: %s", lf.filename) + zip_path = os.path.join(tmpdir, "log.zip") + qdc_api.download_job_log_files(client, lf.filename, zip_path) + try: + shutil.unpack_archive(zip_path, tmpdir, "zip") + except Exception as e: + log.warning("Could not unpack %s as zip: %s", lf.filename, e) + + for root_dir, _, files in os.walk(tmpdir): + for fname in sorted(files): + fpath = os.path.join(root_dir, fname) + content = Path(fpath).read_text(errors="replace") + if fname.endswith(".xml"): + results, failures = _parse_junit_xml(content) + test_results.update(results) + failure_details.update(failures) + elif fname.endswith(".log"): + if fname in _EXCLUDED_LOGS: + continue + log.info("--- %s ---", fname) + log.info("%s", content) + raw_logs[fname] = content + pytest_fallback.update(_parse_pytest_output(content)) + + return (test_results if test_results else pytest_fallback), raw_logs, failure_details + + +def write_summary(result: JobResult, title: str = "QDC Test Results") -> None: + summary_path = os.environ.get("GITHUB_STEP_SUMMARY") + if not summary_path: + return + + icon = "✅" if result.passed else "❌" + + lines = [ + f"## {title}\n", + f"Overall: {icon} {'PASSED' if result.passed else 'FAILED'}\n", + ] + reportable = {n: ok for n, ok in result.tests.items() if "test_install" not in n} + if reportable: + lines += ["| Test | Result |", "| ---- | ------ |"] + for name, ok in reportable.items(): + lines.append(f"| `{name}` | {'✅' if ok else '❌'} |") + passed_n = sum(1 for v in reportable.values() if v) + failed_n = sum(1 for v in reportable.values() if not v) + lines += ["", f"**{passed_n} passed, {failed_n} failed**"] + else: + lines.append("_No per-test data available._") + + failed_names = [n for n, ok in reportable.items() if not ok] + if failed_names: + lines += ["", "### Failures"] + for name in failed_names: + detail = result.failure_details.get(name) + if detail: + lines += [ + f"
{name}", + "", + "```", + detail, + "```", + "", + "
", + ] + + if result.raw_logs: + lines += ["", "### Raw Logs"] + for fname, content in sorted(result.raw_logs.items()): + lines += [ + f"
{fname}", + "", + "```", + content.rstrip(), + "```", + "", + "
", + ] + + with open(summary_path, "a") as f: + f.write("\n".join(lines) + "\n") + + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + p.add_argument("--pkg-dir", required=True, type=Path, + help="Installed llama.cpp package directory (contains bin/ and lib/)") + p.add_argument("--model-url", + help="Direct URL to the GGUF model file (required for --test bench)") + p.add_argument("--device", required=True, + help="QDC chipset name, e.g. SM8750") + p.add_argument("--test", choices=["bench", "backend-ops", "all"], default="bench", + help="Test suite to run (default: bench)") + p.add_argument("--job-timeout", type=int, default=JOB_TIMEOUT, metavar="SECONDS", + help=f"Max seconds to wait for job completion (default: {JOB_TIMEOUT})") + args = p.parse_args() + if args.test in ("bench", "all") and not args.model_url: + p.error("--model-url is required when --test bench or --test all") + return args + + +def main() -> int: + args = parse_args() + + api_key = os.environ.get("QDC_API_KEY") + if not api_key: + log.error("QDC_API_KEY environment variable must be set") + return 1 + if not args.pkg_dir.is_dir(): + log.error("--pkg-dir %s does not exist", args.pkg_dir) + return 1 + + client = qdc_api.get_public_api_client_using_api_key( + api_key_header=api_key, + app_name_header="llama-cpp-ci", + on_behalf_of_header="llama-cpp-ci", + client_type_header="Python", + ) + + target_id = qdc_api.get_target_id(client, args.device) + if target_id is None: + log.error("Could not find QDC target for device %r", args.device) + return 1 + + with tempfile.TemporaryDirectory() as tmpdir: + log.info("Building artifact ...") + zip_path = build_artifact_zip( + args.pkg_dir, Path(tmpdir), + test_mode=args.test, model_url=args.model_url, + ) + log.info("Uploading artifact (%d MB) ...", zip_path.stat().st_size // 1_000_000) + artifact_id = qdc_api.upload_file(client, str(zip_path), ArtifactType.TESTSCRIPT) + + if artifact_id is None: + log.error("Artifact upload failed") + return 1 + + wait_for_capacity(client) + + job_id = qdc_api.submit_job( + public_api_client=client, + target_id=target_id, + job_name="llama.cpp Hexagon tests", + external_job_id=None, + job_type=JobType.AUTOMATED, + job_mode=JobMode.APPLICATION, + timeout=max(1, args.job_timeout // 60), + test_framework=TestFramework.APPIUM, + entry_script=None, + job_artifacts=[artifact_id], + monkey_events=None, + monkey_session_timeout=None, + job_parameters=[JobSubmissionParameter.WIFIENABLED], + ) + if job_id is None: + log.error("Job submission failed") + return 1 + log.info("Job submitted: %s (device=%s)", job_id, args.device) + + try: + job_status = wait_for_job(client, job_id, timeout=args.job_timeout) + except TimeoutError as e: + log.error("%s", e) + write_summary(JobResult(passed=False, tests={}), title=f"QDC Job Timed Out ({args.device})") + return 1 + log.info("Job %s finished: %s", job_id, job_status) + + wait_for_log_upload(client, job_id) + tests, raw_logs, failure_details = fetch_logs_and_parse_tests(client, job_id) + + passed = job_status == JobState.COMPLETED.value.lower() + if tests: + passed = passed and all(tests.values()) + if not passed: + log.error("Job did not complete successfully or tests failed (status=%s)", job_status) + + result = JobResult(passed=passed, tests=tests, raw_logs=raw_logs, failure_details=failure_details) + if args.test == "backend-ops": + title = f"Backend Ops — HTP0 ({args.device})" + elif args.test == "all": + title 
= f"QDC Tests ({args.device})" + else: + title = f"QDC Test Results ({args.device})" + write_summary(result, title=title) + + return 0 if passed else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/snapdragon/qdc/tests/conftest.py b/scripts/snapdragon/qdc/tests/conftest.py new file mode 100644 index 00000000000..0fc5b3e5fa7 --- /dev/null +++ b/scripts/snapdragon/qdc/tests/conftest.py @@ -0,0 +1,20 @@ +"""Shared pytest fixtures for QDC on-device test runners.""" + +import os + +import pytest +from appium import webdriver + +from utils import options, write_qdc_log + + +@pytest.fixture(scope="session", autouse=True) +def driver(): + return webdriver.Remote(command_executor="http://127.0.0.1:4723/wd/hub", options=options) + + +def pytest_sessionfinish(session, exitstatus): + xml_path = getattr(session.config.option, "xmlpath", None) or "results.xml" + if os.path.exists(xml_path): + with open(xml_path) as f: + write_qdc_log("results.xml", f.read()) diff --git a/scripts/snapdragon/qdc/tests/run_backend_ops_posix.py b/scripts/snapdragon/qdc/tests/run_backend_ops_posix.py new file mode 100644 index 00000000000..958fc074762 --- /dev/null +++ b/scripts/snapdragon/qdc/tests/run_backend_ops_posix.py @@ -0,0 +1,41 @@ +""" +On-device test-backend-ops runner for llama.cpp (HTP0 backend). + +Executed by QDC's Appium test framework on the QDC runner. +The runner has ADB access to the allocated device. +""" + +import os +import sys + +import pytest + +from utils import BIN_PATH, CMD_PREFIX, push_bundle_if_needed, run_adb_command, write_qdc_log + + +@pytest.fixture(scope="session", autouse=True) +def install(driver): + push_bundle_if_needed(f"{BIN_PATH}/test-backend-ops") + + +@pytest.mark.parametrize("type_a", ["mxfp4", "fp16", "q4_0"]) +def test_backend_ops_htp0(type_a): + cmd = f"{CMD_PREFIX} GGML_HEXAGON_HOSTBUF=0 GGML_HEXAGON_EXPERIMENTAL=1 {BIN_PATH}/test-backend-ops -b HTP0 -o MUL_MAT" + if type_a == "q4_0": + cmd += r' -p "^(?=.*type_a=q4_0)(?!.*type_b=f32,m=576,n=512,k=576).*$"' + else: + cmd += f" -p type_a={type_a}" + result = run_adb_command( + cmd, + check=False, + ) + write_qdc_log(f"backend_ops_{type_a}.log", result.stdout or "") + assert result.returncode == 0, f"test-backend-ops type_a={type_a} failed (exit {result.returncode})" + + +if __name__ == "__main__": + ret = pytest.main(["-s", "--junitxml=results.xml", os.path.realpath(__file__)]) + if os.path.exists("results.xml"): + with open("results.xml") as f: + write_qdc_log("results.xml", f.read()) + sys.exit(ret) diff --git a/scripts/snapdragon/qdc/tests/run_bench_tests_posix.py b/scripts/snapdragon/qdc/tests/run_bench_tests_posix.py new file mode 100644 index 00000000000..44802c3136a --- /dev/null +++ b/scripts/snapdragon/qdc/tests/run_bench_tests_posix.py @@ -0,0 +1,76 @@ +""" +On-device bench and completion test runner for llama.cpp (CPU, GPU, NPU backends). + +Executed by QDC's Appium test framework on the QDC runner. +The runner has ADB access to the allocated device. + +Placeholders replaced at artifact creation time by run_qdc_jobs.py: + <> Direct URL to the GGUF model file (downloaded on-device via curl) +""" + +import os +import subprocess +import sys + +import pytest + +from utils import BIN_PATH, CMD_PREFIX, push_bundle_if_needed, run_adb_command, write_qdc_log + +MODEL_PATH = "/data/local/tmp/model.gguf" +PROMPT = "What is the capital of France?" 
+CLI_OPTS = "--batch-size 128 -n 128 -no-cnv --seed 42" + + +@pytest.fixture(scope="session", autouse=True) +def install(driver): + push_bundle_if_needed(f"{BIN_PATH}/llama-cli") + + # Skip model download if already present + check = subprocess.run( + ["adb", "shell", f"ls {MODEL_PATH}"], + text=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, + ) + if check.returncode != 0: + run_adb_command(f'curl -L -J --output {MODEL_PATH} "<>"') + + +@pytest.mark.parametrize("device,extra_flags", [ + pytest.param("none", "-ctk q8_0 -ctv q8_0", id="cpu"), + pytest.param("GPUOpenCL", "", id="gpu"), + pytest.param("HTP0", "-ctk q8_0 -ctv q8_0", id="npu"), +]) +def test_llama_completion(device, extra_flags): + result = run_adb_command( + f'{CMD_PREFIX} {BIN_PATH}/llama-completion' + f' -m {MODEL_PATH} --device {device} -ngl 99 -t 4 {CLI_OPTS} {extra_flags} -fa on' + f' -p "{PROMPT}"', + check=False, + ) + write_qdc_log(f"llama_completion_{device}.log", result.stdout or "") + assert result.returncode == 0, f"llama-completion {device} failed (exit {result.returncode})" + + +_DEVICE_LOG_NAME = {"none": "cpu", "GPUOpenCL": "gpu", "HTP0": "htp"} + + +@pytest.mark.parametrize("device", [ + pytest.param("none", id="cpu"), + pytest.param("GPUOpenCL", id="gpu"), + pytest.param("HTP0", id="npu"), +]) +def test_llama_bench(device): + result = run_adb_command( + f"{CMD_PREFIX} {BIN_PATH}/llama-bench" + f" -m {MODEL_PATH} --device {device} -ngl 99 --batch-size 128 -t 4 -p 128 -n 32", + check=False, + ) + write_qdc_log(f"llama_bench_{_DEVICE_LOG_NAME[device]}.log", result.stdout or "") + assert result.returncode == 0, f"llama-bench {device} failed (exit {result.returncode})" + + +if __name__ == "__main__": + ret = pytest.main(["-s", "--junitxml=results.xml", os.path.realpath(__file__)]) + if os.path.exists("results.xml"): + with open("results.xml") as f: + write_qdc_log("results.xml", f.read()) + sys.exit(ret) diff --git a/scripts/snapdragon/qdc/tests/test_bench.py b/scripts/snapdragon/qdc/tests/test_bench.py deleted file mode 100644 index 651ab5b7172..00000000000 --- a/scripts/snapdragon/qdc/tests/test_bench.py +++ /dev/null @@ -1,63 +0,0 @@ -import pytest -import subprocess -import sys - -tmp_path='/data/local/tmp' -pkg_path=f'{tmp_path}/llama.cpp' -lib_path=f'{pkg_path}/lib' -bin_path=f'{pkg_path}/bin' - -model='../gguf/Llama-3.2-1B-Instruct-Q4_0.gguf' -cli_pref=f'cd {pkg_path} && LD_LIBRARY_PATH={lib_path} ADSP_LIBRARY_PATH={lib_path} {bin_path}' - - -def run_cmd(cmd): - p = subprocess.run(cmd, text = True, stdout = subprocess.PIPE, stderr = subprocess.STDOUT) - sys.stdout.write(p.stdout) - assert(p.returncode == 0) - - -@pytest.mark.dependency() -def test_install(): - run_cmd(['adb', 'push', 'llama.cpp', f'{tmp_path}']) - run_cmd(['adb', 'shell', f'chmod 755 {bin_path}/*']) - - -## Basic cli tests -def run_llama_cli(dev, opts): - prompt='what is the most popular cookie in the world?\nPlease provide a very brief bullet point summary.\nBegin your answer with **BEGIN**.' 
- opts = '--batch-size 128 -n 128 -no-cnv --seed 42 ' + opts - run_cmd(['adb', 'shell', f'{cli_pref}/llama-cli -m {model} --device {dev} -ngl 99 -t 4 {opts} -p "{prompt}"']) - - -@pytest.mark.dependency(depends=['test_install']) -def test_llama_cli_cpu(): - run_llama_cli('none', '-ctk q8_0 -ctv q8_0 -fa on') - - -@pytest.mark.dependency(depends=['test_install']) -def test_llama_cli_gpu(): - run_llama_cli('GPUOpenCL', '-fa on') - - -@pytest.mark.dependency(depends=['test_install']) -def test_llama_cli_npu(): - run_llama_cli('HTP0', '-ctk q8_0 -ctv q8_0 -fa on') - - -## Basic bench tests -def run_llama_bench(dev): - run_cmd(['adb', 'shell', f'{cli_pref}/llama-bench -m {model} --device {dev} -ngl 99 --batch-size 128 -t 4 -p 128 -n 32']) - - -@pytest.mark.dependency(depends=['test_install']) -def test_llama_bench_cpu(): - run_llama_bench('none') - - -def test_llama_bench_gpu(): - run_llama_bench('GPUOpenCL') - - -def test_llama_bench_npu(): - run_llama_bench('HTP0') diff --git a/scripts/snapdragon/qdc/tests/utils.py b/scripts/snapdragon/qdc/tests/utils.py new file mode 100644 index 00000000000..00f0f1b2f91 --- /dev/null +++ b/scripts/snapdragon/qdc/tests/utils.py @@ -0,0 +1,93 @@ +"""Shared helpers for QDC on-device test runners.""" + +import logging +import os +import subprocess +import tempfile + +from appium.options.common import AppiumOptions + +log = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# On-device paths +# --------------------------------------------------------------------------- + +BUNDLE_PATH = "/data/local/tmp/llama_cpp_bundle" +QDC_LOGS_PATH = "/data/local/tmp/QDC_logs" +LIB_PATH = f"{BUNDLE_PATH}/lib" +BIN_PATH = f"{BUNDLE_PATH}/bin" +ENV_PREFIX = ( + f"export LD_LIBRARY_PATH={LIB_PATH} && " + f"export ADSP_LIBRARY_PATH={LIB_PATH} && " + f"chmod +x {BIN_PATH}/* &&" +) +CMD_PREFIX = f"cd {BUNDLE_PATH} && {ENV_PREFIX}" + +# --------------------------------------------------------------------------- +# Appium session options +# --------------------------------------------------------------------------- + +options = AppiumOptions() +options.set_capability("automationName", "UiAutomator2") +options.set_capability("platformName", "Android") +options.set_capability("deviceName", os.getenv("ANDROID_DEVICE_VERSION")) + +# --------------------------------------------------------------------------- +# ADB helpers +# --------------------------------------------------------------------------- + + +def run_adb_command(cmd: str, *, check: bool = True) -> subprocess.CompletedProcess: + # Append exit-code sentinel because `adb shell` doesn't reliably propagate + # the on-device exit code (older ADB versions always return 0). 
+ raw = subprocess.run( + ["adb", "shell", f"{cmd}; echo __RC__:$?"], + text=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, + ) + stdout = raw.stdout + returncode = raw.returncode + if stdout: + lines = stdout.rstrip("\n").split("\n") + if lines and lines[-1].startswith("__RC__:"): + try: + returncode = int(lines[-1][7:]) + stdout = "\n".join(lines[:-1]) + "\n" + except ValueError: + pass + log.info("%s", stdout) + result = subprocess.CompletedProcess(raw.args, returncode, stdout=stdout) + if check: + assert returncode == 0, f"Command failed (exit {returncode})" + return result + + +def write_qdc_log(filename: str, content: str) -> None: + """Push content as a log file to QDC_LOGS_PATH on the device for QDC log collection.""" + subprocess.run( + ["adb", "shell", f"mkdir -p {QDC_LOGS_PATH}"], + stdout=subprocess.PIPE, stderr=subprocess.STDOUT, + ) + with tempfile.NamedTemporaryFile(mode="w", suffix=".log", delete=False) as f: + f.write(content) + tmp_path = f.name + try: + subprocess.run( + ["adb", "push", tmp_path, f"{QDC_LOGS_PATH}/{filename}"], + stdout=subprocess.PIPE, stderr=subprocess.STDOUT, + ) + finally: + os.unlink(tmp_path) + + +def push_bundle_if_needed(check_binary: str) -> None: + """Push llama_cpp_bundle to the device if check_binary is not already present.""" + result = subprocess.run( + ["adb", "shell", f"ls {check_binary}"], + text=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, + ) + if result.returncode != 0: + subprocess.run( + ["adb", "push", "/qdc/appium/llama_cpp_bundle/", "/data/local/tmp"], + text=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, + ) diff --git a/scripts/snapdragon/windows/run-bench.ps1 b/scripts/snapdragon/windows/run-bench.ps1 index 5a3a9074dfd..8bf6939d2c0 100644 --- a/scripts/snapdragon/windows/run-bench.ps1 +++ b/scripts/snapdragon/windows/run-bench.ps1 @@ -21,11 +21,11 @@ if ($null -ne $env:V) { } if ($null -ne $env:PROF) { - $env:GGML_HEXAGON_PROFILE=$env:PROF; $env:GGML_HEXAGON_OPSYNC=1 + $env:GGML_HEXAGON_PROFILE=$env:PROF } -if ($null -ne $env:OPMASK) { - $env:GGML_HEXAGON_OPMASK=$env:OPMASK +if ($null -ne $env:OPSTAGE) { + $env:GGML_HEXAGON_OPSTAGE=$env:OPSTAGE } if ($null -ne $env:NHVX) { diff --git a/scripts/snapdragon/windows/run-cli.ps1 b/scripts/snapdragon/windows/run-cli.ps1 index c64aaf725cf..104452f9ba7 100644 --- a/scripts/snapdragon/windows/run-cli.ps1 +++ b/scripts/snapdragon/windows/run-cli.ps1 @@ -25,11 +25,11 @@ if ($null -ne $env:SCHED) { } if ($null -ne $env:PROF) { - $env:GGML_HEXAGON_PROFILE=$env:PROF; $env:GGML_HEXAGON_OPSYNC=1 + $env:GGML_HEXAGON_PROFILE=$env:PROF } -if ($null -ne $env:OPMASK) { - $env:GGML_HEXAGON_OPMASK=$env:OPMASK +if ($null -ne $env:OPSTAGE) { + $env:GGML_HEXAGON_OPSTAGE=$env:OPSTAGE } if ($null -ne $env:NHVX) { diff --git a/scripts/snapdragon/windows/run-completion.ps1 b/scripts/snapdragon/windows/run-completion.ps1 index a896cd3524d..5841a82fa99 100644 --- a/scripts/snapdragon/windows/run-completion.ps1 +++ b/scripts/snapdragon/windows/run-completion.ps1 @@ -25,11 +25,11 @@ if ($null -ne $env:SCHED) { } if ($null -ne $env:PROF) { - $env:GGML_HEXAGON_PROFILE=$env:PROF; $env:GGML_HEXAGON_OPSYNC=1 + $env:GGML_HEXAGON_PROFILE=$env:PROF } -if ($null -ne $env:OPMASK) { - $env:GGML_HEXAGON_OPMASK=$env:OPMASK +if ($null -ne $env:OPSTAGE) { + $env:GGML_HEXAGON_OPSTAGE=$env:OPSTAGE } if ($null -ne $env:NHVX) { diff --git a/scripts/snapdragon/windows/run-mtmd.ps1 b/scripts/snapdragon/windows/run-mtmd.ps1 index f230ac5a6b7..be817875142 100644 --- 
a/scripts/snapdragon/windows/run-mtmd.ps1 +++ b/scripts/snapdragon/windows/run-mtmd.ps1 @@ -34,11 +34,11 @@ if ($null -ne $env:SCHED) { } if ($null -ne $env:PROF) { - $env:GGML_HEXAGON_PROFILE=$env:PROF; $env:GGML_HEXAGON_OPSYNC=1 + $env:GGML_HEXAGON_PROFILE=$env:PROF } -if ($null -ne $env:OPMASK) { - $env:GGML_HEXAGON_OPMASK=$env:OPMASK +if ($null -ne $env:OPSTAGE) { + $env:GGML_HEXAGON_OPSTAGE=$env:OPSTAGE } if ($null -ne $env:NHVX) { diff --git a/scripts/snapdragon/windows/run-tool.ps1 b/scripts/snapdragon/windows/run-tool.ps1 index 39edbfcf76c..15c880f2dbd 100644 --- a/scripts/snapdragon/windows/run-tool.ps1 +++ b/scripts/snapdragon/windows/run-tool.ps1 @@ -31,11 +31,11 @@ if ($null -ne $env:SCHED) { } if ($null -ne $env:PROF) { - $env:GGML_HEXAGON_PROFILE=$env:PROF; $env:GGML_HEXAGON_OPSYNC=1 + $env:GGML_HEXAGON_PROFILE=$env:PROF } -if ($null -ne $env:OPMASK) { - $env:GGML_HEXAGON_OPMASK=$env:OPMASK +if ($null -ne $env:OPSTAGE) { + $env:GGML_HEXAGON_OPSTAGE=$env:OPSTAGE } if ($null -ne $env:NHVX) { diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index e154cc5c69b..de0140cfe24 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -49f84a924f6ea4fc2ef73dbbd8cc4d734b54bd6d +1c40d85a4dcfcd62176f649b8682433bb1a6caef diff --git a/scripts/sync_vendor.py b/scripts/sync_vendor.py index 3f1e74f7cbc..ff1dd075303 100755 --- a/scripts/sync_vendor.py +++ b/scripts/sync_vendor.py @@ -5,7 +5,7 @@ import sys import subprocess -HTTPLIB_VERSION = "refs/tags/v0.40.0" +HTTPLIB_VERSION = "refs/tags/v0.43.1" vendor = { "https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp", diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 6904b9c1a64..3ab9dd4d505 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -109,6 +109,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_ERNIE4_5_MOE, "ernie4_5-moe" }, { LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" }, { LLM_ARCH_HUNYUAN_DENSE, "hunyuan-dense" }, + { LLM_ARCH_HUNYUAN_VL, "hunyuan_vl" }, { LLM_ARCH_SMOLLM3, "smollm3" }, { LLM_ARCH_OPENAI_MOE, "gpt-oss" }, { LLM_ARCH_LFM2, "lfm2" }, @@ -126,6 +127,8 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_PANGU_EMBED, "pangu-embedded" }, { LLM_ARCH_MISTRAL3, "mistral3" }, { LLM_ARCH_MISTRAL4, "mistral4" }, + { LLM_ARCH_EAGLE3, "eagle3" }, + { LLM_ARCH_DFLASH, "dflash" }, { LLM_ARCH_PADDLEOCR, "paddleocr" }, { LLM_ARCH_MIMO2, "mimo2" }, { LLM_ARCH_STEP35, "step35" }, @@ -250,6 +253,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, + { LLM_KV_ROPE_SCALING_ALPHA, "%s.rope.scaling.alpha" }, { LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" }, { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" }, @@ -282,6 +286,14 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" }, + { LLM_KV_EAGLE3_EXTRACT_LAYERS, "%s.extract_layers" }, + { LLM_KV_EAGLE3_TARGET_HIDDEN_SIZE, "%s.target_hidden_size" }, + { LLM_KV_EAGLE3_NORM_BEFORE_RESIDUAL, "%s.norm_before_residual" }, + + { LLM_KV_DFLASH_TARGET_LAYER_IDS, "%s.target_layer_ids" }, + { LLM_KV_DFLASH_BLOCK_SIZE, "%s.block_size" }, + { LLM_KV_DFLASH_MASK_TOKEN_ID, "%s.mask_token_id" }, + { LLM_KV_SHORTCONV_L_CACHE, "%s.shortconv.l_cache" }, // sentence-transformers dense modules feature 
dims { LLM_KV_DENSE_2_FEAT_IN, "%s.dense_2_feat_in" }, @@ -545,6 +557,13 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_INDEXER_PROJ, "blk.%d.indexer.proj" }, { LLM_TENSOR_INDEXER_ATTN_K, "blk.%d.indexer.attn_k" }, { LLM_TENSOR_INDEXER_ATTN_Q_B, "blk.%d.indexer.attn_q_b" }, + // EAGLE-3 specific layers + { LLM_TENSOR_EAGLE3_HIDDEN_NORM, "blk.%d.hidden_norm" }, + { LLM_TENSOR_EAGLE3_FC, "fc" }, + { LLM_TENSOR_EAGLE3_D2T, "d2t" }, + // DFlash specific layers + { LLM_TENSOR_DFLASH_FC, "fc" }, + { LLM_TENSOR_DFLASH_HIDDEN_NORM, "hidden_norm" }, }; // declare information about the model weight tensors: @@ -765,6 +784,13 @@ static const std::map LLM_TENSOR_INFOS = { // Nemotron 3 Super {LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + // EAGLE-3 tensors + {LLM_TENSOR_EAGLE3_FC, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_EAGLE3_HIDDEN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_EAGLE3_D2T, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, + // DFlash tensors + {LLM_TENSOR_DFLASH_FC, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DFLASH_HIDDEN_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, }; LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} diff --git a/src/llama-arch.h b/src/llama-arch.h index c4aabab7e0c..29f260eed52 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -113,6 +113,7 @@ enum llm_arch { LLM_ARCH_ERNIE4_5_MOE, LLM_ARCH_HUNYUAN_MOE, LLM_ARCH_HUNYUAN_DENSE, + LLM_ARCH_HUNYUAN_VL, LLM_ARCH_SMOLLM3, LLM_ARCH_OPENAI_MOE, LLM_ARCH_LFM2, @@ -137,6 +138,8 @@ enum llm_arch { LLM_ARCH_MAINCODER, LLM_ARCH_KIMI_LINEAR, LLM_ARCH_UNKNOWN, + LLM_ARCH_EAGLE3, + LLM_ARCH_DFLASH, }; enum llm_kv { @@ -254,6 +257,7 @@ enum llm_kv { LLM_KV_ROPE_SCALE_LINEAR, LLM_KV_ROPE_SCALING_TYPE, LLM_KV_ROPE_SCALING_FACTOR, + LLM_KV_ROPE_SCALING_ALPHA, LLM_KV_ROPE_SCALING_ATTN_FACTOR, LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, LLM_KV_ROPE_SCALING_FINETUNED, @@ -324,6 +328,14 @@ enum llm_kv { LLM_KV_CLASSIFIER_OUTPUT_LABELS, + LLM_KV_EAGLE3_EXTRACT_LAYERS, + LLM_KV_EAGLE3_TARGET_HIDDEN_SIZE, + LLM_KV_EAGLE3_NORM_BEFORE_RESIDUAL, + + LLM_KV_DFLASH_TARGET_LAYER_IDS, + LLM_KV_DFLASH_BLOCK_SIZE, + LLM_KV_DFLASH_MASK_TOKEN_ID, + LLM_KV_SHORTCONV_L_CACHE, LLM_KV_XIELU_ALPHA_N, @@ -552,6 +564,11 @@ enum llm_tensor { LLM_TENSOR_NEXTN_HNORM, LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, + LLM_TENSOR_EAGLE3_FC, // eagle3: feature fusion layer + LLM_TENSOR_EAGLE3_HIDDEN_NORM, // eagle3: additional normalization layer + LLM_TENSOR_EAGLE3_D2T, // eagle3: draft to target vocabulary mapping + LLM_TENSOR_DFLASH_FC, + LLM_TENSOR_DFLASH_HIDDEN_NORM, }; enum llm_tensor_layer { diff --git a/src/llama-context.cpp b/src/llama-context.cpp index ee0c29235cd..e904db066fd 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -165,6 +165,9 @@ llama_context::llama_context( cparams.op_offload = params.op_offload; cparams.kv_unified = params.kv_unified; + cparams.eagle3_extract_enabled = false; + cparams.dflash_extract_enabled = false; + // initialized later cparams.pipeline_parallel = false; @@ -345,6 +348,15 @@ llama_context::llama_context( if (cparams.pipeline_parallel) { LLAMA_LOG_INFO("%s: pipeline parallelism enabled\n", __func__); } + // temp fix: DFlash encoder/decoder share one model_dft, keep the role on the context + dflash_decoder_ctx = model.arch == LLM_ARCH_DFLASH && params.target_model != nullptr; + // DFlash decoder: 
pre-fill cross with reservation size so build_inp_cross_embd + // uses cparams.n_ctx instead of hparams.n_ctx_train (which can cause OOM) + if (dflash_decoder_ctx) { + cross.n_embd = hparams.n_embd; + cross.n_enc = cparams.n_ctx; + cross.v_embd.resize(cross.n_embd * cross.n_enc, 0.0f); + } sched_reserve(); @@ -1168,7 +1180,78 @@ bool llama_context::set_adapter_cvec( return res; } +void llama_context::set_eagle3(const llama_model * model) { + // Initialize EAGLE3 feature extraction configuration + cparams.eagle3_extract_enabled = !!model; + if (!cparams.eagle3_extract_enabled) { + return; + } + + sched_need_reserve = true; + + const auto & eagle3_hparams = model->hparams; + + // Copy feature extraction layer indices from EAGLE3 model's hparams + eagle3.extract_layer_indices.assign( + eagle3_hparams.eagle3_extract_layers.begin(), + eagle3_hparams.eagle3_extract_layers.end() + ); + + // Allocate tensors array for extraction + eagle3.extract_tensors.resize(eagle3.extract_layer_indices.size(), nullptr); + + LLAMA_LOG_INFO("%s: EAGLE3 extraction enabled for layers [%d, %d, %d]\n", __func__, + eagle3.extract_layer_indices[0], + eagle3.extract_layer_indices[1], + eagle3.extract_layer_indices[2]); +} + +void llama_context::set_dflash(const llama_model * model) { + cparams.dflash_extract_enabled = !!model; + if (!cparams.dflash_extract_enabled) { + return; + } + + sched_need_reserve = true; + + const auto & dflash_hparams = model->hparams; + + dflash.extract_layer_indices.assign( + dflash_hparams.dflash_target_layer_ids.begin(), + dflash_hparams.dflash_target_layer_ids.end() + ); + + dflash.extract_tensors.resize(dflash.extract_layer_indices.size(), nullptr); + + LLAMA_LOG_INFO("%s: DFlash extraction enabled for layers [%d, %d, %d, %d, %d]\n", __func__, + dflash.extract_layer_indices[0], + dflash.extract_layer_indices[1], + dflash.extract_layer_indices[2], + dflash.extract_layer_indices[3], + dflash.extract_layer_indices[4]); +} + +const float * llama_context::get_dflash_target_features() const { + GGML_ASSERT(!dflash.target_features.empty() && "DFlash target features not extracted"); + return dflash.target_features.data(); +} + +void llama_context::set_dflash_accumulated_target_ctx(const float * data, int32_t n_embd, int32_t n_tokens) { + GGML_ASSERT(data != nullptr); + const size_t size = (size_t)n_embd * n_tokens; + // Store in cross struct (reusing T5 style cross-attention for accumulated target features fed to the DFlash decoder) + cross.n_embd = n_embd; + cross.n_enc = n_tokens; + cross.v_embd.resize(size); + std::memcpy(cross.v_embd.data(), data, size * sizeof(float)); +} + llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) { + // DFlash decoder runs through encode path due to no kv-cache but it needs decoder graph type + if (model.arch == LLM_ARCH_DFLASH && dflash_decoder_ctx && gtype == LLM_GRAPH_TYPE_ENCODER) { + gtype = LLM_GRAPH_TYPE_DECODER; + } + if (mctx && !mctx->apply()) { LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__); ret = GGML_STATUS_FAILED; @@ -1225,6 +1308,30 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll // FIXME this call causes a crash if any model inputs were not used in the graph and were therefore not allocated res->set_inputs(&ubatch); + // EAGLE3: Fill g_embeddings for decoder input + if (model.arch == LLM_ARCH_EAGLE3 && gtype == LLM_GRAPH_TYPE_DECODER && !eagle3.g_embeddings.empty()) { + ggml_tensor * g_embd = 
ggml_graph_get_tensor(gf, "inp_g_embeddings"); + if (g_embd) { + ggml_backend_tensor_set(g_embd, eagle3.g_embeddings.data(), 0, ggml_nbytes(g_embd)); + } + } + + // temp fix DFlash: Fill position tensor for decoder + if (model.arch == LLM_ARCH_DFLASH && gtype == LLM_GRAPH_TYPE_DECODER && !cross.v_embd.empty()) { + const int64_t n_ctx = cross.n_enc; + const int64_t n_noise = ubatch.n_tokens; + const int64_t n_total = n_ctx + n_noise; + + ggml_tensor * pos_full = ggml_graph_get_tensor(gf, "inp_pos_full"); + if (pos_full) { + std::vector pos_data(n_total); + for (int64_t i = 0; i < n_total; ++i) { + pos_data[i] = (int32_t)i; + } + ggml_backend_tensor_set(pos_full, pos_data.data(), 0, n_total * sizeof(int32_t)); + } + } + //LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0); } @@ -1235,6 +1342,15 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll return nullptr; } + // EAGLE3: Extract intermediate layer features after graph execution + if (cparams.eagle3_extract_enabled && !eagle3.extract_tensors.empty()) { + extract_eagle3_features(ubatch); + } + + if (cparams.dflash_extract_enabled && !dflash.extract_tensors.empty()) { + extract_dflash_features(ubatch); + } + ret = GGML_STATUS_SUCCESS; return res; @@ -1250,7 +1366,15 @@ int llama_context::encode(const llama_batch & batch_inp) { const auto & hparams = model.hparams; - const int64_t n_embd = hparams.n_embd_inp(); + // EAGLE3/DFlash: use concatenated features size from target for draft encoder input + int64_t n_embd = hparams.n_embd; + if (batch_inp.embd) { + if (model.arch == LLM_ARCH_EAGLE3) { + n_embd = 3 * hparams.eagle3_target_hidden_size; + } else if (model.arch == LLM_ARCH_DFLASH) { + n_embd = (int64_t) hparams.dflash_target_layer_ids.size() * hparams.n_embd; + } + } const int64_t n_vocab = model.vocab.n_tokens(); // note: during encode, we always pass the full sequence starting from pos = 0 @@ -1336,8 +1460,15 @@ int llama_context::encode(const llama_batch & batch_inp) { GGML_ASSERT(embd.data != nullptr); const uint32_t n_embd_out = hparams.n_embd_out(); - GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd.size); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd.data, 0, n_tokens*n_embd_out*sizeof(float)); + if (model.arch == LLM_ARCH_EAGLE3) { + // g_embeddings are stored temporarily in embd buffer + const int64_t out_embd = hparams.n_embd; + GGML_ASSERT(n_tokens * out_embd <= (int64_t) embd.size); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd.data, 0, n_tokens * out_embd * sizeof(float)); + } else { + GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd.size); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd.data, 0, n_tokens*n_embd_out*sizeof(float)); + } } break; case LLAMA_POOLING_TYPE_MEAN: case LLAMA_POOLING_TYPE_CLS: @@ -1730,7 +1861,8 @@ int llama_context::decode(const llama_batch & batch_inp) { auto * t_logits = res->get_logits(); auto * t_embd = cparams.embeddings ? 
res->get_embd() : nullptr; - if (t_embd && res->get_embd_pooled()) { + // For EAGLE3, don't override t_embd with t_embd_pooled - we need the prenorm value during eagle3 decoder autoregressive generation + if (t_embd && res->get_embd_pooled() && model.arch != LLM_ARCH_EAGLE3) { t_embd = res->get_embd_pooled(); } @@ -1745,7 +1877,40 @@ int llama_context::decode(const llama_batch & batch_inp) { if (n_outputs) { GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits.size); - ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float)); + + // EAGLE3: Map draft vocab to target vocab + if (model.arch == LLM_ARCH_EAGLE3 && model.d2t) { + static thread_local std::vector<int64_t> eagle3_d2t_map; + static thread_local std::vector<float> eagle3_draft_logits; + + const int64_t draft_vocab_size = t_logits->ne[0]; + const uint32_t last_idx = n_outputs - 1; + + // Load d2t mapping once (on first call) + if (eagle3_d2t_map.empty()) { + eagle3_d2t_map.resize(model.d2t->ne[0]); + ggml_backend_tensor_get(model.d2t, eagle3_d2t_map.data(), 0, eagle3_d2t_map.size() * sizeof(int64_t)); + } + + // Read only the last token's draft logits + eagle3_draft_logits.resize(draft_vocab_size); + const size_t last_offset = last_idx * draft_vocab_size * sizeof(float); + ggml_backend_tensor_get_async(backend_res, t_logits, eagle3_draft_logits.data(), last_offset, draft_vocab_size * sizeof(float)); + synchronize(); + + + // Map only the last token's draft logits to target vocab + float * last_logits_out = logits_out + last_idx * n_vocab; + std::fill(last_logits_out, last_logits_out + n_vocab, -std::numeric_limits<float>::infinity()); + + for (int64_t j = 0; j < draft_vocab_size; j++) { + const int64_t target_id = j + eagle3_d2t_map[j]; + GGML_ASSERT(target_id >= 0 && target_id < n_vocab); + last_logits_out[target_id] = eagle3_draft_logits[j]; + } + } else { + ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float)); + } } } @@ -2121,7 +2286,23 @@ ggml_cgraph * llama_context::graph_reserve( auto * res = gf_res_reserve.get(); - const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT); + // EAGLE3: auto-detect encoder (embeddings+no target_model) or decoder (has target_model) + llm_graph_type gtype = LLM_GRAPH_TYPE_DEFAULT; + if (model.arch == LLM_ARCH_EAGLE3) { + if (cparams.embeddings && model.target_tok_embd == nullptr) { + gtype = LLM_GRAPH_TYPE_ENCODER; + } else if (model.target_tok_embd != nullptr) { + gtype = LLM_GRAPH_TYPE_DECODER; + } + } + if (model.arch == LLM_ARCH_DFLASH) { + if (cparams.embeddings && !dflash_decoder_ctx) { + gtype = LLM_GRAPH_TYPE_ENCODER; + } else if (dflash_decoder_ctx) { + gtype = LLM_GRAPH_TYPE_DECODER; + } + } + const auto gparams = graph_params(res, ubatch, mctx, gtype); res->reset(); @@ -2162,6 +2343,8 @@ llm_graph_params llama_context::graph_params( /*.loras =*/ loras.get(), /*.mctx =*/ mctx, /*.cross =*/ &cross, + /*.eagle3 =*/ &eagle3, + /*.dflash =*/ &dflash, /*.samplers =*/ sampling.samplers, /*.n_outputs =*/ n_outputs, /*.cb =*/ graph_get_cb(), @@ -2206,6 +2389,41 @@ llm_graph_cb llama_context::graph_get_cb() const { ggml_set_name(cur, name); } + // EAGLE3: Extract intermediate layer features if this is an extraction point + if (cparams.eagle3_extract_enabled) { + static constexpr const char * prefix = "eagle3_extract_"; + static constexpr size_t prefix_len = 15; // strlen("eagle3_extract_") + + if (strncmp(name, prefix, prefix_len) == 
0) { + // Parse the extraction index from the name (e.g., "eagle3_extract_0" -> 0) + size_t extract_idx = 0; + if (sscanf(name + prefix_len, "%zu", &extract_idx) == 1 && extract_idx < eagle3.extract_tensors.size()) { + // Mark as output tensor to ensure proper backend assignment + ggml_set_output(cur); + // Store this tensor reference for post-execution extraction + eagle3.extract_tensors[extract_idx] = cur; + LLAMA_LOG_DEBUG("%s: EAGLE3 stored tensor reference for extraction: " + "index=%zu, layer=%d, target_layer=%d, tensor=%s\n", + __func__, extract_idx, il, + eagle3.extract_layer_indices[extract_idx], name); + } + } + } + + // DFlash: Extract intermediate layer features if this is an extraction point + if (cparams.dflash_extract_enabled) { + static constexpr const char * prefix = "dflash_extract_"; + static constexpr size_t prefix_len = 15; + + if (strncmp(name, prefix, prefix_len) == 0) { + size_t extract_idx = 0; + if (sscanf(name + prefix_len, "%zu", &extract_idx) == 1 && extract_idx < dflash.extract_tensors.size()) { + ggml_set_output(cur); + dflash.extract_tensors[extract_idx] = cur; + } + } + } + // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends // FIXME: fix in ggml_backend_sched const bool full_offload = model.n_gpu_layers() > model.hparams.n_layer; @@ -2224,6 +2442,90 @@ llm_graph_cb llama_context::graph_get_cb() const { }; } +void llama_context::extract_eagle3_features(const llama_ubatch & ubatch) { + const int64_t n_tokens = ubatch.n_tokens; + const int64_t n_embd = model.hparams.n_embd; + const size_t n_layers = eagle3.extract_tensors.size(); + + // Allocate storage for concatenated features + const int64_t n_embd_concat = n_embd * n_layers; + eagle3.target_features.resize(n_embd_concat * n_tokens); + + // Temporary buffer to hold layer features before transposing + static thread_local std::vector temp_layer_features; + temp_layer_features.resize(n_embd * n_tokens); + + LLAMA_LOG_DEBUG("%s: Start to extract EAGLE3 features: %zu layers, %lld tokens, %lld embd\n", + __func__, n_layers, (long long)n_tokens, (long long)n_embd); + + // Extract each layer's features and interleave into token-major layout + for (size_t layer_idx = 0; layer_idx < n_layers; ++layer_idx) { + ggml_tensor * tensor = eagle3.extract_tensors[layer_idx]; + GGML_ASSERT(tensor != nullptr && "EAGLE3 extraction tensor is null"); + + // Get the backend where this tensor is stored + ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched.get(), tensor); + GGML_ASSERT(backend != nullptr && "EAGLE3 tensor has no backend"); + + // Verify tensor shape: should be [n_embd, n_tokens] + GGML_ASSERT(tensor->ne[0] == n_embd && tensor->ne[1] == n_tokens && + "EAGLE3 extraction tensor has unexpected shape"); + + // Get layer features to temp buffer + const size_t size_bytes = n_embd * n_tokens * sizeof(float); + ggml_backend_tensor_get_async(backend, tensor, temp_layer_features.data(), 0, size_bytes); + ggml_backend_sched_synchronize(sched.get()); + + // Then copy to correct position in target_features + // target_features layout: [token_0_all_layers, token_1_all_layers, ...] + // Each token has [layer_0_embd, layer_1_embd, layer_2_embd] + for (int64_t token_idx = 0; token_idx < n_tokens; ++token_idx) { + // Source: temp_layer_features[token_idx * n_embd ... 
(token_idx + 1) * n_embd - 1] + const float * src = temp_layer_features.data() + token_idx * n_embd; + // Dest: target_features[token_idx * n_embd_concat + layer_idx * n_embd] + float * dest = eagle3.target_features.data() + token_idx * n_embd_concat + layer_idx * n_embd; + std::memcpy(dest, src, n_embd * sizeof(float)); + } + } + +} + +void llama_context::extract_dflash_features(const llama_ubatch & ubatch) { + const int64_t n_tokens = ubatch.n_tokens; + const int64_t n_embd = model.hparams.n_embd; + const size_t n_layers = dflash.extract_tensors.size(); + + const int64_t n_embd_concat = n_embd * n_layers; + dflash.target_features.resize(n_embd_concat * n_tokens); + + static thread_local std::vector temp_layer_features; + temp_layer_features.resize(n_embd * n_tokens); + + LLAMA_LOG_DEBUG("%s: Start to extract DFlash features: %zu layers, %lld tokens, %lld embd\n", + __func__, n_layers, (long long)n_tokens, (long long)n_embd); + + for (size_t layer_idx = 0; layer_idx < n_layers; ++layer_idx) { + ggml_tensor * tensor = dflash.extract_tensors[layer_idx]; + GGML_ASSERT(tensor != nullptr && "DFlash extraction tensor is null"); + + ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched.get(), tensor); + GGML_ASSERT(backend != nullptr && "DFlash tensor has no backend"); + + GGML_ASSERT(tensor->ne[0] == n_embd && tensor->ne[1] == n_tokens && + "DFlash extraction tensor has unexpected shape"); + + const size_t size_bytes = n_embd * n_tokens * sizeof(float); + ggml_backend_tensor_get_async(backend, tensor, temp_layer_features.data(), 0, size_bytes); + ggml_backend_sched_synchronize(sched.get()); + + for (int64_t token_idx = 0; token_idx < n_tokens; ++token_idx) { + const float * src = temp_layer_features.data() + token_idx * n_embd; + float * dest = dflash.target_features.data() + token_idx * n_embd_concat + layer_idx * n_embd; + std::memcpy(dest, src, n_embd * sizeof(float)); + } + } +} + // // state save/load // @@ -2636,7 +2938,7 @@ void llama_context::perf_reset() { n_reused = 0; } -std::map llama_context::memory_breakdown() const { +llama_memory_breakdown llama_context::memory_breakdown() const { std::map ret; for (const auto & [buft, size] : model.memory_breakdown()) { ret[buft].model += size; @@ -2916,6 +3218,7 @@ llama_context_params llama_context_default_params() { /*.op_offload =*/ true, /*.swa_full =*/ true, /*.kv_unified =*/ false, + /*.target_model =*/ nullptr, /*.sampler =*/ nullptr, /*.n_sampler =*/ 0, }; @@ -2931,6 +3234,19 @@ llama_context * llama_init_from_model( return nullptr; } + // Auto-setup for EAGLE3: set target embedding if target_model is provided + if (model->arch == LLM_ARCH_EAGLE3 && params.target_model) { + model->target_tok_embd = params.target_model->tok_embd; + LLAMA_LOG_INFO("%s: EAGLE3 auto-setup: using target model's embedding layer\n", __func__); + } + + // Auto-setup for DFlash: set target embedding + lm_head if target_model is provided + if (model->arch == LLM_ARCH_DFLASH && params.target_model) { + model->target_tok_embd = params.target_model->tok_embd; + model->target_output = params.target_model->output; + LLAMA_LOG_INFO("%s: DFlash auto-setup: using target model's embedding + lm_head layers\n", __func__); + } + if (params.n_batch == 0 && params.n_ubatch == 0) { LLAMA_LOG_ERROR("%s: n_batch and n_ubatch cannot both be zero\n", __func__); return nullptr; @@ -3212,6 +3528,22 @@ int32_t llama_set_adapter_cvec( return res ? 
0 : -1; } +// +// eagle3 (tmp) +// + +void llama_set_eagle3( + llama_context * ctx, + const llama_model * model) { + ctx->set_eagle3(model); +} + +void llama_set_dflash( + llama_context * ctx, + const llama_model * model) { + ctx->set_dflash(model); +} + // // memory // @@ -3493,142 +3825,6 @@ void llama_perf_context_reset(llama_context * ctx) { ctx->perf_reset(); } -void llama_memory_breakdown_print(const struct llama_context * ctx) { - const auto & devices = ctx->get_model().devices; - - std::map memory_breakdown = ctx->memory_breakdown(); - - std::vector> table_data; - table_data.reserve(devices.size()); - const std::string template_header = "%s: | %s | %s %s %s %s %s %s %s |\n"; - const std::string template_gpu = "%s: | %s | %s = %s + (%s = %s + %s + %s) + %s |\n"; - const std::string template_other = "%s: | %s | %s %s %s = %s + %s + %s %s |\n"; - - table_data.push_back({template_header, "memory breakdown [MiB]", "total", "free", "self", "model", "context", "compute", "unaccounted"}); - - constexpr size_t MiB = 1024 * 1024; - const std::vector desc_prefixes_strip = {"NVIDIA ", "GeForce ", "Tesla ", "AMD ", "Radeon ", "Instinct "}; - - // track seen buffer types to avoid double counting: - std::set seen_buffer_types; - - // accumulative memory breakdown for each device and for host: - std::vector mb_dev(devices.size()); - llama_memory_breakdown_data mb_host; - - for (const auto & buft_mb : memory_breakdown) { - ggml_backend_buffer_type_t buft = buft_mb.first; - const llama_memory_breakdown_data & mb = buft_mb.second; - if (ggml_backend_buft_is_host(buft)) { - mb_host.model += mb.model; - mb_host.context += mb.context; - mb_host.compute += mb.compute; - seen_buffer_types.insert(buft); - continue; - } - ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft); - if (dev) { - int i_dev = -1; - for (size_t i = 0; i < devices.size(); i++) { - if (devices[i].dev == dev) { - i_dev = i; - break; - } - } - if (i_dev != -1) { - mb_dev[i_dev].model += mb.model; - mb_dev[i_dev].context += mb.context; - mb_dev[i_dev].compute += mb.compute; - seen_buffer_types.insert(buft); - continue; - } - } - } - - // print memory breakdown for each device: - for (size_t i = 0; i < devices.size(); i++) { - ggml_backend_dev_t dev = devices[i].dev; - llama_memory_breakdown_data mb = mb_dev[i]; - - const std::string name = ggml_backend_dev_name(dev); - std::string desc = ggml_backend_dev_description(dev); - for (const std::string & prefix : desc_prefixes_strip) { - if (desc.length() >= prefix.length() && desc.substr(0, prefix.length()) == prefix) { - desc = desc.substr(prefix.length()); - } - } - - size_t free, total; - ggml_backend_dev_memory(dev, &free, &total); - - const size_t self = mb.model + mb.context + mb.compute; - const size_t unaccounted = total - self - free; - - table_data.push_back({ - template_gpu, - " - " + name + " (" + desc + ")", - std::to_string(total / MiB), - std::to_string(free / MiB), - std::to_string(self / MiB), - std::to_string(mb.model / MiB), - std::to_string(mb.context / MiB), - std::to_string(mb.compute / MiB), - std::to_string(unaccounted / MiB)}); - } - - // print memory breakdown for host: - { - const size_t self = mb_host.model + mb_host.context + mb_host.compute; - table_data.push_back({ - template_other, - " - Host", - "", // total - "", // free - std::to_string(self / MiB), - std::to_string(mb_host.model / MiB), - std::to_string(mb_host.context / MiB), - std::to_string(mb_host.compute / MiB), - ""}); // unaccounted - } - - // print memory breakdown for all remaining buffer 
types: - for (const auto & buft_mb : memory_breakdown) { - ggml_backend_buffer_type_t buft = buft_mb.first; - const llama_memory_breakdown_data & mb = buft_mb.second; - if (seen_buffer_types.count(buft) == 1) { - continue; - } - const std::string name = ggml_backend_buft_name(buft); - const size_t self = mb.model + mb.context + mb.compute; - table_data.push_back({ - template_other, - " - " + name, - "", // total - "", // free - std::to_string(self / MiB), - std::to_string(mb.model / MiB), - std::to_string(mb.context / MiB), - std::to_string(mb.compute / MiB), - ""}); // unaccounted - seen_buffer_types.insert(buft); - } - - for (size_t j = 1; j < table_data[0].size(); j++) { - size_t max_len = 0; - for (const auto & td : table_data) { - max_len = std::max(max_len, td[j].length()); - } - for (auto & td : table_data) { - td[j].insert(j == 1 ? td[j].length() : 0, max_len - td[j].length(), ' '); - } - } - for (const auto & td : table_data) { - LLAMA_LOG_INFO(td[0].c_str(), - __func__, td[1].c_str(), td[2].c_str(), td[3].c_str(), td[4].c_str(), td[5].c_str(), - td[6].c_str(), td[7].c_str(), td[8].c_str()); - } -} - // // training // @@ -3659,3 +3855,50 @@ void llama_opt_epoch( callback_train, callback_eval); } + +// +// EAGLE3 member functions +// + +const float * llama_context::get_eagle3_target_features() const { + GGML_ASSERT(!eagle3.target_features.empty() && "EAGLE3 target features not extracted - call llama_encode() on target model first"); + return eagle3.target_features.data(); +} + +void llama_context::set_eagle3_g_embeddings(const float * g_embd, int32_t n_embd, int32_t n_tokens) { + GGML_ASSERT(g_embd != nullptr && "g_embeddings cannot be null"); + GGML_ASSERT(n_embd > 0 && n_tokens > 0 && "invalid dimensions"); + + const size_t size = n_embd * n_tokens; + eagle3.g_embeddings.resize(size); + std::memcpy(eagle3.g_embeddings.data(), g_embd, size * sizeof(float)); +} + +// +// C API wrappers +// + +const float * llama_get_eagle3_target_features(llama_context * ctx) { + return ctx->get_eagle3_target_features(); +} + +void llama_set_eagle3_g_embeddings(llama_context * ctx, const float * g_embd, int32_t n_embd, int32_t n_tokens) { + ctx->set_eagle3_g_embeddings(g_embd, n_embd, n_tokens); +} + +const float * llama_get_dflash_target_features(llama_context * ctx) { + return ctx->get_dflash_target_features(); +} + +void llama_set_dflash_accumulated_target_ctx(llama_context * ctx, const float * data, int32_t n_embd, int32_t n_tokens) { + ctx->set_dflash_accumulated_target_ctx(data, n_embd, n_tokens); +} + + +// +// ext +// + +llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx) { + return ctx->memory_breakdown(); +} diff --git a/src/llama-context.h b/src/llama-context.h index e0d0085c1c3..86f0d81c0cc 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -1,6 +1,7 @@ #pragma once #include "llama.h" +#include "llama-ext.h" #include "llama-cparams.h" #include "llama-graph.h" #include "llama-adapter.h" @@ -22,17 +23,6 @@ class llama_io_write_i; struct llama_memory_i; struct llama_memory_context_i; -// "memory" as in physical memory for a buffer type, in bytes -struct llama_memory_breakdown_data { - size_t model = 0; // memory allocated for the model - size_t context = 0; // memory allocated for the context - size_t compute = 0; // memory allocated for temporary compute buffers - - size_t total() const { - return model + context + compute; - } -}; - struct llama_context { // init scheduler and compute buffers, reserve worst-case graphs llama_context( @@ -116,6 
+106,10 @@ struct llama_context { int32_t il_start, int32_t il_end); + // TODO: tmp + void set_eagle3(const llama_model * model); + void set_dflash(const llama_model * model); + // process a single ubatch with a specific graph type // if memory_context is provided, it will be applied first to the context's memory // ret contains the status of the graph computation @@ -172,7 +166,7 @@ struct llama_context { llama_perf_context_data perf_get_data() const; void perf_reset(); - std::map memory_breakdown() const; + llama_memory_breakdown memory_breakdown() const; // // training @@ -232,6 +226,18 @@ struct llama_context { ggml_cgraph * graph_reserve( uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false, size_t * sizes = nullptr); + // EAGLE3: Get pointer to target model features extracted for EAGLE3 encoder + const float * get_eagle3_target_features() const; + + // EAGLE3: Set g_embeddings from encoder output for decoder input + void set_eagle3_g_embeddings(const float * g_embd, int32_t n_embd, int32_t n_tokens); + + // DFlash: Get pointer to target model features extracted for DFlash encoder + const float * get_dflash_target_features() const; + + // DFlash: Set accumulated target_ctx from encoder output for decoder input + void set_dflash_accumulated_target_ctx(const float * data, int32_t n_embd, int32_t n_tokens); + bool set_sampler(llama_seq_id seq_id, llama_sampler * sampler); private: @@ -243,6 +249,12 @@ struct llama_context { llm_graph_cb graph_get_cb() const; + // EAGLE3: Extract intermediate layer features from target model + void extract_eagle3_features(const llama_ubatch & ubatch); + + // DFlash: Extract intermediate layer features from target model + void extract_dflash_features(const llama_ubatch & ubatch); + // TODO: read/write lora adapters and cvec size_t state_write_data(llama_io_write_i & io); size_t state_read_data (llama_io_read_i & io); @@ -263,6 +275,15 @@ struct llama_context { llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably + mutable llama_eagle3 eagle3; // EAGLE3 draft model support - stores features from target model + // mutable because it's modified during graph building (const function) + + mutable llama_dflash dflash; + + // temp fix: avoid DFlash encoder/decoder mis-detection. They share one model_dft, + // so shared model fields cannot safely identify the decoder (caused OOM). + bool dflash_decoder_ctx = false; + std::unique_ptr memory; // decode output (2-dimensional array: [n_outputs][n_vocab]) diff --git a/src/llama-cparams.h b/src/llama-cparams.h index 9d359474132..906bfbe36c1 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -38,6 +38,8 @@ struct llama_cparams { bool warmup; bool op_offload; bool kv_unified; + bool eagle3_extract_enabled; // enable layer extraction for EAGLE3 speculative decoding + bool dflash_extract_enabled; // enable layer extraction for DFlash speculative decoding bool pipeline_parallel; enum llama_pooling_type pooling_type; diff --git a/src/llama-ext.h b/src/llama-ext.h index 2ffb77934e1..8ce29d217cb 100644 --- a/src/llama-ext.h +++ b/src/llama-ext.h @@ -1,8 +1,12 @@ #pragma once +// this is a staging header for new llama.cpp API +// breaking changes and C++ are allowed. everything here should be considered WIP + #include "llama.h" #include +#include // Reserve a new compute graph. It is valid until the next call to llama_graph_reserve. 
LLAMA_API struct ggml_cgraph * llama_graph_reserve( @@ -14,7 +18,6 @@ LLAMA_API struct ggml_cgraph * llama_graph_reserve( // Get the default ggml_type for a given ftype. LLAMA_API ggml_type llama_ftype_get_default_type(llama_ftype ftype); -// Quantization state. struct quantize_state_impl; LLAMA_API quantize_state_impl * llama_quant_init( @@ -54,3 +57,34 @@ LLAMA_API void llama_quant_compute_types( ggml_tensor ** tensors, ggml_type * result_types, size_t n_tensors); + +// +// device memory querying +// + +// "memory" as in physical memory for a buffer type, in bytes +struct llama_memory_breakdown_data { + size_t model = 0; // memory allocated for the model + size_t context = 0; // memory allocated for the context + size_t compute = 0; // memory allocated for temporary compute buffers + + size_t total() const { + return model + context + compute; + } +}; + +struct llama_device_memory_data { + int64_t total; + int64_t free; + llama_memory_breakdown_data mb; +}; + +// TODO: convert to C-style data structure +using llama_memory_breakdown = std::map; + +LLAMA_API int32_t llama_model_n_expert (const struct llama_model * model); +LLAMA_API int32_t llama_model_n_devices(const struct llama_model * model); + +LLAMA_API ggml_backend_dev_t llama_model_get_device(const struct llama_model * model, int i); + +LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx); diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 66cffa461ac..9fabd242e76 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -946,6 +946,8 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) : loras (params.loras), mctx (params.mctx), cross (params.cross), + eagle3 (params.eagle3), + dflash (params.dflash), samplers (params.samplers), cb_func (params.cb), res (params.res), @@ -1077,9 +1079,9 @@ llm_graph_qkv llm_graph_context::build_qkv( // fused QKV path ggml_tensor * qkv = build_lora_mm(layer.wqkv, cur, layer.wqkv_s); cb(qkv, "wqkv", il); - if (layer.bqkv) { - qkv = ggml_add(ctx0, qkv, layer.bqkv); - cb(qkv, "bqkv", il); + if (layer.wqkv_b) { + qkv = ggml_add(ctx0, qkv, layer.wqkv_b); + cb(qkv, "wqkv_b", il); } if (hparams.f_clamp_kqv > 0.0f) { qkv = ggml_clamp(ctx0, qkv, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); @@ -1097,8 +1099,8 @@ llm_graph_qkv llm_graph_context::build_qkv( // separate Q/K/V path Qcur = build_lora_mm(layer.wq, cur, layer.wq_s); cb(Qcur, "Qcur", il); - if (layer.bq) { - Qcur = ggml_add(ctx0, Qcur, layer.bq); + if (layer.wq_b) { + Qcur = ggml_add(ctx0, Qcur, layer.wq_b); cb(Qcur, "Qcur", il); } if (hparams.f_clamp_kqv > 0.0f) { @@ -1107,8 +1109,8 @@ llm_graph_qkv llm_graph_context::build_qkv( } Kcur = build_lora_mm(layer.wk, cur, layer.wk_s); cb(Kcur, "Kcur", il); - if (layer.bk) { - Kcur = ggml_add(ctx0, Kcur, layer.bk); + if (layer.wk_b) { + Kcur = ggml_add(ctx0, Kcur, layer.wk_b); cb(Kcur, "Kcur", il); } if (hparams.f_clamp_kqv > 0.0f) { @@ -1117,8 +1119,8 @@ llm_graph_qkv llm_graph_context::build_qkv( } Vcur = build_lora_mm(layer.wv, cur, layer.wv_s); cb(Vcur, "Vcur", il); - if (layer.bv) { - Vcur = ggml_add(ctx0, Vcur, layer.bv); + if (layer.wv_b) { + Vcur = ggml_add(ctx0, Vcur, layer.wv_b); cb(Vcur, "Vcur", il); } if (hparams.f_clamp_kqv > 0.0f) { diff --git a/src/llama-graph.h b/src/llama-graph.h index 5cb1756c6a9..1925a275d8a 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -73,6 +73,44 @@ struct llama_cross { std::vector> seq_ids_enc; }; +// EAGLE3 support - stores intermediate features from target model +struct 
llama_eagle3 { + // Configuration: which layers to extract from target model + std::vector extract_layer_indices; + + // Extracted features from target model (for encoder input) + // Concatenated [layer_l, layer_m, layer_h] embeddings + // Shape: [n_layers * n_embd, n_tokens] where n_layers = extract_layer_indices.size() + std::vector target_features; + + // Encoder output (for decoder input) + std::vector g_embeddings; + + // Tensor references for feature extraction from target model + std::vector extract_tensors; + + // Clear all stored data + void clear() { + target_features.clear(); + g_embeddings.clear(); + extract_tensors.clear(); + } +}; + +// DFlash intermediate results struct (similar to Eagle3) +struct llama_dflash { + std::vector extract_layer_indices; + + std::vector target_features; + + std::vector extract_tensors; + + void clear() { + target_features.clear(); + extract_tensors.clear(); + } +}; + struct llm_graph_params; // @@ -544,6 +582,8 @@ struct llm_graph_params { const llama_adapter_loras * loras; const llama_memory_context_i * mctx; const llama_cross * cross; + llama_eagle3 * eagle3; // non-const: we write extracted features here + llama_dflash * dflash; std::map samplers; @@ -758,6 +798,8 @@ struct llm_graph_context { const llama_adapter_loras * loras; const llama_memory_context_i * mctx; const llama_cross * cross; + llama_eagle3 * eagle3; // non-const: we write extracted features here + llama_dflash * dflash; std::map samplers; diff --git a/src/llama-hparams.h b/src/llama-hparams.h index c2000c77c37..fdd5a03bea8 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -116,6 +116,7 @@ struct llama_hparams { float rope_freq_base_train_swa = 10000.0f; float rope_freq_scale_train; float rope_freq_scale_train_swa = 1.0f; + float rope_scaling_alpha = 0.0f; // NTK-aware alpha for XDRoPE uint32_t n_ctx_orig_yarn; float rope_yarn_log_mul = 0.0f; @@ -209,6 +210,21 @@ struct llama_hparams { // qwen3vl deepstack uint32_t n_deepstack_layers = 0; + // EAGLE3 draft model - layer indices to extract from target model + // e.g., for 32-layer target: [2, 16, 29] (low, middle, high) + std::array eagle3_extract_layers = {0, 0, 0}; + + // EAGLE3 draft model - target model hidden size + uint32_t eagle3_target_hidden_size = 0; + + // EAGLE3 draft model - apply hidden_norm before storing residual + bool eagle3_norm_before_residual = false; + + // DFlash draft model + std::array dflash_target_layer_ids = {}; + uint32_t dflash_block_size = 16; + uint32_t dflash_mask_token_id = 0; + // gemma4 per-layer embedding uint32_t n_embd_per_layer = 0; diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index 4e65a45a50d..832fc990c89 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -503,6 +503,7 @@ namespace GGUFMeta { // TODO: this is not very clever - figure out something better template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); + template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); // store DFlash 5 layer ids template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index d9781d7d275..47668954f59 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1,6 +1,7 @@ #include "llama-model.h" #include 
"llama-arch.h" +#include "llama-ext.h" #include "llama-hparams.h" #include "llama-impl.h" #include "llama-mmap.h" @@ -77,11 +78,23 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str const ggml_tensor * tensor_axis_0; uint32_t il; - size_t rotation; + size_t rotation; // when assigning tensor slices, rotate how the rounding is done for more even allocation }; auto get_tensor_config_impl = [&]( const ggml_backend_meta_split_axis axis, const std::string & suffix = "", const std::string & suffix_fallback = "") -> tensor_config { + // the layers in a tensor can be inhomogeneous, if the pattern is cleanly divided by the number of GPUs there can be aliasing effects, + // count only the same type of previous layers to avoid this + auto get_il_eff = [&](const size_t il){ + size_t ret = 0; + const bool il_is_recurrent = hparams.is_recurrent(il); + const bool il_is_swa = hparams.is_swa(il); + for (size_t il_prev = 0; il_prev < il; il_prev++) { + ret += hparams.is_recurrent(il_prev) == il_is_recurrent && hparams.is_swa(il_prev) == il_is_swa; + } + return ret; + }; + uint32_t il; std::string prefix; size_t rotation; @@ -90,13 +103,13 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str GGML_ASSERT(length_prefix != std::string::npos); prefix = tensor_name.substr(0, length_prefix + 1); il = std::stoull(tensor_name.substr(4, length_prefix)); - rotation = il % ud->n_devices; + rotation = get_il_eff(il) % ud->n_devices; } else if (tensor_name.substr(0, 6) == "cache_") { const size_t layer_index_start = tensor_name.find("_l", 6); GGML_ASSERT(layer_index_start != std::string::npos); il = std::stoull(tensor_name.substr(layer_index_start + 2)); prefix = "blk." + std::to_string(il) + "."; - rotation = il % ud->n_devices; + rotation = get_il_eff(il) % ud->n_devices; } else { il = 0; rotation = hparams.n_layer % ud->n_devices; @@ -724,6 +737,13 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_GROUP_COUNT, hparams.n_expert_groups, false); ml.get_key(LLM_KV_EXPERT_GROUP_USED_COUNT, hparams.n_group_used, false); + if (arch == LLM_ARCH_HUNYUAN_VL || arch == LLM_ARCH_HUNYUAN_DENSE) { + if (hparams.n_expert <= 1) { + hparams.n_expert = 0; + hparams.n_expert_used = 0; + } + } + if (arch == LLM_ARCH_WAVTOKENIZER_DEC) { ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd); ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd_out_impl); @@ -802,6 +822,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.rope_freq_scale_train = ropescale == 0.0f ? 
1.0f : 1.0f/ropescale; ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false); + ml.get_key(LLM_KV_ROPE_SCALING_ALPHA, hparams.rope_scaling_alpha, false); // non-transformer models do not have attention heads if (hparams.n_head() > 0) { @@ -2579,9 +2600,18 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_HUNYUAN_VL: case LLM_ARCH_HUNYUAN_DENSE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); + + // XDRoPE / NTK-aware scaling: base = rope_theta * alpha^(dim / (dim - 2)) + if (hparams.rope_scaling_alpha > 0.0f) { + const int dim = hparams.n_embd_head_k(); + hparams.rope_freq_base_train = hparams.rope_freq_base_train + * powf(hparams.rope_scaling_alpha, (float)dim / (float)(dim - 2)); + } switch (hparams.n_embd) { case 1024: type = LLM_TYPE_0_5B; break; @@ -2728,6 +2758,57 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_EAGLE3: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // EAGLE3 layer extraction configuration + // Use array (has template instantiation), then copy first 3 elements + std::array extract_layers_tmp = {}; + if (!ml.get_key_or_arr(LLM_KV_EAGLE3_EXTRACT_LAYERS, extract_layers_tmp, 3, false)) { + throw std::runtime_error("EAGLE3 model requires 'extract_layers' in GGUF metadata"); + } + std::copy_n(extract_layers_tmp.begin(), 3, hparams.eagle3_extract_layers.begin()); + LLAMA_LOG_INFO("%s: EAGLE3 extract_layers = [%d, %d, %d]\n", __func__, + hparams.eagle3_extract_layers[0], + hparams.eagle3_extract_layers[1], + hparams.eagle3_extract_layers[2]); + + // EAGLE3 target model hidden size + ml.get_key(LLM_KV_EAGLE3_TARGET_HIDDEN_SIZE, hparams.eagle3_target_hidden_size); + LLAMA_LOG_INFO("%s: EAGLE3 target_hidden_size = %u (draft n_embd = %u)\n", __func__, + hparams.eagle3_target_hidden_size, hparams.n_embd); + + // EAGLE3 norm_before_residual (optional, default false) + // compatible with Red Hat eagle3 speculator model + ml.get_key(LLM_KV_EAGLE3_NORM_BEFORE_RESIDUAL, hparams.eagle3_norm_before_residual, false); + if (hparams.eagle3_norm_before_residual) { + LLAMA_LOG_INFO("%s: EAGLE3 norm_before_residual = true\n", __func__); + } + + type = LLM_TYPE_UNKNOWN; + } break; + case LLM_ARCH_DFLASH: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + ml.get_key(LLM_KV_DFLASH_BLOCK_SIZE, hparams.dflash_block_size, false); + ml.get_key(LLM_KV_DFLASH_MASK_TOKEN_ID, hparams.dflash_mask_token_id, false); + + if (!ml.get_key_or_arr(LLM_KV_DFLASH_TARGET_LAYER_IDS, hparams.dflash_target_layer_ids, 5, false)) { + throw std::runtime_error("DFlash model requires 'target_layer_ids' in GGUF metadata"); + } + LLAMA_LOG_INFO("%s: DFlash extract_layers = [%d, %d, %d, %d, %d]\n", __func__, + hparams.dflash_target_layer_ids[0], + hparams.dflash_target_layer_ids[1], + hparams.dflash_target_layer_ids[2], + hparams.dflash_target_layer_ids[3], + hparams.dflash_target_layer_ids[4]); + + LLAMA_LOG_INFO("%s: DFlash block_size = %u, mask_token_id = %u\n", + __func__, hparams.dflash_block_size, hparams.dflash_mask_token_id); + + type = LLM_TYPE_UNKNOWN; + } break; case LLM_ARCH_COGVLM: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -3098,14 +3179,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t n_embd_qkv = n_embd_q_ + 
n_embd_k_ + n_embd_v_; layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", bid), {n_embd_, n_embd_qkv}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); if (layer.wqkv) { - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", bid), {n_embd_qkv}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", bid), {n_embd_qkv}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); } else { layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", bid), {n_embd_, n_embd_q_}, flags); layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", bid), {n_embd_, n_embd_k_}, flags); layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", bid), {n_embd_, n_embd_v_}, flags); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", bid), {n_embd_q_}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", bid), {n_embd_k_}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", bid), {n_embd_v_}, TENSOR_NOT_REQUIRED); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", bid), {n_embd_q_}, TENSOR_NOT_REQUIRED); + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", bid), {n_embd_k_}, TENSOR_NOT_REQUIRED); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", bid), {n_embd_v_}, TENSOR_NOT_REQUIRED); } }; @@ -3138,7 +3219,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); // optional bias tensors - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -3201,7 +3282,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // No bias for QKV projections as per config: include_bias=false, include_qkv_bias=false layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); @@ -3336,9 +3417,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); } - // optional bias tensors - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); if (n_ff > 0) { layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -3558,10 +3638,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 
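A sketch of how bias tensors loaded with TENSOR_NOT_REQUIRED (and renamed to wqkv_b / wq_b / wk_b / wv_b / wo_b above) are typically consumed at graph-build time; this is not code from the patch, just the usual "add the bias only if it was loaded" pattern written against plain ggml calls, with a hypothetical helper name.

#include "ggml.h"

// hypothetical helper: apply a projection and an optional bias
static ggml_tensor * attn_proj_with_bias(
        ggml_context * ctx, ggml_tensor * w, ggml_tensor * b, ggml_tensor * cur) {
    cur = ggml_mul_mat(ctx, w, cur); // the weight is always required
    if (b) {                         // the bias may be null when the GGUF has none
        cur = ggml_add(ctx, cur, b);
    }
    return cur;
}

// fused path:  cur = attn_proj_with_bias(ctx, layer.wqkv, layer.wqkv_b, cur);
// split path:  q   = attn_proj_with_bias(ctx, layer.wq,   layer.wq_b,   cur);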
0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -3602,8 +3682,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); @@ -3719,23 +3799,16 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; // JinaBertLayer - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); //output_dens - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); //output_dens + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); //output_dens layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); //output_norm layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); @@ -3783,10 +3856,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -3819,10 +3892,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 
TENSOR_NOT_REQUIRED); layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); @@ -3889,7 +3962,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd*3}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd*3}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd*3}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -4068,7 +4141,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); @@ -4127,7 +4200,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), { n_embd }, 0); @@ -4291,10 +4364,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -4329,7 +4402,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); 
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -4646,7 +4719,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); // optional bias tensors - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -4890,7 +4963,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t n_embd_v_gqa_i = hparams.n_embd_v_gqa(i); create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_i, n_embd_k_gqa_i, n_embd_v_gqa_i, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head_i, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); } // feed forward (w/ optional biases) @@ -5152,10 +5225,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -5570,10 +5643,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -5612,10 +5685,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); // attention biases - all have shape n_embd (output dimension of projections) - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); 
- layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd}, 0); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd}, 0); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -5918,7 +5991,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); // optional bias tensors - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -5987,7 +6060,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t n_embd_v_gqa_i = hparams.n_embd_v_gqa(i); create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_i, n_embd_k_gqa_i, n_embd_v_gqa_i, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head_i, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); } else { if (n_expert != 0) { const int64_t n_ff_exp = hparams.n_ff_exp ? 
hparams.n_ff_exp : n_ff / n_expert_used; @@ -6808,7 +6881,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); // optional bias tensors - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -6890,7 +6963,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // attention layers (with optional bias) create_tensor_qkv(layer, i, hidden_size, n_embd_head_k * attn_num_attention_head, attn_num_key_value_head * n_embd_head_k, attn_num_key_value_head * n_embd_head_v, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * attn_num_attention_head, hidden_size}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {hidden_size}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {hidden_size}, TENSOR_NOT_REQUIRED); layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {hidden_size}, 0); @@ -6942,6 +7015,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); } } break; + case LLM_ARCH_HUNYUAN_VL: case LLM_ARCH_HUNYUAN_DENSE: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -7026,7 +7100,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_gate_inp_b = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "bias", i), {n_expert}, 0); layer.ffn_gate_exps_b = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "bias", i), {n_ff_exp, n_expert}, 0); @@ -7191,7 +7265,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); // optional bias tensors - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); @@ -7231,6 +7305,97 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); } } break; + case LLM_ARCH_EAGLE3: + { + const int64_t n_embd_target_features = 3 * hparams.eagle3_target_hidden_size; + const int64_t n_embd_attn_input = 2 * n_embd; + + // Get vocab size from the d2t tensor in the GGUF file (optional - only needed if EAGLE3 has different vocab_size than target) + // d2t: draft to target vocabulary mapping + int64_t n_draft_vocab = n_vocab; // Default: same as target vocab + const struct ggml_tensor * d2t_meta = ml.get_tensor_meta("d2t"); + if (d2t_meta) { + n_draft_vocab = d2t_meta->ne[0]; // update draft vocab size + d2t = 
create_tensor(tn(LLM_TENSOR_EAGLE3_D2T), {n_draft_vocab}, 0); + LLAMA_LOG_INFO("%s: EAGLE3 using d2t mapping (draft_vocab_size = %lld)\n", __func__, (long long)n_draft_vocab); + } else { + d2t = nullptr; // no d2t, use default vocab size + LLAMA_LOG_INFO("%s: EAGLE3 without d2t - sharing same vocab_size with target (vocab_size = %lld)\n", __func__, (long long)n_draft_vocab); + } + + // Feature fusion layer: projects 3 target layers to draft hidden size + fc = create_tensor(tn(LLM_TENSOR_EAGLE3_FC, "weight"), {n_embd_target_features, n_embd}, 0); + + // Output layer (uses draft vocab size) + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_draft_vocab}, 0); + + // Token embeddings (optional - Llama 3.3 70B EAGLE3 has its own) + const struct ggml_tensor * tok_embd_meta = ml.get_tensor_meta(tn(LLM_TENSOR_TOKEN_EMBD, "weight").str().c_str()); + if (tok_embd_meta) { + const int64_t n_target_vocab = tok_embd_meta->ne[1]; + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_target_vocab}, 0); + LLAMA_LOG_INFO("%s: EAGLE3 using its own token_embd (vocab = %lld)\n", __func__, (long long)n_target_vocab); + } + + // Single decoder layer + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + // input_layernorm: applied to token embeddings + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + // Attention takes input_embeds_normed + fused_target_normed as input + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd_attn_input, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd_attn_input, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd_attn_input, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // EAGLE-3 specific: hidden_norm applied to fused target features + layer.eagle3_hidden_norm = create_tensor(tn(LLM_TENSOR_EAGLE3_HIDDEN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // rope_freqs for llama3 rope scaling (optional - only if EAGLE3 config has rope_scaling) + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED); + } + } break; + case LLM_ARCH_DFLASH: + { + const int64_t n_target_layer_ids = (int64_t)hparams.dflash_target_layer_ids.size(); + const int64_t n_embd_target_features = n_target_layer_ids * n_embd; + + fc = create_tensor(tn(LLM_TENSOR_DFLASH_FC, "weight"), {n_embd_target_features, n_embd}, 0); + dflash_hidden_norm = create_tensor(tn(LLM_TENSOR_DFLASH_HIDDEN_NORM, "weight"), {n_embd}, 0); + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = 
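A sketch (assumptions labelled, not code from this patch) of how the `fc` fusion weight loaded above could be applied: the three captured target-layer hidden states, each of width eagle3_target_hidden_size, are concatenated along the feature dimension and projected down to the draft model's n_embd. The helper name and the order of concatenation are illustrative.

#include "ggml.h"

// h0/h1/h2 are assumed to be the captured target hidden states, each with
// ne[0] == eagle3_target_hidden_size; fc has shape {3*target_hidden_size, n_embd}
static ggml_tensor * eagle3_fuse_features(
        ggml_context * ctx, ggml_tensor * fc,
        ggml_tensor * h0, ggml_tensor * h1, ggml_tensor * h2) {
    ggml_tensor * cat = ggml_concat(ctx, h0, h1, 0); // concatenate along dim 0 (features)
    cat = ggml_concat(ctx, cat, h2, 0);              // now 3*target_hidden_size wide
    return ggml_mul_mat(ctx, fc, cat);               // project to the draft n_embd
}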
create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd_head_k * n_head}, TENSOR_NOT_REQUIRED); + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa}, TENSOR_NOT_REQUIRED); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; case LLM_ARCH_KIMI_LINEAR: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -7422,7 +7587,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); // bias tensors - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -8188,114 +8353,114 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: n_cls_out = %u\n", __func__, hparams.n_cls_out); size_t i = 0; - for (auto label : classifier_labels) { + for (const auto & label : classifier_labels) { LLAMA_LOG_INFO("%s: cls_label[%2zu] = %s\n", __func__, i++, label.c_str()); } } - } - if (arch == LLM_ARCH_MAMBA || - arch == LLM_ARCH_MAMBA2 || - arch == LLM_ARCH_JAMBA || - arch == LLM_ARCH_FALCON_H1 || - arch == LLM_ARCH_PLAMO2 || - arch == LLM_ARCH_GRANITE_HYBRID || - arch == LLM_ARCH_QWEN3NEXT || - arch == LLM_ARCH_QWEN35 || - arch == LLM_ARCH_QWEN35MOE || - arch == LLM_ARCH_NEMOTRON_H || - arch == LLM_ARCH_NEMOTRON_H_MOE) { - LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); - LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); - LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); - LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); - LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group); - LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); - } + if (arch == LLM_ARCH_MAMBA || + arch == LLM_ARCH_MAMBA2 || + arch == LLM_ARCH_JAMBA || + arch == LLM_ARCH_FALCON_H1 || + arch == LLM_ARCH_PLAMO2 || + arch == LLM_ARCH_GRANITE_HYBRID || + arch == LLM_ARCH_QWEN3NEXT || + arch == LLM_ARCH_QWEN35 || + arch == LLM_ARCH_QWEN35MOE || + arch == LLM_ARCH_NEMOTRON_H || + arch == LLM_ARCH_NEMOTRON_H_MOE) { + LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); + LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); + LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); + LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); + LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, 
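The {n_embd_head_k}-shaped attn_q_norm / attn_k_norm tensors loaded for DFlash above suggest per-head Q/K normalization. The following sketch mirrors how other llama.cpp architectures usually apply QK-norm (RMS-norm of each head vector, then a learned per-dimension scale); it is an assumption about this architecture, not code copied from the patch.

#include "ggml.h"

// qcur is assumed to be reshaped to [n_embd_head, n_head, n_tokens];
// norm_w has n_embd_head elements and broadcasts over heads and tokens
static ggml_tensor * qk_rms_norm(
        ggml_context * ctx, ggml_tensor * qcur, ggml_tensor * norm_w, float eps) {
    qcur = ggml_rms_norm(ctx, qcur, eps); // normalize each head vector independently
    return ggml_mul(ctx, qcur, norm_w);   // scale by the learned weight
}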
hparams.ssm_n_group); + LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); + } - LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str()); - if (pimpl->n_elements >= 1e12) { - LLAMA_LOG_INFO("%s: model params = %.2f T\n", __func__, pimpl->n_elements*1e-12); - } else if (pimpl->n_elements >= 1e9) { - LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, pimpl->n_elements*1e-9); - } else if (pimpl->n_elements >= 1e6) { - LLAMA_LOG_INFO("%s: model params = %.2f M\n", __func__, pimpl->n_elements*1e-6); - } else { - LLAMA_LOG_INFO("%s: model params = %.2f K\n", __func__, pimpl->n_elements*1e-3); - } + LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str()); + if (pimpl->n_elements >= 1e12) { + LLAMA_LOG_INFO("%s: model params = %.2f T\n", __func__, pimpl->n_elements*1e-12); + } else if (pimpl->n_elements >= 1e9) { + LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, pimpl->n_elements*1e-9); + } else if (pimpl->n_elements >= 1e6) { + LLAMA_LOG_INFO("%s: model params = %.2f M\n", __func__, pimpl->n_elements*1e-6); + } else { + LLAMA_LOG_INFO("%s: model params = %.2f K\n", __func__, pimpl->n_elements*1e-3); + } - // general kv - LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, name.c_str()); + // general kv + LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, name.c_str()); - if (arch == LLM_ARCH_DEEPSEEK) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - } + if (arch == LLM_ARCH_DEEPSEEK) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + } - if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_DEEPSEEK2OCR || arch == LLM_ARCH_GLM_DSA || arch == LLM_ARCH_MISTRAL4) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); - LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); - LLAMA_LOG_INFO("%s: n_embd_head_k_mla = %d\n", __func__, hparams.n_embd_head_k_mla()); - LLAMA_LOG_INFO("%s: n_embd_head_v_mla = %d\n", __func__, hparams.n_embd_head_v_mla()); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); - LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); - } + if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_DEEPSEEK2OCR || arch == LLM_ARCH_GLM_DSA || arch == LLM_ARCH_MISTRAL4) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); + LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); + LLAMA_LOG_INFO("%s: n_embd_head_k_mla = %d\n", __func__, 
hparams.n_embd_head_k_mla()); + LLAMA_LOG_INFO("%s: n_embd_head_v_mla = %d\n", __func__, hparams.n_embd_head_v_mla()); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); + } - if (arch == LLM_ARCH_QWEN2MOE) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); - } + if (arch == LLM_ARCH_QWEN2MOE) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + } - if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE || arch == LLM_ARCH_QWEN3VLMOE || arch == LLM_ARCH_RND1) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - } + if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE || arch == LLM_ARCH_QWEN3VLMOE || arch == LLM_ARCH_RND1) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + } - if (arch == LLM_ARCH_MINICPM || - arch == LLM_ARCH_GRANITE || - arch == LLM_ARCH_GRANITE_MOE || - arch == LLM_ARCH_GRANITE_HYBRID || - arch == LLM_ARCH_NEMOTRON_H_MOE) { - LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); - LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); - LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); - LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); - } + if (arch == LLM_ARCH_MINICPM || + arch == LLM_ARCH_GRANITE || + arch == LLM_ARCH_GRANITE_MOE || + arch == LLM_ARCH_GRANITE_HYBRID || + arch == LLM_ARCH_NEMOTRON_H_MOE) { + LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); + LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); + LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + } - if (arch == LLM_ARCH_BAILINGMOE) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); - } + if (arch == LLM_ARCH_BAILINGMOE) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + } - if (arch == LLM_ARCH_BAILINGMOE2) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", 
__func__, hparams.n_ff_shexp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); - LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); - LLAMA_LOG_INFO("%s: nextn_predict_layers = %d\n", __func__, hparams.nextn_predict_layers); - } + if (arch == LLM_ARCH_BAILINGMOE2) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); + LLAMA_LOG_INFO("%s: nextn_predict_layers = %d\n", __func__, hparams.nextn_predict_layers); + } - if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); - } + if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); + } - if (arch == LLM_ARCH_GROVEMOE) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_ff_chexp = %d\n", __func__, hparams.n_ff_chexp); - LLAMA_LOG_INFO("%s: n_group_experts = %d\n", __func__, hparams.n_group_experts); - LLAMA_LOG_INFO("%s: expert_group_scale = %.2f\n", __func__, hparams.expert_group_scale); + if (arch == LLM_ARCH_GROVEMOE) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_chexp = %d\n", __func__, hparams.n_ff_chexp); + LLAMA_LOG_INFO("%s: n_group_experts = %d\n", __func__, hparams.n_group_experts); + LLAMA_LOG_INFO("%s: expert_group_scale = %.2f\n", __func__, hparams.expert_group_scale); + } } vocab.print_info(); @@ -8418,6 +8583,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, case LLM_ARCH_LLADA: case LLM_ARCH_LLADA_MOE: case LLM_ARCH_RND1: + case LLM_ARCH_DFLASH: // current DFlash decoder doesn't support KV-cache due to cross_attn + self_attn (no mask) { res = nullptr; } break; @@ -8962,6 +9128,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_HUNYUAN_VL: case LLM_ARCH_HUNYUAN_DENSE: { llm = std::make_unique(*this, params); @@ -9007,6 +9174,22 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_EAGLE3: + { + if (params.gtype == LLM_GRAPH_TYPE_ENCODER) { + llm = std::make_unique(*this, params); + } else { + llm = std::make_unique(*this, params); + } + } break; + case 
LLM_ARCH_DFLASH: + { + if (params.gtype == LLM_GRAPH_TYPE_ENCODER) { + llm = std::make_unique(*this, params); + } else { + llm = std::make_unique(*this, params); + } + } break; case LLM_ARCH_COGVLM: { llm = std::make_unique(*this, params); @@ -9137,6 +9320,14 @@ int32_t llama_model_n_swa(const llama_model * model) { return model->hparams.n_swa; } +int32_t llama_model_dflash_block_size(const llama_model * model) { + return (int32_t) model->hparams.dflash_block_size; +} + +int32_t llama_model_dflash_mask_token_id(const llama_model * model) { + return (int32_t) model->hparams.dflash_mask_token_id; +} + uint32_t llama_model_n_cls_out(const struct llama_model * model) { return model->hparams.n_cls_out; } @@ -9225,6 +9416,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_ERNIE4_5: case LLM_ARCH_ERNIE4_5_MOE: case LLM_ARCH_MISTRAL3: + case LLM_ARCH_EAGLE3: case LLM_ARCH_MISTRAL4: case LLM_ARCH_LLAMA_EMBED: case LLM_ARCH_MAINCODER: @@ -9250,6 +9442,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_QWEN2MOE: case LLM_ARCH_QWEN3: case LLM_ARCH_QWEN3MOE: + case LLM_ARCH_DFLASH: case LLM_ARCH_LLADA_MOE: case LLM_ARCH_RND1: case LLM_ARCH_OLMO2: @@ -9311,6 +9504,9 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GLM4_MOE: return model->hparams.use_mrope() ? LLAMA_ROPE_TYPE_MROPE : LLAMA_ROPE_TYPE_NEOX; + case LLM_ARCH_HUNYUAN_VL: + return model->hparams.use_mrope() ? LLAMA_ROPE_TYPE_MROPE : LLAMA_ROPE_TYPE_NEOX; + // all model arches should be listed explicitly here case LLM_ARCH_UNKNOWN: GGML_ABORT("unknown architecture"); @@ -9445,3 +9641,18 @@ bool llama_model_is_diffusion(const llama_model * model) { const std::vector> & llama_internal_get_tensor_map(const llama_model * model) { return model->tensors_by_name; } + +int32_t llama_model_n_expert(const struct llama_model * model) { + return model->hparams.n_expert; +} + +int32_t llama_model_n_devices(const struct llama_model * model) { + return (int32_t)model->devices.size(); +} + +ggml_backend_dev_t llama_model_get_device(const struct llama_model * model, int i) { + if (i < 0 || i >= (int)model->devices.size()) { + return nullptr; + } + return model->devices[i].dev; +} diff --git a/src/llama-model.h b/src/llama-model.h index 67349e2d6ff..199cc45ca49 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -246,6 +246,8 @@ struct llama_layer { struct ggml_tensor * wkv_b = nullptr; struct ggml_tensor * wk_b = nullptr; struct ggml_tensor * wv_b = nullptr; + struct ggml_tensor * wqkv_b = nullptr; + struct ggml_tensor * wo_b = nullptr; struct ggml_tensor * wq_cross = nullptr; struct ggml_tensor * wk_cross = nullptr; struct ggml_tensor * wv_cross = nullptr; @@ -256,13 +258,6 @@ struct llama_layer { struct ggml_tensor * wo_enc = nullptr; struct ggml_tensor * wqkv_gate = nullptr; - // attention bias - struct ggml_tensor * bq = nullptr; - struct ggml_tensor * bk = nullptr; - struct ggml_tensor * bv = nullptr; - struct ggml_tensor * bo = nullptr; - struct ggml_tensor * bqkv = nullptr; - // relative position bias struct ggml_tensor * attn_rel_b = nullptr; struct ggml_tensor * attn_rel_b_enc = nullptr; @@ -470,6 +465,9 @@ struct llama_layer { struct ggml_tensor * ffn_act_beta = nullptr; struct ggml_tensor * ffn_act_eps = nullptr; + // eagle3 + struct ggml_tensor * eagle3_hidden_norm = nullptr; + // Kimi Linear KDA (using ssm_ prefix for consistency) // Note: ssm_dt_b already exists above (mamba bias), reused for Kimi dt_bias struct ggml_tensor * ssm_q_conv = 
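A small usage sketch for the accessors added above (llama_model_n_expert, llama_model_n_devices, llama_model_get_device), assuming matching declarations are exposed in llama.h; `model` is a loaded llama_model handle.

#include <cstdio>
#include "llama.h"
#include "ggml-backend.h"

static void print_model_devices(const llama_model * model) {
    printf("n_expert  = %d\n", llama_model_n_expert(model));
    const int32_t n_dev = llama_model_n_devices(model);
    for (int32_t i = 0; i < n_dev; i++) {
        ggml_backend_dev_t dev = llama_model_get_device(model, i);
        if (dev) {
            printf("device %d: %s\n", i, ggml_backend_dev_name(dev));
        }
    }
}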
nullptr; @@ -555,6 +553,17 @@ struct llama_model { struct ggml_tensor * per_layer_model_proj = nullptr; struct ggml_tensor * per_layer_proj_norm = nullptr; + // eagle3 + struct ggml_tensor * fc = nullptr; // feature fusion layer + struct ggml_tensor * d2t = nullptr; // draft to target vocabulary mapping + // Reference to target model's embedding layer + // This allows EAGLE3 to use target model's embeddings without copying + struct ggml_tensor * target_tok_embd = nullptr; + + // dflash + struct ggml_tensor * dflash_hidden_norm = nullptr; + struct ggml_tensor * target_output = nullptr; // reference to target model's lm_head + std::vector layers; //Dense linear projections for SentenceTransformers models like embeddinggemma diff --git a/src/llama.cpp b/src/llama.cpp index 484372d8d10..e9c3028585d 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -46,725 +46,6 @@ const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_ty GGML_ABORT("fatal error"); } -struct llama_device_memory_data { - int64_t total; - int64_t free; - llama_memory_breakdown_data mb; -}; - -static std::vector llama_get_device_memory_data( - const char * path_model, const llama_model_params * mparams, const llama_context_params * cparams, - std::vector & devs, uint32_t & hp_ngl, uint32_t & hp_n_ctx_train, uint32_t & hp_n_expert, - const ggml_log_level log_level) { - struct user_data_t { - struct { - ggml_log_callback callback; - void * user_data; - } original_logger; - ggml_log_level min_level; // prints below this log level go to debug log - }; - user_data_t ud; - llama_log_get(&ud.original_logger.callback, &ud.original_logger.user_data); - ud.min_level = log_level; - - llama_log_set([](ggml_log_level level, const char * text, void * user_data) { - const user_data_t * ud = (const user_data_t *) user_data; - const ggml_log_level level_eff = level >= ud->min_level ? level : GGML_LOG_LEVEL_DEBUG; - ud->original_logger.callback(level_eff, text, ud->original_logger.user_data); - }, &ud); - - llama_model_params mparams_copy = *mparams; - mparams_copy.no_alloc = true; - mparams_copy.use_mmap = false; - mparams_copy.use_mlock = false; - - llama_model * model = llama_model_load_from_file(path_model, mparams_copy); - if (model == nullptr) { - llama_log_set(ud.original_logger.callback, ud.original_logger.user_data); - throw std::runtime_error("failed to load model"); - } - - llama_context * ctx = llama_init_from_model(model, *cparams); - if (ctx == nullptr) { - llama_model_free(model); - llama_log_set(ud.original_logger.callback, ud.original_logger.user_data); - throw std::runtime_error("failed to create llama_context from model"); - } - - std::vector ret(model->devices.size()); - - std::map memory_breakdown = ctx->memory_breakdown(); - - for (const auto & [buft, mb] : memory_breakdown) { - if (ggml_backend_buft_is_host(buft)) { - continue; - } - - ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft); - if (!dev) { - continue; - } - for (size_t i = 0; i < ret.size(); i++) { - if (model->devices[i].dev == dev) { - ret[i].mb.model += mb.model; - ret[i].mb.context += mb.context; - ret[i].mb.compute += mb.compute; - break; - } - } - } - for (size_t i = 0; i < ret.size(); i++) { - size_t free; - size_t total; - ggml_backend_dev_memory(model->devices[i].dev, &free, &total); - - // devices can return 0 bytes for free and total memory if they do not - // have any to report. 
in this case, we will use the host memory as a fallback - // fixes: https://github.com/ggml-org/llama.cpp/issues/18577 - if (free == 0 && total == 0) { - ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); - if (cpu_dev == nullptr) { - throw std::runtime_error(format("%s: no CPU backend found", __func__)); - } - ggml_backend_dev_memory(cpu_dev, &free, &total); - } - ret[i].free = free; - ret[i].total = total; - } - - devs = model->devices; - hp_ngl = model->hparams.n_layer; - hp_n_ctx_train = model->hparams.n_ctx_train; - hp_n_expert = model->hparams.n_expert; - - llama_memory_breakdown_print(ctx); // goes to debug log - - llama_free(ctx); - llama_model_free(model); - llama_log_set(ud.original_logger.callback, ud.original_logger.user_data); - return ret; -} - -// enum to identify part of a layer for distributing its tensors: -enum layer_fraction_t { - LAYER_FRACTION_NONE = 0, // nothing - LAYER_FRACTION_ATTN = 1, // attention - LAYER_FRACTION_UP = 2, // attention + up - LAYER_FRACTION_GATE = 3, // attention + up + gate - LAYER_FRACTION_MOE = 4, // everything but sparse MoE weights -}; -// this enum is only used in llama_params_fit_impl but needs to be defined outside of it to fix a Windows compilation issue - -class llama_params_fit_exception : public std::runtime_error { - using std::runtime_error::runtime_error; -}; - -static void llama_params_fit_impl( - const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams, - float * tensor_split, struct llama_model_tensor_buft_override * tensor_buft_overrides, - size_t * margins_s, uint32_t n_ctx_min, enum ggml_log_level log_level) { - if (mparams->split_mode == LLAMA_SPLIT_MODE_TENSOR) { - throw llama_params_fit_exception("llama_params_fit is not implemented for SPLIT_MODE_TENSOR, abort"); - } - constexpr int64_t MiB = 1024*1024; - typedef std::vector dmds_t; - const llama_model_params default_mparams = llama_model_default_params(); - - std::vector devs; - uint32_t hp_ngl = 0; // hparams.n_gpu_layers - uint32_t hp_nct = 0; // hparams.n_ctx_train - uint32_t hp_nex = 0; // hparams.n_expert - - // step 1: get data for default parameters and check whether any changes are necessary in the first place - - LLAMA_LOG_DEBUG("%s: getting device memory data for initial parameters:\n", __func__); - const dmds_t dmds_full = llama_get_device_memory_data(path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level); - const size_t nd = devs.size(); // number of devices - if (nd == 0) { - LLAMA_LOG_INFO("%s: no devices with dedicated memory found\n", __func__); - return; - } - - std::vector margins; // this function uses int64_t rather than size_t for memory sizes to more conveniently handle deficits - margins.reserve(nd); - for (size_t id = 0; id < nd; id++) { - margins.push_back(margins_s[id]); - } - - std::vector dev_names; - { - dev_names.reserve(nd); - size_t max_length = 0; - for (const llama_device & dev : devs) { - std::string name = ggml_backend_dev_name(dev.dev); - name += " ("; - name += ggml_backend_dev_description(dev.dev); - name += ")"; - dev_names.push_back(name); - max_length = std::max(max_length, name.length()); - } - for (std::string & dn : dev_names) { - dn.insert(dn.end(), max_length - dn.length(), ' '); - } - } - - int64_t sum_free = 0; - int64_t sum_projected_free = 0; - int64_t sum_projected_used = 0; - int64_t sum_projected_model = 0; - std::vector projected_free_per_device; - projected_free_per_device.reserve(nd); - - if (nd > 1) { - LLAMA_LOG_INFO("%s: 
projected memory use with initial parameters [MiB]:\n", __func__); - } - for (size_t id = 0; id < nd; id++) { - const llama_device_memory_data & dmd = dmds_full[id]; - - const int64_t projected_used = dmd.mb.total(); - const int64_t projected_free = dmd.free - projected_used; - projected_free_per_device.push_back(projected_free); - - sum_free += dmd.free; - sum_projected_used += projected_used; - sum_projected_free += projected_free; - sum_projected_model += dmd.mb.model; - - if (nd > 1) { - LLAMA_LOG_INFO("%s: - %s: %6" PRId64 " total, %6" PRId64 " used, %6" PRId64 " free vs. target of %6" PRId64 "\n", - __func__, dev_names[id].c_str(), dmd.total/MiB, projected_used/MiB, projected_free/MiB, margins[id]/MiB); - } - } - assert(sum_free >= 0 && sum_projected_used >= 0); - LLAMA_LOG_INFO("%s: projected to use %" PRId64 " MiB of device memory vs. %" PRId64 " MiB of free device memory\n", - __func__, sum_projected_used/MiB, sum_free/MiB); - if (nd == 1) { - if (projected_free_per_device[0] >= margins[0]) { - LLAMA_LOG_INFO("%s: will leave %" PRId64 " >= %" PRId64 " MiB of free device memory, no changes needed\n", - __func__, projected_free_per_device[0]/MiB, margins[0]/MiB); - return; - } - } else { - bool changes_needed = false; - for (size_t id = 0; id < nd; id++) { - if (projected_free_per_device[id] < margins[id]) { - changes_needed = true; - break; - } - } - if (!changes_needed) { - LLAMA_LOG_INFO("%s: targets for free memory can be met on all devices, no changes needed\n", __func__); - return; - } - } - - // step 2: try reducing memory use by reducing the context size - - { - int64_t global_surplus = sum_projected_free; - for (size_t id = 0; id < nd; id++) { - global_surplus -= margins[id]; - } - if (global_surplus < 0) { - if (nd == 1) { - LLAMA_LOG_INFO("%s: cannot meet free memory target of %" PRId64 " MiB, need to reduce device memory by %" PRId64 " MiB\n", - __func__, margins[0]/MiB, -global_surplus/MiB); - } else { - LLAMA_LOG_INFO( - "%s: cannot meet free memory targets on all devices, need to use %" PRId64 " MiB less in total\n", - __func__, -global_surplus/MiB); - } - if (cparams->n_ctx == 0) { - if (hp_nct > n_ctx_min) { - int64_t sum_used_target = sum_free; - for (size_t id = 0; id < nd; id++) { - sum_used_target -= margins[id]; - } - if (nd > 1) { - // for multiple devices we need to be more conservative in terms of how much context we think can fit: - // - for dense models only whole layers can be assigned to devices - // - for MoE models only whole tensors can be assigned to devices, which we estimate to be <= 1/3 of a layer - // - on average we expect a waste of 0.5 layers/tensors per device - // - use slightly more than the expected average for nd devices to be safe - const int64_t model_per_layer = sum_projected_model / std::min(uint32_t(mparams->n_gpu_layers), hp_ngl); - sum_used_target -= (nd + 1) * model_per_layer / (hp_nex == 0 ? 
2 : 6); - } - - int64_t sum_projected_used_min_ctx = 0; - cparams->n_ctx = n_ctx_min; - const dmds_t dmds_min_ctx = llama_get_device_memory_data(path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level); - for (const auto & dmd : dmds_min_ctx) { - sum_projected_used_min_ctx += dmd.mb.total(); - } - if (sum_used_target > sum_projected_used_min_ctx) { - // linear interpolation between minimum and maximum context size: - cparams->n_ctx += (hp_nct - n_ctx_min) * (sum_used_target - sum_projected_used_min_ctx) - / (sum_projected_used - sum_projected_used_min_ctx); - cparams->n_ctx = std::max(cparams->n_ctx - cparams->n_ctx % 256, n_ctx_min); // round down context for CUDA backend - - const int64_t bytes_per_ctx = (sum_projected_used - sum_projected_used_min_ctx) / (hp_nct - n_ctx_min); - const int64_t memory_reduction = (hp_nct - cparams->n_ctx) * bytes_per_ctx; - LLAMA_LOG_INFO("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n", - __func__, hp_nct, cparams->n_ctx, memory_reduction/MiB); - if (nd == 1) { - LLAMA_LOG_INFO("%s: entire model can be fit by reducing context\n", __func__); - return; - } - LLAMA_LOG_INFO("%s: entire model should be fit across devices by reducing context\n", __func__); - } else { - const int64_t memory_reduction = sum_projected_used - sum_projected_used_min_ctx; - LLAMA_LOG_INFO("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n", - __func__, hp_nct, cparams->n_ctx, memory_reduction/MiB); - } - } else { - if (n_ctx_min == UINT32_MAX) { - LLAMA_LOG_INFO("%s: user has requested full context size of %" PRIu32 " -> no change\n", __func__, hp_nct); - } else { - LLAMA_LOG_INFO("%s: default model context size is %" PRIu32 " which is <= the min. 
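A self-contained arithmetic sketch of the linear interpolation used above to pick a context size between n_ctx_min and the training context; the byte figures are made-up examples, and rounding to a multiple of 256 matches the code above.

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    // made-up example figures, in MiB
    const uint32_t n_ctx_min   = 4096;
    const uint32_t n_ctx_train = 131072;
    const int64_t  used_full   = 24000; // projected use at the full training context
    const int64_t  used_min    = 9000;  // projected use at n_ctx_min
    const int64_t  target      = 16000; // free memory minus the per-device margins

    uint32_t n_ctx = n_ctx_min;
    if (target > used_min) {
        n_ctx += (uint32_t) ((int64_t) (n_ctx_train - n_ctx_min) * (target - used_min) / (used_full - used_min));
        n_ctx  = std::max(n_ctx - n_ctx % 256, n_ctx_min); // keep a multiple of 256
    }
    printf("fitted n_ctx = %u\n", n_ctx); // 63232 with these numbers
    return 0;
}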
context size of %" PRIu32 " -> no change\n", - __func__, hp_nct, n_ctx_min); - } - } - } else { - LLAMA_LOG_INFO("%s: context size set by user to %" PRIu32 " -> no change\n", __func__, cparams->n_ctx); - } - } - } - - if (mparams->n_gpu_layers != default_mparams.n_gpu_layers) { - throw llama_params_fit_exception("n_gpu_layers already set by user to " + std::to_string(mparams->n_gpu_layers) + ", abort"); - } - if (nd > 1) { - if (!tensor_split) { - throw llama_params_fit_exception("did not provide a buffer to write the tensor_split to, abort"); - } - if (mparams->tensor_split) { - for (size_t id = 0; id < nd; id++) { - if (mparams->tensor_split[id] != 0.0f) { - throw llama_params_fit_exception("model_params::tensor_split already set by user, abort"); - } - } - } - if (mparams->split_mode == LLAMA_SPLIT_MODE_ROW) { - throw llama_params_fit_exception("changing weight allocation for LLAMA_SPLIT_MODE_ROW not implemented, abort"); - } - } - if (!tensor_buft_overrides) { - throw llama_params_fit_exception("did not provide buffer to set tensor_buft_overrides, abort"); - } - if (mparams->tensor_buft_overrides && (mparams->tensor_buft_overrides->pattern || mparams->tensor_buft_overrides->buft)) { - throw llama_params_fit_exception("model_params::tensor_buft_overrides already set by user, abort"); - } - - // step 3: iteratively fill the back to front with "dense" layers - // - for a dense model simply fill full layers, giving each device a contiguous slice of the model - // - for a MoE model, same as dense model but with all MoE tensors in system memory - - // utility function that returns a static C string matching the tensors for a specific layer index and layer fraction: - auto get_overflow_pattern = [&](const size_t il, const layer_fraction_t lf) -> const char * { - constexpr size_t n_strings = 1000; - if (il >= n_strings) { - throw std::runtime_error("at most " + std::to_string(n_strings) + " model layers are supported"); - } - switch (lf) { - case LAYER_FRACTION_ATTN: { - static std::array patterns; - if (patterns[il].empty()) { - patterns[il] = "blk\\." + std::to_string(il) + "\\.ffn_(gate|up|gate_up|down).*"; - } - return patterns[il].c_str(); - } - case LAYER_FRACTION_UP: { - static std::array patterns; - if (patterns[il].empty()) { - patterns[il] = "blk\\." + std::to_string(il) + "\\.ffn_(gate|gate_up|down).*"; - } - return patterns[il].c_str(); - } - case LAYER_FRACTION_GATE: { - static std::array patterns; - if (patterns[il].empty()) { - patterns[il] = "blk\\." + std::to_string(il) + "\\.ffn_down.*"; - } - return patterns[il].c_str(); - } - case LAYER_FRACTION_MOE: { - static std::array patterns; - if (patterns[il].empty()) { - patterns[il] = "blk\\." 
+ std::to_string(il) + "\\.ffn_(up|down|gate_up|gate)_(ch|)exps"; - } - return patterns[il].c_str(); - } - default: - GGML_ABORT("fatal error"); - } - }; - - struct ngl_t { - uint32_t n_layer = 0; // number of total layers - uint32_t n_part = 0; // number of partial layers, <= n_layer - - // for the first partial layer varying parts can overflow, all further layers use LAYER_FRACTION_MOE: - layer_fraction_t overflow_type = LAYER_FRACTION_MOE; - - uint32_t n_full() const { - assert(n_layer >= n_part); - return n_layer - n_part; - } - }; - - const size_t ntbo = llama_max_tensor_buft_overrides(); - - // utility function to set n_gpu_layers and tensor_split - auto set_ngl_tensor_split_tbo = [&]( - const std::vector & ngl_per_device, - const std::vector & overflow_bufts, - llama_model_params & mparams) { - mparams.n_gpu_layers = 0; - for (size_t id = 0; id < nd; id++) { - mparams.n_gpu_layers += ngl_per_device[id].n_layer; - if (nd > 1) { - tensor_split[id] = ngl_per_device[id].n_layer; - } - } - assert(uint32_t(mparams.n_gpu_layers) <= hp_ngl + 1); - uint32_t il0 = hp_ngl + 1 - mparams.n_gpu_layers; // start index for tensor buft overrides - - mparams.tensor_split = tensor_split; - - size_t itbo = 0; - for (size_t id = 0; id < nd; id++) { - il0 += ngl_per_device[id].n_full(); - for (uint32_t il = il0; il < il0 + ngl_per_device[id].n_part; il++) { - if (itbo + 1 >= ntbo) { - tensor_buft_overrides[itbo].pattern = nullptr; - tensor_buft_overrides[itbo].buft = nullptr; - itbo++; - mparams.tensor_buft_overrides = tensor_buft_overrides; - throw llama_params_fit_exception("llama_max_tensor_buft_overrides() == " - + std::to_string(ntbo) + " is insufficient for model"); - } - tensor_buft_overrides[itbo].pattern = get_overflow_pattern(il, il == il0 ? ngl_per_device[id].overflow_type : LAYER_FRACTION_MOE); - tensor_buft_overrides[itbo].buft = il == il0 ? 
overflow_bufts[id] : ggml_backend_cpu_buffer_type(); - itbo++; - } - il0 += ngl_per_device[id].n_part; - } - tensor_buft_overrides[itbo].pattern = nullptr; - tensor_buft_overrides[itbo].buft = nullptr; - itbo++; - mparams.tensor_buft_overrides = tensor_buft_overrides; - }; - - // utility function that returns the memory use per device for given numbers of layers per device - auto get_memory_for_layers = [&]( - const char * func_name, - const std::vector & ngl_per_device, - const std::vector & overflow_bufts) -> std::vector { - llama_model_params mparams_copy = *mparams; - set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, mparams_copy); - - const dmds_t dmd_nl = llama_get_device_memory_data( - path_model, &mparams_copy, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level); - - LLAMA_LOG_DEBUG("%s: memory for test allocation by device:\n", func_name); - for (size_t id = 0; id < nd; id++) { - const ngl_t & n = ngl_per_device[id]; - LLAMA_LOG_DEBUG( - "%s: id=%zu, n_layer=%2" PRIu32 ", n_part=%2" PRIu32 ", overflow_type=%d, mem=%6" PRId64 " MiB\n", - func_name, id, n.n_layer, n.n_part, int(n.overflow_type), dmd_nl[id].mb.total()/MiB); - } - - std::vector ret; - ret.reserve(nd); - for (const llama_device_memory_data & dmd : dmd_nl) { - ret.push_back(dmd.mb.total()); - } - return ret; - }; - - int64_t global_surplus_cpu_moe = 0; - if (hp_nex > 0) { - const static std::string pattern_moe_all = "blk\\.\\d+\\.ffn_(up|down|gate_up|gate)_(ch|)exps"; // matches all MoE tensors - ggml_backend_buffer_type_t cpu_buft = ggml_backend_cpu_buffer_type(); - tensor_buft_overrides[0] = {pattern_moe_all.c_str(), cpu_buft}; - tensor_buft_overrides[1] = {nullptr, nullptr}; - mparams->tensor_buft_overrides = tensor_buft_overrides; - - LLAMA_LOG_DEBUG("%s: getting device memory data with all MoE tensors moved to system memory:\n", __func__); - const dmds_t dmds_cpu_moe = llama_get_device_memory_data( - path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level); - - for (size_t id = 0; id < nd; id++) { - global_surplus_cpu_moe += dmds_cpu_moe[id].free; - global_surplus_cpu_moe -= int64_t(dmds_cpu_moe[id].mb.total()) + margins[id]; - } - - if (global_surplus_cpu_moe > 0) { - LLAMA_LOG_INFO("%s: with only dense weights in device memory there is a total surplus of %" PRId64 " MiB\n", - __func__, global_surplus_cpu_moe/MiB); - } else { - LLAMA_LOG_INFO("%s: with only dense weights in device memory there is still a total deficit of %" PRId64 " MiB\n", - __func__, -global_surplus_cpu_moe/MiB); - } - - // reset - tensor_buft_overrides[0] = {nullptr, nullptr}; - mparams->tensor_buft_overrides = tensor_buft_overrides; - } - - std::vector targets; // maximum acceptable memory use per device - targets.reserve(nd); - for (size_t id = 0; id < nd; id++) { - targets.push_back(dmds_full[id].free - margins[id]); - LLAMA_LOG_DEBUG("%s: id=%zu, target=%" PRId64 " MiB\n", __func__, id, targets[id]/MiB); - } - - std::vector overflow_bufts; // which bufts the first partial layer of a device overflows to: - overflow_bufts.reserve(nd); - for (size_t id = 0; id < nd; id++) { - overflow_bufts.push_back(ggml_backend_cpu_buffer_type()); - } - - std::vector ngl_per_device(nd); - std::vector mem = get_memory_for_layers(__func__, ngl_per_device, overflow_bufts); - - // optimize the number of layers per device using the method of false position: - // - ngl_per_device has 0 layers for each device, lower bound - // - try a "high" configuration where a device is given all unassigned layers - // - interpolate the memory use / layer 
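The helper above always terminates the override list with a {nullptr, nullptr} entry and writes raw per-device layer counts into tensor_split (treated as relative proportions downstream, which is an assumption of this sketch). A minimal illustration of that output contract, with made-up device and layer numbers:

```cpp
// Sketch of the data written by set_ngl_tensor_split_tbo(); all numbers are made up.
#include "llama.h"

static void example_fill(llama_model_params & mparams,
                         float * tensor_split,                          // capacity >= llama_max_devices()
                         llama_model_tensor_buft_override * overrides,  // capacity >= 3 entries used here
                         ggml_backend_buffer_type_t host_buft) {        // e.g. ggml_backend_cpu_buffer_type()
    // two devices: 20 full layers on device 0, 13 layers on device 1 of which one (say layer 25) is partial
    mparams.n_gpu_layers = 33;
    tensor_split[0] = 20.0f;
    tensor_split[1] = 13.0f;
    mparams.tensor_split = tensor_split;

    // the partial layer keeps its attention tensors on the device and overflows the FFN to host memory;
    // the override list must always end with a {nullptr, nullptr} terminator
    overrides[0] = { "blk\\.25\\.ffn_(gate|up|gate_up|down).*", host_buft };
    overrides[1] = { nullptr, nullptr };
    mparams.tensor_buft_overrides = overrides;
}
```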
between low and high linearly to get a guess where it meets our target - // - check memory use of our guess, replace either the low or high bound - // - once we only have a difference of a single layer, stop and return the lower bound that just barely still fits - // - the last device has the output layer, which cannot be a partial layer - if (hp_nex == 0) { - LLAMA_LOG_INFO("%s: filling dense layers back-to-front:\n", __func__); - } else { - LLAMA_LOG_INFO("%s: filling dense-only layers back-to-front:\n", __func__); - } - for (int id = nd - 1; id >= 0; id--) { - uint32_t n_unassigned = hp_ngl + 1; - for (size_t jd = id + 1; jd < nd; ++jd) { - assert(n_unassigned >= ngl_per_device[jd].n_layer); - n_unassigned -= ngl_per_device[jd].n_layer; - } - - std::vector ngl_per_device_high = ngl_per_device; - ngl_per_device_high[id].n_layer = n_unassigned; - if (hp_nex > 0) { - ngl_per_device_high[id].n_part = size_t(id) < nd - 1 ? ngl_per_device_high[id].n_layer : ngl_per_device_high[id].n_layer - 1; - } - if (ngl_per_device_high[id].n_layer > 0) { - std::vector mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts); - if (mem_high[id] > targets[id]) { - assert(ngl_per_device_high[id].n_layer > ngl_per_device[id].n_layer); - uint32_t delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer; - LLAMA_LOG_DEBUG("%s: start filling device %" PRIu32 ", delta=%" PRIu32 "\n", __func__, id, delta); - while (delta > 1) { - uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]); - step_size = std::max(step_size, uint32_t(1)); - step_size = std::min(step_size, delta - 1); - - std::vector ngl_per_device_test = ngl_per_device; - ngl_per_device_test[id].n_layer += step_size; - if (hp_nex) { - ngl_per_device_test[id].n_part += size_t(id) == nd - 1 && ngl_per_device_test[id].n_part == 0 ? 
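The loop above is the method of false position applied to an integer layer count: memory use is roughly affine in the number of layers, so interpolating between a known-fitting lower bound and an over-budget upper bound converges in a handful of probes. A self-contained sketch with a made-up cost function standing in for get_memory_for_layers():

```cpp
// Find the largest layer count n with mem(n) <= target by linear interpolation
// between a fitting lower bound and an over-budget upper bound.
#include <algorithm>
#include <cstdint>
#include <cstdio>

static int64_t mem_for_layers(uint32_t n_layer) {
    return 900 + 510*int64_t(n_layer); // pretend MiB: fixed overhead plus per-layer cost
}

int main() {
    const int64_t target = 16000; // pretend free VRAM minus margin, in MiB

    uint32_t lo = 0;  // known to fit
    uint32_t hi = 48; // all remaining layers, may not fit
    int64_t mem_lo = mem_for_layers(lo);
    int64_t mem_hi = mem_for_layers(hi);

    if (mem_hi <= target) {
        printf("all %u layers fit\n", hi);
        return 0;
    }

    while (hi - lo > 1) {
        // false position: guess where the line through (lo, mem_lo) and (hi, mem_hi) meets the target
        uint32_t step = uint32_t(int64_t(hi - lo) * (target - mem_lo) / (mem_hi - mem_lo));
        step = std::max(step, 1u);
        step = std::min(step, hi - lo - 1);

        const uint32_t mid     = lo + step;
        const int64_t  mem_mid = mem_for_layers(mid);
        if (mem_mid <= target) { lo = mid; mem_lo = mem_mid; }
        else                   { hi = mid; mem_hi = mem_mid; }
    }
    printf("largest fitting layer count: %u (%lld MiB <= %lld MiB)\n", lo, (long long) mem_lo, (long long) target);
    return 0;
}
```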
- step_size - 1 : step_size; // the first layer is the output layer which must always be full - } - const std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts); - - if (mem_test[id] <= targets[id]) { - ngl_per_device = ngl_per_device_test; - mem = mem_test; - LLAMA_LOG_DEBUG("%s: set ngl_per_device[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device[id].n_layer); - } else { - ngl_per_device_high = ngl_per_device_test; - mem_high = mem_test; - LLAMA_LOG_DEBUG("%s: set ngl_per_device_high[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device_high[id].n_layer); - } - delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer; - } - } else { - assert(ngl_per_device_high[id].n_layer == n_unassigned); - ngl_per_device = ngl_per_device_high; - mem = mem_high; - LLAMA_LOG_DEBUG("%s: set ngl_per_device[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device[id].n_layer); - } - } - - const int64_t projected_margin = dmds_full[id].free - mem[id]; - LLAMA_LOG_INFO( - "%s: - %s: %2" PRIu32 " layers, %6" PRId64 " MiB used, %6" PRId64 " MiB free\n", - __func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, mem[id]/MiB, projected_margin/MiB); - } - if (hp_nex == 0 || global_surplus_cpu_moe <= 0) { - set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams); - return; - } - - // step 4: for a MoE model where all dense tensors fit, - // convert the dense-only layers in the back to full layers in the front until all devices are full - // essentially the same procedure as for the dense-only layers except front-to-back - // also, try fitting at least part of one more layer to reduce waste for "small" GPUs with e.g. 24 GiB VRAM - - size_t id_dense_start = nd; - for (int id = nd - 1; id >= 0; id--) { - if (ngl_per_device[id].n_layer > 0) { - id_dense_start = id; - continue; - } - break; - } - assert(id_dense_start < nd); - - LLAMA_LOG_INFO("%s: converting dense-only layers to full layers and filling them front-to-back with overflow to next device/system memory:\n", __func__); - for (size_t id = 0; id <= id_dense_start && id_dense_start < nd; id++) { - std::vector ngl_per_device_high = ngl_per_device; - for (size_t jd = id_dense_start; jd < nd; jd++) { - const uint32_t n_layer_move = jd < nd - 1 ? 
ngl_per_device_high[jd].n_layer : ngl_per_device_high[jd].n_layer - 1; - ngl_per_device_high[id].n_layer += n_layer_move; - ngl_per_device_high[jd].n_layer -= n_layer_move; - ngl_per_device_high[jd].n_part = 0; - } - size_t id_dense_start_high = nd - 1; - std::vector mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts); - - if (mem_high[id] > targets[id]) { - assert(ngl_per_device_high[id].n_full() >= ngl_per_device[id].n_full()); - uint32_t delta = ngl_per_device_high[id].n_full() - ngl_per_device[id].n_full(); - while (delta > 1) { - uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]); - step_size = std::max(step_size, uint32_t(1)); - step_size = std::min(step_size, delta - 1); - - std::vector ngl_per_device_test = ngl_per_device; - size_t id_dense_start_test = id_dense_start; - uint32_t n_converted_test = 0; - for (;id_dense_start_test < nd; id_dense_start_test++) { - const uint32_t n_convert_jd = std::min(step_size - n_converted_test, ngl_per_device_test[id_dense_start_test].n_part); - ngl_per_device_test[id_dense_start_test].n_layer -= n_convert_jd; - ngl_per_device_test[id_dense_start_test].n_part -= n_convert_jd; - ngl_per_device_test[id].n_layer += n_convert_jd; - n_converted_test += n_convert_jd; - - if (ngl_per_device_test[id_dense_start_test].n_part > 0) { - break; - } - } - const std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts); - - if (mem_test[id] <= targets[id]) { - ngl_per_device = ngl_per_device_test; - mem = mem_test; - id_dense_start = id_dense_start_test; - LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part)=(%" PRIu32 ", %" PRIu32 "), id_dense_start=%zu\n", - __func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start); - } else { - ngl_per_device_high = ngl_per_device_test; - mem_high = mem_test; - id_dense_start_high = id_dense_start_test; - LLAMA_LOG_DEBUG("%s: set ngl_per_device_high[%zu].(n_layer, n_part)=(%" PRIu32 ", %" PRIu32 "), id_dense_start_high=%zu\n", - __func__, id, ngl_per_device_high[id].n_layer, ngl_per_device_high[id].n_part, id_dense_start_high); - } - assert(ngl_per_device_high[id].n_full() >= ngl_per_device[id].n_full()); - delta = ngl_per_device_high[id].n_full() - ngl_per_device[id].n_full(); - } - } else { - ngl_per_device = ngl_per_device_high; - mem = mem_high; - id_dense_start = id_dense_start_high; - LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part)=(%" PRIu32 ", %" PRIu32 "), id_dense_start=%zu\n", - __func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start); - } - - // try to fit at least part of one more layer - if (ngl_per_device[id_dense_start].n_layer > (id < nd - 1 ? 
0 : 1)) { - std::vector ngl_per_device_test = ngl_per_device; - size_t id_dense_start_test = id_dense_start; - ngl_per_device_test[id_dense_start_test].n_layer--; - ngl_per_device_test[id_dense_start_test].n_part--; - ngl_per_device_test[id].n_layer++; - ngl_per_device_test[id].n_part++; - if (ngl_per_device_test[id_dense_start_test].n_part == 0) { - id_dense_start_test++; - } - ngl_per_device_test[id].overflow_type = LAYER_FRACTION_UP; - std::vector overflow_bufts_test = overflow_bufts; - if (id < nd - 1) { - overflow_bufts_test[id] = ggml_backend_dev_buffer_type(devs[id + 1].dev); - } - LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_UP\n", __func__); - std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test); - if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) { - ngl_per_device = ngl_per_device_test; - overflow_bufts = overflow_bufts_test; - mem = mem_test; - id_dense_start = id_dense_start_test; - LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", UP), id_dense_start=%zu\n", - __func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start); - - ngl_per_device_test[id].overflow_type = LAYER_FRACTION_GATE; - LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_GATE\n", __func__); - mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test); - if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) { - ngl_per_device = ngl_per_device_test; - overflow_bufts = overflow_bufts_test; - mem = mem_test; - id_dense_start = id_dense_start_test; - LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", GATE), id_dense_start=%zu\n", - __func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start); - } - } else { - ngl_per_device_test[id].overflow_type = LAYER_FRACTION_ATTN; - LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_ATTN\n", __func__); - mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test); - if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) { - ngl_per_device = ngl_per_device_test; - overflow_bufts = overflow_bufts_test; - mem = mem_test; - id_dense_start = id_dense_start_test; - LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", ATTN), id_dense_start=%zu\n", - __func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start); - } - } - } - - const int64_t projected_margin = dmds_full[id].free - mem[id]; - LLAMA_LOG_INFO( - "%s: - %s: %2" PRIu32 " layers (%2" PRIu32 " overflowing), %6" PRId64 " MiB used, %6" PRId64 " MiB free\n", - __func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, ngl_per_device[id].n_part, mem[id]/MiB, projected_margin/MiB); - } - - // print info for devices that were not changed during the conversion from dense only to full layers: - for (size_t id = id_dense_start + 1; id < nd; id++) { - const int64_t projected_margin = dmds_full[id].free - mem[id]; - LLAMA_LOG_INFO( - "%s: - %s: %2" PRIu32 " layers (%2" PRIu32 " overflowing), %6" PRId64 " MiB used, %6" PRId64 " MiB free\n", - __func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, ngl_per_device[id].n_part, mem[id]/MiB, projected_margin/MiB); - } - - 
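The block above squeezes in one extra partial layer by escalating through the overflow types: UP is tried first, then the larger GATE fraction if UP fits, otherwise the smaller ATTN fraction. A stand-in sketch of that decision, where fits() abstracts the projected-memory-versus-target check:

```cpp
// Stand-in sketch of the escalation order; fits() replaces the real memory check.
#include <cstdio>
#include <functional>
#include <optional>

enum layer_fraction { LAYER_FRACTION_ATTN, LAYER_FRACTION_UP, LAYER_FRACTION_GATE };

// returns the overflow type to use for one extra partial layer, or nullopt if none fits
static std::optional<layer_fraction> pick_extra_fraction(const std::function<bool(layer_fraction)> & fits) {
    if (fits(LAYER_FRACTION_UP)) {
        // GATE keeps more weights on the device than UP, so prefer it if it also fits
        return fits(LAYER_FRACTION_GATE) ? LAYER_FRACTION_GATE : LAYER_FRACTION_UP;
    }
    if (fits(LAYER_FRACTION_ATTN)) {
        return LAYER_FRACTION_ATTN; // smallest fraction: only the attention tensors stay
    }
    return std::nullopt;            // keep the previous configuration
}

int main() {
    // pretend UP and ATTN fit, but GATE does not
    const auto chosen = pick_extra_fraction([](layer_fraction f) {
        return f == LAYER_FRACTION_UP || f == LAYER_FRACTION_ATTN;
    });
    printf("chosen overflow type: %d\n", chosen ? int(*chosen) : -1); // prints 1 (UP)
    return 0;
}
```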
set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams); -} - -enum llama_params_fit_status llama_params_fit( - const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams, - float * tensor_split, struct llama_model_tensor_buft_override * tensor_buft_overrides, - size_t * margins, uint32_t n_ctx_min, enum ggml_log_level log_level) { - const int64_t t0_us = llama_time_us(); - llama_params_fit_status status = LLAMA_PARAMS_FIT_STATUS_SUCCESS; - try { - llama_params_fit_impl(path_model, mparams, cparams, tensor_split, tensor_buft_overrides, margins, n_ctx_min, log_level); - LLAMA_LOG_INFO("%s: successfully fit params to free device memory\n", __func__); - } catch (const llama_params_fit_exception & e) { - LLAMA_LOG_WARN("%s: failed to fit params to free device memory: %s\n", __func__, e.what()); - status = LLAMA_PARAMS_FIT_STATUS_FAILURE; - } catch (const std::runtime_error & e) { - LLAMA_LOG_ERROR("%s: encountered an error while trying to fit params to free device memory: %s\n", __func__, e.what()); - status = LLAMA_PARAMS_FIT_STATUS_ERROR; - } - const int64_t t1_us = llama_time_us(); - LLAMA_LOG_INFO("%s: fitting params to free memory took %.2f seconds\n", __func__, (t1_us - t0_us) * 1e-6); - return status; -} - struct llama_sampler_chain_params llama_sampler_chain_default_params() { struct llama_sampler_chain_params result = { /*.no_perf =*/ true, diff --git a/src/models/apertus.cpp b/src/models/apertus.cpp index 80e63e3b459..af44cea6054 100644 --- a/src/models/apertus.cpp +++ b/src/models/apertus.cpp @@ -50,7 +50,7 @@ llm_build_apertus::llm_build_apertus(const llama_model & model, const llm_graph_ cb(Vcur, "Vcur_pos", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } diff --git a/src/models/arcee.cpp b/src/models/arcee.cpp index 948df17d809..2e71f5d9e2a 100644 --- a/src/models/arcee.cpp +++ b/src/models/arcee.cpp @@ -55,7 +55,7 @@ llm_build_arcee::llm_build_arcee(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } diff --git a/src/models/bailingmoe.cpp b/src/models/bailingmoe.cpp index 4a6969b9789..67a7120d622 100644 --- a/src/models/bailingmoe.cpp +++ b/src/models/bailingmoe.cpp @@ -48,7 +48,7 @@ llm_build_bailingmoe::llm_build_bailingmoe(const llama_model & model, const llm_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_rot)), il); } diff --git a/src/models/bailingmoe2.cpp b/src/models/bailingmoe2.cpp index 016072a9695..497b4babd0c 100644 --- a/src/models/bailingmoe2.cpp +++ b/src/models/bailingmoe2.cpp @@ -48,7 +48,7 @@ llm_build_bailingmoe2::llm_build_bailingmoe2(const llama_model & model, const ll cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } diff 
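For context, this is how a caller would drive the llama_params_fit() entry point whose implementation is removed (or relocated) in this hunk, based only on the signature visible above. The margin value (bytes to leave free per device) and the context minimum are illustrative assumptions; the array capacities come from llama_max_devices() and llama_max_tensor_buft_overrides():

```cpp
// Caller-side sketch; margins and n_ctx_min are illustrative, error handling is minimal.
#include "llama.h"
#include <vector>

bool fit_and_load(const char * path_model) {
    llama_model_params   mparams = llama_model_default_params();
    llama_context_params cparams = llama_context_default_params();

    std::vector<float> tensor_split(llama_max_devices(), 0.0f);
    std::vector<llama_model_tensor_buft_override> overrides(llama_max_tensor_buft_overrides());
    std::vector<size_t> margins(llama_max_devices(), size_t(1024)*1024*1024); // leave ~1 GiB free per device (assumed unit: bytes)

    const enum llama_params_fit_status status = llama_params_fit(
        path_model, &mparams, &cparams,
        tensor_split.data(), overrides.data(), margins.data(),
        /*n_ctx_min =*/ 4096, GGML_LOG_LEVEL_INFO);

    if (status != LLAMA_PARAMS_FIT_STATUS_SUCCESS) {
        return false; // fall back to whatever the user configured
    }

    llama_model * model = llama_model_load_from_file(path_model, mparams);
    return model != nullptr;
}
```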
--git a/src/models/bert.cpp b/src/models/bert.cpp index 57916c8aeb8..7e046cfd2a4 100644 --- a/src/models/bert.cpp +++ b/src/models/bert.cpp @@ -72,7 +72,7 @@ llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); cb(cur, "kqv_out", il); } diff --git a/src/models/bitnet.cpp b/src/models/bitnet.cpp index 257cf4ca4ea..71526354ca6 100644 --- a/src/models/bitnet.cpp +++ b/src/models/bitnet.cpp @@ -57,8 +57,8 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa cb(cur, "attn_sub_norm", il); cur = build_lora_mm(model.layers[il].wo, cur, model.layers[il].wo_s); - if (model.layers[il].bo) { - cur = ggml_add(ctx0, cur, model.layers[il].bo); + if (model.layers[il].wo_b) { + cur = ggml_add(ctx0, cur, model.layers[il].wo_b); } cb(cur, "attn_out", il); } diff --git a/src/models/bloom.cpp b/src/models/bloom.cpp index cf188211dfd..f3b0999bf54 100644 --- a/src/models/bloom.cpp +++ b/src/models/bloom.cpp @@ -33,7 +33,7 @@ llm_build_bloom::llm_build_bloom(const llama_model & model, const llm_graph_para n_embd_head, n_head, n_head_kv, il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/src/models/codeshell.cpp b/src/models/codeshell.cpp index 5efa087e798..3ceb5835b85 100644 --- a/src/models/codeshell.cpp +++ b/src/models/codeshell.cpp @@ -47,7 +47,7 @@ llm_build_codeshell::llm_build_codeshell(const llama_model & model, const llm_gr cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/src/models/cohere2-iswa.cpp b/src/models/cohere2-iswa.cpp index bf39edc0deb..670b08e7d97 100644 --- a/src/models/cohere2-iswa.cpp +++ b/src/models/cohere2-iswa.cpp @@ -58,7 +58,7 @@ llm_build_cohere2_iswa::llm_build_cohere2_iswa(const llama_model & model, const cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/src/models/command-r.cpp b/src/models/command-r.cpp index fb10eac9c9f..067961caa08 100644 --- a/src/models/command-r.cpp +++ b/src/models/command-r.cpp @@ -54,7 +54,7 @@ llm_build_command_r::llm_build_command_r(const llama_model & model, const llm_gr cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/src/models/deci.cpp b/src/models/deci.cpp index ed52d2b9942..30272eabd69 100644 --- a/src/models/deci.cpp +++ b/src/models/deci.cpp @@ -59,7 +59,7 @@ llm_build_deci::llm_build_deci(const llama_model & model, const llm_graph_params cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, 
- model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/src/models/deepseek.cpp b/src/models/deepseek.cpp index 73667cd665a..671b72dfead 100644 --- a/src/models/deepseek.cpp +++ b/src/models/deepseek.cpp @@ -49,7 +49,7 @@ llm_build_deepseek::llm_build_deepseek(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/src/models/dflash.cpp b/src/models/dflash.cpp new file mode 100644 index 00000000000..0adba127eab --- /dev/null +++ b/src/models/dflash.cpp @@ -0,0 +1,161 @@ +#include "models.h" + +ggml_tensor * llm_build_dflash_encode::build_inp_embd() const { + const int64_t n_target_layer_ids = (int64_t) hparams.dflash_target_layer_ids.size(); + const int64_t n_embd_target_features = n_target_layer_ids * n_embd; + + auto inp_target = std::make_unique(n_embd_target_features); + inp_target->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_target_features, n_tokens); + ggml_set_input(inp_target->embd); + + ggml_tensor * cur = inp_target->embd; + cb(cur, "inp_embd", -1); + + res->add_input(std::move(inp_target)); + + return cur; +} + +llm_build_dflash_encode::llm_build_dflash_encode(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + ggml_tensor * cur = build_inp_embd(); + + cur = build_lora_mm(model.fc, cur); + cb(cur, "fc_out", -1); + + cur = build_norm(cur, model.dflash_hidden_norm, NULL, LLM_NORM_RMS, -1); + cb(cur, "hidden_norm_out", -1); + + res->t_embd = cur; + + ggml_build_forward_expand(gf, cur); +} + +llm_build_dflash_decode::llm_build_dflash_decode(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + + // Noise tokens [MASK] + GGML_ASSERT(model.target_tok_embd != nullptr && "DFlash decoder requires target model's tok_embd"); + ggml_tensor * noise_embd = build_inp_embd(model.target_tok_embd); + cb(noise_embd, "inp_noise_embd", -1); + + // Target context via llama_cross (filled from accumulated_target_ctx), graph rebuilds every step + ggml_tensor * target_ctx = build_inp_cross_embd(); + const int64_t n_ctx = target_ctx->ne[1]; + + ggml_tensor * inpL = noise_embd; + + const int64_t n_tokens_kv = n_ctx + n_tokens; + + // Position tensor covering target_ctx + noise + ggml_tensor * inp_pos_full = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens_kv); + ggml_set_input(inp_pos_full); + cb(inp_pos_full, "inp_pos_full", -1); + + // Q positions: last n_tokens entries (noise only) + ggml_tensor * inp_pos_q = ggml_view_1d(ctx0, inp_pos_full, n_tokens, + n_ctx * ggml_element_size(inp_pos_full)); + + const float kq_scale = 1.0f/sqrtf(float(n_embd_head)); + + for (int il = 0; il < n_layer; ++il) { + const auto & layer = model.layers[il]; + + ggml_tensor * noise_norm = build_norm(inpL, layer.attn_norm, NULL, LLM_NORM_RMS, il); + cb(noise_norm, "noise_norm", il); + + // Q from noise only + ggml_tensor * Qcur = build_lora_mm(layer.wq, noise_norm); + if (layer.wq_b) { Qcur = ggml_add(ctx0, Qcur, layer.wq_b); } + cb(Qcur, "Qcur", il); + + // 
K = concat(k_proj(target_ctx), k_proj(noise)) + ggml_tensor * K_tgt = build_lora_mm(layer.wk, target_ctx); + ggml_tensor * K_noise = build_lora_mm(layer.wk, noise_norm); + if (layer.wk_b) { + K_tgt = ggml_add(ctx0, K_tgt, layer.wk_b); + K_noise = ggml_add(ctx0, K_noise, layer.wk_b); + } + ggml_tensor * Kcur = ggml_concat(ctx0, K_tgt, K_noise, 1); + cb(Kcur, "Kcur", il); + + // V = concat(v_proj(target_ctx), v_proj(noise)) + ggml_tensor * V_tgt = build_lora_mm(layer.wv, target_ctx); + ggml_tensor * V_noise = build_lora_mm(layer.wv, noise_norm); + if (layer.wv_b) { + V_tgt = ggml_add(ctx0, V_tgt, layer.wv_b); + V_noise = ggml_add(ctx0, V_noise, layer.wv_b); + } + ggml_tensor * Vcur = ggml_concat(ctx0, V_tgt, V_noise, 1); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens_kv); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens_kv); + + Qcur = build_norm(Qcur, layer.attn_q_norm, NULL, LLM_NORM_RMS, il); + Kcur = build_norm(Kcur, layer.attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + cb(Kcur, "Kcur_normed", il); + + // RoPE: K uses full positions [0..n_ctx+n_tokens-1], Q uses last n_tokens + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos_full, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur_rope", il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos_q, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur_rope", il); + + // Full attention (no causal mask) + ggml_build_forward_expand(gf, Qcur); + ggml_build_forward_expand(gf, Kcur); + ggml_build_forward_expand(gf, Vcur); + + ggml_tensor * cur = build_attn_mha(Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "kqv_out", il); + + cur = build_lora_mm(layer.wo, cur); + if (layer.wo_b) { cur = ggml_add(ctx0, cur, layer.wo_b); } + cur = ggml_add(ctx0, cur, inpL); + cb(cur, "attn_res", il); + + ggml_tensor * ffn_inp = cur; + cur = build_norm(cur, layer.ffn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + layer.ffn_up, NULL, NULL, + layer.ffn_gate, NULL, NULL, + layer.ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "l_out", il); + + inpL = cur; + } + + ggml_tensor * cur = inpL; + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); + + res->t_embd = cur; + + if (model.target_output) { + cur = build_lora_mm(model.target_output, cur); + cb(cur, "result_output", -1); + res->t_logits = cur; + } + + ggml_build_forward_expand(gf, cur); +} \ No newline at end of file diff --git a/src/models/dots1.cpp b/src/models/dots1.cpp index f1668fe6284..5d1750fedda 100644 --- a/src/models/dots1.cpp +++ b/src/models/dots1.cpp @@ -49,7 +49,7 @@ llm_build_dots1::llm_build_dots1(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/src/models/dream.cpp b/src/models/dream.cpp index ad6608b56f9..8e7d9ae64c7 100644 --- a/src/models/dream.cpp +++ 
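In the DFlash decoder above, the K/V rows span the accumulated target context followed by the noise tokens, while Q spans only the noise tokens, so the Q positions are simply a view of the last n_tokens entries of inp_pos_full. A plain C++ sketch of that bookkeeping (sizes are made up; the real graph fills the position tensor host-side):

```cpp
// Position layout: K/V cover target context + noise, Q covers only the noise tokens.
#include <cstdio>
#include <vector>

int main() {
    const int n_ctx    = 5; // accumulated target-context tokens (example)
    const int n_tokens = 3; // noise / [MASK] tokens being drafted (example)

    std::vector<int> pos_full(n_ctx + n_tokens);
    for (int i = 0; i < n_ctx + n_tokens; i++) {
        pos_full[i] = i; // filled from the host side in the real graph
    }

    // equivalent of ggml_view_1d(inp_pos_full, n_tokens, n_ctx * element_size):
    const int * pos_q = pos_full.data() + n_ctx;

    printf("K/V length: %d, Q length: %d, first Q position: %d\n",
           n_ctx + n_tokens, n_tokens, pos_q[0]); // -> 8, 3, 5
    return 0;
}
```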
b/src/models/dream.cpp @@ -43,7 +43,7 @@ llm_build_dream::llm_build_dream(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/src/models/eagle3.cpp b/src/models/eagle3.cpp new file mode 100644 index 00000000000..69ac3be8c59 --- /dev/null +++ b/src/models/eagle3.cpp @@ -0,0 +1,186 @@ +#include "models.h" + +ggml_tensor * llm_build_eagle3_encode::build_inp_embd() const { + const int64_t n_embd_target_features = 3 * hparams.eagle3_target_hidden_size; + + ggml_tensor * cur = nullptr; + + // Input: Target model features (3 layers concatenated: low, mid, high) + // Data will be provided via ubatch->embd in encode_eagle3_features() + auto inp_target = std::make_unique(n_embd_target_features); + inp_target->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_target_features, n_tokens); + ggml_set_input(inp_target->embd); + + cur = inp_target->embd; + cb(cur, "inp_embd", -1); + + res->add_input(std::move(inp_target)); + + return cur; +} + +// EAGLE3 Encoder: processes target model features through feature fusion layer +// Input: target_features e.g. [12288, n_tokens] from target model layers low, middle, high +// Output: g_embeddings e.g. [4096, n_tokens] stored in context +llm_build_eagle3_encode::llm_build_eagle3_encode(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + ggml_tensor * cur = nullptr; + + cur = build_inp_embd(); + + // Feature fusion layer + cur = build_lora_mm(model.fc, cur); + cb(cur, "fc_out", -1); + + // Output: g_embeddings e.g. [4096, n_tokens] + res->t_embd = cur; + + ggml_build_forward_expand(gf, cur); +} + +// EAGLE3 Decoder: processes draft tokens using g_embeddings from encoder +// Input: draft tokens + g_embeddings from encoder +// Output: draft logits +llm_build_eagle3_decode::llm_build_eagle3_decode(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_layer == 1); // EAGLE-3 has only one decoder layer + + ggml_tensor * cur; + ggml_tensor * inpL; + + // EAGLE3 Decoder receives: + // 1. Token embeddings (e.g.from EAGLE3's own tok_embd for Llama 3.3 70B, or target model for Llama 3.1 8B) + // 2. g_embeddings from encoder + // Choose token_embd_eagle3: prefer EAGLE3's own if available (Llama 3.3 70B), else use target's (Llama 3.1 8B) + ggml_tensor * token_embd_eagle3 = (model.tok_embd != nullptr) ? model.tok_embd : model.target_tok_embd; + GGML_ASSERT(token_embd_eagle3 != nullptr && "EAGLE3 decoder requires token embeddings (own or from target model)"); + ggml_tensor * inp_embd = build_inp_embd(token_embd_eagle3); + cb(inp_embd, "inp_embd", -1); + + // TODO: refactor into llm_graph_input + ggml_tensor * inp_g = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); + ggml_set_input(inp_g); + cb(inp_g, "inp_g_embeddings", -1); // TODO: do not change the name! 
refactor into llm_graph_input + + inpL = inp_g; + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv(); + + const float kq_scale = 1.0f/sqrtf(float(n_embd_head)); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + // Single decoder layer (il = 0) + const int il = 0; + { + // Apply input_layernorm to the token embeddings + ggml_tensor * embd_norm = build_norm(inp_embd, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(embd_norm, "embd_norm", il); + + // Apply hidden_norm to inp_g + ggml_tensor * g_norm = build_norm(inp_g, + model.layers[il].eagle3_hidden_norm, NULL, + LLM_NORM_RMS, -1); + cb(g_norm, "g_norm", il); + + // norm_before_residual: determines what goes into the residual connection (compatible with Red Hat eagle3 speculator models) + // - false (default): use raw inp_g for residual + // - true: use normalized g_norm for residual + // inpL at this point is still the raw inp_g (the g_embeddings from the encoder); the concatenated input is built into cur below + ggml_tensor * inpSA = hparams.eagle3_norm_before_residual ? g_norm : inpL; + + // Concatenate normalized inp_embd and normalized inp_g + cur = ggml_concat(ctx0, embd_norm, g_norm, il); + cb(cur, "concat_embd", il); + + // Self-attention with concatenated input + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // rope freq factors, returns nullptr if not available + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // RoPE + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur_rope", il); + cb(Kcur, "Kcur_rope", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + + if (inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // Add residual and update it + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // Apply FFN norm to the sum + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "post_attn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + // Output norm with residual + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "eagle3_prenorm", il); + + inpL = cur; + } + + cur = inpL; + + // Output prenorm state (for next token's g_embeddings in autoregressive generation) + ggml_set_output(cur); + res->t_embd = cur; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); + + // lm_head - projects to draft vocabulary + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits =
cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/src/models/exaone.cpp b/src/models/exaone.cpp index 626056e4d6d..4f845bf4106 100644 --- a/src/models/exaone.cpp +++ b/src/models/exaone.cpp @@ -46,7 +46,7 @@ llm_build_exaone::llm_build_exaone(const llama_model & model, const llm_graph_pa cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/src/models/gpt2.cpp b/src/models/gpt2.cpp index 22e7d7f415c..f8dc53eb723 100644 --- a/src/models/gpt2.cpp +++ b/src/models/gpt2.cpp @@ -37,7 +37,7 @@ llm_build_gpt2::llm_build_gpt2(const llama_model & model, const llm_graph_params n_embd_head, n_head, n_head_kv, il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/src/models/gptneox.cpp b/src/models/gptneox.cpp index 87010841a17..0016ddede43 100644 --- a/src/models/gptneox.cpp +++ b/src/models/gptneox.cpp @@ -46,7 +46,7 @@ llm_build_gptneox::llm_build_gptneox(const llama_model & model, const llm_graph_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/src/models/granite-hybrid.cpp b/src/models/granite-hybrid.cpp index d6e0e8d9374..e983742bef5 100644 --- a/src/models/granite-hybrid.cpp +++ b/src/models/granite-hybrid.cpp @@ -92,7 +92,7 @@ ggml_tensor * llm_build_granite_hybrid::build_attention_layer(ggml_tensor * const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); return cur; diff --git a/src/models/granite.cpp b/src/models/granite.cpp index 7b42142c067..6ea90285225 100644 --- a/src/models/granite.cpp +++ b/src/models/granite.cpp @@ -101,7 +101,7 @@ ggml_tensor * llm_build_granite::build_attention_layer( const float kq_scale = hparams.f_attention_scale == 0.0f ? 
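A small shape sketch for the EAGLE-3 encoder/decoder pair above: the encoder fuses three concatenated target-layer feature vectors into one g-embedding, and the decoder's QKV projections consume the concatenation of the normalized token embedding and the normalized g-embedding, i.e. an input of width 2*n_embd. The sizes below follow the examples in the patch comments, not any particular model file:

```cpp
// Shape arithmetic only; sizes are illustrative.
#include <cstdio>

int main() {
    const int target_hidden = 4096; // hparams.eagle3_target_hidden_size (example)
    const int n_embd        = 4096; // draft model embedding width (example)
    const int n_tokens      = 8;

    const int enc_in  = 3 * target_hidden; // encoder input:  low/mid/high features -> [12288, n_tokens]
    const int enc_out = n_embd;            // encoder output: g_embeddings          -> [ 4096, n_tokens]
    const int dec_in  = 2 * n_embd;        // decoder QKV input: concat(embd_norm, g_norm),
                                           // so wq/wk/wv of the single decoder layer take 2*n_embd columns

    printf("encoder: [%d, %d] -> [%d, %d]\n", enc_in, n_tokens, enc_out, n_tokens);
    printf("decoder attention input: [%d, %d]\n", dec_in, n_tokens);
    return 0;
}
```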
1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); return cur; diff --git a/src/models/grok.cpp b/src/models/grok.cpp index 69eccb94b7b..b8f35afdc03 100644 --- a/src/models/grok.cpp +++ b/src/models/grok.cpp @@ -50,7 +50,7 @@ llm_build_grok::llm_build_grok(const llama_model & model, const llm_graph_params cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/src/models/grovemoe.cpp b/src/models/grovemoe.cpp index 7806a02c400..151108a2a71 100644 --- a/src/models/grovemoe.cpp +++ b/src/models/grovemoe.cpp @@ -50,7 +50,7 @@ llm_build_grovemoe::llm_build_grovemoe(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } diff --git a/src/models/hunyuan-dense.cpp b/src/models/hunyuan-dense.cpp index 97f5da8ee90..1cd85d6d9d4 100644 --- a/src/models/hunyuan-dense.cpp +++ b/src/models/hunyuan-dense.cpp @@ -6,6 +6,11 @@ llm_build_hunyuan_dense::llm_build_hunyuan_dense(const llama_model & model, cons GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); GGML_ASSERT(n_embd_head == n_rot); + const bool use_mrope = hparams.use_mrope(); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + ggml_tensor * cur; ggml_tensor * inpL; @@ -37,22 +42,36 @@ llm_build_hunyuan_dense::llm_build_hunyuan_dense(const llama_model & model, cons auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, n_embd_head, n_head, n_head_kv, il); - Qcur = ggml_rope_ext( - ctx0, Qcur, inp_pos, rope_factors, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); + if (use_mrope) { + Qcur = ggml_rope_multi( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_multi( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + } else { + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + } cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Kcur = ggml_rope_ext( - ctx0, Kcur, inp_pos, rope_factors, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); - Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); @@ -64,7 +83,7 @@ llm_build_hunyuan_dense::llm_build_hunyuan_dense(const llama_model & model, cons cb(Qcur, "Qcur_norm", il); cur = build_attn(inp_attn, - model.layers[il].wo, 
model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } diff --git a/src/models/hunyuan-moe.cpp b/src/models/hunyuan-moe.cpp index 0e32b7d5e86..ffe1664b0e1 100644 --- a/src/models/hunyuan-moe.cpp +++ b/src/models/hunyuan-moe.cpp @@ -65,7 +65,7 @@ llm_build_hunyuan_moe::llm_build_hunyuan_moe(const llama_model & model, const ll cb(Qcur, "Qcur_norm", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } diff --git a/src/models/internlm2.cpp b/src/models/internlm2.cpp index 5f688840e3f..83be2ca0aee 100644 --- a/src/models/internlm2.cpp +++ b/src/models/internlm2.cpp @@ -50,7 +50,7 @@ llm_build_internlm2::llm_build_internlm2(const llama_model & model, const llm_gr cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/src/models/jais.cpp b/src/models/jais.cpp index 0f817c1d8b9..31101f3c14b 100644 --- a/src/models/jais.cpp +++ b/src/models/jais.cpp @@ -27,7 +27,7 @@ llm_build_jais::llm_build_jais(const llama_model & model, const llm_graph_params n_embd_head, n_head, n_head_kv, il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/float(n_embd_head), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/src/models/jais2.cpp b/src/models/jais2.cpp index 30abe8bc0de..507e04fa4aa 100644 --- a/src/models/jais2.cpp +++ b/src/models/jais2.cpp @@ -51,7 +51,7 @@ llm_build_jais2::llm_build_jais2(const llama_model & model, const llm_graph_para cb(Kcur, "Kcur_rope", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/src/models/llama.cpp b/src/models/llama.cpp index 3f8caeef8b8..19ed4d07cb9 100644 --- a/src/models/llama.cpp +++ b/src/models/llama.cpp @@ -31,6 +31,16 @@ llm_build_llama::llm_build_llama(const llama_model & model, const llm_gra for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; + // EAGLE3: Extract intermediate layer features from target model at layer INPUT + if (eagle3 && cparams.eagle3_extract_enabled && !eagle3->extract_layer_indices.empty()) { + static const char * eagle3_extract_names[] = {"eagle3_extract_0", "eagle3_extract_1", "eagle3_extract_2"}; + for (size_t i = 0; i < eagle3->extract_layer_indices.size() && i < 3; ++i) { + if (eagle3->extract_layer_indices[i] == il) { + cb(inpL, eagle3_extract_names[i], il); + break; + } + } + } // norm cur = build_norm(inpL, model.layers[il].attn_norm, NULL, @@ -70,7 +80,7 @@ llm_build_llama::llm_build_llama(const llama_model & model, const llm_gra cb(Kcur, "Kcur_normed", il); } cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, 
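The hunyuan-dense change above copies hparams.rope_sections into a plain int[4] because ggml_rope_multi() takes a C array; with M-RoPE each section drives a contiguous slice of the rotary slots from a different component of the multi-dimensional position. A bookkeeping sketch with illustrative section sizes (whether a slot corresponds to a single dimension or a dimension pair is ggml's internal convention and is not asserted here):

```cpp
// Partition bookkeeping for the sections[4] argument; sizes are made up.
#include <cstdio>

int main() {
    const int sections[4] = {16, 24, 24, 0}; // e.g. temporal / height / width / extra
    const char * names[4] = {"temporal", "height", "width", "extra"};

    int start = 0;
    for (int i = 0; i < 4; i++) {
        if (sections[i] == 0) continue;
        printf("%-8s -> rotary slots [%d, %d)\n", names[i], start, start + sections[i]);
        start += sections[i];
    }
    printf("total rotary slots used: %d\n", start);
    return 0;
}
```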
nullptr, nullptr, nullptr, kq_scale, il); if (model.layers[il].wo_s) { cur = ggml_mul(ctx0, cur, model.layers[il].wo_s); diff --git a/src/models/llama4.cpp b/src/models/llama4.cpp index d40d37a9248..4e4bfb43f33 100644 --- a/src/models/llama4.cpp +++ b/src/models/llama4.cpp @@ -84,7 +84,7 @@ llm_build_llama4::llm_build_llama4(const llama_model & model, const llm_gr cb(Kcur, "Kcur_normed", il); } cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } diff --git a/src/models/maincoder.cpp b/src/models/maincoder.cpp index 1e25d50fa7d..8a76931c007 100644 --- a/src/models/maincoder.cpp +++ b/src/models/maincoder.cpp @@ -56,7 +56,7 @@ llm_build_maincoder::llm_build_maincoder(const llama_model & model, const llm_gr cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/src/models/mistral3.cpp b/src/models/mistral3.cpp index 8e0e13a7452..b5ae72a2ee1 100644 --- a/src/models/mistral3.cpp +++ b/src/models/mistral3.cpp @@ -67,7 +67,7 @@ llm_build_mistral3::llm_build_mistral3(const llama_model & model, const llm_grap } cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } diff --git a/src/models/models.h b/src/models/models.h index 94991c55fe8..062e6ff621d 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -202,6 +202,26 @@ struct llm_build_dream : public llm_graph_context { llm_build_dream(const llama_model & model, const llm_graph_params & params); }; +struct llm_build_eagle3_encode : public llm_graph_context { + llm_build_eagle3_encode(const llama_model & model, const llm_graph_params & params); +private: + ggml_tensor * build_inp_embd() const; +}; + +struct llm_build_eagle3_decode : public llm_graph_context { + llm_build_eagle3_decode(const llama_model & model, const llm_graph_params & params); +}; + +struct llm_build_dflash_encode : public llm_graph_context { + llm_build_dflash_encode(const llama_model & model, const llm_graph_params & params); +private: + ggml_tensor * build_inp_embd() const; +}; + +struct llm_build_dflash_decode : public llm_graph_context { + llm_build_dflash_decode(const llama_model & model, const llm_graph_params & params); +}; + struct llm_build_ernie4_5 : public llm_graph_context { llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params); }; diff --git a/src/models/mpt.cpp b/src/models/mpt.cpp index 7a7169a7515..8596bbb2024 100644 --- a/src/models/mpt.cpp +++ b/src/models/mpt.cpp @@ -56,7 +56,7 @@ llm_build_mpt::llm_build_mpt(const llama_model & model, const llm_graph_params & cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } diff --git a/src/models/nemotron-h.cpp b/src/models/nemotron-h.cpp index 66eb0bdb956..dc07d43df58 100644 --- a/src/models/nemotron-h.cpp +++ 
b/src/models/nemotron-h.cpp @@ -70,7 +70,7 @@ ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor * const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); return cur; diff --git a/src/models/nemotron.cpp b/src/models/nemotron.cpp index 09ec2936be6..054b16fe0ef 100644 --- a/src/models/nemotron.cpp +++ b/src/models/nemotron.cpp @@ -51,7 +51,7 @@ llm_build_nemotron::llm_build_nemotron(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/src/models/openai-moe-iswa.cpp b/src/models/openai-moe-iswa.cpp index e7b7a2bc8af..db7492a6836 100644 --- a/src/models/openai-moe-iswa.cpp +++ b/src/models/openai-moe-iswa.cpp @@ -19,6 +19,30 @@ llm_build_openai_moe_iswa::llm_build_openai_moe_iswa(const llama_model & model, ggml_tensor * inpSA = inpL; + // EAGLE3: Extract intermediate layer features from target model at layer INPUT + if (eagle3 && cparams.eagle3_extract_enabled && !eagle3->extract_layer_indices.empty()) { + static const char * eagle3_extract_names[] = {"eagle3_extract_0", "eagle3_extract_1", "eagle3_extract_2"}; + for (size_t i = 0; i < eagle3->extract_layer_indices.size() && i < 3; ++i) { + if (eagle3->extract_layer_indices[i] == il) { + cb(inpL, eagle3_extract_names[i], il); + break; + } + } + } + + // DFlash: Extract intermediate layer features from target model at layer INPUT + if (dflash && cparams.dflash_extract_enabled && !dflash->extract_layer_indices.empty()) { + static const char * dflash_extract_names[] = { + "dflash_extract_0", "dflash_extract_1", "dflash_extract_2", + "dflash_extract_3", "dflash_extract_4" + }; + for (size_t i = 0; i < dflash->extract_layer_indices.size() && i < 5; ++i) { + if (dflash->extract_layer_indices[i] == il) { + cb(inpL, dflash_extract_names[i], il); + break; + } + } + } // norm cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, @@ -48,7 +72,7 @@ llm_build_openai_moe_iswa::llm_build_openai_moe_iswa(const llama_model & model, cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, model.layers[il].attn_sinks, nullptr, 1.0f/sqrtf(float(n_rot)), il); cb(cur, "attn_out", il); diff --git a/src/models/paddleocr.cpp b/src/models/paddleocr.cpp index 4bc74c175e7..56cb1d94c5f 100644 --- a/src/models/paddleocr.cpp +++ b/src/models/paddleocr.cpp @@ -55,7 +55,7 @@ llm_build_paddleocr::llm_build_paddleocr(const llama_model & model, const llm_gr cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { diff --git a/src/models/pangu-embedded.cpp b/src/models/pangu-embedded.cpp index 8046750d048..53464f21d22 100644 --- 
a/src/models/pangu-embedded.cpp +++ b/src/models/pangu-embedded.cpp @@ -49,7 +49,7 @@ llm_build_pangu_embedded::llm_build_pangu_embedded(const llama_model & model, co cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/src/models/phi2.cpp b/src/models/phi2.cpp index 8181afd343d..0fb3ffa2e63 100644 --- a/src/models/phi2.cpp +++ b/src/models/phi2.cpp @@ -51,7 +51,7 @@ llm_build_phi2::llm_build_phi2(const llama_model & model, const llm_graph_params Qcur = ggml_scale(ctx0, Qcur, 1.0f/sqrtf(float(n_embd_head))); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/src/models/phi3.cpp b/src/models/phi3.cpp index e00a517c78c..39af285d3c5 100644 --- a/src/models/phi3.cpp +++ b/src/models/phi3.cpp @@ -60,7 +60,7 @@ llm_build_phi3::llm_build_phi3(const llama_model & model, const llm_graph_ cb(Qcur, "Qcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/src/models/qwen2.cpp b/src/models/qwen2.cpp index f0c0553d3dc..2892dd75087 100644 --- a/src/models/qwen2.cpp +++ b/src/models/qwen2.cpp @@ -50,7 +50,7 @@ llm_build_qwen2::llm_build_qwen2(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/src/models/qwen2moe.cpp b/src/models/qwen2moe.cpp index 166a8fb2fb9..5f0a6861b68 100644 --- a/src/models/qwen2moe.cpp +++ b/src/models/qwen2moe.cpp @@ -50,7 +50,7 @@ llm_build_qwen2moe::llm_build_qwen2moe(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/src/models/qwen2vl.cpp b/src/models/qwen2vl.cpp index 47dfc92a18e..da7937c7667 100644 --- a/src/models/qwen2vl.cpp +++ b/src/models/qwen2vl.cpp @@ -53,7 +53,7 @@ llm_build_qwen2vl::llm_build_qwen2vl(const llama_model & model, const llm_graph_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/src/models/qwen3.cpp b/src/models/qwen3.cpp index 68149bfca95..fa8c3940226 100644 --- a/src/models/qwen3.cpp +++ b/src/models/qwen3.cpp @@ -21,6 +21,31 @@ llm_build_qwen3::llm_build_qwen3(const llama_model & model, const llm_graph_para for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; + // EAGLE3: Extract 
intermediate layer features from target model at layer INPUT + if (eagle3 && cparams.eagle3_extract_enabled && !eagle3->extract_layer_indices.empty()) { + static const char * eagle3_extract_names[] = {"eagle3_extract_0", "eagle3_extract_1", "eagle3_extract_2"}; + for (size_t i = 0; i < eagle3->extract_layer_indices.size() && i < 3; ++i) { + if (eagle3->extract_layer_indices[i] == il) { + cb(inpL, eagle3_extract_names[i], il); + break; + } + } + } + + // DFlash: Extract intermediate layer features from target model at layer INPUT + if (dflash && cparams.dflash_extract_enabled && !dflash->extract_layer_indices.empty()) { + static const char * dflash_extract_names[] = { + "dflash_extract_0", "dflash_extract_1", "dflash_extract_2", + "dflash_extract_3", "dflash_extract_4" + }; + for (size_t i = 0; i < dflash->extract_layer_indices.size() && i < 5; ++i) { + if (dflash->extract_layer_indices[i] == il) { + cb(inpL, dflash_extract_names[i], il); + break; + } + } + } + // norm cur = build_norm(inpL, model.layers[il].attn_norm, NULL, @@ -56,7 +81,7 @@ llm_build_qwen3::llm_build_qwen3(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); if (model.layers[il].wo_s) { cur = ggml_mul(ctx0, cur, model.layers[il].wo_s); diff --git a/src/models/qwen35.cpp b/src/models/qwen35.cpp index 87790f08e4e..19d3d95619d 100644 --- a/src/models/qwen35.cpp +++ b/src/models/qwen35.cpp @@ -26,6 +26,20 @@ llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_pa for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; + // DFlash: Extract intermediate layer features from target model + if (dflash && cparams.dflash_extract_enabled && !dflash->extract_layer_indices.empty()) { + static const char * dflash_extract_names[] = { + "dflash_extract_0", "dflash_extract_1", "dflash_extract_2", + "dflash_extract_3", "dflash_extract_4" + }; + for (size_t i = 0; i < dflash->extract_layer_indices.size() && i < 5; ++i) { + if (dflash->extract_layer_indices[i] == il) { + cb(inpL, dflash_extract_names[i], il); + break; + } + } + } + cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); cb(cur, "attn_norm", il); diff --git a/src/models/qwen35moe.cpp b/src/models/qwen35moe.cpp index 7dc6a23c751..b367bdecf36 100644 --- a/src/models/qwen35moe.cpp +++ b/src/models/qwen35moe.cpp @@ -26,6 +26,20 @@ llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_gr for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; + // DFlash: Extract intermediate layer features from target model + if (dflash && cparams.dflash_extract_enabled && !dflash->extract_layer_indices.empty()) { + static const char * dflash_extract_names[] = { + "dflash_extract_0", "dflash_extract_1", "dflash_extract_2", + "dflash_extract_3", "dflash_extract_4" + }; + for (size_t i = 0; i < dflash->extract_layer_indices.size() && i < 5; ++i) { + if (dflash->extract_layer_indices[i] == il) { + cb(inpL, dflash_extract_names[i], il); + break; + } + } + } + cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); cb(cur, "attn_norm", il); diff --git a/src/models/qwen3moe.cpp b/src/models/qwen3moe.cpp index 533e64b4366..bdb67332e81 100644 --- a/src/models/qwen3moe.cpp +++ b/src/models/qwen3moe.cpp @@ -21,6 +21,17 @@ 
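The extraction hooks added above only name the layer-input tensor ("eagle3_extract_<i>" / "dflash_extract_<i>") through the build callback at the configured layers; a consumer then collects those tensors by name. A sketch of such a consumer; the callback wiring itself is not part of this hunk, so the signature used here is illustrative:

```cpp
// Collects tensors whose callback names follow the extraction naming convention above.
#include <cstring>
#include <map>
#include <string>

struct ggml_tensor; // opaque here; provided by ggml.h in the real code

struct extract_collector {
    std::map<std::string, ggml_tensor *> features;

    // call this from the graph-build callback after the tensor has been named
    void on_tensor(ggml_tensor * t, const char * name) {
        if (std::strncmp(name, "eagle3_extract_", 15) == 0 ||
            std::strncmp(name, "dflash_extract_", 15) == 0) {
            features[name] = t; // e.g. "eagle3_extract_0" -> layer-input activations
        }
    }
};

int main() {
    extract_collector c;
    c.on_tensor(nullptr, "eagle3_extract_0"); // illustrative; real tensors come from the graph
    return int(c.features.size());            // -> 1
}
```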
llm_build_qwen3moe::llm_build_qwen3moe(const llama_model & model, const llm_grap for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; + // EAGLE3: Extract intermediate layer features from target model at layer INPUT + if (eagle3 && cparams.eagle3_extract_enabled && !eagle3->extract_layer_indices.empty()) { + static const char * eagle3_extract_names[] = {"eagle3_extract_0", "eagle3_extract_1", "eagle3_extract_2"}; + for (size_t i = 0; i < eagle3->extract_layer_indices.size() && i < 3; ++i) { + if (eagle3->extract_layer_indices[i] == il) { + cb(inpL, eagle3_extract_names[i], il); + break; + } + } + } + // norm cur = build_norm(inpL, model.layers[il].attn_norm, NULL, @@ -56,7 +67,7 @@ llm_build_qwen3moe::llm_build_qwen3moe(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); if (model.layers[il].wo_s) { cur = ggml_mul(ctx0, cur, model.layers[il].wo_s); diff --git a/src/models/qwen3vl-moe.cpp b/src/models/qwen3vl-moe.cpp index fe5ef578f33..29ee8278a4d 100644 --- a/src/models/qwen3vl-moe.cpp +++ b/src/models/qwen3vl-moe.cpp @@ -62,7 +62,7 @@ llm_build_qwen3vlmoe::llm_build_qwen3vlmoe(const llama_model & model, const llm_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/src/models/qwen3vl.cpp b/src/models/qwen3vl.cpp index 333dba6eae0..faa5f2ef3c8 100644 --- a/src/models/qwen3vl.cpp +++ b/src/models/qwen3vl.cpp @@ -62,7 +62,7 @@ llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/src/models/rnd1.cpp b/src/models/rnd1.cpp index b53c075f5eb..a917c19f25a 100644 --- a/src/models/rnd1.cpp +++ b/src/models/rnd1.cpp @@ -58,7 +58,7 @@ llm_build_rnd1::llm_build_rnd1(const llama_model & model, const llm_graph_params cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/src/models/seed-oss.cpp b/src/models/seed-oss.cpp index 82c71d8df1d..6db8d9781fe 100644 --- a/src/models/seed-oss.cpp +++ b/src/models/seed-oss.cpp @@ -52,7 +52,7 @@ llm_build_seed_oss::llm_build_seed_oss(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } diff --git a/src/models/smallthinker.cpp b/src/models/smallthinker.cpp index 5d9cc82f8f9..55d09ec325d 100644 --- a/src/models/smallthinker.cpp +++ b/src/models/smallthinker.cpp @@ -59,7 +59,7 @@ llm_build_smallthinker::llm_build_smallthinker(const llama_model & model, 
cb(Kcur, "Kcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/src/models/smollm3.cpp b/src/models/smollm3.cpp index 6600abcda75..83636dbf546 100644 --- a/src/models/smollm3.cpp +++ b/src/models/smollm3.cpp @@ -55,7 +55,7 @@ llm_build_smollm3::llm_build_smollm3(const llama_model & model, const llm_graph_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } diff --git a/src/models/starcoder.cpp b/src/models/starcoder.cpp index be4af1f5a31..cf9fe95c35b 100644 --- a/src/models/starcoder.cpp +++ b/src/models/starcoder.cpp @@ -36,7 +36,7 @@ llm_build_starcoder::llm_build_starcoder(const llama_model & model, const llm_gr n_embd_head, n_head, n_head_kv, il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/src/models/starcoder2.cpp b/src/models/starcoder2.cpp index 1fa50b985c0..b6d4d5aac1a 100644 --- a/src/models/starcoder2.cpp +++ b/src/models/starcoder2.cpp @@ -50,7 +50,7 @@ llm_build_starcoder2::llm_build_starcoder2(const llama_model & model, const llm_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/src/models/t5.cpp b/src/models/t5.cpp index 7675532b2d2..9f9dfef4012 100644 --- a/src/models/t5.cpp +++ b/src/models/t5.cpp @@ -41,7 +41,7 @@ llm_build_t5::llm_build_t5(const llama_model & model, const llm_graph_par ggml_tensor * kq_b = build_pos_bias(pos_bucket_dec, attn_rel_b); cur = build_attn(inp_attn_self, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, kq_b, nullptr, nullptr, 1.0f, il); cb(cur, "kqv_out", il); } diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index b282c3239f0..edb585b9f65 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -155,6 +155,8 @@ if (NOT WIN32 OR NOT BUILD_SHARED_LIBS) llama_build_and_test(test-grammar-integration.cpp) llama_build_and_test(test-llama-grammar.cpp) llama_build_and_test(test-chat.cpp WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}) + target_include_directories(test-chat PRIVATE ${PROJECT_SOURCE_DIR}/tools/server) + target_link_libraries(test-chat PRIVATE server-context) # TODO: disabled on loongarch64 because the ggml-ci node lacks Python 3.8 if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64") llama_build_and_test(test-json-schema-to-grammar.cpp WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 828a9c14a45..71601131671 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -3522,6 +3522,40 @@ struct test_add_rms_norm : public test_case { } }; +// 
GGML_OP_UNARY(RELU) + GGML_OP_SQR (fused operation) +struct test_relu_sqr : public test_case { + const ggml_type type; + const std::array ne; + + std::string op_desc(ggml_tensor * t) override { + GGML_UNUSED(t); + return "RELU_SQR"; + } + + bool run_whole_graph() override { return true; } + + std::string vars() override { + return VARS_TO_STR2(type, ne); + } + + test_relu_sqr(ggml_type type = GGML_TYPE_F32, + std::array ne = {128, 2, 2, 2}) + : type(type), ne(ne) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_name(a, "a"); + + ggml_tensor * r = ggml_relu(ctx, a); + ggml_set_name(r, "relu"); + + ggml_tensor * out = ggml_sqr(ctx, r); + ggml_set_name(out, "out"); + + return out; + } +}; + // GGML_OP_SSM_CONV struct test_ssm_conv : public test_case { const ggml_type type; @@ -7311,6 +7345,12 @@ static std::vector> make_test_cases_eval() { } } + // fused relu + sqr (squared ReLU) + for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) { + test_cases.emplace_back(new test_relu_sqr(type, { 128, 2, 2, 2 })); + test_cases.emplace_back(new test_relu_sqr(type, { 5, 7, 11, 13 })); + } + // glu ops for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) { for (int v : {0, 1}) { diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 3b8de5ce02e..52b480c24e9 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -7,6 +7,7 @@ // #include "../src/llama-grammar.h" #include "../src/unicode.h" +#include "../tools/server/server-chat.h" #include "chat-auto-parser.h" #include "chat.h" #include "common.h" @@ -1514,6 +1515,117 @@ static void test_tools_oaicompat_json_conversion() { common_chat_tools_to_json_oaicompat({ special_function_tool }).dump(2)); } +static void test_convert_responses_to_chatcmpl() { + LOG_DBG("%s\n", __func__); + + // Test basic conversion with input messages (user/assistant alternating) + { + json input = json::parse(R"({ + "input": [ + { + "type": "message", + "role": "user", + "content": "hi wassup" + }, + { + "type": "message", + "role": "assistant", + "content": "Hey! 👋 Not much, just here ready to chat. What's up with you? Anything I can help you with today?" + }, + { + "type": "message", + "role": "user", + "content": "hi" + } + ], + "model": "gpt-5-mini", + "stream": false, + "text": {}, + "reasoning": { + "effort": "medium" + } + })"); + + json result = server_chat_convert_responses_to_chatcmpl(input); + + // Verify messages were converted correctly + assert_equals(true, result.contains("messages")); + assert_equals(true, result.at("messages").is_array()); + assert_equals((size_t)3, result.at("messages").size()); + + // Check first message (user) + const auto & msg0 = result.at("messages")[0]; + assert_equals(std::string("user"), msg0.at("role").get()); + assert_equals(true, msg0.at("content").is_array()); + assert_equals(std::string("text"), msg0.at("content")[0].at("type").get()); + assert_equals(std::string("hi wassup"), msg0.at("content")[0].at("text").get()); + + // Check second message (assistant) + const auto & msg1 = result.at("messages")[1]; + assert_equals(std::string("assistant"), msg1.at("role").get()); + assert_equals(true, msg1.at("content").is_array()); + assert_equals(std::string("text"), msg1.at("content")[0].at("type").get()); + assert_equals(std::string("Hey! 👋 Not much, just here ready to chat. What's up with you? 
Anything I can help you with today?"), msg1.at("content")[0].at("text").get()); + + // Check third message (user) + const auto & msg2 = result.at("messages")[2]; + assert_equals(std::string("user"), msg2.at("role").get()); + assert_equals(true, msg2.at("content").is_array()); + assert_equals(std::string("text"), msg2.at("content")[0].at("type").get()); + assert_equals(std::string("hi"), msg2.at("content")[0].at("text").get()); + + // Verify other fields preserved + assert_equals(std::string("gpt-5-mini"), result.at("model").get()); + assert_equals(false, result.at("stream").get()); + } + + // Test string input + { + json input = json::parse(R"({ + "input": "Hello, world!", + "model": "test-model" + })"); + + json result = server_chat_convert_responses_to_chatcmpl(input); + + assert_equals((size_t)1, result.at("messages").size()); + const auto & msg = result.at("messages")[0]; + assert_equals(std::string("user"), msg.at("role").get()); + assert_equals(std::string("Hello, world!"), msg.at("content").get()); + } + + // Test with instructions (system message) + { + json input = json::parse(R"({ + "input": "Hello", + "instructions": "You are a helpful assistant.", + "model": "test-model" + })"); + + json result = server_chat_convert_responses_to_chatcmpl(input); + + assert_equals((size_t)2, result.at("messages").size()); + const auto & sys_msg = result.at("messages")[0]; + assert_equals(std::string("system"), sys_msg.at("role").get()); + assert_equals(std::string("You are a helpful assistant."), sys_msg.at("content").get()); + } + + // Test with max_output_tokens conversion + { + json input = json::parse(R"({ + "input": "Hello", + "model": "test-model", + "max_output_tokens": 100 + })"); + + json result = server_chat_convert_responses_to_chatcmpl(input); + + assert_equals(true, result.contains("max_tokens")); + assert_equals(false, result.contains("max_output_tokens")); + assert_equals(100, result.at("max_tokens").get()); + } +} + static void test_template_output_peg_parsers(bool detailed_debug) { LOG_DBG("%s\n", __func__); @@ -1796,7 +1908,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { "\n" "\n1\n\n" "\n" - "") + "\n") .enable_thinking(false) .reasoning_format(COMMON_REASONING_FORMAT_AUTO) .tools({ special_function_tool }) @@ -1809,7 +1921,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { "\n" "\n1\n\n" "\n" - "") + "\n") .reasoning_format(COMMON_REASONING_FORMAT_AUTO) .tools({ special_function_tool }) .expect(message_assist_call_thoughts) @@ -1826,7 +1938,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { "\n1\n\n" "\n2\n\n" "\n" - "") + "\n") .enable_thinking(false) .reasoning_format(COMMON_REASONING_FORMAT_AUTO) .parallel_tool_calls(true) @@ -1849,7 +1961,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { "hello()\n" "\n" "\n" - "") + "\n") .enable_thinking(false) .reasoning_format(COMMON_REASONING_FORMAT_AUTO) .tools({ @@ -1892,7 +2004,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { "hello()\n" "\n" "\n" - "" + "\n" ) .enable_thinking(true) .reasoning_format(COMMON_REASONING_FORMAT_AUTO) @@ -1908,7 +2020,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { "hello()\n" "\n" "\n" - "") + "\n") .expect_tool_calls({ { "python", "{\"code\": \"def hello():\\n print(\\\"Hello, world!\\\")\\n\\nhello()\"}", {} }, }) @@ -3595,6 +3707,51 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .run(); } + // Reka Edge + { + auto tst = 
peg_tester("models/templates/Reka-Edge.jinja", detailed_debug); + tst.test("Hello, world!\nWhat's up?") + .enable_thinking(false) + .expect(message_assist) + .run(); + tst.test("I'm\nthinking\n\nHello, world!\nWhat's up?") + .enable_thinking(true) + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .expect(message_assist_thoughts) + .run(); + tst.test("\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n") + .enable_thinking(false) + .tools({ special_function_tool }) + .expect(message_assist_call) + .run(); + tst.test("Hello, world!\nWhat's up?\n\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n") + .enable_thinking(false) + .tools({ special_function_tool }) + .expect(message_assist_call_content) + .run(); + tst.test("I'm\nthinking\n\n\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n") + .enable_thinking(true) + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .tools({ special_function_tool }) + .expect(message_assist_call_thoughts) + .run(); + tst.test("\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n\n\n{\"name\": \"special_function_with_opt\", \"arguments\": {\"arg1\": 1, \"arg2\": 2}}\n") + .enable_thinking(false) + .parallel_tool_calls(true) + .tools({ special_function_tool, special_function_tool_with_optional_param }) + .expect_tool_calls({ + { "special_function", R"({"arg1": 1})", {} }, + { "special_function_with_opt", R"({"arg1": 1, "arg2": 2})", {} }, + }) + .run(); + tst.test("\n{\"name\": \"special_function\", \"arguments\": {\"arg") + .enable_thinking(false) + .tools({ special_function_tool }) + .is_partial(true) + .expect(message_assist_call_cutoff_args) + .run(); + } + // Apriel 1.5 { auto tst = peg_tester("models/templates/unsloth-Apriel-1.5.jinja", detailed_debug); @@ -4077,6 +4234,55 @@ static void test_template_output_peg_parsers(bool detailed_debug) { } } +static void test_reka_edge_common_path() { + auto tmpls = read_templates("models/templates/Reka-Edge.jinja"); + + { + common_chat_templates_inputs inputs; + common_chat_msg system_msg; + system_msg.role = "system"; + system_msg.content = "Use tools when needed."; + + common_chat_msg tool_call_msg = simple_assist_msg("", "", "special_function", "{\"arg1\": 1}"); + + common_chat_msg tool_msg; + tool_msg.role = "tool"; + tool_msg.tool_name = "special_function"; + tool_msg.tool_call_id = "call0"; + tool_msg.content = "Sunny"; + + inputs.messages = { system_msg, message_user, tool_call_msg, tool_msg, message_user }; + inputs.tools = { special_function_tool }; + inputs.enable_thinking = true; + inputs.add_generation_prompt = true; + + auto params = common_chat_templates_apply(tmpls.get(), inputs); + + if (params.prompt.find("\nSunny\n") == std::string::npos) { + throw std::runtime_error("Reka Edge prompt did not render tool response history"); + } + if (params.prompt.rfind("assistant: \n") == std::string::npos) { + throw std::runtime_error("Reka Edge prompt did not render thinking generation prompt"); + } + } + + { + common_chat_templates_inputs inputs; + inputs.messages = { + message_user, + simple_assist_msg("The first point is") + }; + inputs.add_generation_prompt = false; + inputs.enable_thinking = false; + inputs.chat_template_kwargs["continue_final_message"] = "true"; + + auto params = common_chat_templates_apply(tmpls.get(), inputs); + if (string_ends_with(params.prompt, "")) { + throw std::runtime_error("Reka Edge continue_final_message unexpectedly closed the assistant turn"); + } + } +} + // Test the developer role to system workaround with a 
simple mock template static void test_developer_role_to_system_workaround() { LOG_DBG("%s\n", __func__); @@ -4197,7 +4403,7 @@ int main(int argc, char ** argv) { bool detailed_debug = false; bool only_run_filtered = false; - // Check for --template flag + // Check for --template and --detailed flags for (int i = 1; i < argc; i++) { std::string arg = argv[i]; if (arg == "--template" && i + 1 < argc) { @@ -4222,7 +4428,20 @@ int main(int argc, char ** argv) { } #ifndef _WIN32 - if (argc > 1) { + // Check if any argument is a .jinja file (for template format detection mode) + bool has_jinja_files = false; + for (int i = 1; i < argc; i++) { + std::string arg = argv[i]; + if (arg == "--detailed") { + continue; + } + if (arg.size() >= 6 && arg.rfind(".jinja") == arg.size() - 6) { + has_jinja_files = true; + break; + } + } + + if (has_jinja_files) { common_chat_templates_inputs inputs; common_chat_msg msg; msg.role = "user"; @@ -4255,7 +4474,9 @@ int main(int argc, char ** argv) { test_msg_diffs_compute(); test_msgs_oaicompat_json_conversion(); test_tools_oaicompat_json_conversion(); + test_convert_responses_to_chatcmpl(); test_developer_role_to_system_workaround(); + test_reka_edge_common_path(); test_template_output_peg_parsers(detailed_debug); std::cout << "\n[chat] All tests passed!" << '\n'; } diff --git a/tests/test-mtmd-c-api.c b/tests/test-mtmd-c-api.c index 7a0ce593c01..b49498c87c1 100644 --- a/tests/test-mtmd-c-api.c +++ b/tests/test-mtmd-c-api.c @@ -42,7 +42,7 @@ int main(void) { const mtmd_image_tokens * image_tokens = mtmd_input_chunk_get_tokens_image(chunk); size_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens); // get position of the last token, which should be (nx - 1, ny - 1) - struct mtmd_decoder_pos pos = mtmd_image_tokens_get_decoder_pos(image_tokens, n_tokens - 1); + struct mtmd_decoder_pos pos = mtmd_image_tokens_get_decoder_pos(image_tokens, 0, n_tokens - 1); size_t nx = pos.x + 1; size_t ny = pos.y + 1; const char * id = mtmd_image_tokens_get_id(image_tokens); diff --git a/tools/cli/cli.cpp b/tools/cli/cli.cpp index 79482a83170..369c24216b7 100644 --- a/tools/cli/cli.cpp +++ b/tools/cli/cli.cpp @@ -2,6 +2,7 @@ #include "common.h" #include "arg.h" #include "console.h" +#include "fit.h" // #include "log.h" #include "server-common.h" @@ -58,8 +59,6 @@ struct cli_context { std::vector input_files; task_params defaults; bool verbose_prompt; - int reasoning_budget = -1; - std::string reasoning_budget_message; // thread for showing "loading" animation std::atomic loading_show; @@ -76,8 +75,6 @@ struct cli_context { // defaults.return_progress = true; // TODO: show progress verbose_prompt = params.verbose_prompt; - reasoning_budget = params.reasoning_budget; - reasoning_budget_message = params.reasoning_budget_message; } std::string generate_completion(result_timings & out_timings) { @@ -105,7 +102,7 @@ struct cli_context { const llama_vocab * vocab = llama_model_get_vocab( llama_get_model(ctx_server.get_llama_context())); - task.params.sampling.reasoning_budget_tokens = reasoning_budget; + task.params.sampling.reasoning_budget_tokens = defaults.sampling.reasoning_budget_tokens; task.params.sampling.generation_prompt = chat_params.generation_prompt; if (!chat_params.thinking_start_tag.empty()) { @@ -115,7 +112,7 @@ struct cli_context { task.params.sampling.reasoning_budget_end = common_tokenize(vocab, chat_params.thinking_end_tag, false, true); task.params.sampling.reasoning_budget_forced = - common_tokenize(vocab, reasoning_budget_message + chat_params.thinking_end_tag, 
false, true); + common_tokenize(vocab, defaults.sampling.reasoning_budget_message + chat_params.thinking_end_tag, false, true); } rd.post_task({std::move(task)}); @@ -206,6 +203,8 @@ struct cli_context { auto meta = ctx_server.get_meta(); auto & chat_params = meta.chat_params; + auto caps = common_chat_templates_get_caps(chat_params.tmpls.get()); + common_chat_templates_inputs inputs; inputs.messages = common_chat_msgs_parse_oaicompat(messages); inputs.tools = {}; // TODO @@ -213,7 +212,7 @@ struct cli_context { inputs.json_schema = ""; // TODO inputs.grammar = ""; // TODO inputs.use_jinja = chat_params.use_jinja; - inputs.parallel_tool_calls = false; + inputs.parallel_tool_calls = caps["supports_parallel_tool_calls"]; inputs.add_generation_prompt = true; inputs.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; inputs.force_pure_content = chat_params.force_pure_content; @@ -225,7 +224,7 @@ struct cli_context { }; // TODO?: Make this reusable, enums, docs -static const std::array cmds = { +static const std::array cmds = { "/audio ", "/clear", "/exit", @@ -239,19 +238,19 @@ static std::vector> auto_completion_callback(std: std::vector> matches; std::string cmd; - if (line.length() > 1 && line[0] == '/' && !std::any_of(cmds.begin(), cmds.end(), [line](const std::string & prefix) { + if (line.length() > 1 && line.front() == '/' && !std::any_of(cmds.begin(), cmds.end(), [line](std::string_view prefix) { return string_starts_with(line, prefix); })) { auto it = cmds.begin(); - while ((it = std::find_if(it, cmds.end(), [line](const std::string & cmd_line) { + while ((it = std::find_if(it, cmds.end(), [line](std::string_view cmd_line) { return string_starts_with(cmd_line, line); })) != cmds.end()) { - matches.emplace_back(*it, (*it).length()); + matches.emplace_back(*it, it->length()); ++it; } } else { - auto it = std::find_if(cmds.begin(), cmds.end(), [line](const std::string & prefix) { + auto it = std::find_if(cmds.begin(), cmds.end(), [line](std::string_view prefix) { return prefix.back() == ' ' && string_starts_with(line, prefix); }); @@ -268,18 +267,18 @@ static std::vector> auto_completion_callback(std: std::string expanded_prefix = path_prefix; #if !defined(_WIN32) - if (string_starts_with(path_prefix, "~")) { + if (string_starts_with(path_prefix, '~')) { const char * home = std::getenv("HOME"); if (home && home[0]) { - expanded_prefix = std::string(home) + path_prefix.substr(1); + expanded_prefix = home + path_prefix.substr(1); } } - if (string_starts_with(expanded_prefix, "/")) { + if (string_starts_with(expanded_prefix, '/')) { #else if (std::isalpha(expanded_prefix[0]) && expanded_prefix.find(':') == 1) { #endif cur_dir = std::filesystem::path(expanded_prefix).parent_path(); - cur_dir_str = ""; + cur_dir_str.clear(); } else if (!path_prefix.empty()) { cur_dir /= std::filesystem::path(path_prefix).parent_path(); } @@ -302,7 +301,7 @@ static std::vector> auto_completion_callback(std: } if (expanded_prefix.empty() || string_starts_with(path_entry, expanded_prefix)) { - std::string updated_line = cmd + path_entry; + const std::string updated_line = cmd + path_entry; matches.emplace_back(updated_line + path_postfix, updated_line.length()); } @@ -312,7 +311,7 @@ static std::vector> auto_completion_callback(std: } if (matches.empty()) { - std::string updated_line = cmd + path_prefix; + const std::string updated_line = cmd + path_prefix; matches.emplace_back(updated_line + path_postfix, updated_line.length()); } @@ -329,7 +328,7 @@ static std::vector> auto_completion_callback(std: len = 
std::min(len, static_cast(cmp.first - match0.begin())); } - std::string updated_line = std::string(match0.substr(0, len)); + const std::string updated_line = std::string(match0.substr(0, len)); matches.emplace_back(updated_line + path_postfix, updated_line.length()); } @@ -566,10 +565,10 @@ int main(int argc, char ** argv) { if (endpath != std::string::npos) { std::string rel_pattern = pattern.substr(0, endpath); #if !defined(_WIN32) - if (string_starts_with(rel_pattern, "~")) { + if (string_starts_with(rel_pattern, '~')) { const char * home = std::getenv("HOME"); if (home && home[0]) { - rel_pattern = std::string(home) + rel_pattern.substr(1); + rel_pattern = home + rel_pattern.substr(1); } } #endif @@ -647,7 +646,7 @@ int main(int argc, char ** argv) { // bump the log level to display timings common_log_set_verbosity_thold(LOG_LEVEL_INFO); - llama_memory_breakdown_print(ctx_cli.ctx_server.get_llama_context()); + common_memory_breakdown_print(ctx_cli.ctx_server.get_llama_context()); return 0; } diff --git a/tools/fit-params/fit-params.cpp b/tools/fit-params/fit-params.cpp index 3c0404ed309..bcdf4404016 100644 --- a/tools/fit-params/fit-params.cpp +++ b/tools/fit-params/fit-params.cpp @@ -1,14 +1,12 @@ #include "llama.h" +#include "../src/llama-ext.h" #include "arg.h" #include "common.h" +#include "fit.h" #include "log.h" -#include #include -#include - -using namespace std::chrono_literals; #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -19,49 +17,58 @@ int main(int argc, char ** argv) { common_init(); - if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) { + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_FIT_PARAMS)) { return 1; } llama_backend_init(); llama_numa_init(params.numa); + auto mparams = common_model_params_to_llama(params); auto cparams = common_context_params_to_llama(params); - const llama_params_fit_status status = llama_params_fit(params.model.path.c_str(), &mparams, &cparams, - params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target.data(), params.fit_params_min_ctx, - params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR); - if (status != LLAMA_PARAMS_FIT_STATUS_SUCCESS) { - LOG_ERR("%s: failed to fit CLI arguments to free memory, exiting...\n", __func__); - exit(1); - } - LOG_INF("%s: printing fitted CLI arguments to stdout...\n", __func__); - common_log_flush(common_log_main()); - printf("-c %" PRIu32 " -ngl %" PRIi32, cparams.n_ctx, mparams.n_gpu_layers); + if (!params.fit_params_print) { + const common_params_fit_status status = common_fit_params(params.model.path.c_str(), &mparams, &cparams, + params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target.data(), params.fit_params_min_ctx, + params.verbosity >= 4 ? 
GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR); + if (status != COMMON_PARAMS_FIT_STATUS_SUCCESS) { + LOG_ERR("%s: failed to fit CLI arguments to free memory, exiting...\n", __func__); + exit(1); + } - size_t nd = llama_max_devices(); - while (nd > 1 && mparams.tensor_split[nd - 1] == 0.0f) { - nd--; - } - if (nd > 1) { - for (size_t id = 0; id < nd; id++) { - if (id == 0) { - printf(" -ts "); + LOG_INF("%s: printing fitted CLI arguments to stdout...\n", __func__); + common_log_flush(common_log_main()); + printf("-c %" PRIu32 " -ngl %" PRIi32, cparams.n_ctx, mparams.n_gpu_layers); + + size_t nd = llama_max_devices(); + while (nd > 1 && mparams.tensor_split[nd - 1] == 0.0f) { + nd--; + } + if (nd > 1) { + for (size_t id = 0; id < nd; id++) { + if (id == 0) { + printf(" -ts "); + } + printf("%s%" PRIu32, id > 0 ? "," : "", uint32_t(mparams.tensor_split[id])); } - printf("%s%" PRIu32, id > 0 ? "," : "", uint32_t(mparams.tensor_split[id])); } - } - const size_t ntbo = llama_max_tensor_buft_overrides(); - bool any_tbo = false; - for (size_t itbo = 0; itbo < ntbo && mparams.tensor_buft_overrides[itbo].pattern != nullptr; itbo++) { - if (itbo == 0) { - printf(" -ot \""); + const size_t ntbo = llama_max_tensor_buft_overrides(); + bool any_tbo = false; + for (size_t itbo = 0; itbo < ntbo && mparams.tensor_buft_overrides[itbo].pattern != nullptr; itbo++) { + if (itbo == 0) { + printf(" -ot \""); + } + printf("%s%s=%s", itbo > 0 ? "," : "", mparams.tensor_buft_overrides[itbo].pattern, ggml_backend_buft_name(mparams.tensor_buft_overrides[itbo].buft)); + any_tbo = true; } - printf("%s%s=%s", itbo > 0 ? "," : "", mparams.tensor_buft_overrides[itbo].pattern, ggml_backend_buft_name(mparams.tensor_buft_overrides[itbo].buft)); - any_tbo = true; + printf("%s\n", any_tbo ? "\"" : ""); + } else { + LOG_INF("%s: printing estimated memory in MiB to stdout (device, model, context, compute) ...\n", __func__); + common_log_flush(common_log_main()); + + common_fit_print(params.model.path.c_str(), &mparams, &cparams); } - printf("%s\n", any_tbo ? 
"\"" : ""); return 0; } diff --git a/tools/llama-bench/llama-bench.cpp b/tools/llama-bench/llama-bench.cpp index 59920ab516b..e21a80e697b 100644 --- a/tools/llama-bench/llama-bench.cpp +++ b/tools/llama-bench/llama-bench.cpp @@ -22,6 +22,7 @@ #include "build-info.h" #include "common.h" #include "download.h" +#include "fit.h" #include "ggml.h" #include "llama.h" @@ -2225,7 +2226,7 @@ int main(int argc, char ** argv) { prev_inst = nullptr; } - // use default n_gpu_layers and n_ctx so llama_params_fit can adjust them + // use default n_gpu_layers and n_ctx so common_fit_params can adjust them mparams.n_gpu_layers = llama_model_default_params().n_gpu_layers; mparams.tensor_split = fit_tensor_split.data(); mparams.tensor_buft_overrides = fit_overrides.data(); @@ -2236,7 +2237,7 @@ int main(int argc, char ** argv) { uint32_t n_ctx_needed = inst.n_prompt + inst.n_gen + inst.n_depth; cparams.n_ctx = std::max(cparams.n_ctx, n_ctx_needed); - llama_params_fit(inst.model.c_str(), &mparams, &cparams, + common_fit_params(inst.model.c_str(), &mparams, &cparams, fit_tensor_split.data(), fit_overrides.data(), margins.data(), diff --git a/tools/mtmd/CMakeLists.txt b/tools/mtmd/CMakeLists.txt index 399876128ef..35d721d5a4c 100644 --- a/tools/mtmd/CMakeLists.txt +++ b/tools/mtmd/CMakeLists.txt @@ -40,6 +40,7 @@ add_library(mtmd models/deepseekocr.cpp models/mobilenetv5.cpp models/youtuvl.cpp + models/yasa2.cpp ) set_target_properties(mtmd PROPERTIES diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index 17cb703f7fb..7d6484eea85 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -150,7 +150,7 @@ #define TN_TOK_BOI "v.boi" #define TN_TOK_EOI "v.eoi" -// hunyuanocr +// hunyuanocr / hunyuanvl (shared GGUF tensor names) #define TN_MM_PRE_NORM "mm.pre_norm.%s" #define TN_TOK_IMG_BEGIN "mm.image_begin" #define TN_TOK_IMG_END "mm.image_end" @@ -242,6 +242,15 @@ #define TN_STD_BIAS "v.std_bias" #define TN_STD_SCALE "v.std_scale" +// yasa2 +#define TN_YASA_PATCH_LN_W "v.patch_ln.weight" +#define TN_YASA_PATCH_LN_B "v.patch_ln.bias" +#define TN_YASA_BACKBONE_LN_W "v.backbone_ln.weight" +#define TN_YASA_BACKBONE_LN_B "v.backbone_ln.bias" +#define TN_YASA_POS_EMBD "v.vision_pos_embed" +#define TN_YASA_STAGE_DOWN_LN "v.stage.%d.down.ln.%s" +#define TN_YASA_STAGE_DOWN_CONV "v.stage.%d.down.conv.%s" +#define TN_YASA_STAGE_BLK "v.stage.%d.blk.%d.%s.%s" // align x to upper multiple of n #define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n)) @@ -290,9 +299,11 @@ enum projector_type { PROJECTOR_TYPE_LFM2A, PROJECTOR_TYPE_GLM4V, PROJECTOR_TYPE_YOUTUVL, + PROJECTOR_TYPE_YASA2, PROJECTOR_TYPE_KIMIK25, PROJECTOR_TYPE_NEMOTRON_V2_VL, PROJECTOR_TYPE_HUNYUANOCR, + PROJECTOR_TYPE_HUNYUANVL, PROJECTOR_TYPE_UNKNOWN, }; @@ -335,9 +346,11 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_LFM2A, "lfm2a"}, { PROJECTOR_TYPE_GLM4V, "glm4v"}, { PROJECTOR_TYPE_YOUTUVL, "youtuvl"}, + { PROJECTOR_TYPE_YASA2, "yasa2"}, { PROJECTOR_TYPE_KIMIK25, "kimik25"}, { PROJECTOR_TYPE_NEMOTRON_V2_VL, "nemotron_v2_vl"}, { PROJECTOR_TYPE_HUNYUANOCR, "hunyuanocr"}, + { PROJECTOR_TYPE_HUNYUANVL, "hunyuanvl"}, }; static projector_type clip_projector_type_from_string(const std::string & str) { diff --git a/tools/mtmd/clip-model.h b/tools/mtmd/clip-model.h index 9a93584d9be..bf8031b55b2 100644 --- a/tools/mtmd/clip-model.h +++ b/tools/mtmd/clip-model.h @@ -268,6 +268,27 @@ struct mobilenetv5_block { ggml_tensor * attn_norm_w = nullptr; }; +struct yasa2_block { + ggml_tensor * dw_w = nullptr; + ggml_tensor * dw_b = nullptr; + 
ggml_tensor * ln_w = nullptr; + ggml_tensor * ln_b = nullptr; + ggml_tensor * pw1_w = nullptr; + ggml_tensor * pw1_b = nullptr; + ggml_tensor * grn_w = nullptr; + ggml_tensor * grn_b = nullptr; + ggml_tensor * pw2_w = nullptr; + ggml_tensor * pw2_b = nullptr; +}; + +struct yasa2_stage { + ggml_tensor * down_ln_w = nullptr; + ggml_tensor * down_ln_b = nullptr; + ggml_tensor * down_conv_w = nullptr; + ggml_tensor * down_conv_b = nullptr; + std::vector blocks; +}; + struct clip_model { clip_modality modality = CLIP_MODALITY_VISION; projector_type proj_type = PROJECTOR_TYPE_MLP; @@ -402,6 +423,15 @@ struct clip_model { ggml_tensor * msfa_ffn_expand_bn = nullptr; ggml_tensor * msfa_ffn_project_bn = nullptr; + // yasa2 + ggml_tensor * yasa_patch_w = nullptr; + ggml_tensor * yasa_patch_b = nullptr; + ggml_tensor * yasa_patch_ln_w = nullptr; + ggml_tensor * yasa_patch_ln_b = nullptr; + ggml_tensor * yasa_backbone_ln_w = nullptr; + ggml_tensor * yasa_backbone_ln_b = nullptr; + ggml_tensor * yasa_vision_pos_embed = nullptr; + std::vector yasa_stages; // pixtral, glm4v ggml_tensor * token_embd_img_break = nullptr; diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index f0e8786b660..45e39898d82 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -912,6 +912,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 builder = std::make_unique(ctx, img); } break; case PROJECTOR_TYPE_HUNYUANOCR: + case PROJECTOR_TYPE_HUNYUANVL: { builder = std::make_unique(ctx, img); } break; @@ -947,6 +948,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { builder = std::make_unique(ctx, img); } break; + case PROJECTOR_TYPE_YASA2: + { + builder = std::make_unique(ctx, img); + } break; default: GGML_ABORT("missing cgraph builder"); } @@ -1389,6 +1394,16 @@ struct clip_model_loader { hparams.set_limit_image_tokens(1, 62500); hparams.set_warmup_n_tokens(16*16); // avoid OOM on warmup } break; + case PROJECTOR_TYPE_YASA2: + { + hparams.ffn_op = FFN_GELU_ERF; + log_ffn_op = "gelu_erf"; + hparams.image_resize_algo = RESIZE_ALGO_BICUBIC; + + // reka model performs better when using resize_bicubic, which stretches + // the image to fit fixed square size + hparams.image_resize_pad = false; + } break; case PROJECTOR_TYPE_GLM4V: { hparams.rope_theta = 10000.0f; @@ -1459,6 +1474,16 @@ struct clip_model_loader { get_u32(KEY_IMAGE_MAX_PIXELS, hparams.image_max_pixels); hparams.set_warmup_n_tokens(28*28); } break; + case PROJECTOR_TYPE_HUNYUANVL: + { + hparams.n_merge = 2; + hparams.image_resize_algo = RESIZE_ALGO_BICUBIC_PILLOW; + hparams.image_resize_pad = false; + hparams.ffn_op = FFN_GELU; + get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.n_merge, false); + hparams.set_limit_image_tokens(256, 16384); + hparams.set_warmup_n_tokens(32*32); + } break; case PROJECTOR_TYPE_LFM2A: { // audio preprocessing params @@ -1839,6 +1864,55 @@ struct clip_model_loader { model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); // merger.mlp.2 model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); } break; + case PROJECTOR_TYPE_YASA2: + { + // reuse tensors already loaded by the common section + // (TN_PATCH_EMBD and TN_PATCH_BIAS have the same tensor names) + GGML_ASSERT(model.patch_embeddings_0 && "yasa2 requires v.patch_embd.weight"); + model.yasa_patch_w = model.patch_embeddings_0; + model.yasa_patch_b = model.patch_bias; + model.yasa_patch_ln_w = get_tensor(TN_YASA_PATCH_LN_W, false); + model.yasa_patch_ln_b = get_tensor(TN_YASA_PATCH_LN_B, 
false); + model.yasa_backbone_ln_w = get_tensor(TN_YASA_BACKBONE_LN_W, false); + model.yasa_backbone_ln_b = get_tensor(TN_YASA_BACKBONE_LN_B, false); + model.yasa_vision_pos_embed = get_tensor(TN_YASA_POS_EMBD, false); + model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight")); + model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"), false); + model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); + model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false); + + model.yasa_stages.clear(); + for (int s = 0; ; ++s) { + yasa2_stage stage; + stage.down_ln_w = get_tensor(string_format(TN_YASA_STAGE_DOWN_LN, s, "weight"), false); + stage.down_ln_b = get_tensor(string_format(TN_YASA_STAGE_DOWN_LN, s, "bias"), false); + stage.down_conv_w = get_tensor(string_format(TN_YASA_STAGE_DOWN_CONV, s, "weight"), false); + stage.down_conv_b = get_tensor(string_format(TN_YASA_STAGE_DOWN_CONV, s, "bias"), false); + + for (int bi = 0; ; ++bi) { + yasa2_block blk; + blk.dw_w = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "dw", "weight"), false); + if (!blk.dw_w) { + break; + } + blk.dw_b = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "dw", "bias"), false); + blk.ln_w = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "ln", "weight"), false); + blk.ln_b = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "ln", "bias"), false); + blk.pw1_w = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "pw1", "weight"), false); + blk.pw1_b = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "pw1", "bias"), false); + blk.grn_w = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "grn", "weight"), false); + blk.grn_b = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "grn", "bias"), false); + blk.pw2_w = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "pw2", "weight"), false); + blk.pw2_b = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "pw2", "bias"), false); + stage.blocks.push_back(blk); + } + + if (!stage.down_conv_w && stage.blocks.empty()) { + break; + } + model.yasa_stages.push_back(std::move(stage)); + } + } break; case PROJECTOR_TYPE_GLM4V: { model.mm_fc_w = get_tensor(string_format(TN_MM_PROJECTOR, "weight")); @@ -2159,6 +2233,7 @@ struct clip_model_loader { model.mm_eoi = get_tensor(TN_TOK_EOI); } break; case PROJECTOR_TYPE_HUNYUANOCR: + case PROJECTOR_TYPE_HUNYUANVL: { // proj.0 -> mm.0 (conv1), proj.2 -> mm.2 (conv2), mlp -> mm.model.fc (linear) model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight")); @@ -2797,6 +2872,7 @@ int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * case PROJECTOR_TYPE_GLM4V: case PROJECTOR_TYPE_PADDLEOCR: case PROJECTOR_TYPE_HUNYUANOCR: + case PROJECTOR_TYPE_HUNYUANVL: case PROJECTOR_TYPE_YOUTUVL: return (img->nx / params.patch_size) / 2; case PROJECTOR_TYPE_STEP3VL: @@ -2816,6 +2892,7 @@ int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * case PROJECTOR_TYPE_QWEN3VL: case PROJECTOR_TYPE_GLM4V: case PROJECTOR_TYPE_PADDLEOCR: + case PROJECTOR_TYPE_HUNYUANVL: case PROJECTOR_TYPE_YOUTUVL: return (img->ny / params.patch_size) / 2; case PROJECTOR_TYPE_STEP3VL: @@ -2843,6 +2920,10 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im { // do nothing } break; + case PROJECTOR_TYPE_YASA2: + { + n_patches = 64; // adaptive average pooling to 8x8 tokens + } break; case PROJECTOR_TYPE_LDP: case PROJECTOR_TYPE_LDPV2: case PROJECTOR_TYPE_GLM_EDGE: @@ -3003,6 +3084,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, 
struct clip_image_f32 * im n_patches = h * (h + 1) + 1; } break; case PROJECTOR_TYPE_HUNYUANOCR: + case PROJECTOR_TYPE_HUNYUANVL: { int merge = ctx->model.hparams.n_merge; int ow = (img->nx / patch_size) / merge; @@ -3463,9 +3545,74 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima case PROJECTOR_TYPE_PHI4: case PROJECTOR_TYPE_COGVLM: case PROJECTOR_TYPE_HUNYUANOCR: + case PROJECTOR_TYPE_YASA2: { // do nothing } break; + case PROJECTOR_TYPE_HUNYUANVL: + { + // Compute the HunyuanVL 2D position embedding on CPU (with the + // custom sf=(target+0.1)/n_grid bilinear sampling that the + // reference implementation uses) and upload it to the graph + // input declared in clip_graph_hunyuanocr::build(). + GGML_ASSERT(model.position_embeddings != nullptr); + ggml_tensor * src_t = model.position_embeddings; + const int64_t n_embd = src_t->ne[0]; + const int64_t n_pos = src_t->ne[1]; // = n_grid * n_grid + const int n_grid = (int)std::lround(std::sqrt((double)n_pos)); + GGML_ASSERT((int64_t)n_grid * n_grid == n_pos); + const int out_w = pos_w; // pw + const int out_h = pos_h; // ph + + // Pull weight to host. + std::vector src(n_embd * n_pos); + ggml_backend_tensor_get(src_t, src.data(), 0, ggml_nbytes(src_t)); + + // Output layout matches ggml_new_tensor_2d(F32, n_embd, out_h*out_w): + // ne[0] = n_embd (fastest), ne[1] = out_h*out_w + // dst[(y*out_w + x) * n_embd + c] + std::vector dst((size_t)n_embd * out_h * out_w); + + const float sx = (float)(out_w + 0.1f) / (float)n_grid; + const float sy = (float)(out_h + 0.1f) / (float)n_grid; + + for (int y = 0; y < out_h; ++y) { + // Match ggml_compute_forward_upscale_f32 pixel-center + // convention (align_corners=False): src_y = (y+0.5)/sy - 0.5. + const float fy = ((float)y + 0.5f) / sy - 0.5f; + int y0 = (int)std::floor(fy); + int y1 = y0 + 1; + y0 = std::clamp(y0, 0, n_grid - 1); + y1 = std::clamp(y1, 0, n_grid - 1); + float wy1 = std::clamp(fy - (float)y0, 0.0f, 1.0f); + const float wy0 = 1.0f - wy1; + for (int x = 0; x < out_w; ++x) { + const float fx = ((float)x + 0.5f) / sx - 0.5f; + int x0 = (int)std::floor(fx); + int x1 = x0 + 1; + x0 = std::clamp(x0, 0, n_grid - 1); + x1 = std::clamp(x1, 0, n_grid - 1); + float wx1 = std::clamp(fx - (float)x0, 0.0f, 1.0f); + const float wx0 = 1.0f - wx1; + + const float w00 = wy0 * wx0; + const float w01 = wy0 * wx1; + const float w10 = wy1 * wx0; + const float w11 = wy1 * wx1; + + const float * s00 = &src[((size_t)y0 * n_grid + x0) * n_embd]; + const float * s01 = &src[((size_t)y0 * n_grid + x1) * n_embd]; + const float * s10 = &src[((size_t)y1 * n_grid + x0) * n_embd]; + const float * s11 = &src[((size_t)y1 * n_grid + x1) * n_embd]; + float * d = &dst[((size_t)y * out_w + x) * n_embd]; + for (int c = 0; c < n_embd; ++c) { + d[c] = w00 * s00[c] + w01 * s01[c] + w10 * s10[c] + w11 * s11[c]; + } + } + } + + set_input_f32("hunyuanvl_pos_embd", dst); + } break; case PROJECTOR_TYPE_LLAMA4: { // set the 2D positions @@ -3689,8 +3836,10 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { case PROJECTOR_TYPE_KIMIVL: case PROJECTOR_TYPE_PADDLEOCR: case PROJECTOR_TYPE_KIMIK25: + case PROJECTOR_TYPE_YASA2: return ctx->model.mm_2_w->ne[1]; case PROJECTOR_TYPE_HUNYUANOCR: + case PROJECTOR_TYPE_HUNYUANVL: return ctx->model.mm_model_proj->ne[1]; case PROJECTOR_TYPE_COGVLM: return ctx->model.mm_4h_to_h_w->ne[1]; diff --git a/tools/mtmd/models/hunyuanocr.cpp b/tools/mtmd/models/hunyuanocr.cpp index 37d1e2b86a9..45ed684f70d 100644 --- a/tools/mtmd/models/hunyuanocr.cpp +++ 
b/tools/mtmd/models/hunyuanocr.cpp @@ -5,7 +5,21 @@ ggml_cgraph * clip_graph_hunyuanocr::build() { const int pw = n_patches_x; const int ph = n_patches_y; - ggml_tensor * pos_embd = resize_position_embeddings(GGML_SCALE_MODE_BILINEAR); + // Position embedding interpolation. + // HunyuanVL needs scale factors sf=(target+0.1)/n_grid, which the standard + // ggml_interpolate cannot express. To avoid adding a new ggml op, the + // resize is computed on CPU in clip_image_batch_encode and uploaded here + // as a graph input (named "hunyuanvl_pos_embd"). + // HunyuanOCR uses the same square layout and the standard ratio-based + // interpolation provided by resize_position_embeddings(). + ggml_tensor * pos_embd = nullptr; + if (proj_type == PROJECTOR_TYPE_HUNYUANVL && model.position_embeddings) { + pos_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ph * pw); + ggml_set_name(pos_embd, "hunyuanvl_pos_embd"); + ggml_set_input(pos_embd); + } else { + pos_embd = resize_position_embeddings(GGML_SCALE_MODE_BILINEAR); + } ggml_tensor * inp = build_inp(); ggml_tensor * cur = build_vit(inp, n_patches, NORM_TYPE_NORMAL, hparams.ffn_op, pos_embd, nullptr); diff --git a/tools/mtmd/models/models.h b/tools/mtmd/models/models.h index 03d99e15b05..c30d79133ef 100644 --- a/tools/mtmd/models/models.h +++ b/tools/mtmd/models/models.h @@ -43,6 +43,14 @@ struct clip_graph_youtuvl : clip_graph { ggml_cgraph * build() override; }; +struct clip_graph_yasa2 : clip_graph { + clip_graph_yasa2(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} + ggml_cgraph * build() override; + + ggml_tensor * layer_norm_channels(ggml_tensor * inp, ggml_tensor * w, ggml_tensor * b, float eps = 1e-6f); + ggml_tensor * convnext_grn(ggml_tensor * inp, ggml_tensor * w, ggml_tensor * b); +}; + struct clip_graph_minicpmv : clip_graph { clip_graph_minicpmv(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} ggml_cgraph * build() override; diff --git a/tools/mtmd/models/yasa2.cpp b/tools/mtmd/models/yasa2.cpp new file mode 100644 index 00000000000..e8cd3dacbf5 --- /dev/null +++ b/tools/mtmd/models/yasa2.cpp @@ -0,0 +1,191 @@ +// ABOUTME: Yasa2 vision encoder graph builder for ConvNeXt-based architecture. +// ABOUTME: Implements patch embedding, ConvNeXt stages with GRN, and adaptive pooling. + +#include "models.h" + +static ggml_tensor * add_channel_bias( + ggml_context * ctx0, + ggml_tensor * x_whcb, + ggml_tensor * b_c) { + if (!b_c) { + return x_whcb; + } + ggml_tensor * b4 = ggml_reshape_4d(ctx0, b_c, 1, 1, b_c->ne[0], 1); + return ggml_add(ctx0, x_whcb, b4); +} + +static ggml_tensor * mul_channel_weight( + ggml_context * ctx0, + ggml_tensor * x_whcb, + ggml_tensor * w_c) { + if (!w_c) { + return x_whcb; + } + ggml_tensor * w4 = ggml_reshape_4d(ctx0, w_c, 1, 1, w_c->ne[0], 1); + return ggml_mul(ctx0, x_whcb, w4); +} + +ggml_tensor * clip_graph_yasa2::layer_norm_channels(ggml_tensor * inp, ggml_tensor * w, ggml_tensor * b, float eps) { + // Match HF ConvNextLayerNorm(channels_first): + // u = mean_c(x), s = mean_c((x-u)^2), x = (x-u)/sqrt(s+eps) + // cast back to input dtype before affine. 
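+    // Layout note (illustrative): ggml_mean reduces over ne[0], so the permutes below move
+    // the channel dim into ne[0] ([W,H,C,B] -> [C,H,W,B]) before taking the per-channel
+    // mean and variance, then move it back once x has been normalized.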
+ ggml_tensor * cur = ggml_permute(ctx0, inp, 2, 1, 0, 3); // [W,H,C,B] -> [C,H,W,B] + cur = ggml_cont(ctx0, cur); + + ggml_tensor * u = ggml_mean(ctx0, cur); // [1,H,W,B] + ggml_tensor * xm = ggml_sub(ctx0, cur, u); // [C,H,W,B] + + ggml_tensor * s = ggml_mul(ctx0, xm, xm); // [C,H,W,B] + s = ggml_mean(ctx0, s); // [1,H,W,B] + s = ggml_clamp(ctx0, s, eps, 1e30f); // avoid div-by-zero in no-alloc warmup + s = ggml_sqrt(ctx0, s); // [1,H,W,B] + + ggml_tensor * xhat = ggml_div(ctx0, xm, s); // [C,H,W,B] + xhat = ggml_permute(ctx0, xhat, 2, 1, 0, 3); // [W,H,C,B] + xhat = ggml_cont(ctx0, xhat); + xhat = mul_channel_weight(ctx0, xhat, w); + xhat = add_channel_bias(ctx0, xhat, b); + return xhat; +} + +ggml_tensor * clip_graph_yasa2::convnext_grn(ggml_tensor * inp, ggml_tensor * w, ggml_tensor * b) { + // Exact ConvNeXtV2 GRN: + // Gx = ||x||_2 over spatial dims (W,H), Nx = Gx / (mean_c(Gx) + eps) + // y = w * (x * Nx) + b + x + const int64_t wdim = inp->ne[0]; + const int64_t hdim = inp->ne[1]; + const int64_t cdim = inp->ne[2]; + const int64_t bdim = inp->ne[3]; + + // Keep GRN math in fp32 for stability; fp16/bf16 accumulation can drift. + ggml_tensor * sq = ggml_mul(ctx0, inp, inp); + ggml_tensor * sq_flat = ggml_reshape_4d(ctx0, sq, wdim * hdim, cdim, 1, bdim); // [WH,C,1,B] + ggml_tensor * gx = ggml_sum_rows(ctx0, sq_flat); // [1,C,1,B] + gx = ggml_sqrt(ctx0, gx); // [1,C,1,B] + + ggml_tensor * gx_ch_first = ggml_permute(ctx0, gx, 1, 0, 2, 3); // [C,1,1,B] + gx_ch_first = ggml_cont(ctx0, gx_ch_first); + ggml_tensor * gx_mean = ggml_mean(ctx0, gx_ch_first); // [1,1,1,B] + + gx_mean = ggml_clamp(ctx0, gx_mean, 1e-6f, 1e30f); // approx +eps, warmup-safe + ggml_tensor * nx = ggml_div(ctx0, gx, gx_mean); // [1,C,1,B] + nx = ggml_permute(ctx0, nx, 0, 2, 1, 3); // [1,1,C,B] + nx = ggml_cont(ctx0, nx); + + ggml_tensor * xnx = ggml_mul(ctx0, inp, nx); + xnx = mul_channel_weight(ctx0, xnx, w); + xnx = add_channel_bias(ctx0, xnx, b); + return ggml_add(ctx0, inp, xnx); +} + +ggml_cgraph * clip_graph_yasa2::build() { + ggml_tensor * cur = build_inp_raw(); + + // Patch embedding Conv2d(kernel=4, stride=4) + cur = ggml_conv_2d(ctx0, model.yasa_patch_w, cur, patch_size, patch_size, 0, 0, 1, 1); + cur = add_channel_bias(ctx0, cur, model.yasa_patch_b); + ggml_set_name(cur, "yasa2_patch_conv_out"); + cb(cur, "yasa2_patch_conv_out", -1); + cur = layer_norm_channels(cur, model.yasa_patch_ln_w, model.yasa_patch_ln_b, eps); + ggml_set_name(cur, "yasa2_patch_ln_out"); + cb(cur, "yasa2_patch_ln_out", -1); + + // ConvNeXt stages + for (size_t s = 0; s < model.yasa_stages.size(); ++s) { + const auto & stage = model.yasa_stages[s]; + + if (stage.down_conv_w) { + cur = layer_norm_channels(cur, stage.down_ln_w, stage.down_ln_b, eps); + cur = ggml_conv_2d(ctx0, stage.down_conv_w, cur, 2, 2, 0, 0, 1, 1); + cur = add_channel_bias(ctx0, cur, stage.down_conv_b); + ggml_format_name(cur, "yasa2_stage%zu_down_out", s); + } + + for (size_t bi = 0; bi < stage.blocks.size(); ++bi) { + const auto & blk = stage.blocks[bi]; + ggml_tensor * res = cur; + + ggml_tensor * x = ggml_conv_2d_dw(ctx0, blk.dw_w, cur, 1, 1, 3, 3, 1, 1); + x = add_channel_bias(ctx0, x, blk.dw_b); + x = layer_norm_channels(x, blk.ln_w, blk.ln_b, eps); + + // pwconv1/pwconv2 are HF Linear layers over channels; implement via matmul on tokens. 
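+            // Shape round-trip of the block below (T = W*H tokens, C = stage channels, assuming the
+            // usual 4x ConvNeXt expansion): [W,H,C,B] -> [C,T,B] -> pw1 -> [4C,T,B] -> GELU + GRN in
+            // [W,H,4C,B] -> pw2 -> [C,T,B] -> [W,H,C,B], followed by the residual add.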
+ const int64_t w = x->ne[0]; + const int64_t h = x->ne[1]; + const int64_t b = x->ne[3]; + + ggml_tensor * tok = ggml_reshape_3d(ctx0, x, w * h, x->ne[2], b); // [T,C,B] + tok = ggml_permute(ctx0, tok, 1, 0, 2, 3); // [C,T,B] + tok = ggml_cont(ctx0, tok); + + tok = ggml_mul_mat(ctx0, blk.pw1_w, tok); // [4C,T,B] + if (blk.pw1_b) { + ggml_tensor * b1 = ggml_reshape_3d(ctx0, blk.pw1_b, blk.pw1_b->ne[0], 1, 1); // [4C,1,1] + tok = ggml_add(ctx0, tok, b1); + } + x = ggml_permute(ctx0, tok, 1, 0, 2, 3); // [T,4C,B] + x = ggml_cont(ctx0, x); + x = ggml_reshape_4d(ctx0, x, w, h, tok->ne[0], b); // [W,H,4C,B] + x = ggml_gelu_erf(ctx0, x); + x = convnext_grn(x, blk.grn_w, blk.grn_b); + + tok = ggml_reshape_3d(ctx0, x, w * h, x->ne[2], b); // [T,4C,B] + tok = ggml_permute(ctx0, tok, 1, 0, 2, 3); // [4C,T,B] + tok = ggml_cont(ctx0, tok); + + tok = ggml_mul_mat(ctx0, blk.pw2_w, tok); // [C,T,B] + if (blk.pw2_b) { + ggml_tensor * b2 = ggml_reshape_3d(ctx0, blk.pw2_b, blk.pw2_b->ne[0], 1, 1); // [C,1,1] + tok = ggml_add(ctx0, tok, b2); + } + x = ggml_permute(ctx0, tok, 1, 0, 2, 3); // [T,C,B] + x = ggml_cont(ctx0, x); + x = ggml_reshape_4d(ctx0, x, w, h, tok->ne[0], b); // [W,H,C,B] + + cur = ggml_add(ctx0, res, x); + ggml_format_name(cur, "yasa2_stage%zu_blk%zu_out", s, bi); + } + } + + // HF path adds vision position embeddings BEFORE adaptive pooling. + const int64_t pre_w = cur->ne[0]; + const int64_t pre_h = cur->ne[1]; + ggml_tensor * tokens_pre = ggml_reshape_3d(ctx0, cur, pre_w * pre_h, cur->ne[2], cur->ne[3]); // [T,C,B] + tokens_pre = ggml_permute(ctx0, tokens_pre, 1, 0, 2, 3); // [C,T,B] + tokens_pre = ggml_cont(ctx0, tokens_pre); + if (model.yasa_vision_pos_embed && tokens_pre->ne[1] == model.yasa_vision_pos_embed->ne[1]) { + const int64_t n_ch = model.yasa_vision_pos_embed->ne[0]; + const int64_t n_tokens = model.yasa_vision_pos_embed->ne[1]; + ggml_tensor * pos = ggml_reshape_3d(ctx0, model.yasa_vision_pos_embed, (int) n_ch, (int) n_tokens, 1); + tokens_pre = ggml_add(ctx0, tokens_pre, pos); + } + cur = ggml_permute(ctx0, tokens_pre, 1, 0, 2, 3); // [T,C,B] + cur = ggml_cont(ctx0, cur); + cur = ggml_reshape_4d(ctx0, cur, pre_w, pre_h, cur->ne[1], cur->ne[2]); // [W,H,C,B] + + // AdaptiveAvgPool2d target is 8x8 for real inputs, but warmup can use tiny images. 
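+    // Worked example (hypothetical sizes): a 32x32 backbone grid gives pooled_w = pooled_h = 8
+    // and kw = kh = 4, i.e. the expected 8x8 / 64-token output; a 4x4 warmup grid gives
+    // pooled_w = 4 and kw = kh = 1, leaving the tiny grid unchanged.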
+ const int pooled_w = std::min(8, (int) cur->ne[0]); + const int pooled_h = std::min(8, (int) cur->ne[1]); + const int kw = std::max(1, (int) cur->ne[0] / pooled_w); + const int kh = std::max(1, (int) cur->ne[1] / pooled_h); + cur = ggml_pool_2d(ctx0, cur, GGML_OP_POOL_AVG, kw, kh, kw, kh, 0, 0); + + // [W,H,C,B] -> [C,T,B] + ggml_tensor * tokens = ggml_reshape_3d(ctx0, cur, cur->ne[0] * cur->ne[1], cur->ne[2], cur->ne[3]); + tokens = ggml_permute(ctx0, tokens, 1, 0, 2, 3); + tokens = ggml_cont(ctx0, tokens); + cb(tokens, "yasa2_tokens", -1); + + GGML_ASSERT(model.mm_0_w && model.mm_2_w); + ggml_tensor * embeddings = build_ffn( + tokens, + model.mm_0_w, model.mm_0_b, + nullptr, nullptr, + model.mm_2_w, model.mm_2_b, + FFN_GELU_ERF, + -1); + cb(embeddings, "yasa2_emb", -1); + + ggml_build_forward_expand(gf, embeddings); + return gf; +} diff --git a/tools/mtmd/mtmd-helper.cpp b/tools/mtmd/mtmd-helper.cpp index 145b88cea44..40940741637 100644 --- a/tools/mtmd/mtmd-helper.cpp +++ b/tools/mtmd/mtmd-helper.cpp @@ -114,10 +114,10 @@ llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks) { return n_pos; } -void mtmd_helper_image_get_decoder_pos(const mtmd_image_tokens * chunks, mtmd_decoder_pos * out_pos) { +void mtmd_helper_image_get_decoder_pos(const mtmd_image_tokens * chunks, llama_pos pos_0, mtmd_decoder_pos * out_pos) { size_t n_tokens = mtmd_image_tokens_get_n_tokens(chunks); for (size_t i = 0; i < n_tokens; i++) { - out_pos[i] = mtmd_image_tokens_get_decoder_pos(chunks, i); + out_pos[i] = mtmd_image_tokens_get_decoder_pos(chunks, pos_0, i); } } @@ -163,15 +163,15 @@ struct decode_embd_batch { } // M-RoPE for image - void set_position_mrope_2d(llama_pos pos_0, const std::vector & rel_pos, llama_seq_id seq_id) { + void set_position_mrope_2d(const std::vector & rel_pos, llama_seq_id seq_id) { GGML_ASSERT(n_pos_per_embd == 4); GGML_ASSERT(!rel_pos.empty() && (int32_t)rel_pos.size() == batch.n_tokens); seq_id_0[0] = seq_id; for (int32_t i = 0; i < batch.n_tokens; i++) { - pos[i ] = pos_0 + rel_pos[i].t; - pos[i + batch.n_tokens ] = pos_0 + rel_pos[i].y; - pos[i + batch.n_tokens * 2] = pos_0 + rel_pos[i].x; - pos[i + batch.n_tokens * 3] = 0; // last pos dim is unused + pos[i ] = rel_pos[i].t; + pos[i + batch.n_tokens ] = rel_pos[i].y; + pos[i + batch.n_tokens * 2] = rel_pos[i].x; + pos[i + batch.n_tokens * 3] = rel_pos[i].z; } for (int i = 0; i < batch.n_tokens; i++) { batch.n_seq_id[i] = 1; @@ -188,7 +188,7 @@ struct decode_embd_batch { pos[i ] = pos_0 + i; pos[i + batch.n_tokens ] = pos_0 + i; pos[i + batch.n_tokens * 2] = pos_0 + i; - pos[i + batch.n_tokens * 3] = 0; // last pos dim is unused + pos[i + batch.n_tokens * 3] = pos_0 + i; } for (int i = 0; i < batch.n_tokens; i++) { batch.n_seq_id[i] = 1; @@ -268,8 +268,8 @@ int32_t mtmd_helper_decode_image_chunk( } const auto n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens); std::vector rel_pos(n_tokens); - mtmd_helper_image_get_decoder_pos(image_tokens, rel_pos.data()); - batch_embd.set_position_mrope_2d(n_past, rel_pos, seq_id); + mtmd_helper_image_get_decoder_pos(image_tokens, n_past, rel_pos.data()); + batch_embd.set_position_mrope_2d(rel_pos, seq_id); } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_AUDIO) { batch_embd.set_position_mrope_1d(n_past, seq_id); } else { diff --git a/tools/mtmd/mtmd-helper.h b/tools/mtmd/mtmd-helper.h index ff34a412141..57da78a754f 100644 --- a/tools/mtmd/mtmd-helper.h +++ b/tools/mtmd/mtmd-helper.h @@ -49,7 +49,7 @@ MTMD_API llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks); // 
helper to get the list of relative positions corresponding to the embedding tokens, to be used by M-RoPE // out_pos must have length == mtmd_helper_get_n_tokens(image) -MTMD_API void mtmd_helper_image_get_decoder_pos(const mtmd_image_tokens * image, struct mtmd_decoder_pos * out_pos); +MTMD_API void mtmd_helper_image_get_decoder_pos(const mtmd_image_tokens * image, llama_pos pos_0, struct mtmd_decoder_pos * out_pos); // helper function that automatically: // 1. run llama_decode() on text chunks diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index d0a0a4865ef..59907786786 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -33,11 +33,25 @@ struct mtmd_bitmap { bool is_audio = false; // true if the bitmap is audio }; +// position indexing for decoder model +enum mtmd_pos_type { + MTMD_POS_TYPE_NORMAL, // number of positions equals to number of tokens + MTMD_POS_TYPE_MROPE, // qwen-vl mrope style, each image takes max(t,h,w) position indexes + MTMD_POS_TYPE_HUNYUANVL, // HunyuanVL mrope + BOI/EOI/newline layout with XD-RoPE dim-3 +}; + struct mtmd_image_tokens { uint32_t nx; // number of tokens in x direction uint32_t ny; // number of tokens in y direction - bool use_mrope_pos = false; // use M-RoPE position counting (the whole image is 1 temporal position) - uint32_t n_tokens() const { return nx * ny; } + mtmd_pos_type pos = MTMD_POS_TYPE_NORMAL; + uint32_t image_idx = 0; // 0-based position of this image among image chunks in the prompt(used by pos == MTMD_POS_TYPE_HUNYUANVL) + uint32_t n_tokens() const { + if (pos == MTMD_POS_TYPE_HUNYUANVL) { + // [BOI] [row0 tokens + newline] ... [row(ny-1) tokens + newline] [EOI] + return (nx + 1) * ny + 2; + } + return nx * ny; + } clip_image_f32_batch batch_f32; // preprocessed image patches std::string id; // optional user-defined ID, useful for KV cache tracking @@ -45,7 +59,8 @@ struct mtmd_image_tokens { return mtmd_image_tokens{ nx, ny, - use_mrope_pos, + pos, + image_idx, batch_f32.clone(), id }; @@ -131,6 +146,7 @@ struct mtmd_context { int n_threads; std::string media_marker; const int n_embd_text; + mtmd_pos_type pos_type; // these are not token, but strings used to mark the beginning and end of image/audio embeddings std::string img_beg; @@ -177,6 +193,23 @@ struct mtmd_context { throw std::runtime_error("media_marker must not be empty"); } + auto decoder_rope_type = llama_model_rope_type(text_model); + switch (decoder_rope_type) { + case LLAMA_ROPE_TYPE_NONE: + case LLAMA_ROPE_TYPE_NORM: + case LLAMA_ROPE_TYPE_NEOX: + { + pos_type = MTMD_POS_TYPE_NORMAL; + } break; + case LLAMA_ROPE_TYPE_MROPE: + case LLAMA_ROPE_TYPE_IMROPE: + { + pos_type = MTMD_POS_TYPE_MROPE; + } break; + default: + throw std::runtime_error(string_format("unsupported decoder rope type: %d\n", decoder_rope_type)); + } + clip_context_params ctx_clip_params { /* use_gpu */ ctx_params.use_gpu, /* flash_attn_type */ mtmd_get_clip_flash_attn_type(ctx_params.flash_attn_type), @@ -293,6 +326,19 @@ struct mtmd_context { img_end = "<|vision_end|>"; image_preproc = std::make_unique(ctx_v); } break; + case PROJECTOR_TYPE_YASA2: + { + img_beg = ""; + img_end = ""; + // Currently only supprots single-tile preprocessing: any input is downscaled + // to one image_size x image_size tile (64 output tokens via 8x8 adaptive avg + // pool). + // However, the model itself supports llava-uhd multi-tile tiling for high-res + // images. 
This will be implemented in a future PR (dispatch on has_pinpoints + // - see LDP/COGVLM branch above) and emit image_grid_pinpoints in the conversion + // script. + image_preproc = std::make_unique(ctx_v); + } break; case PROJECTOR_TYPE_GEMMA3: case PROJECTOR_TYPE_GEMMA3NV: { @@ -430,6 +476,7 @@ struct mtmd_context { image_preproc = std::make_unique(ctx_v); } break; case PROJECTOR_TYPE_HUNYUANOCR: + case PROJECTOR_TYPE_HUNYUANVL: { // note: these use fullwidth | (U+FF5C) and ▁ (U+2581) to match the tokenizer vocabulary img_beg = "<|hy_place▁holder▁no▁100|>"; @@ -575,6 +622,7 @@ struct mtmd_tokenizer { const llama_vocab * vocab; mtmd_input_chunks cur; + uint32_t n_images_added = 0; // 0-based index assigned to the next image chunk mtmd_tokenizer(mtmd_context * ctx, const mtmd_input_text * text, @@ -777,12 +825,20 @@ struct mtmd_tokenizer { // for Qwen2VL, we need this information for M-RoPE decoding positions image_tokens->nx = clip_n_output_tokens_x(ctx->ctx_v, batch_f32.entries[0].get()); image_tokens->ny = clip_n_output_tokens_y(ctx->ctx_v, batch_f32.entries[0].get()); - image_tokens->use_mrope_pos = true; } else { // other models, we only need the total number of tokens image_tokens->nx = n_tokens; image_tokens->ny = 1; } + image_tokens->pos = ctx->pos_type; + // HunyuanVL wraps the image grid with BOI/EOI and adds one newline per row, + // and uses XD-RoPE (dim-3 = image index). Override the position type so that + // n_tokens() and mtmd_image_tokens_get_decoder_pos pick the HunyuanVL layout. + if (ctx->proj_type_v() == PROJECTOR_TYPE_HUNYUANVL) { + image_tokens->pos = MTMD_POS_TYPE_HUNYUANVL; + image_tokens->image_idx = n_images_added; + GGML_ASSERT(n_tokens == (size_t)image_tokens->n_tokens()); + } image_tokens->batch_f32 = std::move(batch_f32); image_tokens->id = bitmap->id; // optional @@ -803,6 +859,9 @@ struct mtmd_tokenizer { add_text(ctx->img_end, true); // add image end token } + // advance image-chunk counter so the next image gets the next XD-RoPE dim-3 slot + n_images_added++; + } else { // handle audio @@ -1014,7 +1073,7 @@ float * mtmd_get_output_embd(mtmd_context * ctx) { return ctx->image_embd_v.data(); } -bool mtmd_decode_use_non_causal(mtmd_context * ctx, const mtmd_input_chunk * chunk) { +bool mtmd_decode_use_non_causal(const mtmd_context * ctx, const mtmd_input_chunk * chunk) { auto proj_type = ctx->proj_type_v(); if (chunk && chunk->type == MTMD_INPUT_CHUNK_TYPE_AUDIO) { proj_type = ctx->proj_type_a(); @@ -1028,32 +1087,19 @@ bool mtmd_decode_use_non_causal(mtmd_context * ctx, const mtmd_input_chunk * chu } } -bool mtmd_decode_use_mrope(mtmd_context * ctx) { - if (ctx->ctx_v == nullptr && ctx->proj_type_a() == PROJECTOR_TYPE_QWEN3A) { - // qwen3-asr - return true; - } - switch (ctx->proj_type_v()) { - case PROJECTOR_TYPE_QWEN2VL: - case PROJECTOR_TYPE_QWEN25VL: - case PROJECTOR_TYPE_QWEN3VL: - case PROJECTOR_TYPE_GLM4V: - case PROJECTOR_TYPE_PADDLEOCR: - return true; - default: - return false; - } +bool mtmd_decode_use_mrope(const mtmd_context * ctx) { + return ctx->pos_type == MTMD_POS_TYPE_MROPE; } -bool mtmd_support_vision(mtmd_context * ctx) { +bool mtmd_support_vision(const mtmd_context * ctx) { return ctx->ctx_v != nullptr; } -bool mtmd_support_audio(mtmd_context * ctx) { +bool mtmd_support_audio(const mtmd_context * ctx) { return ctx->ctx_a != nullptr; } -int mtmd_get_audio_sample_rate(mtmd_context * ctx) { +int mtmd_get_audio_sample_rate(const mtmd_context * ctx) { if (!ctx->ctx_a) { return -1; } @@ -1246,11 +1292,58 @@ size_t 
mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens) { return image_tokens->ny; } -mtmd_decoder_pos mtmd_image_tokens_get_decoder_pos(const mtmd_image_tokens * image_tokens, size_t i) { +mtmd_decoder_pos mtmd_image_tokens_get_decoder_pos(const mtmd_image_tokens * image_tokens, llama_pos pos_0, size_t i) { mtmd_decoder_pos pos; - pos.t = 0; - pos.x = i % image_tokens->nx; - pos.y = i / image_tokens->nx; + switch (image_tokens->pos) { + case MTMD_POS_TYPE_MROPE: + { + pos.t = pos_0; + pos.x = pos_0 + (i % image_tokens->nx); + pos.y = pos_0 + (i / image_tokens->nx); + pos.z = 0; // unused for now + } break; + case MTMD_POS_TYPE_NORMAL: + { + pos.t = pos_0 + i; + pos.x = pos_0 + i; + pos.y = pos_0 + i; + pos.z = pos_0 + i; + } break; + case MTMD_POS_TYPE_HUNYUANVL: + { + // HunyuanVL layout: [BOI] [row0 tokens + newline] ... [row(ny-1) tokens + newline] [EOI] + // Total = 1 + ny*(nx+1) + 1. BOI and EOI use sequential positions in every dim; + // content and row-newline tokens use (row, col) with XD-RoPE dim-3 = image_idx. + const uint32_t nx = image_tokens->nx; + const uint32_t n_total = image_tokens->n_tokens(); + if (i == 0) { + // BOI + pos.t = pos_0 + i; + pos.x = pos_0 + i; + pos.y = pos_0 + i; + pos.z = pos_0 + i; + } else if (i == n_total - 1) { + // EOI + pos.t = pos_0 + i; + pos.x = pos_0 + i; + pos.y = pos_0 + i; + pos.z = pos_0 + i; + } else { + // content token at (row, col), or the trailing newline of a row (col == nx) + // section 0 = sequential, section 1 = w(col), section 2 = h(row), section 3 = image_count. + // set_position_mrope_2d writes .y -> section 1 and .x -> section 2 + const uint32_t offset = (uint32_t)i - 1; + const uint32_t row = offset / (nx + 1); + const uint32_t col = offset % (nx + 1); + pos.t = pos_0 + i; + pos.x = row; + pos.y = col; + pos.z = image_tokens->image_idx; + } + } break; + default: + GGML_ABORT("invalid position type"); + } return pos; } @@ -1259,12 +1352,18 @@ const char * mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) { } llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens) { - if (image_tokens->use_mrope_pos) { - // for M-RoPE, temporal dimension = max(t,h,w) - // t is omitted as we don't support video input - return std::max(image_tokens->nx, image_tokens->ny); + switch (image_tokens->pos) { + case MTMD_POS_TYPE_MROPE: + return std::max(image_tokens->nx, image_tokens->ny); + case MTMD_POS_TYPE_NORMAL: + return image_tokens->n_tokens(); + case MTMD_POS_TYPE_HUNYUANVL: + // HunyuanVL: the sequential (dim-0) position advances by the full token count + // (includes BOI/EOI and row newline tokens), not by max(nx, ny) + return image_tokens->n_tokens(); + default: + GGML_ABORT("invalid position type"); } - return image_tokens->n_tokens(); } // test function diff --git a/tools/mtmd/mtmd.h b/tools/mtmd/mtmd.h index a6fd8efa5d0..e364174b820 100644 --- a/tools/mtmd/mtmd.h +++ b/tools/mtmd/mtmd.h @@ -112,20 +112,20 @@ MTMD_API void mtmd_free(mtmd_context * ctx); // whether we need to set non-causal mask before llama_decode // if chunk is nullptr, we assume the default case where chunk is an image chunk -MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx, const mtmd_input_chunk * chunk); +MTMD_API bool mtmd_decode_use_non_causal(const mtmd_context * ctx, const mtmd_input_chunk * chunk); // whether the current model use M-RoPE for llama_decode -MTMD_API bool mtmd_decode_use_mrope(mtmd_context * ctx); +MTMD_API bool mtmd_decode_use_mrope(const mtmd_context * ctx); // whether the current model 
supports vision input
-MTMD_API bool mtmd_support_vision(mtmd_context * ctx);
+MTMD_API bool mtmd_support_vision(const mtmd_context * ctx);
 
 // whether the current model supports audio input
-MTMD_API bool mtmd_support_audio(mtmd_context * ctx);
+MTMD_API bool mtmd_support_audio(const mtmd_context * ctx);
 
 // get audio sample rate in Hz, for example 16000 for Whisper
 // return -1 if audio is not supported
-MTMD_API int mtmd_get_audio_sample_rate(mtmd_context * ctx);
+MTMD_API int mtmd_get_audio_sample_rate(const mtmd_context * ctx);
 
 // mtmd_bitmap
 //
@@ -196,11 +196,13 @@ struct mtmd_decoder_pos {
     uint32_t t;
     uint32_t x;
     uint32_t y;
+    uint32_t z; // 4th position dim (e.g. the XD-RoPE image index for HunyuanVL)
 };
 
 // get position for decoder attention, to be used by M-RoPE models
 // i is the index of the embedding token, ranging from 0 to mtmd_image_tokens_get_n_tokens() - 1
+// pos_0 is the absolute position of the first token
 // return relative position (for example, embedding 0 will have position (0, 0, 0); remember to adjust it to the current absolute position)
-MTMD_API struct mtmd_decoder_pos mtmd_image_tokens_get_decoder_pos(const mtmd_image_tokens * image_tokens, size_t i);
+MTMD_API struct mtmd_decoder_pos mtmd_image_tokens_get_decoder_pos(const mtmd_image_tokens * image_tokens, llama_pos pos_0, size_t i);
 
 // tokenize an input text prompt and a list of bitmaps (images/audio)
 // the prompt must have the input image marker (default: "<__media__>") in it
diff --git a/tools/mtmd/tests.sh b/tools/mtmd/tests.sh
index 5da48d61bfd..83416fb272b 100755
--- a/tools/mtmd/tests.sh
+++ b/tools/mtmd/tests.sh
@@ -91,6 +91,7 @@ add_test_vision "ggml-org/LightOnOCR-1B-1025-GGUF:Q8_0"
 add_test_vision "ggml-org/DeepSeek-OCR-GGUF:Q8_0" -p "Free OCR." --chat-template deepseek-ocr
 add_test_vision "ggml-org/dots.ocr-GGUF:Q8_0" -p "OCR"
 add_test_vision "ggml-org/HunyuanOCR-GGUF:Q8_0" -p "OCR"
+add_test_vision "ggml-org/HunyuanVL-4B-GGUF:Q8_0"
 add_test_vision "ggml-org/gemma-4-E2B-it-GGUF:Q8_0" --jinja
 
 add_test_audio "ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF:Q8_0"
diff --git a/tools/perplexity/perplexity.cpp b/tools/perplexity/perplexity.cpp
index 6e319ce55d4..75defd7c87b 100644
--- a/tools/perplexity/perplexity.cpp
+++ b/tools/perplexity/perplexity.cpp
@@ -1,5 +1,6 @@
 #include "arg.h"
 #include "common.h"
+#include "fit.h"
 #include "log.h"
 #include "llama.h"
@@ -2087,7 +2088,7 @@ int main(int argc, char ** argv) {
     LOG("\n");
 
     llama_perf_context_print(ctx);
-    llama_memory_breakdown_print(ctx);
+    common_memory_breakdown_print(ctx);
 
     llama_backend_free();
diff --git a/tools/server/CMakeLists.txt b/tools/server/CMakeLists.txt
index 0cce99f5968..71cc0e7a8c2 100644
--- a/tools/server/CMakeLists.txt
+++ b/tools/server/CMakeLists.txt
@@ -5,6 +5,8 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_BINARY_DIR})
 set(TARGET server-context)
 add_library(${TARGET} STATIC
+    server-chat.cpp
+    server-chat.h
     server-task.cpp
     server-task.h
     server-queue.cpp
diff --git a/tools/server/README.md b/tools/server/README.md
index b30309bf3b0..db1f2703904 100644
--- a/tools/server/README.md
+++ b/tools/server/README.md
@@ -167,7 +167,7 @@ For the full list of features, please refer to [server's changelog](https://gith
 | `-cpent, --checkpoint-every-n-tokens N` | create a checkpoint every n tokens during prefill (processing), -1 to disable (default: 8192)<br/>(env: LLAMA_ARG_CHECKPOINT_EVERY_NT) |
 | `-cram, --cache-ram N` | set the maximum cache size in MiB (default: 8192, -1 - no limit, 0 - disable)[(more info)](https://github.com/ggml-org/llama.cpp/pull/16391)<br/>(env: LLAMA_ARG_CACHE_RAM) |
 | `-kvu, --kv-unified, -no-kvu, --no-kv-unified` | use single unified KV buffer shared across all sequences (default: enabled if number of slots is auto)<br/>(env: LLAMA_ARG_KV_UNIFIED) |
-| `--clear-idle, --no-clear-idle` | save and clear idle slots on new task (default: enabled, requires unified KV and cache-ram)<br/>(env: LLAMA_ARG_CLEAR_IDLE) |
+| `--cache-idle-slots, --no-cache-idle-slots` | save and clear idle slots on new task (default: enabled, requires unified KV and cache-ram)<br/>(env: LLAMA_ARG_CACHE_IDLE_SLOTS) |
 | `--context-shift, --no-context-shift` | whether to use context shift on infinite text generation (default: disabled)<br/>
(env: LLAMA_ARG_CONTEXT_SHIFT) | | `-r, --reverse-prompt PROMPT` | halt generation at PROMPT, return control in interactive mode | | `-sp, --special` | special tokens output enabled (default: false) | @@ -806,6 +806,7 @@ By default, it is read-only. To make POST request to change global properties, y "modalities": { "vision": false }, + "media_marker": "<__media_YoNhud46VdDqbuFmKYEO9PY7A4ARzRfg__>", "build_info": "b(build number)-(build commit hash)", "is_sleeping": false } diff --git a/tools/server/server-chat.cpp b/tools/server/server-chat.cpp new file mode 100644 index 00000000000..a1558346944 --- /dev/null +++ b/tools/server/server-chat.cpp @@ -0,0 +1,630 @@ +#include "server-chat.h" +#include "server-common.h" + +#include + +json server_chat_convert_responses_to_chatcmpl(const json & response_body) { + if (!response_body.contains("input")) { + throw std::invalid_argument("'input' is required"); + } + if (!json_value(response_body, "previous_response_id", std::string{}).empty()) { + throw std::invalid_argument("llama.cpp does not support 'previous_response_id'."); + } + + const json input_value = response_body.at("input"); + json chatcmpl_body = response_body; + chatcmpl_body.erase("input"); + std::vector chatcmpl_messages; + + if (response_body.contains("instructions")) { + chatcmpl_messages.push_back({ + {"role", "system"}, + {"content", json_value(response_body, "instructions", std::string())}, + }); + chatcmpl_body.erase("instructions"); + } + + if (input_value.is_string()) { + // #responses_create-input-text_input + chatcmpl_messages.push_back({ + {"role", "user"}, + {"content", input_value}, + }); + } else if (input_value.is_array()) { + // #responses_create-input-input_item_list + + static auto exists_and_is_array = [](const json & j, const char * key) -> bool { + return j.contains(key) && j.at(key).is_array(); + }; + static auto exists_and_is_string = [](const json & j, const char * key) -> bool { + return j.contains(key) && j.at(key).is_string(); + }; + + for (json item : input_value) { + bool merge_prev = !chatcmpl_messages.empty() && chatcmpl_messages.back().value("role", "") == "assistant"; + + if (exists_and_is_string(item, "content")) { + // #responses_create-input-input_item_list-input_message-content-text_input + // Only "Input message" contains item["content"]::string + // After converting item["content"]::string to item["content"]::array, + // we can treat "Input message" as sum of "Item-Input message" and "Item-Output message" + item["content"] = json::array({ + json { + {"text", item.at("content")}, + {"type", "input_text"} + } + }); + } + + if (exists_and_is_array(item, "content") && + exists_and_is_string(item, "role") && + (item.at("role") == "user" || + item.at("role") == "system" || + item.at("role") == "developer") + ) { + // #responses_create-input-input_item_list-item-input_message + std::vector chatcmpl_content; + + for (const json & input_item : item.at("content")) { + const std::string type = json_value(input_item, "type", std::string()); + + if (type == "input_text") { + if (!input_item.contains("text")) { + throw std::invalid_argument("'Input text' requires 'text'"); + } + chatcmpl_content.push_back({ + {"text", input_item.at("text")}, + {"type", "text"}, + }); + } else if (type == "input_image") { + // While `detail` is marked as required, + // it has default value("auto") and can be omitted. 
+ + if (!input_item.contains("image_url")) { + throw std::invalid_argument("'image_url' is required"); + } + chatcmpl_content.push_back({ + {"image_url", json { + {"url", input_item.at("image_url")} + }}, + {"type", "image_url"}, + }); + } else if (type == "input_file") { + throw std::invalid_argument("'input_file' is not supported by llamacpp at this moment"); + } else { + throw std::invalid_argument("'type' must be one of 'input_text', 'input_image', or 'input_file'"); + } + } + + if (item.contains("type")) { + item.erase("type"); + } + if (item.contains("status")) { + item.erase("status"); + } + item["content"] = chatcmpl_content; + + chatcmpl_messages.push_back(item); + } else if (exists_and_is_string(item, "role") && + item.at("role") == "assistant" && + exists_and_is_string(item, "type") && + item.at("type") == "message" + ) { + // #responses_create-input-input_item_list-item-output_message + auto chatcmpl_content = json::array(); + + // Handle both string content and array content + if (item.contains("content") && item.at("content").is_string()) { + // String content - convert to text content part + chatcmpl_content.push_back({ + {"text", item.at("content")}, + {"type", "text"}, + }); + } else if (exists_and_is_array(item, "content")) { + // Array content - process each item + for (const auto & output_text : item.at("content")) { + const std::string type = json_value(output_text, "type", std::string()); + if (type == "output_text" || type == "input_text") { + // Accept both output_text and input_text (string content gets converted to input_text) + if (!exists_and_is_string(output_text, "text")) { + throw std::invalid_argument("'Output text' requires 'text'"); + } + chatcmpl_content.push_back({ + {"text", output_text.at("text")}, + {"type", "text"}, + }); + } else if (type == "refusal") { + if (!exists_and_is_string(output_text, "refusal")) { + throw std::invalid_argument("'Refusal' requires 'refusal'"); + } + chatcmpl_content.push_back({ + {"refusal", output_text.at("refusal")}, + {"type", "refusal"}, + }); + } else { + throw std::invalid_argument("'type' must be one of 'output_text' or 'refusal'"); + } + } + } + + if (merge_prev) { + auto & prev_msg = chatcmpl_messages.back(); + if (!exists_and_is_array(prev_msg, "content")) { + prev_msg["content"] = json::array(); + } + auto & prev_content = prev_msg["content"]; + prev_content.insert(prev_content.end(), chatcmpl_content.begin(), chatcmpl_content.end()); + } else { + item.erase("status"); + item.erase("type"); + item["content"] = chatcmpl_content; + chatcmpl_messages.push_back(item); + } + } else if (exists_and_is_string(item, "arguments") && + exists_and_is_string(item, "call_id") && + exists_and_is_string(item, "name") && + exists_and_is_string(item, "type") && + item.at("type") == "function_call" + ) { + // #responses_create-input-input_item_list-item-function_tool_call + json tool_call = { + {"function", json { + {"arguments", item.at("arguments")}, + {"name", item.at("name")}, + }}, + {"id", item.at("call_id")}, + {"type", "function"}, + }; + + if (merge_prev) { + auto & prev_msg = chatcmpl_messages.back(); + if (!exists_and_is_array(prev_msg, "tool_calls")) { + prev_msg["tool_calls"] = json::array(); + } + prev_msg["tool_calls"].push_back(tool_call); + } else { + chatcmpl_messages.push_back(json { + {"role", "assistant"}, + {"tool_calls", json::array({tool_call})} + }); + } + } else if (exists_and_is_string(item, "call_id") && + (exists_and_is_string(item, "output") || exists_and_is_array(item, "output")) && + 
exists_and_is_string(item, "type") && + item.at("type") == "function_call_output" + ) { + // #responses_create-input-input_item_list-item-function_tool_call_output + if (item.at("output").is_string()) { + chatcmpl_messages.push_back(json { + {"content", item.at("output")}, + {"role", "tool"}, + {"tool_call_id", item.at("call_id")}, + }); + } else { + json chatcmpl_outputs = item.at("output"); + for (json & chatcmpl_output : chatcmpl_outputs) { + if (!chatcmpl_output.contains("type") || chatcmpl_output.at("type") != "input_text") { + throw std::invalid_argument("Output of tool call should be 'Input text'"); + } + chatcmpl_output["type"] = "text"; + } + chatcmpl_messages.push_back(json { + {"content", chatcmpl_outputs}, + {"role", "tool"}, + {"tool_call_id", item.at("call_id")}, + }); + } + } else if (exists_and_is_array(item, "summary") && + exists_and_is_string(item, "type") && + item.at("type") == "reasoning") { + // #responses_create-input-input_item_list-item-reasoning + + if (!exists_and_is_array(item, "content")) { + throw std::invalid_argument("item['content'] is not an array"); + } + if (item.at("content").empty()) { + throw std::invalid_argument("item['content'] is empty"); + } + if (!exists_and_is_string(item.at("content")[0], "text")) { + throw std::invalid_argument("item['content']['text'] is not a string"); + } + + if (merge_prev) { + auto & prev_msg = chatcmpl_messages.back(); + prev_msg["reasoning_content"] = item.at("content")[0].at("text"); + } else { + chatcmpl_messages.push_back(json { + {"role", "assistant"}, + {"content", json::array()}, + {"reasoning_content", item.at("content")[0].at("text")}, + }); + } + } else { + throw std::invalid_argument("Cannot determine type of 'item'"); + } + } + } else { + throw std::invalid_argument("'input' must be a string or array of objects"); + } + + chatcmpl_body["messages"] = chatcmpl_messages; + + if (response_body.contains("tools")) { + if (!response_body.at("tools").is_array()) { + throw std::invalid_argument("'tools' must be an array of objects"); + } + std::vector chatcmpl_tools; + for (json resp_tool : response_body.at("tools")) { + json chatcmpl_tool; + + if (json_value(resp_tool, "type", std::string()) != "function") { + throw std::invalid_argument("'type' of tool must be 'function'"); + } + resp_tool.erase("type"); + chatcmpl_tool["type"] = "function"; + + if (!resp_tool.contains("strict")) { + resp_tool["strict"] = true; + } + chatcmpl_tool["function"] = resp_tool; + chatcmpl_tools.push_back(chatcmpl_tool); + } + chatcmpl_body.erase("tools"); + chatcmpl_body["tools"] = chatcmpl_tools; + } + + if (response_body.contains("max_output_tokens")) { + chatcmpl_body.erase("max_output_tokens"); + chatcmpl_body["max_tokens"] = response_body["max_output_tokens"]; + } + + return chatcmpl_body; +} + +// Edits the cch section of an "x-anthropic-billing-header" system prompt. +// Does nothing to any other prompt. +// +// This is a claude message with a "cch=ef01a" attribute that breaks prefix caching. +// The cch stamp is a whitebox end-to-end integrity hint. It's not meaningful as a +// system prompt data, particularly to llama.cpp, but its presence means the prefix +// cache will not get past it: It changes on each request. +// +// Reference: https://github.com/ggml-org/llama.cpp/pull/21793 +// Example header: +// ``` +// x-anthropic-billing-header: cc_version=2.1.101.e51; cc_entrypoint=cli; cch=a5145;You are Claude Code, Anthropic's official CLI for Claude. 
+// ^^^^^ +// ``` +static void normalize_anthropic_billing_header(std::string & system_text) { + if (system_text.rfind("x-anthropic-billing-header:", 0) != 0) { + return; + } + + const size_t header_prefix_length = strlen("x-anthropic-billing-header:"); + const size_t cch_length = 5; + const size_t index_cch = system_text.find("cch=", header_prefix_length); + if (index_cch == std::string::npos) { + return; + } + + const size_t index_replace = index_cch + 4; + if (index_replace + cch_length < system_text.length() && system_text[index_replace + cch_length] == ';') { + for (size_t i = 0; i < cch_length; ++i) { + system_text[index_replace + i] = 'f'; + } + } else { + LOG_ERR("anthropic string not as expected: %s", system_text.c_str()); + } +} + +json server_chat_convert_anthropic_to_oai(const json & body) { + json oai_body; + + // Convert system prompt + json oai_messages = json::array(); + auto system_param = json_value(body, "system", json()); + if (!system_param.is_null()) { + std::string system_content; + + if (system_param.is_string()) { + system_content = system_param.get(); + normalize_anthropic_billing_header(system_content); + } else if (system_param.is_array()) { + for (const auto & block : system_param) { + if (json_value(block, "type", std::string()) == "text") { + auto system_text = json_value(block, "text", std::string()); + normalize_anthropic_billing_header(system_text); + system_content += system_text; + } + } + } + + oai_messages.push_back({ + {"role", "system"}, + {"content", system_content} + }); + } + + // Convert messages + if (!body.contains("messages")) { + throw std::runtime_error("'messages' is required"); + } + const json & messages = body.at("messages"); + if (messages.is_array()) { + for (const auto & msg : messages) { + std::string role = json_value(msg, "role", std::string()); + + if (!msg.contains("content")) { + if (role == "assistant") { + continue; + } + oai_messages.push_back(msg); + continue; + } + + const json & content = msg.at("content"); + + if (content.is_string()) { + oai_messages.push_back(msg); + continue; + } + + if (!content.is_array()) { + oai_messages.push_back(msg); + continue; + } + + json tool_calls = json::array(); + json converted_content = json::array(); + json tool_results = json::array(); + std::string reasoning_content; + bool has_tool_calls = false; + + for (const auto & block : content) { + std::string type = json_value(block, "type", std::string()); + + if (type == "text") { + converted_content.push_back(block); + } else if (type == "thinking") { + reasoning_content += json_value(block, "thinking", std::string()); + } else if (type == "image") { + json source = json_value(block, "source", json::object()); + std::string source_type = json_value(source, "type", std::string()); + + if (source_type == "base64") { + std::string media_type = json_value(source, "media_type", std::string("image/jpeg")); + std::string data = json_value(source, "data", std::string()); + std::ostringstream ss; + ss << "data:" << media_type << ";base64," << data; + + converted_content.push_back({ + {"type", "image_url"}, + {"image_url", { + {"url", ss.str()} + }} + }); + } else if (source_type == "url") { + std::string url = json_value(source, "url", std::string()); + converted_content.push_back({ + {"type", "image_url"}, + {"image_url", { + {"url", url} + }} + }); + } + } else if (type == "tool_use") { + tool_calls.push_back({ + {"id", json_value(block, "id", std::string())}, + {"type", "function"}, + {"function", { + {"name", json_value(block, "name", 
std::string())}, + {"arguments", json_value(block, "input", json::object()).dump()} + }} + }); + has_tool_calls = true; + } else if (type == "tool_result") { + std::string tool_use_id = json_value(block, "tool_use_id", std::string()); + + auto result_content = json_value(block, "content", json()); + std::string result_text; + if (result_content.is_string()) { + result_text = result_content.get(); + } else if (result_content.is_array()) { + for (const auto & c : result_content) { + if (json_value(c, "type", std::string()) == "text") { + result_text += json_value(c, "text", std::string()); + } + } + } + + tool_results.push_back({ + {"role", "tool"}, + {"tool_call_id", tool_use_id}, + {"content", result_text} + }); + } + } + + if (!converted_content.empty() || has_tool_calls || !reasoning_content.empty()) { + json new_msg = {{"role", role}}; + if (!converted_content.empty()) { + new_msg["content"] = converted_content; + } else if (has_tool_calls || !reasoning_content.empty()) { + new_msg["content"] = ""; + } + if (!tool_calls.empty()) { + new_msg["tool_calls"] = tool_calls; + } + if (!reasoning_content.empty()) { + new_msg["reasoning_content"] = reasoning_content; + } + oai_messages.push_back(new_msg); + } + + for (const auto & tool_msg : tool_results) { + oai_messages.push_back(tool_msg); + } + } + } + + oai_body["messages"] = oai_messages; + + // Convert tools + if (body.contains("tools")) { + const json & tools = body.at("tools"); + if (tools.is_array()) { + json oai_tools = json::array(); + for (const auto & tool : tools) { + oai_tools.push_back({ + {"type", "function"}, + {"function", { + {"name", json_value(tool, "name", std::string())}, + {"description", json_value(tool, "description", std::string())}, + {"parameters", tool.contains("input_schema") ? 
tool.at("input_schema") : json::object()} + }} + }); + } + oai_body["tools"] = oai_tools; + } + } + + // Convert tool_choice + if (body.contains("tool_choice")) { + const json & tc = body.at("tool_choice"); + if (tc.is_object()) { + std::string type = json_value(tc, "type", std::string()); + if (type == "auto") { + oai_body["tool_choice"] = "auto"; + } else if (type == "any" || type == "tool") { + oai_body["tool_choice"] = "required"; + } + } + } + + // Convert stop_sequences to stop + if (body.contains("stop_sequences")) { + oai_body["stop"] = body.at("stop_sequences"); + } + + // Handle max_tokens (required in Anthropic, but we're permissive) + if (body.contains("max_tokens")) { + oai_body["max_tokens"] = body.at("max_tokens"); + } else { + oai_body["max_tokens"] = 4096; + } + + // Pass through common params + for (const auto & key : {"temperature", "top_p", "top_k", "stream", "chat_template_kwargs"}) { + if (body.contains(key)) { + oai_body[key] = body.at(key); + } + } + + // Handle Anthropic-specific thinking param + if (body.contains("thinking")) { + json thinking = json_value(body, "thinking", json::object()); + std::string thinking_type = json_value(thinking, "type", std::string()); + if (thinking_type == "enabled") { + int budget_tokens = json_value(thinking, "budget_tokens", 10000); + oai_body["thinking_budget_tokens"] = budget_tokens; + } + } + + // Handle Anthropic-specific metadata param + if (body.contains("metadata")) { + json metadata = json_value(body, "metadata", json::object()); + std::string user_id = json_value(metadata, "user_id", std::string()); + if (!user_id.empty()) { + oai_body["__metadata_user_id"] = user_id; + } + } + + return oai_body; +} + +json server_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) { + json delta = json::object(); + if (!diff.reasoning_content_delta.empty()) { + delta["reasoning_content"] = diff.reasoning_content_delta; + } + if (!diff.content_delta.empty()) { + delta["content"] = diff.content_delta; + } + if (diff.tool_call_index != std::string::npos) { + json tool_call; + tool_call["index"] = diff.tool_call_index; + if (!diff.tool_call_delta.id.empty()) { + tool_call["id"] = diff.tool_call_delta.id; + tool_call["type"] = "function"; + } + if (!diff.tool_call_delta.name.empty() || !diff.tool_call_delta.arguments.empty()) { + json function = json::object(); + if (!diff.tool_call_delta.name.empty()) { + function["name"] = diff.tool_call_delta.name; + } + if (!diff.tool_call_delta.arguments.empty()) { + function["arguments"] = diff.tool_call_delta.arguments; + } + tool_call["function"] = function; + } + delta["tool_calls"] = json::array({ tool_call }); + } + return delta; +} + +json convert_transcriptions_to_chatcmpl( + const json & inp_body, + const common_chat_templates * tmpls, + const std::map & in_files, + std::vector & out_files) { + // TODO @ngxson : this function may need to be improved in the future + // handle input files + out_files.clear(); + auto it = in_files.find("file"); + if (it != in_files.end()) { + out_files.push_back(it->second); + } else { + throw std::invalid_argument("No input file found for transcription"); + } + + // handle input data + std::string prompt = json_value(inp_body, "prompt", std::string()); + std::string language = json_value(inp_body, "language", std::string()); + std::string response_format = json_value(inp_body, "response_format", std::string("json")); + if (response_format != "json") { + throw std::invalid_argument("Only 'json' response_format is supported for transcription"); + } + 
const common_chat_prompt_preset preset = common_chat_get_asr_prompt(tmpls); + if (prompt.empty()) { + prompt = preset.user; + } + if (!language.empty()) { + prompt += string_format(" (language: %s)", language.c_str()); + } + prompt += get_media_marker(); + + json messages = json::array(); + if (!preset.system.empty()) { + messages.push_back({{"role", "system"}, {"content", preset.system}}); + } + messages.push_back({{"role", "user"}, {"content", prompt}}); + + json chatcmpl_body = inp_body; // copy all fields + chatcmpl_body["messages"] = messages; + + // because input from form-data, everything is string, we need to correct the types here + std::string stream = json_value(inp_body, "stream", std::string("false")); + chatcmpl_body["stream"] = stream == "true"; + + if (inp_body.contains("max_tokens")) { + std::string inp = inp_body["max_tokens"].get(); + chatcmpl_body["max_tokens"] = std::stoul(inp); + } + + if (inp_body.contains("temperature")) { + std::string inp = inp_body["temperature"].get(); + chatcmpl_body["temperature"] = std::stof(inp); + } + + return chatcmpl_body; +} diff --git a/tools/server/server-chat.h b/tools/server/server-chat.h new file mode 100644 index 00000000000..5c5b792cf5d --- /dev/null +++ b/tools/server/server-chat.h @@ -0,0 +1,25 @@ +// Chat conversion functions for server (Responses API, Anthropic API, OAI streaming diffs) + +#pragma once + +#include "chat.h" +#include "server-common.h" + +#include + +using json = nlohmann::ordered_json; + +// Convert OpenAI Responses API format to OpenAI Chat Completions API format +json server_chat_convert_responses_to_chatcmpl(const json & body); + +// Convert Anthropic Messages API format to OpenAI Chat Completions API format +json server_chat_convert_anthropic_to_oai(const json & body); + +// convert OpenAI transcriptions API format to OpenAI Chat Completions API format +json convert_transcriptions_to_chatcmpl( + const json & body, + const common_chat_templates * tmpls, + const std::map & in_files, + std::vector & out_files); + +json server_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff); diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index f66b1f2557c..ad8834e317a 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -391,15 +391,25 @@ void server_tokens::push_back(server_tokens & tokens) { } void server_tokens::insert(const llama_tokens & inp_tokens) { - GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled tokens.insert(tokens.end(), inp_tokens.begin(), inp_tokens.end()); } -const llama_tokens & server_tokens::get_text_tokens() const { - GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled +const llama_tokens & server_tokens::get_tokens() const { + GGML_ASSERT(!has_mtmd); return tokens; } +llama_tokens server_tokens::get_text_tokens() const { + llama_tokens res; + res.reserve(tokens.size()); + for (llama_token t : tokens) { + if (t != LLAMA_TOKEN_NULL) { + res.push_back(t); + } + } + return res; +} + void server_tokens::set_token(llama_pos pos, llama_token id) { GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled tokens[pos] = id; @@ -1017,6 +1027,8 @@ json oaicompat_chat_params_parse( } } + auto caps = common_chat_templates_get_caps(opt.tmpls.get()); + common_chat_templates_inputs inputs; inputs.messages = common_chat_msgs_parse_oaicompat(messages); inputs.tools = common_chat_tools_parse_oaicompat(tools); @@ -1024,7 +1036,7 @@ json oaicompat_chat_params_parse( inputs.json_schema = json_schema.is_null() ? 
"" : json_schema.dump(); inputs.grammar = grammar; inputs.use_jinja = opt.use_jinja; - inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false); + inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", caps["supports_parallel_tool_calls"]); inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true); inputs.reasoning_format = opt.reasoning_format; if (body.contains("reasoning_format")) { @@ -1154,573 +1166,6 @@ json oaicompat_chat_params_parse( return llama_params; } -json convert_responses_to_chatcmpl(const json & response_body) { - if (!response_body.contains("input")) { - throw std::invalid_argument("'input' is required"); - } - if (!json_value(response_body, "previous_response_id", std::string{}).empty()) { - throw std::invalid_argument("llama.cpp does not support 'previous_response_id'."); - } - - const json input_value = response_body.at("input"); - json chatcmpl_body = response_body; - chatcmpl_body.erase("input"); - std::vector chatcmpl_messages; - - if (response_body.contains("instructions")) { - chatcmpl_messages.push_back({ - {"role", "system"}, - {"content", json_value(response_body, "instructions", std::string())}, - }); - chatcmpl_body.erase("instructions"); - } - - if (input_value.is_string()) { - // #responses_create-input-text_input - chatcmpl_messages.push_back({ - {"role", "user"}, - {"content", input_value}, - }); - } else if (input_value.is_array()) { - // #responses_create-input-input_item_list - - static auto exists_and_is_array = [](const json & j, const char * key) -> bool { - return j.contains(key) && j.at(key).is_array(); - }; - static auto exists_and_is_string = [](const json & j, const char * key) -> bool { - return j.contains(key) && j.at(key).is_string(); - }; - - for (json item : input_value) { - bool merge_prev = !chatcmpl_messages.empty() && chatcmpl_messages.back().value("role", "") == "assistant"; - - if (exists_and_is_string(item, "content")) { - // #responses_create-input-input_item_list-input_message-content-text_input - // Only "Input message" contains item["content"]::string - // After converting item["content"]::string to item["content"]::array, - // we can treat "Input message" as sum of "Item-Input message" and "Item-Output message" - item["content"] = json::array({ - json { - {"text", item.at("content")}, - {"type", "input_text"} - } - }); - } - - if (exists_and_is_array(item, "content") && - exists_and_is_string(item, "role") && - (item.at("role") == "user" || - item.at("role") == "system" || - item.at("role") == "developer") - ) { - // #responses_create-input-input_item_list-item-input_message - std::vector chatcmpl_content; - - for (const json & input_item : item.at("content")) { - const std::string type = json_value(input_item, "type", std::string()); - - if (type == "input_text") { - if (!input_item.contains("text")) { - throw std::invalid_argument("'Input text' requires 'text'"); - } - chatcmpl_content.push_back({ - {"text", input_item.at("text")}, - {"type", "text"}, - }); - } else if (type == "input_image") { - // While `detail` is marked as required, - // it has default value("auto") and can be omitted. 
- - if (!input_item.contains("image_url")) { - throw std::invalid_argument("'image_url' is required"); - } - chatcmpl_content.push_back({ - {"image_url", json { - {"url", input_item.at("image_url")} - }}, - {"type", "image_url"}, - }); - } else if (type == "input_file") { - throw std::invalid_argument("'input_file' is not supported by llamacpp at this moment"); - // if (input_item.contains("file_url")) { - // // chat completion API does not support file_url - // throw std::invalid_argument("'file_url' is not supported"); - // } - // if (!input_item.contains("file_data") || !input_item.contains("filename")) { - // throw std::invalid_argument("Both 'file_data' and 'filename' are required"); - // } - // chatcmpl_content.push_back({ - // {"file", json { - // {"file_data", input_item.at("file_data")}, - // {"filename", input_item.at("filename")}, - // }}, - // {"type", "file"}, - // }); - } else { - throw std::invalid_argument("'type' must be one of 'input_text', 'input_image', or 'input_file'"); - } - } - - if (item.contains("type")) { - item.erase("type"); - } - if (item.contains("status")) { - item.erase("status"); - } - item["content"] = chatcmpl_content; - - chatcmpl_messages.push_back(item); - } else if (exists_and_is_array(item, "content") && - exists_and_is_string(item, "role") && - item.at("role") == "assistant" && - // exists_and_is_string(item, "status") && - // (item.at("status") == "in_progress" || - // item.at("status") == "completed" || - // item.at("status") == "incomplete") && - // item["status"] not sent by codex-cli - exists_and_is_string(item, "type") && - item.at("type") == "message" - ) { - // #responses_create-input-input_item_list-item-output_message - auto chatcmpl_content = json::array(); - - for (const auto & output_text : item.at("content")) { - const std::string type = json_value(output_text, "type", std::string()); - if (type == "output_text") { - if (!exists_and_is_string(output_text, "text")) { - throw std::invalid_argument("'Output text' requires 'text'"); - // Ignore annotations and logprobs for now - chatcmpl_content.push_back({ - {"text", output_text.at("text")}, - {"type", "text"}, - }); - } - } else if (type == "refusal") { - if (!exists_and_is_string(output_text, "refusal")) { - throw std::invalid_argument("'Refusal' requires 'refusal'"); - // Ignore annotations and logprobs for now - chatcmpl_content.push_back({ - {"refusal", output_text.at("refusal")}, - {"type", "refusal"}, - }); - } - } else { - throw std::invalid_argument("'type' must be one of 'output_text' or 'refusal'"); - } - } - - if (merge_prev) { - auto & prev_msg = chatcmpl_messages.back(); - if (!exists_and_is_array(prev_msg, "content")) { - prev_msg["content"] = json::array(); - } - auto & prev_content = prev_msg["content"]; - prev_content.insert(prev_content.end(), chatcmpl_content.begin(), chatcmpl_content.end()); - } else { - item.erase("status"); - item.erase("type"); - item["content"] = chatcmpl_content; - chatcmpl_messages.push_back(item); - } - } else if (exists_and_is_string(item, "arguments") && - exists_and_is_string(item, "call_id") && - exists_and_is_string(item, "name") && - exists_and_is_string(item, "type") && - item.at("type") == "function_call" - ) { - // #responses_create-input-input_item_list-item-function_tool_call - json tool_call = { - {"function", json { - {"arguments", item.at("arguments")}, - {"name", item.at("name")}, - }}, - {"id", item.at("call_id")}, - {"type", "function"}, - }; - - if (merge_prev) { - auto & prev_msg = chatcmpl_messages.back(); - if 
(!exists_and_is_array(prev_msg, "tool_calls")) { - prev_msg["tool_calls"] = json::array(); - } - prev_msg["tool_calls"].push_back(tool_call); - } else { - chatcmpl_messages.push_back(json { - {"role", "assistant"}, - {"tool_calls", json::array({tool_call})} - }); - } - } else if (exists_and_is_string(item, "call_id") && - (exists_and_is_string(item, "output") || exists_and_is_array(item, "output")) && - exists_and_is_string(item, "type") && - item.at("type") == "function_call_output" - ) { - // #responses_create-input-input_item_list-item-function_tool_call_output - if (item.at("output").is_string()) { - chatcmpl_messages.push_back(json { - {"content", item.at("output")}, - {"role", "tool"}, - {"tool_call_id", item.at("call_id")}, - }); - } else { - json chatcmpl_outputs = item.at("output"); - for (json & chatcmpl_output : chatcmpl_outputs) { - if (!chatcmpl_output.contains("type") || chatcmpl_output.at("type") != "input_text") { - throw std::invalid_argument("Output of tool call should be 'Input text'"); - } - chatcmpl_output["type"] = "text"; - } - chatcmpl_messages.push_back(json { - {"content", chatcmpl_outputs}, - {"role", "tool"}, - {"tool_call_id", item.at("call_id")}, - }); - } - } else if (// exists_and_is_string(item, "id") && - // item["id"] not sent by codex-cli - exists_and_is_array(item, "summary") && - exists_and_is_string(item, "type") && - item.at("type") == "reasoning") { - // #responses_create-input-input_item_list-item-reasoning - - if (!exists_and_is_array(item, "content")) { - throw std::invalid_argument("item['content'] is not an array"); - } - if (item.at("content").empty()) { - throw std::invalid_argument("item['content'] is empty"); - } - if (!exists_and_is_string(item.at("content")[0], "text")) { - throw std::invalid_argument("item['content']['text'] is not a string"); - } - - if (merge_prev) { - auto & prev_msg = chatcmpl_messages.back(); - prev_msg["reasoning_content"] = item.at("content")[0].at("text"); - } else { - chatcmpl_messages.push_back(json { - {"role", "assistant"}, - {"content", json::array()}, - {"reasoning_content", item.at("content")[0].at("text")}, - }); - } - } else { - throw std::invalid_argument("Cannot determine type of 'item'"); - } - } - } else { - throw std::invalid_argument("'input' must be a string or array of objects"); - } - - chatcmpl_body["messages"] = chatcmpl_messages; - - if (response_body.contains("tools")) { - if (!response_body.at("tools").is_array()) { - throw std::invalid_argument("'tools' must be an array of objects"); - } - std::vector chatcmpl_tools; - for (json resp_tool : response_body.at("tools")) { - json chatcmpl_tool; - - if (json_value(resp_tool, "type", std::string()) != "function") { - throw std::invalid_argument("'type' of tool must be 'function'"); - } - resp_tool.erase("type"); - chatcmpl_tool["type"] = "function"; - - if (!resp_tool.contains("strict")) { - resp_tool["strict"] = true; - } - chatcmpl_tool["function"] = resp_tool; - chatcmpl_tools.push_back(chatcmpl_tool); - } - chatcmpl_body.erase("tools"); - chatcmpl_body["tools"] = chatcmpl_tools; - } - - if (response_body.contains("max_output_tokens")) { - chatcmpl_body.erase("max_output_tokens"); - chatcmpl_body["max_tokens"] = response_body["max_output_tokens"]; - } - - return chatcmpl_body; -} - -json convert_transcriptions_to_chatcmpl( - const json & inp_body, - const std::map & in_files, - std::vector & out_files) { - // TODO @ngxson : this function may need to be improved in the future - // handle input files - out_files.clear(); - auto it = 
in_files.find("file"); - if (it != in_files.end()) { - out_files.push_back(it->second); - } else { - throw std::invalid_argument("No input file found for transcription"); - } - - // handle input data - std::string prompt = json_value(inp_body, "prompt", std::string()); - std::string language = json_value(inp_body, "language", std::string()); - std::string response_format = json_value(inp_body, "response_format", std::string("json")); - if (response_format != "json") { - throw std::invalid_argument("Only 'json' response_format is supported for transcription"); - } - if (prompt.empty()) { - prompt = "Transcribe audio to text"; - } - if (!language.empty()) { - prompt += string_format(" (language: %s)", language.c_str()); - } - prompt += get_media_marker(); - - json chatcmpl_body = inp_body; // copy all fields - chatcmpl_body["messages"] = json::array({ - { - {"role", "user"}, - {"content", prompt}, - }, - }); - - // because input from form-data, everything is string, we need to correct the types here - std::string stream = json_value(inp_body, "stream", std::string("false")); - chatcmpl_body["stream"] = stream == "true"; - - if (inp_body.contains("max_tokens")) { - std::string inp = inp_body["max_tokens"].get(); - chatcmpl_body["max_tokens"] = std::stoul(inp); - } - - if (inp_body.contains("temperature")) { - std::string inp = inp_body["temperature"].get(); - chatcmpl_body["temperature"] = std::stof(inp); - } - - return chatcmpl_body; -} - -json convert_anthropic_to_oai(const json & body) { - json oai_body; - - // Convert system prompt - json oai_messages = json::array(); - auto system_param = json_value(body, "system", json()); - if (!system_param.is_null()) { - std::string system_content; - - if (system_param.is_string()) { - system_content = system_param.get(); - } else if (system_param.is_array()) { - for (const auto & block : system_param) { - if (json_value(block, "type", std::string()) == "text") { - system_content += json_value(block, "text", std::string()); - } - } - } - - oai_messages.push_back({ - {"role", "system"}, - {"content", system_content} - }); - } - - // Convert messages - if (!body.contains("messages")) { - throw std::runtime_error("'messages' is required"); - } - const json & messages = body.at("messages"); - if (messages.is_array()) { - for (const auto & msg : messages) { - std::string role = json_value(msg, "role", std::string()); - - if (!msg.contains("content")) { - if (role == "assistant") { - continue; - } - oai_messages.push_back(msg); - continue; - } - - const json & content = msg.at("content"); - - if (content.is_string()) { - oai_messages.push_back(msg); - continue; - } - - if (!content.is_array()) { - oai_messages.push_back(msg); - continue; - } - - json tool_calls = json::array(); - json converted_content = json::array(); - json tool_results = json::array(); - std::string reasoning_content; - bool has_tool_calls = false; - - for (const auto & block : content) { - std::string type = json_value(block, "type", std::string()); - - if (type == "text") { - converted_content.push_back(block); - } else if (type == "thinking") { - reasoning_content += json_value(block, "thinking", std::string()); - } else if (type == "image") { - json source = json_value(block, "source", json::object()); - std::string source_type = json_value(source, "type", std::string()); - - if (source_type == "base64") { - std::string media_type = json_value(source, "media_type", std::string("image/jpeg")); - std::string data = json_value(source, "data", std::string()); - std::ostringstream ss; - 
ss << "data:" << media_type << ";base64," << data; - - converted_content.push_back({ - {"type", "image_url"}, - {"image_url", { - {"url", ss.str()} - }} - }); - } else if (source_type == "url") { - std::string url = json_value(source, "url", std::string()); - converted_content.push_back({ - {"type", "image_url"}, - {"image_url", { - {"url", url} - }} - }); - } - } else if (type == "tool_use") { - tool_calls.push_back({ - {"id", json_value(block, "id", std::string())}, - {"type", "function"}, - {"function", { - {"name", json_value(block, "name", std::string())}, - {"arguments", json_value(block, "input", json::object()).dump()} - }} - }); - has_tool_calls = true; - } else if (type == "tool_result") { - std::string tool_use_id = json_value(block, "tool_use_id", std::string()); - - auto result_content = json_value(block, "content", json()); - std::string result_text; - if (result_content.is_string()) { - result_text = result_content.get(); - } else if (result_content.is_array()) { - for (const auto & c : result_content) { - if (json_value(c, "type", std::string()) == "text") { - result_text += json_value(c, "text", std::string()); - } - } - } - - tool_results.push_back({ - {"role", "tool"}, - {"tool_call_id", tool_use_id}, - {"content", result_text} - }); - } - } - - if (!converted_content.empty() || has_tool_calls || !reasoning_content.empty()) { - json new_msg = {{"role", role}}; - if (!converted_content.empty()) { - new_msg["content"] = converted_content; - } else if (has_tool_calls || !reasoning_content.empty()) { - new_msg["content"] = ""; - } - if (!tool_calls.empty()) { - new_msg["tool_calls"] = tool_calls; - } - if (!reasoning_content.empty()) { - new_msg["reasoning_content"] = reasoning_content; - } - oai_messages.push_back(new_msg); - } - - for (const auto & tool_msg : tool_results) { - oai_messages.push_back(tool_msg); - } - } - } - - oai_body["messages"] = oai_messages; - - // Convert tools - if (body.contains("tools")) { - const json & tools = body.at("tools"); - if (tools.is_array()) { - json oai_tools = json::array(); - for (const auto & tool : tools) { - oai_tools.push_back({ - {"type", "function"}, - {"function", { - {"name", json_value(tool, "name", std::string())}, - {"description", json_value(tool, "description", std::string())}, - {"parameters", tool.contains("input_schema") ? 
tool.at("input_schema") : json::object()} - }} - }); - } - oai_body["tools"] = oai_tools; - } - } - - // Convert tool_choice - if (body.contains("tool_choice")) { - const json & tc = body.at("tool_choice"); - if (tc.is_object()) { - std::string type = json_value(tc, "type", std::string()); - if (type == "auto") { - oai_body["tool_choice"] = "auto"; - } else if (type == "any" || type == "tool") { - oai_body["tool_choice"] = "required"; - } - } - } - - // Convert stop_sequences to stop - if (body.contains("stop_sequences")) { - oai_body["stop"] = body.at("stop_sequences"); - } - - // Handle max_tokens (required in Anthropic, but we're permissive) - if (body.contains("max_tokens")) { - oai_body["max_tokens"] = body.at("max_tokens"); - } else { - oai_body["max_tokens"] = 4096; - } - - // Pass through common params - for (const auto & key : {"temperature", "top_p", "top_k", "stream"}) { - if (body.contains(key)) { - oai_body[key] = body.at(key); - } - } - - // Handle Anthropic-specific thinking param - if (body.contains("thinking")) { - json thinking = json_value(body, "thinking", json::object()); - std::string thinking_type = json_value(thinking, "type", std::string()); - if (thinking_type == "enabled") { - int budget_tokens = json_value(thinking, "budget_tokens", 10000); - oai_body["thinking_budget_tokens"] = budget_tokens; - } - } - - // Handle Anthropic-specific metadata param - if (body.contains("metadata")) { - json metadata = json_value(body, "metadata", json::object()); - std::string user_id = json_value(metadata, "user_id", std::string()); - if (!user_id.empty()) { - oai_body["__metadata_user_id"] = user_id; - } - } - - return oai_body; -} - json format_embeddings_response_oaicompat( const json & request, const std::string & model_name, diff --git a/tools/server/server-common.h b/tools/server/server-common.h index 57545aa53ed..4681f9c5155 100644 --- a/tools/server/server-common.h +++ b/tools/server/server-common.h @@ -190,7 +190,9 @@ struct server_tokens { void insert(const llama_tokens & inp_tokens); // for compatibility with speculative decoding, ctx shift, slot save/load - const llama_tokens & get_text_tokens() const; + const llama_tokens & get_tokens() const; + + llama_tokens get_text_tokens() const; // for compatibility with speculative decoding void set_token(llama_pos pos, llama_token id); @@ -305,18 +307,6 @@ json oaicompat_chat_params_parse( const server_chat_params & opt, std::vector & out_files); -// convert OpenAI Responses API format to OpenAI Chat Completions API format -json convert_responses_to_chatcmpl(const json & body); - -// convert OpenAI transcriptions API format to OpenAI Chat Completions API format -json convert_transcriptions_to_chatcmpl( - const json & body, - const std::map & in_files, - std::vector & out_files); - -// convert Anthropic Messages API format to OpenAI Chat Completions API format -json convert_anthropic_to_oai(const json & body); - // TODO: move it to server-task.cpp json format_embeddings_response_oaicompat( const json & request, diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 4b899ecf007..c835dd8a44c 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -1,4 +1,6 @@ + #include "server-context.h" +#include "server-chat.h" #include "server-common.h" #include "server-http.h" #include "server-task.h" @@ -19,6 +21,7 @@ #include #include #include +#include // fix problem with std::min and std::max #if defined(_WIN32) @@ -33,6 +36,31 @@ using json = nlohmann::ordered_json; constexpr 
int HTTP_POLLING_SECONDS = 1; +static server_prompt_checkpoint server_get_checkpoint(llama_context * ctx, int id, int64_t n_tokens, llama_pos pos_min = -1, llama_pos pos_max = -1) { + if (pos_min == -1) { + pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), id); + } + if (pos_max == -1) { + pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), id); + } + + const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + + auto cur = server_prompt_checkpoint { + /*.pos_min = */ pos_min, + /*.pos_max = */ pos_max, + /*.n_tokens = */ n_tokens, + /*.data = */ std::vector(checkpoint_size), + }; + + const size_t n = llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + if (n != checkpoint_size) { + GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n); + } + + return cur; +} + // state diagram: https://github.com/ggml-org/llama.cpp/pull/9283 enum slot_state { SLOT_STATE_IDLE, @@ -51,13 +79,18 @@ enum server_state { struct server_slot { int id; - // TODO: change to unique_ptrs for consistency: llama_context * ctx = nullptr; + common_context_seq_rm_type ctx_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO; + // multimodal mtmd_context * mctx = nullptr; - common_speculative * spec = nullptr; + // speculative decoding + llama_tokens spec_draft; + std::vector spec_i_batch; + server_prompt_checkpoint spec_ckpt; + common_speculative_ptr spec; // TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state // see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837 @@ -83,11 +116,6 @@ struct server_slot { std::string debug_generated_text; llama_tokens generated_tokens; - // idx of draft tokens in the main batch - // non-empty if we went to evaluate draft tokens - // ref: https://github.com/ggml-org/llama.cpp/pull/17808 - std::vector i_batch_dft; - std::vector generated_token_probs; bool has_next_token = true; @@ -147,8 +175,7 @@ struct server_slot { common_sampler_ptr smpl; - llama_token sampled; // in speculative mode, this is the last accepted token - llama_tokens drafted; + llama_token sampled; // in speculative mode, this is the last accepted token // stats size_t n_sent_text = 0; // number of sent text character @@ -178,8 +205,11 @@ struct server_slot { stopping_word = ""; n_sent_text = 0; - drafted.clear(); - i_batch_dft.clear(); + if (can_speculate()) { + spec_draft.clear(); + spec_i_batch.clear(); + spec_ckpt.clear(); + } generated_tokens.clear(); generated_token_probs.clear(); json_schema = json(); @@ -300,6 +330,83 @@ struct server_slot { return n_draft_max; } + void update_batch(llama_batch & batch) { + const int n_draft_max = get_n_draft_max(); + if (n_draft_max > 0) { + GGML_ASSERT(can_speculate()); + + // generate draft tokens in speculative decoding mode + // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK] + // perform the speculative drafting for all sequences at the same time in a single batch + const llama_tokens & tokens = prompt.tokens.get_text_tokens(); + + const auto & params_spec = task->params.speculative; + + if (!spec_draft.empty()) { + // we have a previous (partial) draft to reuse + if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) { + GGML_ASSERT(!spec_ckpt.empty()); + } + } else { + GGML_ASSERT(spec_i_batch.empty()); + + // generate a new draft + spec_draft = common_speculative_draft(spec.get(), params_spec, tokens, 
sampled); + + if (spec_draft.size() > (size_t) n_draft_max) { + SLT_WRN(*this, "draft size %d exceeds max %d, truncating\n", (int) spec_draft.size(), n_draft_max); + spec_draft.resize(n_draft_max); + } + + if (spec_draft.size() < (size_t) params_spec.n_min) { + SLT_DBG(*this, "ignoring small draft: %d < %d\n", (int) spec_draft.size(), params_spec.n_min); + spec_draft.clear(); + } + + if (!spec_draft.empty() && ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) { + const auto n_tokens = prompt.tokens.size(); + + spec_ckpt = server_get_checkpoint(ctx, this->id, n_tokens); + + SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %zu, size = %.3f MiB)\n", + spec_ckpt.pos_min, spec_ckpt.pos_max, n_tokens, (float) spec_ckpt.data.size() / 1024 / 1024); + } + } + + GGML_ASSERT(spec_draft.size() <= (size_t) n_draft_max); + } + + if (spec_draft.empty()) { + // no speculative decoding + i_batch = batch.n_tokens; + + common_batch_add(batch, sampled, prompt.tokens.pos_next(), { this->id }, true); + + SLT_DBG(*this, "slot decode token, id=%d, n_ctx = %d, n_tokens = %d, truncated = %d\n", + sampled, n_ctx, prompt.n_tokens(), truncated); + } else { + SLT_DBG(*this, "generate_draft: id=%d, #tokens=%zu, #draft=%zu, pos_next=%d\n", + sampled, prompt.tokens.size(), spec_draft.size(), prompt.tokens.pos_next()); + + GGML_ASSERT(spec_i_batch.empty()); + + spec_i_batch.push_back(batch.n_tokens); + for (size_t i = 0; i < spec_draft.size(); i++) { + spec_i_batch.push_back(batch.n_tokens + i + 1); + } + + auto pos0 = prompt.tokens.pos_next(); + + common_batch_add(batch, sampled, pos0++, { this->id }, true); + for (auto token : spec_draft) { + common_batch_add(batch, token, pos0++, { this->id }, true); + } + } + + prompt.tokens.push_back(sampled); + prompt.tokens.insert(spec_draft); + } + void release() { if (is_processing()) { GGML_ASSERT(task); @@ -400,7 +507,7 @@ struct server_slot { ); } - common_speculative_print_stats(spec); + common_speculative_print_stats(spec.get()); } json to_json(bool only_metrics = false) const { @@ -568,6 +675,10 @@ struct server_context_impl { int32_t n_ctx; // total context for all clients / slots + // set to llama_model_n_swa(model) + // if swa_full is enabled, this is set to 0 to simulate a non-SWA model + int32_t n_swa; + // slots / clients std::vector slots; @@ -591,16 +702,17 @@ struct server_context_impl { void destroy() { llama_init.reset(); + ctx = nullptr; model = nullptr; mtmd_free(mctx); mctx = nullptr; - // Clear any sampling context for (server_slot & slot : slots) { - common_speculative_free(slot.spec); - slot.spec = nullptr; + if (slot.can_speculate()) { + slot.spec.reset(); + } } llama_batch_free(batch); @@ -611,7 +723,7 @@ struct server_context_impl { return; } SLT_INF(slot, "%s", "saving idle slot to prompt cache\n"); - SLT_DBG(slot, "%s", "__TEST_TAG_CLEAR_IDLE_SLOT__\n"); + SLT_DBG(slot, "%s", "__TEST_TAG_CACHE_IDLE_SLOT__\n"); slot.prompt_save(*prompt_cache); slot.prompt_clear(false); prompt_cache->update(); @@ -642,9 +754,6 @@ struct server_context_impl { llama_init = common_init_from_params(params_base); - // propagate model-metadata sampling defaults back to caller - params.sampling = params_base.sampling; - model = llama_init->model(); ctx = llama_init->context(); @@ -660,6 +769,7 @@ struct server_context_impl { add_bos_token = llama_vocab_get_add_bos(vocab); if (params_base.speculative.has_dft()) { + // TODO speculative: move to common/speculative.cpp? 
SRV_INF("loading draft model '%s'\n", params_base.speculative.mparams_dft.path.c_str()); const auto & params_spec = params_base.speculative; @@ -691,7 +801,28 @@ struct server_context_impl { } params_base.speculative.model_dft = model_dft.get(); + params_base.speculative.model_tgt = model; params_base.speculative.cparams_dft = common_context_params_to_llama(params_dft); + + if (params_base.speculative.eagle3) { + // EAGLE3 current limitation: extracted target features are per-context; multiple slots would overwrite each other + if (params_base.n_parallel > 1) { + SRV_ERR("%s", "EAGLE3 speculative decoding is not supported with n_parallel > 1\n"); + return false; + } + llama_set_eagle3(ctx, model_dft.get()); + SRV_INF("%s", "EAGLE3 feature extraction enabled on target model\n"); + } + + if (params_base.speculative.dflash) { + // DFlash current limitation: extracted target features are per-context; multiple slots would overwrite each other + if (params_base.n_parallel > 1) { + SRV_ERR("%s", "DFlash speculative decoding is not supported with n_parallel > 1\n"); + return false; + } + llama_set_dflash(ctx, model_dft.get()); + SRV_INF("%s", "DFlash feature extraction enabled on target model\n"); + } } std::string & mmproj_path = params_base.mmproj.path; @@ -727,11 +858,6 @@ struct server_context_impl { params_base.n_cache_reuse = 0; SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled"); } - - if (params_base.speculative.type != COMMON_SPECULATIVE_TYPE_NONE) { - params_base.speculative.type = COMMON_SPECULATIVE_TYPE_NONE; - SRV_WRN("%s\n", "speculative decoding is not supported by multimodal, it will be disabled"); - } } if (!llama_memory_can_shift(llama_get_memory(ctx))) { @@ -753,6 +879,8 @@ struct server_context_impl { } } + n_swa = params_base.swa_full ? 
0 : llama_model_n_swa(model); + // Necessary similarity of prompt for slot selection slot_prompt_similarity = params_base.slot_prompt_similarity; @@ -769,33 +897,38 @@ struct server_context_impl { slots.clear(); - const bool can_spec = common_speculative_is_compat(ctx); - if (!can_spec) { + const auto ctx_seq_rm_type = common_context_can_seq_rm(ctx); + if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_NO) { SRV_WRN("%s", "speculative decoding not supported by this context\n"); } + if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) { + SRV_WRN("%s", "speculative decoding will use checkpoints\n"); + } + // initialize slots for (int i = 0; i < params_base.n_parallel; i++) { - server_slot slot; + slots.emplace_back(); + } + + for (int i = 0; i < params_base.n_parallel; i++) { + server_slot & slot = slots[i]; slot.id = i; slot.ctx = ctx; slot.n_ctx = n_ctx_slot; + slot.ctx_seq_rm_type = ctx_seq_rm_type; + slot.mctx = mctx; slot.prompt.tokens.has_mtmd = mctx != nullptr; // try speculative decoding - if (can_spec) { - slot.spec = common_speculative_init(params_base.speculative, slot.ctx); + if (ctx_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) { + slot.spec.reset(common_speculative_init(params_base.speculative, slot.ctx)); + if (slot.spec) { - if (mctx) { - SRV_ERR("%s\n", "speculative decoding is not supported with multimodal"); - return false; - } SLT_INF(slot, "%s", "speculative decoding context initialized\n"); - } else { - SLT_INF(slot, "%s", "speculative decoding context not initialized\n"); } } @@ -806,8 +939,6 @@ struct server_context_impl { }; slot.reset(); - - slots.push_back(std::move(slot)); } { @@ -854,6 +985,9 @@ struct server_context_impl { model_aliases = params_base.model_alias; model_tags = params_base.model_tags; + // propagate new defaults back to caller + params = params_base; + if (!is_resume) { return init(); } @@ -880,16 +1014,16 @@ struct server_context_impl { metrics.init(); - if (params_base.clear_idle) { + if (params_base.cache_idle_slots) { if (!params_base.kv_unified) { - SRV_WRN("%s: --clear-idle requires --kv-unified, disabling\n", __func__); - params_base.clear_idle = false; + SRV_WRN("%s: --cache-idle-slots requires --kv-unified, disabling\n", __func__); + params_base.cache_idle_slots = false; } else if (params_base.cache_ram_mib == 0) { - SRV_WRN("%s: --clear-idle requires --cache-ram, disabling\n", __func__); - params_base.clear_idle = false; + SRV_WRN("%s: --cache-idle-slots requires --cache-ram, disabling\n", __func__); + params_base.cache_idle_slots = false; } else { SRV_INF("%s: idle slots will be saved to prompt cache and cleared upon starting a new task\n", __func__); - SRV_DBG("%s", "__TEST_TAG_CLEAR_IDLE_ENABLED__\n"); + SRV_DBG("%s", "__TEST_TAG_CACHE_IDLE_SLOTS_ENABLED__\n"); } } @@ -938,8 +1072,8 @@ struct server_context_impl { /* allow_image */ mctx ? mtmd_support_vision(mctx) : false, /* allow_audio */ mctx ? 
mtmd_support_audio (mctx) : false, /* enable_thinking */ enable_thinking, - /* reasoning_budget */ params_base.reasoning_budget, - /* reasoning_budget_msg */ params_base.reasoning_budget_message, + /* reasoning_budget */ params_base.sampling.reasoning_budget_tokens, + /* reasoning_budget_msg */ params_base.sampling.reasoning_budget_message, /* media_path */ params_base.media_path, /* force_pure_content */ params_base.force_pure_content_parser }; @@ -1197,7 +1331,7 @@ struct server_context_impl { backend_sampling &= task.params.sampling.backend_sampling; // TODO: speculative decoding requires multiple samples per batch - not supported yet - backend_sampling &= !(slot.spec && task.params.speculative.n_max > 0); + backend_sampling &= !(slot.can_speculate() && task.params.speculative.n_max > 0); // TODO: getting post/pre sampling logits is not yet supported with backend sampling backend_sampling &= !need_logits; @@ -1703,6 +1837,26 @@ struct server_context_impl { return true; } + // n_tokens_cur: the number of tokens added to the batch for the current slot + void create_checkpoint(server_slot & slot, const int64_t n_tokens_cur, llama_pos pos_min, llama_pos pos_max) { + while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) { + // make room for the new checkpoint, if needed + const auto & cur = slot.prompt.checkpoints.front(); + + SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", + cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024); + + slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin()); + } + + const auto & cur = slot.prompt.checkpoints.emplace_back(server_get_checkpoint(ctx, slot.id, slot.prompt.n_tokens() - n_tokens_cur, pos_min, pos_max)); + + SLT_WRN(slot, + "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", + (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, + cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024); + } + void process_single_task(server_task && task) { switch (task.type) { case SERVER_TASK_TYPE_COMPLETION: @@ -1759,7 +1913,7 @@ struct server_context_impl { break; // drop the task } - if (params_base.clear_idle) { + if (params_base.cache_idle_slots) { for (auto & s : slots) { if (!s.is_processing()) { slot_save_and_clear(s); @@ -1854,7 +2008,7 @@ struct server_context_impl { std::string filename = task.slot_action.filename; std::string filepath = task.slot_action.filepath; - const llama_tokens & tokens = slot->prompt.tokens.get_text_tokens(); + const llama_tokens & tokens = slot->prompt.tokens.get_tokens(); const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count); const int64_t t_end = ggml_time_us(); @@ -2061,7 +2215,7 @@ struct server_context_impl { { GGML_ASSERT(!slot.prompt.tokens.has_mtmd); - llama_tokens new_tokens = slot.prompt.tokens.get_text_tokens(); // copy + llama_tokens new_tokens = slot.prompt.tokens.get_tokens(); // copy for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) { new_tokens[i - n_discard] = new_tokens[i]; } @@ -2100,61 +2254,7 @@ struct server_context_impl { continue; } - // generate draft tokens in speculative decoding mode - // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK] - // perform the speculative drafting for all sequences at the same time in a single batch - const int n_draft_max = 
slot.get_n_draft_max(); - if (n_draft_max > 0) { - if (mctx) { - // we should never reach this, as speculative is automatically disabled if mmproj is loaded - GGML_ABORT("not supported by multimodal"); - } - - const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens(); - - const auto & params_spec = slot.task->params.speculative; - - llama_tokens draft = common_speculative_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled); - - if (draft.size() > (size_t) n_draft_max) { - SLT_WRN(slot, "draft size %d exceeds max %d, truncating\n", (int) draft.size(), n_draft_max); - draft.resize(n_draft_max); - } - - // add the sampled token to the batch - slot.i_batch_dft.push_back(batch.n_tokens); - common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true); - slot.prompt.tokens.push_back(slot.sampled); - - if (slot.task->params.speculative.n_min > (int) draft.size()) { - SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min); - // fallback to normal decoding - slot.i_batch = slot.i_batch_dft[0]; - slot.drafted.clear(); - slot.i_batch_dft.clear(); - } else { - // keep track of total number of drafted tokens tested - slot.n_draft_total += draft.size(); - - // add all drafted tokens to the batch - for (size_t i = 0; i < draft.size(); i++) { - slot.i_batch_dft.push_back(batch.n_tokens); - common_batch_add(batch, draft[i], slot.prompt.tokens.pos_next(), { slot.id }, true); - slot.prompt.tokens.push_back(draft[i]); - } - slot.drafted = std::move(draft); - } - } else { - // no speculative decoding - slot.i_batch = batch.n_tokens; - - common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true); - - slot.prompt.tokens.push_back(slot.sampled); - - SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n", - slot.n_ctx, slot.prompt.n_tokens(), slot.truncated); - } + slot.update_batch(batch); } // process in chunks of params.n_batch @@ -2342,9 +2442,6 @@ struct server_context_impl { llama_pos pos_next = slot.prompt.tokens.pos_next(n_past); - // note: when n_swa == 0, the model does not use SWA - const auto n_swa = std::max(0, llama_model_n_swa(model)); - // the largest pos_min required for a checkpoint to be useful const auto pos_min_thold = std::max(0, pos_next - n_swa); @@ -2515,15 +2612,11 @@ struct server_context_impl { // make a checkpoint of the parts of the memory that cannot be rolled back. // checkpoints are created only if: - // - the model uses SWA and we are not using `swa_full` - // - the model architecture is marked as recurrent or hybrid - // - // TODO: try to make this conditional on the context or the memory module, instead of the model type + // - the model does not support partial sequence removal + // - the model uses SWA (and we are not using `swa_full`) do_checkpoint = do_checkpoint && ( - llama_model_is_recurrent(model) || - llama_model_is_hybrid(model) || - (llama_model_n_swa(model) > 0 && !params_base.swa_full) - ); + (slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) || + (n_swa > 0)); bool has_mtmd = false; @@ -2651,40 +2744,12 @@ struct server_context_impl { // no need to create checkpoints that are too close together do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || slot.prompt.n_tokens() - n_tokens_cur > slot.prompt.checkpoints.back().n_tokens + 64); + SLT_DBG(slot, "main/do_checkpoint = %s, pos_min = %d, pos_max = %d\n", do_checkpoint ? 
"yes" : "no", pos_min, pos_max); // note: we create the checkpoint before calling llama_decode(), so the current batch is not // yet processed and therefore it is not part of the checkpoint. if (do_checkpoint) { - while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) { - // make room for the new checkpoint, if needed - const auto & cur = slot.prompt.checkpoints.front(); - - SLT_WRN(slot, - "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 - ", size = %.3f MiB)\n", - cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024); - - slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin()); - } - - const size_t checkpoint_size = - llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - - auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{ - /*.pos_min = */ pos_min, - /*.pos_max = */ pos_max, - /*.n_tokens = */ slot.prompt.n_tokens() - n_tokens_cur, - /*.data = */ std::vector(checkpoint_size), - }); - - llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, - LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - - SLT_WRN(slot, - "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 - ", size = %.3f MiB)\n", - (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, - cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024); + create_checkpoint(slot, n_tokens_cur, pos_min, pos_max); } } @@ -2856,19 +2921,19 @@ struct server_context_impl { slot.state = SLOT_STATE_GENERATING; if (slot.can_speculate()) { - common_speculative_begin(slot.spec, slot.prompt.tokens.get_text_tokens()); + common_speculative_begin(slot.spec.get(), slot.prompt.tokens.get_text_tokens()); } } else if (slot.state != SLOT_STATE_GENERATING) { continue; // continue loop of slots } - if (slot.i_batch_dft.size() > 0) { + if (slot.can_speculate() && !slot.spec_draft.empty()) { continue; // sample using speculative decoding } const int tok_idx = slot.i_batch - i; - llama_token id = common_sampler_sample(slot.smpl.get(), ctx, tok_idx); + llama_token id = common_sampler_sample(slot.smpl.get(), slot.ctx, tok_idx); slot.i_batch = -1; @@ -2889,7 +2954,7 @@ struct server_context_impl { completion_token_output result; result.tok = id; - result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); + result.text_to_send = common_token_to_piece(slot.ctx, result.tok, accept_special_token(slot, result.tok)); result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs if (slot.task->params.sampling.n_probs > 0) { @@ -2909,43 +2974,91 @@ struct server_context_impl { // speculative decoding - main model sample and accept for (auto & slot : slots) { - if (slot.state != SLOT_STATE_GENERATING || slot.i_batch_dft.empty()) { + if (slot.state != SLOT_STATE_GENERATING || !slot.can_speculate() || slot.spec_draft.empty()) { continue; } - const size_t n_draft = slot.drafted.size(); + // save the original draft size + const size_t n_draft = slot.spec_draft.size(); + + GGML_ASSERT(n_draft > 0); - // the accepted tokens from the speculation - const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx, slot.i_batch_dft, slot.drafted); - slot.i_batch_dft.clear(); - slot.drafted.clear(); + // verify and try to accept the draft + { + const bool use_ckpt = slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; + + // only save the sampler sampler state if we use checkpoints + 
common_sampler_ptr smpl_save; + if (use_ckpt) { + smpl_save.reset(common_sampler_clone(slot.smpl.get())); + } + + GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1); + auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx, slot.spec_i_batch, slot.spec_draft); + slot.spec_i_batch.clear(); + + SLT_DBG(slot, "%s: n_draft=%zu, accepted=%zu\n", __func__, slot.spec_draft.size(), accepted.size()); + + GGML_ASSERT(accepted.size() >= 1); + + // check for partial draft acceptance + if (accepted.size() < slot.spec_draft.size() + 1) { + if (use_ckpt) { + // partial acceptance is not supported by the context -> truncate the draft and restore the state + slot.spec_draft = std::move(accepted); + + const auto & ckpt = slot.spec_ckpt; + + SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n", + ckpt.pos_min, ckpt.pos_max, ckpt.size()); + + const size_t n = llama_state_seq_set_data_ext(slot.ctx, ckpt.data.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + if (n != ckpt.size()) { + GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu", + __func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt.size(), n); + } + + llama_memory_seq_rm(llama_get_memory(slot.ctx), slot.id, ckpt.pos_max + 1, -1); + + slot.prompt.tokens.keep_first(ckpt.n_tokens); + slot.smpl = std::move(smpl_save); + + continue; + } + + LOG_DBG("%s: partial acceptance: %zu < %zu\n", __func__, accepted.size(), slot.spec_draft.size()); + } + + common_speculative_accept(slot.spec.get(), accepted.size() - 1); + + slot.spec_draft = std::move(accepted); + } const int64_t t_current = ggml_time_us(); - slot.n_decoded += ids.size(); + const auto ids = std::move(slot.spec_draft); + slot.n_decoded += ids.size(); slot.t_token_generation = std::max(1, t_current - slot.t_start_generation) / 1e3; // update how many tokens out of those tested were accepted slot.n_draft_accepted += ids.size() - 1; - - // inform the speculative decoding about the number of accepted tokens - common_speculative_accept(slot.spec, ids.size() - 1); - - // rollback to the state before sampling the draft tokens - slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft); + slot.n_draft_total += n_draft; // add accepted tokens to the prompt + slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft); slot.prompt.tokens.insert({ids.begin(), ids.end() - 1}); + slot.sampled = ids.back(); // last accepted token + SLT_DBG(slot, "add accepted tokens: sampled=%d, ids.size=%zu, n_draft=%zu\n", slot.sampled, ids.size(), n_draft); - llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1); + llama_memory_seq_rm(llama_get_memory(slot.ctx), slot.id, slot.prompt.n_tokens(), -1); for (size_t i = 0; i < ids.size(); ++i) { completion_token_output result; result.tok = ids[i]; - result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); + result.text_to_send = common_token_to_piece(slot.ctx, result.tok, accept_special_token(slot, result.tok)); result.prob = 1.0f; // set later // TODO: set result.probs @@ -3537,6 +3650,7 @@ void server_routes::init_routes() { {"vision", meta->has_inp_image}, {"audio", meta->has_inp_audio}, } }, + { "media_marker", get_media_marker() }, { "endpoint_slots", params.endpoint_slots }, { "endpoint_props", params.endpoint_props }, { "endpoint_metrics", params.endpoint_metrics }, @@ -3570,34 +3684,6 @@ void server_routes::init_routes() { return res; }; - 
this->get_api_show = [this](const server_http_req &) { - auto res = create_response(); - std::string tmpl_default = common_chat_templates_source(meta->chat_params.tmpls.get(), ""); - json data = { - { - "model_info", { - { "llama.context_length", meta->slot_n_ctx }, - } - }, - {"modelfile", ""}, - {"parameters", ""}, - {"template", tmpl_default}, - {"details", { - {"parent_model", ""}, - {"format", "gguf"}, - {"family", ""}, - {"families", {""}}, - {"parameter_size", ""}, - {"quantization_level", ""} - }}, - {"model_info", ""}, - {"capabilities", meta->has_mtmd ? json({"completion","multimodal"}) : json({"completion"})} - }; - - res->ok(data); - return res; - }; - this->post_infill = [this](const server_http_req & req) { auto res = create_response(); // check model compatibility @@ -3664,7 +3750,7 @@ void server_routes::init_routes() { params.n_predict, meta->slot_n_ctx, params.spm_infill, - tokenized_prompts[0].get_text_tokens() // TODO: this could maybe be multimodal. + tokenized_prompts[0].get_tokens() // TODO: this could maybe be multimodal. ); std::vector files; // dummy @@ -3719,7 +3805,7 @@ void server_routes::init_routes() { this->post_responses_oai = [this](const server_http_req & req) { auto res = create_response(); std::vector files; - json body = convert_responses_to_chatcmpl(json::parse(req.body)); + json body = server_chat_convert_responses_to_chatcmpl(json::parse(req.body)); SRV_DBG("%s\n", "Request converted: OpenAI Responses -> OpenAI Chat Completions"); SRV_DBG("converted request: %s\n", body.dump().c_str()); json body_parsed = oaicompat_chat_params_parse( @@ -3745,6 +3831,7 @@ void server_routes::init_routes() { std::vector files; json body = convert_transcriptions_to_chatcmpl( json::parse(req.body), + meta->chat_params.tmpls.get(), req.files, files); SRV_DBG("%s\n", "Request converted: OpenAI Transcriptions -> OpenAI Chat Completions"); @@ -3764,7 +3851,7 @@ void server_routes::init_routes() { this->post_anthropic_messages = [this](const server_http_req & req) { auto res = create_response(); std::vector files; - json body = convert_anthropic_to_oai(json::parse(req.body)); + json body = server_chat_convert_anthropic_to_oai(json::parse(req.body)); SRV_DBG("%s\n", "Request converted: Anthropic -> OpenAI Chat Completions"); SRV_DBG("converted request: %s\n", body.dump().c_str()); json body_parsed = oaicompat_chat_params_parse( @@ -3782,7 +3869,7 @@ void server_routes::init_routes() { this->post_anthropic_count_tokens = [this](const server_http_req & req) { auto res = create_response(); std::vector files; - json body = convert_anthropic_to_oai(json::parse(req.body)); + json body = server_chat_convert_anthropic_to_oai(json::parse(req.body)); SRV_DBG("%s\n", "Request converted: Anthropic -> OpenAI Chat Completions"); SRV_DBG("converted request: %s\n", body.dump().c_str()); json body_parsed = oaicompat_chat_params_parse( diff --git a/tools/server/server-context.h b/tools/server/server-context.h index 6856043fad6..37f10dc7792 100644 --- a/tools/server/server-context.h +++ b/tools/server/server-context.h @@ -105,7 +105,6 @@ struct server_routes { server_http_context::handler_t post_slots; server_http_context::handler_t get_props; server_http_context::handler_t post_props; - server_http_context::handler_t get_api_show; server_http_context::handler_t post_infill; server_http_context::handler_t post_completions; server_http_context::handler_t post_completions_oai; diff --git a/tools/server/server-http.cpp b/tools/server/server-http.cpp index 83f656f5c9d..ae39fbff9bd 100644 --- 
a/tools/server/server-http.cpp +++ b/tools/server/server-http.cpp @@ -143,7 +143,6 @@ bool server_http_context::init(const common_params & params) { "/v1/health", "/models", "/v1/models", - "/api/tags", "/", "/index.html", "/bundle.js", diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp index a1eeec30e99..15c11c3c9fb 100644 --- a/tools/server/server-models.cpp +++ b/tools/server/server-models.cpp @@ -712,6 +712,11 @@ void server_models::unload(const std::string & name) { if (it->second.meta.is_running()) { SRV_INF("stopping model instance name=%s\n", name.c_str()); stopping_models.insert(name); + if (it->second.meta.status == SERVER_MODEL_STATUS_LOADING) { + // special case: if model is in loading state, unloading means force-killing it + SRV_WRN("model name=%s is still loading, force-killing\n", name.c_str()); + subprocess_terminate(it->second.subproc.get()); + } cv_stop.notify_all(); // status change will be handled by the managing thread } else { @@ -1147,7 +1152,7 @@ server_http_proxy::server_http_proxy( // setup Client cli->set_follow_location(true); - cli->set_connection_timeout(5, 0); // 5 seconds + cli->set_connection_timeout(timeout_read, 0); // use --timeout value instead of hardcoded 5 s cli->set_write_timeout(timeout_read, 0); // reversed for cli (client) vs srv (server) cli->set_read_timeout(timeout_write, 0); this->status = 500; // to be overwritten upon response diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index 4fb953b4920..4c341d7c50f 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -1,6 +1,7 @@ #include "server-task.h" #include "build-info.h" +#include "server-chat.h" #include "chat.h" #include "common.h" #include "json-schema-to-grammar.h" @@ -162,7 +163,7 @@ common_chat_msg task_result_state::update_chat_msg( bool filter_tool_calls) { generated_text += text_added; auto msg_prv_copy = chat_msg; - SRV_DBG("Parsing chat message: %s\n", generated_text.c_str()); + //SRV_DBG("Parsing chat message: %s\n", generated_text.c_str()); auto new_msg = common_chat_parse( generated_text, is_partial, @@ -269,6 +270,7 @@ task_params server_task::params_from_json_cmpl( params.n_indent = json_value(data, "n_indent", defaults.n_indent); params.n_keep = json_value(data, "n_keep", defaults.n_keep); params.n_discard = json_value(data, "n_discard", defaults.n_discard); + params.n_discard = std::max(0, params.n_discard); params.n_cmpl = json_value(data, "n_cmpl", json_value(data, "n", 1)); params.n_cache_reuse = json_value(data, "n_cache_reuse", defaults.n_cache_reuse); //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement @@ -304,6 +306,8 @@ task_params server_task::params_from_json_cmpl( params.sampling.backend_sampling = json_value(data, "backend_sampling", defaults.sampling.backend_sampling); params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs); + params.speculative = defaults.speculative; + params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min); params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max); params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min); @@ -871,7 +875,7 @@ json server_task_result_cmpl_final::to_json_oaicompat_chat_stream() { json { {"finish_reason", nullptr}, {"index", index}, - {"delta", common_chat_msg_diff_to_json_oaicompat(diff)}, + {"delta", 
server_chat_msg_diff_to_json_oaicompat(diff)}, }, })}, {"created", t}, @@ -1108,7 +1112,7 @@ json server_task_result_cmpl_final::to_json_oaicompat_resp_stream() { json server_task_result_cmpl_final::to_json_oaicompat_asr() { json event = json { {"type", "transcript.text.done"}, - {"text", content}, + {"text", oaicompat_msg.content}, {"usage", json { {"type", "tokens"}, {"input_tokens", n_prompt_tokens}, @@ -1520,7 +1524,7 @@ json server_task_result_cmpl_partial::to_json_oaicompat_chat() { } for (const auto & diff : oaicompat_msg_diffs) { - add_delta(common_chat_msg_diff_to_json_oaicompat(diff)); + add_delta(server_chat_msg_diff_to_json_oaicompat(diff)); } if (!deltas.empty()) { diff --git a/tools/server/server-task.h b/tools/server/server-task.h index 95f39207b18..289e1fb8d24 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -576,6 +576,17 @@ struct server_prompt_checkpoint { size_t size() const { return data.size(); } + + bool empty() const { + return data.empty(); + } + + void clear() { + pos_min = 0; + pos_max = 0; + n_tokens = 0; + data.clear(); + } }; struct server_prompt { diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 06318463fd4..6566949edf1 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -7,6 +7,7 @@ #include "arg.h" #include "build-info.h" #include "common.h" +#include "fit.h" #include "llama.h" #include "log.h" @@ -141,7 +142,6 @@ int main(int argc, char ** argv) { // note: routes.get_health stays the same routes.get_metrics = models_routes->proxy_get; routes.post_props = models_routes->proxy_post; - routes.get_api_show = models_routes->proxy_get; routes.post_completions = models_routes->proxy_post; routes.post_completions_oai = models_routes->proxy_post; routes.post_chat_completions = models_routes->proxy_post; @@ -174,16 +174,13 @@ int main(int argc, char ** argv) { ctx_http.get ("/metrics", ex_wrapper(routes.get_metrics)); ctx_http.get ("/props", ex_wrapper(routes.get_props)); ctx_http.post("/props", ex_wrapper(routes.post_props)); - ctx_http.post("/api/show", ex_wrapper(routes.get_api_show)); ctx_http.get ("/models", ex_wrapper(routes.get_models)); // public endpoint (no API key check) ctx_http.get ("/v1/models", ex_wrapper(routes.get_models)); // public endpoint (no API key check) - ctx_http.get ("/api/tags", ex_wrapper(routes.get_models)); // ollama specific endpoint. 
public endpoint (no API key check) ctx_http.post("/completion", ex_wrapper(routes.post_completions)); // legacy ctx_http.post("/completions", ex_wrapper(routes.post_completions)); ctx_http.post("/v1/completions", ex_wrapper(routes.post_completions_oai)); ctx_http.post("/chat/completions", ex_wrapper(routes.post_chat_completions)); ctx_http.post("/v1/chat/completions", ex_wrapper(routes.post_chat_completions)); - ctx_http.post("/api/chat", ex_wrapper(routes.post_chat_completions)); // ollama specific endpoint ctx_http.post("/v1/responses", ex_wrapper(routes.post_responses_oai)); ctx_http.post("/responses", ex_wrapper(routes.post_responses_oai)); ctx_http.post("/v1/audio/transcriptions", ex_wrapper(routes.post_transcriptions_oai)); @@ -348,7 +345,7 @@ int main(int argc, char ** argv) { auto * ll_ctx = ctx_server.get_llama_context(); if (ll_ctx != nullptr) { - llama_memory_breakdown_print(ll_ctx); + common_memory_breakdown_print(ll_ctx); } } diff --git a/tools/server/tests/unit/test_kv_keep_only_active.py b/tools/server/tests/unit/test_kv_keep_only_active.py index da93d50011e..44c05fab0cb 100644 --- a/tools/server/tests/unit/test_kv_keep_only_active.py +++ b/tools/server/tests/unit/test_kv_keep_only_active.py @@ -48,7 +48,7 @@ def test_clear_and_restore(): log = LogReader(server.log_path) # verify feature is enabled - assert "__TEST_TAG_CLEAR_IDLE_ENABLED__" in log.drain() + assert "__TEST_TAG_CACHE_IDLE_SLOTS_ENABLED__" in log.drain() res = server.make_request("POST", "/completion", data={ "prompt": LONG_PROMPT, @@ -59,7 +59,7 @@ def test_clear_and_restore(): original_prompt_n = res.body["timings"]["prompt_n"] # Slot 0 is the only slot with KV — should NOT be cleared - assert "__TEST_TAG_CLEAR_IDLE_SLOT__" not in log.drain() + assert "__TEST_TAG_CACHE_IDLE_SLOT__" not in log.drain() # Launching slot 1 clears idle slot 0 res = server.make_request("POST", "/completion", data={ @@ -68,7 +68,7 @@ def test_clear_and_restore(): "cache_prompt": True, }) assert res.status_code == 200 - assert "__TEST_TAG_CLEAR_IDLE_SLOT__" in log.drain() + assert "__TEST_TAG_CACHE_IDLE_SLOT__" in log.drain() # Re-send same prompt — should restore from cache-ram res = server.make_request("POST", "/completion", data={ @@ -86,17 +86,17 @@ def test_clear_and_restore(): "cache_prompt": True, }) assert res.status_code == 200 - assert "__TEST_TAG_CLEAR_IDLE_SLOT__" not in log.drain() + assert "__TEST_TAG_CACHE_IDLE_SLOT__" not in log.drain() def test_disabled_with_flag(): global server - server.no_clear_idle = True + server.no_cache_idle_slots = True server.start() log = LogReader(server.log_path) # Feature should not be enabled - assert "__TEST_TAG_CLEAR_IDLE_ENABLED__" not in log.drain() + assert "__TEST_TAG_CACHE_IDLE_SLOTS_ENABLED__" not in log.drain() res = server.make_request("POST", "/completion", data={ "prompt": LONG_PROMPT, @@ -112,4 +112,4 @@ def test_disabled_with_flag(): "cache_prompt": True, }) assert res.status_code == 200 - assert "__TEST_TAG_CLEAR_IDLE_SLOT__" not in log.drain() + assert "__TEST_TAG_CACHE_IDLE_SLOT__" not in log.drain() diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py index 5ddac5be496..ddbb76c9adb 100644 --- a/tools/server/tests/utils.py +++ b/tools/server/tests/utils.py @@ -103,7 +103,7 @@ class ServerProcess: media_path: str | None = None sleep_idle_seconds: int | None = None cache_ram: int | None = None - no_clear_idle: bool = False + no_cache_idle_slots: bool = False log_path: str | None = None webui_mcp_proxy: bool = False @@ -242,8 +242,8 @@ def start(self, 
timeout_seconds: int = DEFAULT_HTTP_TIMEOUT) -> None: server_args.extend(["--sleep-idle-seconds", self.sleep_idle_seconds]) if self.cache_ram is not None: server_args.extend(["--cache-ram", self.cache_ram]) - if self.no_clear_idle: - server_args.append("--no-clear-idle") + if self.no_cache_idle_slots: + server_args.append("--no-cache-idle-slots") if self.webui_mcp_proxy: server_args.append("--webui-mcp-proxy") diff --git a/ty.toml b/ty.toml index bcd23db9b8b..a07d7485d43 100644 --- a/ty.toml +++ b/ty.toml @@ -1,5 +1,5 @@ [environment] -extra-paths = ["./gguf-py", "./examples/model-conversion/scripts", "./tools/server/tests"] +extra-paths = ["./gguf-py", "./examples/model-conversion/scripts", "./tools/server/tests", "./scripts/snapdragon/qdc/tests"] python-version = "3.10" [rules] @@ -13,6 +13,7 @@ exclude = [ [[overrides]] include = [ "./tools/server/tests/**", + "./scripts/snapdragon/qdc/tests/**", ] [overrides.rules] diff --git a/vendor/cpp-httplib/CMakeLists.txt b/vendor/cpp-httplib/CMakeLists.txt index 28485a0ce80..df4b9ecce3f 100644 --- a/vendor/cpp-httplib/CMakeLists.txt +++ b/vendor/cpp-httplib/CMakeLists.txt @@ -81,7 +81,7 @@ if (LLAMA_BUILD_BORINGSSL) target_link_libraries(${TARGET} PUBLIC ssl crypto) elseif (LLAMA_BUILD_LIBRESSL) - set(LIBRESSL_VERSION "4.2.1" CACHE STRING "LibreSSL version") + set(LIBRESSL_VERSION "4.3.1" CACHE STRING "LibreSSL version") message(STATUS "Fetching LibreSSL version ${LIBRESSL_VERSION}") @@ -161,12 +161,24 @@ if(LLAMA_BUILD_BORINGSSL OR LLAMA_BUILD_LIBRESSL) if(LLAMA_BUILD_BORINGSSL) target_compile_options(fipsmodule PRIVATE /w) endif() + if(LLAMA_BUILD_LIBRESSL) + target_compile_options(ssl_obj PRIVATE /w) + target_compile_options(bs_obj PRIVATE /w) + target_compile_options(compat_obj PRIVATE /w) + target_compile_options(crypto_obj PRIVATE /w) + endif() else() target_compile_options(ssl PRIVATE -w) target_compile_options(crypto PRIVATE -w) if(LLAMA_BUILD_BORINGSSL) target_compile_options(fipsmodule PRIVATE -w) endif() + if(LLAMA_BUILD_LIBRESSL) + target_compile_options(ssl_obj PRIVATE -w) + target_compile_options(bs_obj PRIVATE -w) + target_compile_options(compat_obj PRIVATE -w) + target_compile_options(crypto_obj PRIVATE -w) + endif() endif() endif() diff --git a/vendor/cpp-httplib/httplib.cpp b/vendor/cpp-httplib/httplib.cpp index 8ff1da57bb5..95bf0eb1bb5 100644 --- a/vendor/cpp-httplib/httplib.cpp +++ b/vendor/cpp-httplib/httplib.cpp @@ -1,7 +1,5 @@ #include "httplib.h" namespace httplib { -// httplib::any — type-erased value container (C++11 compatible) -// On C++17+ builds, thin wrappers around std::any are provided. /* * Implementation that will be part of the .cc file if split into .h + .cc. 
@@ -874,7 +872,8 @@ bool write_websocket_frame(Stream &strm, ws::Opcode opcode, if (strm.write(reinterpret_cast(header), 2) < 0) { return false; } uint8_t ext[8]; for (int i = 7; i >= 0; i--) { - ext[7 - i] = static_cast((len >> (i * 8)) & 0xFF); + ext[7 - i] = + static_cast((static_cast(len) >> (i * 8)) & 0xFF); } if (strm.write(reinterpret_cast(ext), 8) < 0) { return false; } } @@ -1036,10 +1035,15 @@ bool canonicalize_path(const char *path, std::string &resolved) { char buf[_MAX_PATH]; if (_fullpath(buf, path, _MAX_PATH) == nullptr) { return false; } resolved = buf; -#else +#elif defined(PATH_MAX) char buf[PATH_MAX]; if (realpath(path, buf) == nullptr) { return false; } resolved = buf; +#else + auto buf = realpath(path, nullptr); + auto guard = scope_exit([&]() { std::free(buf); }); + if (buf == nullptr) { return false; } + resolved = buf; #endif return true; } @@ -1877,7 +1881,7 @@ int getaddrinfo_with_timeout(const char *node, const char *service, } return ret; -#elif TARGET_OS_MAC +#elif TARGET_OS_MAC && defined(__clang__) if (!node) { return EAI_NONAME; } // macOS implementation using CFHost API for asynchronous DNS resolution CFStringRef hostname_ref = CFStringCreateWithCString( @@ -2767,6 +2771,35 @@ EncodingType encoding_type(const Request &req, const Response &res) { return best; } +std::unique_ptr make_compressor(EncodingType type) { +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (type == EncodingType::Gzip) { + return detail::make_unique(); + } +#endif +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + if (type == EncodingType::Brotli) { + return detail::make_unique(); + } +#endif +#ifdef CPPHTTPLIB_ZSTD_SUPPORT + if (type == EncodingType::Zstd) { + return detail::make_unique(); + } +#endif + (void)type; + return nullptr; +} + +const char *encoding_name(EncodingType type) { + switch (type) { + case EncodingType::Gzip: return "gzip"; + case EncodingType::Brotli: return "br"; + case EncodingType::Zstd: return "zstd"; + default: return ""; + } +} + bool nocompressor::compress(const char *data, size_t data_length, bool /*last*/, Callback callback) { if (!data_length) { return true; } @@ -3099,6 +3132,29 @@ const char *get_header_value(const Headers &headers, return def; } +size_t get_header_value_count(const Headers &headers, + const std::string &key) { + auto r = headers.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +template +typename Map::mapped_type +get_multimap_value(const Map &m, const std::string &key, size_t id) { + auto rng = m.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { return it->second; } + return typename Map::mapped_type(); +} + +void set_header(Headers &headers, const std::string &key, + const std::string &val) { + if (fields::is_field_name(key) && fields::is_field_value(val)) { + headers.emplace(key, val); + } +} + bool read_headers(Stream &strm, Headers &headers) { const auto bufsiz = 2048; char buf[bufsiz]; @@ -5793,16 +5849,12 @@ std::string Request::get_header_value(const std::string &key, } size_t Request::get_header_value_count(const std::string &key) const { - auto r = headers.equal_range(key); - return static_cast(std::distance(r.first, r.second)); + return detail::get_header_value_count(headers, key); } void Request::set_header(const std::string &key, const std::string &val) { - if (detail::fields::is_field_name(key) && - detail::fields::is_field_value(val)) { - headers.emplace(key, val); - } + detail::set_header(headers, key, val); } bool Request::has_trailer(const std::string &key) 
const { @@ -5811,11 +5863,7 @@ bool Request::has_trailer(const std::string &key) const { std::string Request::get_trailer_value(const std::string &key, size_t id) const { - auto rng = trailers.equal_range(key); - auto it = rng.first; - std::advance(it, static_cast(id)); - if (it != rng.second) { return it->second; } - return std::string(); + return detail::get_multimap_value(trailers, key, id); } size_t Request::get_trailer_value_count(const std::string &key) const { @@ -5829,11 +5877,18 @@ bool Request::has_param(const std::string &key) const { std::string Request::get_param_value(const std::string &key, size_t id) const { + return detail::get_multimap_value(params, key, id); +} + +std::vector +Request::get_param_values(const std::string &key) const { auto rng = params.equal_range(key); - auto it = rng.first; - std::advance(it, static_cast(id)); - if (it != rng.second) { return it->second; } - return std::string(); + std::vector values; + values.reserve(static_cast(std::distance(rng.first, rng.second))); + for (auto it = rng.first; it != rng.second; ++it) { + values.push_back(it->second); + } + return values; } size_t Request::get_param_value_count(const std::string &key) const { @@ -5877,11 +5932,7 @@ size_t MultipartFormData::get_field_count(const std::string &key) const { FormData MultipartFormData::get_file(const std::string &key, size_t id) const { - auto rng = files.equal_range(key); - auto it = rng.first; - std::advance(it, static_cast(id)); - if (it != rng.second) { return it->second; } - return FormData(); + return detail::get_multimap_value(files, key, id); } std::vector @@ -5920,16 +5971,12 @@ std::string Response::get_header_value(const std::string &key, } size_t Response::get_header_value_count(const std::string &key) const { - auto r = headers.equal_range(key); - return static_cast(std::distance(r.first, r.second)); + return detail::get_header_value_count(headers, key); } void Response::set_header(const std::string &key, const std::string &val) { - if (detail::fields::is_field_name(key) && - detail::fields::is_field_value(val)) { - headers.emplace(key, val); - } + detail::set_header(headers, key, val); } bool Response::has_trailer(const std::string &key) const { return trailers.find(key) != trailers.end(); @@ -5937,11 +5984,7 @@ bool Response::has_trailer(const std::string &key) const { std::string Response::get_trailer_value(const std::string &key, size_t id) const { - auto rng = trailers.equal_range(key); - auto it = rng.first; - std::advance(it, static_cast(id)); - if (it != rng.second) { return it->second; } - return std::string(); + return detail::get_multimap_value(trailers, key, id); } size_t Response::get_trailer_value_count(const std::string &key) const { @@ -6244,15 +6287,6 @@ void ThreadPool::worker(bool is_dynamic) { assert(true == static_cast(fn)); fn(); - - // Dynamic thread: exit if queue is empty after task completion - if (is_dynamic) { - std::unique_lock lock(mutex_); - if (jobs_.empty()) { - move_to_finished(std::this_thread::get_id()); - break; - } - } } #if defined(CPPHTTPLIB_OPENSSL_SUPPORT) && !defined(OPENSSL_IS_BORINGSSL) && \ @@ -6782,61 +6816,51 @@ Server::make_matcher(const std::string &pattern) { } Server &Server::Get(const std::string &pattern, Handler handler) { - get_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); - return *this; + return add_handler(get_handlers_, pattern, std::move(handler)); } Server &Server::Post(const std::string &pattern, Handler handler) { - post_handlers_.emplace_back(make_matcher(pattern), 
std::move(handler)); - return *this; + return add_handler(post_handlers_, pattern, std::move(handler)); } Server &Server::Post(const std::string &pattern, HandlerWithContentReader handler) { - post_handlers_for_content_reader_.emplace_back(make_matcher(pattern), - std::move(handler)); - return *this; + return add_handler(post_handlers_for_content_reader_, pattern, + std::move(handler)); } Server &Server::Put(const std::string &pattern, Handler handler) { - put_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); - return *this; + return add_handler(put_handlers_, pattern, std::move(handler)); } Server &Server::Put(const std::string &pattern, HandlerWithContentReader handler) { - put_handlers_for_content_reader_.emplace_back(make_matcher(pattern), - std::move(handler)); - return *this; + return add_handler(put_handlers_for_content_reader_, pattern, + std::move(handler)); } Server &Server::Patch(const std::string &pattern, Handler handler) { - patch_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); - return *this; + return add_handler(patch_handlers_, pattern, std::move(handler)); } Server &Server::Patch(const std::string &pattern, HandlerWithContentReader handler) { - patch_handlers_for_content_reader_.emplace_back(make_matcher(pattern), - std::move(handler)); - return *this; + return add_handler(patch_handlers_for_content_reader_, pattern, + std::move(handler)); } Server &Server::Delete(const std::string &pattern, Handler handler) { - delete_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); - return *this; + return add_handler(delete_handlers_, pattern, std::move(handler)); } Server &Server::Delete(const std::string &pattern, HandlerWithContentReader handler) { - delete_handlers_for_content_reader_.emplace_back(make_matcher(pattern), - std::move(handler)); - return *this; + return add_handler(delete_handlers_for_content_reader_, pattern, + std::move(handler)); } Server &Server::Options(const std::string &pattern, Handler handler) { - options_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); - return *this; + return add_handler(options_handlers_, pattern, std::move(handler)); } Server &Server::WebSocket(const std::string &pattern, @@ -7013,6 +7037,15 @@ Server &Server::set_keep_alive_timeout(time_t sec) { return *this; } +template +Server &Server::set_keep_alive_timeout( + const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec(duration, [&](time_t sec, time_t /*usec*/) { + set_keep_alive_timeout(sec); + }); + return *this; +} + Server &Server::set_read_timeout(time_t sec, time_t usec) { read_timeout_sec_ = sec; read_timeout_usec_ = usec; @@ -7036,6 +7069,11 @@ Server &Server::set_payload_max_length(size_t length) { return *this; } +Server &Server::set_websocket_max_missed_pongs(int count) { + websocket_max_missed_pongs_ = count; + return *this; +} + Server &Server::set_websocket_ping_interval(time_t sec) { websocket_ping_interval_sec_ = sec; return *this; @@ -7261,23 +7299,10 @@ Server::write_content_with_provider(Stream &strm, const Request &req, if (res.is_chunked_content_provider_) { auto type = detail::encoding_type(req, res); - std::unique_ptr compressor; - if (type == detail::EncodingType::Gzip) { -#ifdef CPPHTTPLIB_ZLIB_SUPPORT - compressor = detail::make_unique(); -#endif - } else if (type == detail::EncodingType::Brotli) { -#ifdef CPPHTTPLIB_BROTLI_SUPPORT - compressor = detail::make_unique(); -#endif - } else if (type == detail::EncodingType::Zstd) { -#ifdef CPPHTTPLIB_ZSTD_SUPPORT - compressor = 
detail::make_unique(); -#endif - } else { + auto compressor = detail::make_compressor(type); + if (!compressor) { compressor = detail::make_unique(); } - assert(compressor != nullptr); return detail::write_content_chunked(strm, res.content_provider_, is_shutting_down, *compressor); @@ -7899,14 +7924,8 @@ void Server::apply_ranges(const Request &req, Response &res, if (res.content_provider_) { if (res.is_chunked_content_provider_) { res.set_header("Transfer-Encoding", "chunked"); - if (type == detail::EncodingType::Gzip) { - res.set_header("Content-Encoding", "gzip"); - res.set_header("Vary", "Accept-Encoding"); - } else if (type == detail::EncodingType::Brotli) { - res.set_header("Content-Encoding", "br"); - res.set_header("Vary", "Accept-Encoding"); - } else if (type == detail::EncodingType::Zstd) { - res.set_header("Content-Encoding", "zstd"); + if (type != detail::EncodingType::None) { + res.set_header("Content-Encoding", detail::encoding_name(type)); res.set_header("Vary", "Accept-Encoding"); } } @@ -7937,27 +7956,7 @@ void Server::apply_ranges(const Request &req, Response &res, if (type != detail::EncodingType::None) { output_pre_compression_log(req, res); - std::unique_ptr compressor; - std::string content_encoding; - - if (type == detail::EncodingType::Gzip) { -#ifdef CPPHTTPLIB_ZLIB_SUPPORT - compressor = detail::make_unique(); - content_encoding = "gzip"; -#endif - } else if (type == detail::EncodingType::Brotli) { -#ifdef CPPHTTPLIB_BROTLI_SUPPORT - compressor = detail::make_unique(); - content_encoding = "br"; -#endif - } else if (type == detail::EncodingType::Zstd) { -#ifdef CPPHTTPLIB_ZSTD_SUPPORT - compressor = detail::make_unique(); - content_encoding = "zstd"; -#endif - } - - if (compressor) { + if (auto compressor = detail::make_compressor(type)) { std::string compressed; if (compressor->compress(res.body.data(), res.body.size(), true, [&](const char *data, size_t data_len) { @@ -7965,7 +7964,7 @@ void Server::apply_ranges(const Request &req, Response &res, return true; })) { res.body.swap(compressed); - res.set_header("Content-Encoding", content_encoding); + res.set_header("Content-Encoding", detail::encoding_name(type)); res.set_header("Vary", "Accept-Encoding"); } } @@ -8213,7 +8212,8 @@ Server::process_request(Stream &strm, const std::string &remote_addr, { // Use WebSocket-specific read timeout instead of HTTP timeout strm.set_read_timeout(CPPHTTPLIB_WEBSOCKET_READ_TIMEOUT_SECOND, 0); - ws::WebSocket ws(strm, req, true, websocket_ping_interval_sec_); + ws::WebSocket ws(strm, req, true, websocket_ping_interval_sec_, + websocket_max_missed_pongs_); entry.handler(req, ws); } return true; @@ -9119,20 +9119,21 @@ bool ClientImpl::redirect(Request &req, Response &res, Error &error) { auto location = res.get_header_value("location"); if (location.empty()) { return false; } - thread_local const std::regex re( - R"((?:(https?):)?(?://(?:\[([a-fA-F\d:]+)\]|([^:/?#]+))(?::(\d+))?)?([^?#]*)(\?[^#]*)?(?:#.*)?)"); + detail::UrlComponents uc; + if (!detail::parse_url(location, uc)) { return false; } - std::smatch m; - if (!std::regex_match(location, m, re)) { return false; } + // Only follow http/https redirects + if (!uc.scheme.empty() && uc.scheme != "http" && uc.scheme != "https") { + return false; + } auto scheme = is_ssl() ? 
"https" : "http"; - auto next_scheme = m[1].str(); - auto next_host = m[2].str(); - if (next_host.empty()) { next_host = m[3].str(); } - auto port_str = m[4].str(); - auto next_path = m[5].str(); - auto next_query = m[6].str(); + auto next_scheme = std::move(uc.scheme); + auto next_host = std::move(uc.host); + auto port_str = std::move(uc.port); + auto next_path = std::move(uc.path); + auto next_query = std::move(uc.query); auto next_port = port_; if (!port_str.empty()) { @@ -9145,7 +9146,7 @@ bool ClientImpl::redirect(Request &req, Response &res, Error &error) { if (next_host.empty()) { next_host = host_; } if (next_path.empty()) { next_path = "/"; } - auto path = decode_query_component(next_path, true) + next_query; + auto path = decode_path_component(next_path) + next_query; // Same host redirect - use current client if (next_scheme == scheme && next_host == host_ && next_port == port_) { @@ -10803,38 +10804,6 @@ void ClientImpl::enable_server_hostname_verification(bool enabled) { } #endif -// ClientImpl::set_ca_cert_store is defined after TLS namespace (uses helpers) -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -X509_STORE *ClientImpl::create_ca_cert_store(const char *ca_cert, - std::size_t size) const { - auto mem = BIO_new_mem_buf(ca_cert, static_cast(size)); - auto se = detail::scope_exit([&] { BIO_free_all(mem); }); - if (!mem) { return nullptr; } - - auto inf = PEM_X509_INFO_read_bio(mem, nullptr, nullptr, nullptr); - if (!inf) { return nullptr; } - - auto cts = X509_STORE_new(); - if (cts) { - for (auto i = 0; i < static_cast(sk_X509_INFO_num(inf)); i++) { - auto itmp = sk_X509_INFO_value(inf, i); - if (!itmp) { continue; } - - if (itmp->x509) { X509_STORE_add_cert(cts, itmp->x509); } - if (itmp->crl) { X509_STORE_add_crl(cts, itmp->crl); } - } - } - - sk_X509_INFO_pop_free(inf, X509_INFO_free); - return cts; -} - -void ClientImpl::set_server_certificate_verifier( - std::function /*verifier*/) { - // Base implementation does nothing - SSLClient overrides this -} -#endif - void ClientImpl::set_logger(Logger logger) { logger_ = std::move(logger); } @@ -10869,12 +10838,9 @@ Client::Client(const std::string &scheme_host_port) Client::Client(const std::string &scheme_host_port, const std::string &client_cert_path, const std::string &client_key_path) { - const static std::regex re( - R"((?:([a-z]+):\/\/)?(?:\[([a-fA-F\d:]+)\]|([^:/?#]+))(?::(\d+))?)"); - - std::smatch m; - if (std::regex_match(scheme_host_port, m, re)) { - auto scheme = m[1].str(); + detail::UrlComponents uc; + if (detail::parse_url(scheme_host_port, uc) && !uc.host.empty()) { + auto &scheme = uc.scheme; #ifdef CPPHTTPLIB_SSL_ENABLED if (!scheme.empty() && (scheme != "http" && scheme != "https")) { @@ -10890,12 +10856,10 @@ Client::Client(const std::string &scheme_host_port, auto is_ssl = scheme == "https"; - auto host = m[2].str(); - if (host.empty()) { host = m[3].str(); } + auto host = std::move(uc.host); - auto port_str = m[4].str(); auto port = is_ssl ? 
443 : 80; - if (!port_str.empty() && !detail::parse_port(port_str, port)) { return; } + if (!uc.port.empty() && !detail::parse_port(uc.port, port)) { return; } if (is_ssl) { #ifdef CPPHTTPLIB_SSL_ENABLED @@ -10913,10 +10877,10 @@ Client::Client(const std::string &scheme_host_port, cli_ = detail::make_unique(scheme_host_port, 80, client_cert_path, client_key_path); } -} // namespace detail +} Client::Client(const std::string &host, int port) - : cli_(detail::make_unique(host, port)) {} + : Client(host, port, std::string(), std::string()) {} Client::Client(const std::string &host, int port, const std::string &client_cert_path, @@ -11491,12 +11455,6 @@ void Client::set_follow_location(bool on) { void Client::set_path_encode(bool on) { cli_->set_path_encode(on); } -[[deprecated("Use set_path_encode() instead. " - "This function will be removed by v1.0.0.")]] -void Client::set_url_encode(bool on) { - cli_->set_path_encode(on); -} - void Client::set_compress(bool on) { cli_->set_compress(on); } void Client::set_decompress(bool on) { cli_->set_decompress(on); } @@ -11879,24 +11837,31 @@ SSLClient::SSLClient(const std::string &host) SSLClient::SSLClient(const std::string &host, int port) : SSLClient(host, port, std::string(), std::string()) {} +void SSLClient::init_ctx() { + ctx_ = tls::create_client_context(); + if (ctx_) { tls::set_min_version(ctx_, tls::Version::TLS1_2); } +} + +void SSLClient::reset_ctx_on_error() { + last_backend_error_ = tls::get_error(); + tls::free_context(ctx_); + ctx_ = nullptr; +} + SSLClient::SSLClient(const std::string &host, int port, const std::string &client_cert_path, const std::string &client_key_path, const std::string &private_key_password) : ClientImpl(host, port, client_cert_path, client_key_path) { - ctx_ = tls::create_client_context(); + init_ctx(); if (!ctx_) { return; } - tls::set_min_version(ctx_, tls::Version::TLS1_2); - if (!client_cert_path.empty() && !client_key_path.empty()) { const char *password = private_key_password.empty() ? nullptr : private_key_password.c_str(); if (!tls::set_client_cert_file(ctx_, client_cert_path.c_str(), client_key_path.c_str(), password)) { - last_backend_error_ = tls::get_error(); - tls::free_context(ctx_); - ctx_ = nullptr; + reset_ctx_on_error(); } } } @@ -11904,17 +11869,13 @@ SSLClient::SSLClient(const std::string &host, int port, SSLClient::SSLClient(const std::string &host, int port, const PemMemory &pem) : ClientImpl(host, port) { - ctx_ = tls::create_client_context(); + init_ctx(); if (!ctx_) { return; } - tls::set_min_version(ctx_, tls::Version::TLS1_2); - if (pem.cert_pem && pem.key_pem) { if (!tls::set_client_cert_pem(ctx_, pem.cert_pem, pem.key_pem, pem.private_key_password)) { - last_backend_error_ = tls::get_error(); - tls::free_context(ctx_); - ctx_ = nullptr; + reset_ctx_on_error(); } } } @@ -12465,23 +12426,6 @@ std::string Request::sni() const { * Group 8: TLS abstraction layer - OpenSSL backend */ -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -SSL_CTX *Client::ssl_context() const { - if (is_ssl_) { return static_cast(*cli_).ssl_context(); } - return nullptr; -} - -void Client::set_server_certificate_verifier( - std::function verifier) { - cli_->set_server_certificate_verifier(verifier); -} - -long Client::get_verify_result() const { - if (is_ssl_) { return static_cast(*cli_).get_verify_result(); } - return -1; // NOTE: -1 doesn't match any of X509_V_ERR_??? 
-} -#endif // CPPHTTPLIB_OPENSSL_SUPPORT - /* * OpenSSL Backend Implementation */ @@ -12491,54 +12435,6 @@ namespace tls { namespace impl { -// OpenSSL-specific helpers for converting native types to PEM -std::string x509_to_pem(X509 *cert) { - if (!cert) return {}; - BIO *bio = BIO_new(BIO_s_mem()); - if (!bio) return {}; - if (PEM_write_bio_X509(bio, cert) != 1) { - BIO_free(bio); - return {}; - } - char *data = nullptr; - long len = BIO_get_mem_data(bio, &data); - std::string pem(data, static_cast(len)); - BIO_free(bio); - return pem; -} - -std::string evp_pkey_to_pem(EVP_PKEY *key) { - if (!key) return {}; - BIO *bio = BIO_new(BIO_s_mem()); - if (!bio) return {}; - if (PEM_write_bio_PrivateKey(bio, key, nullptr, nullptr, 0, nullptr, - nullptr) != 1) { - BIO_free(bio); - return {}; - } - char *data = nullptr; - long len = BIO_get_mem_data(bio, &data); - std::string pem(data, static_cast(len)); - BIO_free(bio); - return pem; -} - -std::string x509_store_to_pem(X509_STORE *store) { - if (!store) return {}; - std::string pem; - auto objs = X509_STORE_get0_objects(store); - if (!objs) return {}; - auto count = sk_X509_OBJECT_num(objs); - for (decltype(count) i = 0; i < count; i++) { - auto obj = sk_X509_OBJECT_value(objs, i); - if (X509_OBJECT_get_type(obj) == X509_LU_X509) { - auto cert = X509_OBJECT_get0_X509(obj); - if (cert) { pem += x509_to_pem(cert); } - } - } - return pem; -} - // Helper to map OpenSSL SSL_get_error to ErrorCode ErrorCode map_ssl_error(int ssl_error, int &out_errno) { switch (ssl_error) { @@ -12571,8 +12467,10 @@ STACK_OF(X509_NAME) * X509 *cert = nullptr; while ((cert = PEM_read_bio_X509(bio, nullptr, nullptr, nullptr)) != nullptr) { - X509_NAME *name = X509_get_subject_name(cert); - if (name) { sk_X509_NAME_push(ca_list, X509_NAME_dup(name)); } + const X509_NAME *name = X509_get_subject_name(cert); + if (name) { + sk_X509_NAME_push(ca_list, X509_NAME_dup(const_cast(name))); + } X509_free(cert); } BIO_free(bio); @@ -12580,45 +12478,6 @@ STACK_OF(X509_NAME) * return ca_list; } -// Helper: Extract CA names from X509_STORE -// Returns a new STACK_OF(X509_NAME)* or nullptr on failure -// Caller takes ownership of returned list -STACK_OF(X509_NAME) * - extract_client_ca_list_from_store(X509_STORE *store) { - if (!store) { return nullptr; } - - auto ca_list = sk_X509_NAME_new_null(); - if (!ca_list) { return nullptr; } - - auto objs = X509_STORE_get0_objects(store); - if (!objs) { - sk_X509_NAME_free(ca_list); - return nullptr; - } - - auto count = sk_X509_OBJECT_num(objs); - for (decltype(count) i = 0; i < count; i++) { - auto obj = sk_X509_OBJECT_value(objs, i); - if (X509_OBJECT_get_type(obj) == X509_LU_X509) { - auto cert = X509_OBJECT_get0_X509(obj); - if (cert) { - auto subject = X509_get_subject_name(cert); - if (subject) { - auto name_dup = X509_NAME_dup(subject); - if (name_dup) { sk_X509_NAME_push(ca_list, name_dup); } - } - } - } - } - - if (sk_X509_NAME_num(ca_list) == 0) { - sk_X509_NAME_free(ca_list); - return nullptr; - } - - return ca_list; -} - // OpenSSL verify callback wrapper int openssl_verify_callback(int preverify_ok, X509_STORE_CTX *ctx) { auto &callback = get_verify_callback(); @@ -13054,6 +12913,9 @@ ssize_t read(session_t session, void *buf, size_t len, TlsError &err) { auto ssl_err = SSL_get_error(ssl, ret); err.code = impl::map_ssl_error(ssl_err, err.sys_errno); + if (err.code == ErrorCode::PeerClosed) { + return 0; + } // Gracefully handle the peer closed state. 
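// --- Illustrative sketch, not part of this patch ----------------------------
// With the change above, tls::read() reports an orderly peer close by
// returning 0 instead of surfacing it as a -1 error, so a caller loop can
// treat 0 as EOF and reserve -1 for real failures. The httplib::tls
// qualification and this helper function are assumptions made for the example.
#include "httplib.h"

static bool drain_session(httplib::tls::session_t session, std::string &out) {
  char buf[4096];
  for (;;) {
    httplib::tls::TlsError err{};
    auto n = httplib::tls::read(session, buf, sizeof(buf), err);
    if (n > 0) { out.append(buf, static_cast<size_t>(n)); continue; }
    if (n == 0) { return true; }  // peer closed the connection cleanly
    return false;                 // inspect err.code / err.backend_code on failure
  }
}
// ----------------------------------------------------------------------------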
if (err.code == ErrorCode::Fatal) { err.backend_code = ERR_get_error(); } return -1; } @@ -13491,164 +13353,8 @@ std::string verify_error_string(long error_code) { return str ? str : "unknown error"; } -namespace impl { - -// OpenSSL-specific helpers for public API wrappers -ctx_t create_server_context_from_x509(X509 *cert, EVP_PKEY *key, - X509_STORE *client_ca_store, - int &out_error) { - out_error = 0; - auto cert_pem = x509_to_pem(cert); - auto key_pem = evp_pkey_to_pem(key); - if (cert_pem.empty() || key_pem.empty()) { - out_error = static_cast(ERR_get_error()); - return nullptr; - } - - auto ctx = create_server_context(); - if (!ctx) { - out_error = static_cast(get_error()); - return nullptr; - } - - if (!set_server_cert_pem(ctx, cert_pem.c_str(), key_pem.c_str(), nullptr)) { - out_error = static_cast(get_error()); - free_context(ctx); - return nullptr; - } - - if (client_ca_store) { - // Set cert store for verification (SSL_CTX_set_cert_store takes ownership) - SSL_CTX_set_cert_store(static_cast(ctx), client_ca_store); - - // Extract and set client CA list directly from store (more efficient than - // PEM conversion) - auto ca_list = extract_client_ca_list_from_store(client_ca_store); - if (ca_list) { - SSL_CTX_set_client_CA_list(static_cast(ctx), ca_list); - } - - set_verify_client(ctx, true); - } - - return ctx; -} - -void update_server_certs_from_x509(ctx_t ctx, X509 *cert, EVP_PKEY *key, - X509_STORE *client_ca_store) { - auto cert_pem = x509_to_pem(cert); - auto key_pem = evp_pkey_to_pem(key); - - if (!cert_pem.empty() && !key_pem.empty()) { - update_server_cert(ctx, cert_pem.c_str(), key_pem.c_str(), nullptr); - } - - if (client_ca_store) { - auto ca_pem = x509_store_to_pem(client_ca_store); - if (!ca_pem.empty()) { update_server_client_ca(ctx, ca_pem.c_str()); } - X509_STORE_free(client_ca_store); - } -} - -ctx_t create_client_context_from_x509(X509 *cert, EVP_PKEY *key, - const char *password, - uint64_t &out_error) { - out_error = 0; - auto ctx = create_client_context(); - if (!ctx) { - out_error = get_error(); - return nullptr; - } - - if (cert && key) { - auto cert_pem = x509_to_pem(cert); - auto key_pem = evp_pkey_to_pem(key); - if (cert_pem.empty() || key_pem.empty()) { - out_error = ERR_get_error(); - free_context(ctx); - return nullptr; - } - if (!set_client_cert_pem(ctx, cert_pem.c_str(), key_pem.c_str(), - password)) { - out_error = get_error(); - free_context(ctx); - return nullptr; - } - } - - return ctx; -} - -} // namespace impl - } // namespace tls -// ClientImpl::set_ca_cert_store - defined here to use -// tls::impl::x509_store_to_pem Deprecated: converts X509_STORE to PEM and -// stores for redirect transfer -void ClientImpl::set_ca_cert_store(X509_STORE *ca_cert_store) { - if (ca_cert_store) { - ca_cert_pem_ = tls::impl::x509_store_to_pem(ca_cert_store); - } -} - -SSLServer::SSLServer(X509 *cert, EVP_PKEY *private_key, - X509_STORE *client_ca_cert_store) { - ctx_ = tls::impl::create_server_context_from_x509( - cert, private_key, client_ca_cert_store, last_ssl_error_); -} - -SSLServer::SSLServer( - const std::function &setup_ssl_ctx_callback) { - // Use abstract API to create context - ctx_ = tls::create_server_context(); - if (ctx_) { - // Pass to OpenSSL-specific callback (ctx_ is SSL_CTX* internally) - auto ssl_ctx = static_cast(ctx_); - if (!setup_ssl_ctx_callback(*ssl_ctx)) { - tls::free_context(ctx_); - ctx_ = nullptr; - } - } -} - -SSL_CTX *SSLServer::ssl_context() const { - return static_cast(ctx_); -} - -void SSLServer::update_certs(X509 *cert, 
EVP_PKEY *private_key, - X509_STORE *client_ca_cert_store) { - std::lock_guard guard(ctx_mutex_); - tls::impl::update_server_certs_from_x509(ctx_, cert, private_key, - client_ca_cert_store); -} - -SSLClient::SSLClient(const std::string &host, int port, - X509 *client_cert, EVP_PKEY *client_key, - const std::string &private_key_password) - : ClientImpl(host, port) { - const char *password = - private_key_password.empty() ? nullptr : private_key_password.c_str(); - ctx_ = tls::impl::create_client_context_from_x509( - client_cert, client_key, password, last_backend_error_); -} - -long SSLClient::get_verify_result() const { return verify_result_; } - -void SSLClient::set_server_certificate_verifier( - std::function verifier) { - // Wrap SSL* callback into backend-independent session_verifier_ - auto v = std::make_shared>( - std::move(verifier)); - session_verifier_ = [v](tls::session_t session) { - return (*v)(static_cast(session)); - }; -} - -SSL_CTX *SSLClient::ssl_context() const { - return static_cast(ctx_); -} - bool SSLClient::verify_host(X509 *server_cert) const { /* Quote from RFC2818 section 3.1 "Server Identity" @@ -16162,7 +15868,11 @@ ReadResult WebSocket::read(std::string &msg) { payload.size(), true, !is_server_); continue; } - case Opcode::Pong: continue; + case Opcode::Pong: { + std::lock_guard lock(ping_mutex_); + unacked_pings_ = 0; + continue; + } case Opcode::Close: { if (!closed_.exchange(true)) { // Echo close frame back @@ -16196,7 +15906,11 @@ ReadResult WebSocket::read(std::string &msg) { true, !is_server_); continue; } - if (cont_opcode == Opcode::Pong) { continue; } + if (cont_opcode == Opcode::Pong) { + std::lock_guard lock(ping_mutex_); + unacked_pings_ = 0; + continue; + } if (cont_opcode == Opcode::Close) { if (!closed_.exchange(true)) { std::lock_guard lock(write_mutex_); @@ -16284,12 +15998,22 @@ void WebSocket::start_heartbeat() { while (!closed_) { ping_cv_.wait_for(lock, std::chrono::seconds(ping_interval_sec_)); if (closed_) { break; } + // If the peer has failed to respond to the previous pings, give up. + // RFC 6455 does not define a pong-timeout mechanism; this is an + // opt-in liveness check controlled by max_missed_pongs_. + if (max_missed_pongs_ > 0 && unacked_pings_ >= max_missed_pongs_) { + lock.unlock(); + close(CloseStatus::GoingAway, "pong timeout"); + return; + } lock.unlock(); if (!send_frame(Opcode::Ping, nullptr, 0)) { + lock.lock(); closed_ = true; break; } lock.lock(); + unacked_pings_++; } }); } @@ -16302,12 +16026,10 @@ bool WebSocket::is_open() const { return !closed_; } WebSocketClient::WebSocketClient( const std::string &scheme_host_port_path, const Headers &headers) : headers_(headers) { - const static std::regex re( - R"(([a-z]+):\/\/(?:\[([a-fA-F\d:]+)\]|([^:/?#]+))(?::(\d+))?(\/.*))"); - - std::smatch m; - if (std::regex_match(scheme_host_port_path, m, re)) { - auto scheme = m[1].str(); + detail::UrlComponents uc; + if (detail::parse_url(scheme_host_port_path, uc) && !uc.scheme.empty() && + !uc.host.empty() && !uc.path.empty()) { + auto &scheme = uc.scheme; #ifdef CPPHTTPLIB_SSL_ENABLED if (scheme != "ws" && scheme != "wss") { @@ -16323,14 +16045,12 @@ WebSocketClient::WebSocketClient( auto is_ssl = scheme == "wss"; - host_ = m[2].str(); - if (host_.empty()) { host_ = m[3].str(); } + host_ = std::move(uc.host); - auto port_str = m[4].str(); port_ = is_ssl ? 
443 : 80; - if (!port_str.empty() && !detail::parse_port(port_str, port_)) { return; } + if (!uc.port.empty() && !detail::parse_port(uc.port, port_)) { return; } - path_ = m[5].str(); + path_ = std::move(uc.path); #ifdef CPPHTTPLIB_SSL_ENABLED is_ssl_ = is_ssl; @@ -16421,8 +16141,9 @@ bool WebSocketClient::connect() { Request req; req.method = "GET"; req.path = path_; - ws_ = std::unique_ptr( - new WebSocket(std::move(strm), req, false, websocket_ping_interval_sec_)); + ws_ = std::unique_ptr(new WebSocket(std::move(strm), req, false, + websocket_ping_interval_sec_, + websocket_max_missed_pongs_)); return true; } @@ -16466,6 +16187,10 @@ void WebSocketClient::set_websocket_ping_interval(time_t sec) { websocket_ping_interval_sec_ = sec; } +void WebSocketClient::set_websocket_max_missed_pongs(int count) { + websocket_max_missed_pongs_ = count; +} + void WebSocketClient::set_tcp_nodelay(bool on) { tcp_nodelay_ = on; } void WebSocketClient::set_address_family(int family) { diff --git a/vendor/cpp-httplib/httplib.h b/vendor/cpp-httplib/httplib.h index 2967ddf5e50..8581d1695a8 100644 --- a/vendor/cpp-httplib/httplib.h +++ b/vendor/cpp-httplib/httplib.h @@ -8,8 +8,8 @@ #ifndef CPPHTTPLIB_HTTPLIB_H #define CPPHTTPLIB_HTTPLIB_H -#define CPPHTTPLIB_VERSION "0.40.0" -#define CPPHTTPLIB_VERSION_NUM "0x002800" +#define CPPHTTPLIB_VERSION "0.43.1" +#define CPPHTTPLIB_VERSION_NUM "0x002b01" #ifdef _WIN32 #if defined(_WIN32_WINNT) && _WIN32_WINNT < 0x0A00 @@ -205,6 +205,10 @@ #define CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND 30 #endif +#ifndef CPPHTTPLIB_WEBSOCKET_MAX_MISSED_PONGS +#define CPPHTTPLIB_WEBSOCKET_MAX_MISSED_PONGS 0 +#endif + /* * Headers */ @@ -333,13 +337,10 @@ using socket_t = int; #include #include #include -#if __cplusplus >= 201703L -#include -#endif // On macOS with a TLS backend, enable Keychain root certificates by default // unless the user explicitly opts out. -#if defined(__APPLE__) && \ +#if defined(__APPLE__) && defined(__clang__) && \ !defined(CPPHTTPLIB_DISABLE_MACOSX_AUTOMATIC_ROOT_CERTIFICATES) && \ (defined(CPPHTTPLIB_OPENSSL_SUPPORT) || \ defined(CPPHTTPLIB_MBEDTLS_SUPPORT) || \ @@ -358,7 +359,7 @@ using socket_t = int; #if defined(CPPHTTPLIB_USE_NON_BLOCKING_GETADDRINFO) || \ defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) -#if TARGET_OS_MAC +#if TARGET_OS_MAC && defined(__clang__) #include #include #endif @@ -701,9 +702,96 @@ inline bool parse_port(const std::string &s, int &port) { return parse_port(s.data(), s.size(), port); } +struct UrlComponents { + std::string scheme; + std::string host; + std::string port; + std::string path; + std::string query; +}; + +inline bool parse_url(const std::string &url, UrlComponents &uc) { + uc = {}; + size_t pos = 0; + + auto sep = url.find("://"); + if (sep != std::string::npos) { + uc.scheme = url.substr(0, sep); + + // Scheme must be [a-z]+ only + if (uc.scheme.empty()) { return false; } + for (auto c : uc.scheme) { + if (c < 'a' || c > 'z') { return false; } + } + + pos = sep + 3; + } else if (url.compare(0, 2, "//") == 0) { + pos = 2; + } + + auto has_authority_prefix = pos > 0; + auto has_authority = has_authority_prefix || (!url.empty() && url[0] != '/' && + url[0] != '?' 
&& url[0] != '#'); + if (has_authority) { + if (pos < url.size() && url[pos] == '[') { + auto close = url.find(']', pos); + if (close == std::string::npos) { return false; } + uc.host = url.substr(pos + 1, close - pos - 1); + + // IPv6 host must be [a-fA-F0-9:]+ only + if (uc.host.empty()) { return false; } + for (auto c : uc.host) { + if (!((c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F') || + (c >= '0' && c <= '9') || c == ':')) { + return false; + } + } + + pos = close + 1; + } else { + auto end = url.find_first_of(":/?#", pos); + if (end == std::string::npos) { end = url.size(); } + uc.host = url.substr(pos, end - pos); + pos = end; + } + + if (pos < url.size() && url[pos] == ':') { + ++pos; + auto end = url.find_first_of("/?#", pos); + if (end == std::string::npos) { end = url.size(); } + uc.port = url.substr(pos, end - pos); + pos = end; + } + + // Without :// or //, the entire input must be consumed as host[:port]. + // If there is leftover (path, query, etc.), this is not a valid + // host[:port] string — clear and reparse as a plain path. + if (!has_authority_prefix && pos < url.size()) { + uc.host.clear(); + uc.port.clear(); + pos = 0; + } + } + + if (pos < url.size() && url[pos] != '?' && url[pos] != '#') { + auto end = url.find_first_of("?#", pos); + if (end == std::string::npos) { end = url.size(); } + uc.path = url.substr(pos, end - pos); + pos = end; + } + + if (pos < url.size() && url[pos] == '?') { + auto end = url.find('#', pos); + if (end == std::string::npos) { end = url.size(); } + uc.query = url.substr(pos, end - pos); + } + + return true; +} + } // namespace detail -enum SSLVerifierResponse { +enum class SSLVerifierResponse { // no decision has been made, use the built-in certificate verifier NoDecisionMade, // connection certificate is verified and accepted @@ -797,38 +885,15 @@ using Match = std::smatch; using DownloadProgress = std::function; using UploadProgress = std::function; - -#if __cplusplus >= 201703L - -using any = std::any; -using bad_any_cast = std::bad_any_cast; - -template T any_cast(const any &a) { return std::any_cast(a); } -template T any_cast(any &a) { return std::any_cast(a); } -template T any_cast(any &&a) { - return std::any_cast(std::move(a)); -} -template const T *any_cast(const any *a) noexcept { - return std::any_cast(a); -} -template T *any_cast(any *a) noexcept { - return std::any_cast(a); -} - -#else // C++11/14 implementation - -class bad_any_cast : public std::bad_cast { -public: - const char *what() const noexcept override { return "bad any_cast"; } -}; - +/* + * detail: type-erased storage used by UserData. + * ABI-stable regardless of C++ standard — always uses this custom + * implementation instead of std::any. + */ namespace detail { using any_type_id = const void *; -// Returns a unique per-type ID without RTTI. -// The static address is stable across TUs because function templates are -// implicitly inline and the ODR merges their statics into one. template any_type_id any_typeid() noexcept { static const char id = 0; return &id; @@ -851,88 +916,59 @@ template struct any_value final : any_storage { } // namespace detail -class any { - std::unique_ptr storage_; - +class UserData { public: - any() noexcept = default; - any(const any &o) : storage_(o.storage_ ? o.storage_->clone() : nullptr) {} - any(any &&) noexcept = default; - any &operator=(const any &o) { - storage_ = o.storage_ ? 
o.storage_->clone() : nullptr; - return *this; + UserData() = default; + UserData(UserData &&) noexcept = default; + UserData &operator=(UserData &&) noexcept = default; + + UserData(const UserData &o) { + for (const auto &e : o.entries_) { + if (e.second) { entries_[e.first] = e.second->clone(); } + } } - any &operator=(any &&) noexcept = default; - - template < - typename T, typename D = typename std::decay::type, - typename std::enable_if::value, int>::type = 0> - any(T &&v) : storage_(new detail::any_value(std::forward(v))) {} - - template < - typename T, typename D = typename std::decay::type, - typename std::enable_if::value, int>::type = 0> - any &operator=(T &&v) { - storage_.reset(new detail::any_value(std::forward(v))); + + UserData &operator=(const UserData &o) { + if (this != &o) { + entries_.clear(); + for (const auto &e : o.entries_) { + if (e.second) { entries_[e.first] = e.second->clone(); } + } + } return *this; } - bool has_value() const noexcept { return storage_ != nullptr; } - void reset() noexcept { storage_.reset(); } - - template friend T *any_cast(any *a) noexcept; - template friend const T *any_cast(const any *a) noexcept; -}; + template void set(const std::string &key, T &&value) { + using D = typename std::decay::type; + entries_[key].reset(new detail::any_value(std::forward(value))); + } -template T *any_cast(any *a) noexcept { - if (!a || !a->storage_) { return nullptr; } - if (a->storage_->type_id() != detail::any_typeid()) { return nullptr; } - return &static_cast *>(a->storage_.get())->value; -} + template T *get(const std::string &key) noexcept { + auto it = entries_.find(key); + if (it == entries_.end() || !it->second) { return nullptr; } + if (it->second->type_id() != detail::any_typeid()) { return nullptr; } + return &static_cast *>(it->second.get())->value; + } -template const T *any_cast(const any *a) noexcept { - if (!a || !a->storage_) { return nullptr; } - if (a->storage_->type_id() != detail::any_typeid()) { return nullptr; } - return &static_cast *>(a->storage_.get())->value; -} + template const T *get(const std::string &key) const noexcept { + auto it = entries_.find(key); + if (it == entries_.end() || !it->second) { return nullptr; } + if (it->second->type_id() != detail::any_typeid()) { return nullptr; } + return &static_cast *>(it->second.get())->value; + } -template T any_cast(const any &a) { - using U = - typename std::remove_cv::type>::type; - const U *p = any_cast(&a); -#ifndef CPPHTTPLIB_NO_EXCEPTIONS - if (!p) { throw bad_any_cast{}; } -#else - if (!p) { std::abort(); } -#endif - return static_cast(*p); -} + bool has(const std::string &key) const noexcept { + return entries_.find(key) != entries_.end(); + } -template T any_cast(any &a) { - using U = - typename std::remove_cv::type>::type; - U *p = any_cast(&a); -#ifndef CPPHTTPLIB_NO_EXCEPTIONS - if (!p) { throw bad_any_cast{}; } -#else - if (!p) { std::abort(); } -#endif - return static_cast(*p); -} + void erase(const std::string &key) { entries_.erase(key); } -template T any_cast(any &&a) { - using U = - typename std::remove_cv::type>::type; - U *p = any_cast(&a); -#ifndef CPPHTTPLIB_NO_EXCEPTIONS - if (!p) { throw bad_any_cast{}; } -#else - if (!p) { std::abort(); } -#endif - return static_cast(std::move(*p)); -} + void clear() noexcept { entries_.clear(); } -#endif // __cplusplus >= 201703L +private: + std::unordered_map> + entries_; +}; struct Response; using ResponseHandler = std::function; @@ -1261,6 +1297,7 @@ struct Request { bool has_param(const std::string &key) const; 
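// --- Illustrative sketch, not part of this patch ----------------------------
// Minimal usage of the UserData container defined above, which replaces the
// former httplib::any based storage. The keys and values here are invented
// for the example.
#include <string>
#include "httplib.h"

static void user_data_example() {
  httplib::UserData ud;
  ud.set("user_id", 42);                   // stores an int under "user_id"
  ud.set("token", std::string("abc123"));  // stores a std::string

  if (const int *id = ud.get<int>("user_id")) {
    // *id == 42; get<T>() returns nullptr on a missing key or a type mismatch
    (void)id;
  }
  if (ud.get<double>("user_id") == nullptr) {
    // type mismatch: stored as int, requested as double
  }

  ud.erase("token");
  bool gone = !ud.has("token");  // true
  (void)gone;
  ud.clear();
}
// ----------------------------------------------------------------------------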
std::string get_param_value(const std::string &key, size_t id = 0) const; + std::vector get_param_values(const std::string &key) const; size_t get_param_value_count(const std::string &key) const; bool is_multipart_form_data() const; @@ -1293,7 +1330,7 @@ struct Response { // User-defined context — set by pre-routing/pre-request handlers and read // by route handlers to pass arbitrary data (e.g. decoded auth tokens). - std::map user_data; + UserData user_data; bool has_header(const std::string &key) const; std::string get_header_value(const std::string &key, const char *def = "", @@ -1664,6 +1701,9 @@ class Server { Server &set_keep_alive_max_count(size_t count); Server &set_keep_alive_timeout(time_t sec); + template + Server & + set_keep_alive_timeout(const std::chrono::duration &duration); Server &set_read_timeout(time_t sec, time_t usec = 0); template @@ -1684,6 +1724,8 @@ class Server { Server &set_websocket_ping_interval( const std::chrono::duration &duration); + Server &set_websocket_max_missed_pongs(int count); + bool bind_to_port(const std::string &host, int port, int socket_flags = 0); int bind_to_any_port(const std::string &host, int socket_flags = 0); bool listen_after_bind(); @@ -1720,6 +1762,7 @@ class Server { size_t payload_max_length_ = CPPHTTPLIB_PAYLOAD_MAX_LENGTH; time_t websocket_ping_interval_sec_ = CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND; + int websocket_max_missed_pongs_ = CPPHTTPLIB_WEBSOCKET_MAX_MISSED_PONGS; private: using Handlers = @@ -1731,6 +1774,14 @@ class Server { static std::unique_ptr make_matcher(const std::string &pattern); + template + Server &add_handler( + std::vector, H>> &handlers, + const std::string &pattern, H handler) { + handlers.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; + } + Server &set_error_handler_core(HandlerWithResponse handler, std::true_type); Server &set_error_handler_core(Handler handler, std::false_type); @@ -1892,15 +1943,6 @@ class Result { int ssl_error_ = 0; uint64_t ssl_backend_error_ = 0; #endif - -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -public: - [[deprecated("Use ssl_backend_error() instead. " - "This function will be removed by v1.0.0.")]] - uint64_t ssl_openssl_error() const { - return ssl_backend_error_; - } -#endif }; struct ClientConnection { @@ -2373,22 +2415,6 @@ class ClientImpl { int last_ssl_error_ = 0; uint64_t last_backend_error_ = 0; #endif - -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -public: - [[deprecated("Use load_ca_cert_store() instead. " - "This function will be removed by v1.0.0.")]] - void set_ca_cert_store(X509_STORE *ca_cert_store); - - [[deprecated("Use tls::create_ca_store() instead. " - "This function will be removed by v1.0.0.")]] - X509_STORE *create_ca_cert_store(const char *ca_cert, std::size_t size) const; - - [[deprecated("Use set_server_certificate_verifier(VerifyCallback) instead. " - "This function will be removed by v1.0.0.")]] - virtual void set_server_certificate_verifier( - std::function verifier); -#endif }; class Client { @@ -2563,7 +2589,6 @@ class Client { void set_follow_location(bool on); void set_path_encode(bool on); - void set_url_encode(bool on); void set_compress(bool on); @@ -2611,22 +2636,6 @@ class Client { private: bool is_ssl_ = false; #endif - -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -public: - [[deprecated("Use tls_context() instead. " - "This function will be removed by v1.0.0.")]] - SSL_CTX *ssl_context() const; - - [[deprecated("Use set_session_verifier(session_t) instead. 
" - "This function will be removed by v1.0.0.")]] - void set_server_certificate_verifier( - std::function verifier); - - [[deprecated("Use Result::ssl_backend_error() instead. " - "This function will be removed by v1.0.0.")]] - long get_verify_result() const; -#endif }; #ifdef CPPHTTPLIB_SSL_ENABLED @@ -2672,29 +2681,6 @@ class SSLServer : public Server { std::mutex ctx_mutex_; int last_ssl_error_ = 0; - -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -public: - [[deprecated("Use SSLServer(PemMemory) or " - "SSLServer(ContextSetupCallback) instead. " - "This constructor will be removed by v1.0.0.")]] - SSLServer(X509 *cert, EVP_PKEY *private_key, - X509_STORE *client_ca_cert_store = nullptr); - - [[deprecated("Use SSLServer(ContextSetupCallback) instead. " - "This constructor will be removed by v1.0.0.")]] - SSLServer( - const std::function &setup_ssl_ctx_callback); - - [[deprecated("Use tls_context() instead. " - "This function will be removed by v1.0.0.")]] - SSL_CTX *ssl_context() const; - - [[deprecated("Use update_certs_pem() instead. " - "This function will be removed by v1.0.0.")]] - void update_certs(X509 *cert, EVP_PKEY *private_key, - X509_STORE *client_ca_cert_store = nullptr); -#endif }; class SSLClient final : public ClientImpl { @@ -2758,6 +2744,9 @@ class SSLClient final : public ClientImpl { Response &res, bool &success, Error &error); bool initialize_ssl(Socket &socket, Error &error); + void init_ctx(); + void reset_ctx_on_error(); + bool load_certs(); tls::ctx_t ctx_ = nullptr; @@ -2775,26 +2764,6 @@ class SSLClient final : public ClientImpl { friend class ClientImpl; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT -public: - [[deprecated("Use SSLClient(host, port, PemMemory) instead. " - "This constructor will be removed by v1.0.0.")]] - explicit SSLClient(const std::string &host, int port, X509 *client_cert, - EVP_PKEY *client_key, - const std::string &private_key_password = std::string()); - - [[deprecated("Use Result::ssl_backend_error() instead. " - "This function will be removed by v1.0.0.")]] - long get_verify_result() const; - - [[deprecated("Use tls_context() instead. " - "This function will be removed by v1.0.0.")]] - SSL_CTX *ssl_context() const; - - [[deprecated("Use set_session_verifier(session_t) instead. 
" - "This function will be removed by v1.0.0.")]] - void set_server_certificate_verifier( - std::function verifier) override; - private: bool verify_host(X509 *server_cert) const; bool verify_host_with_subject_alt_name(X509 *server_cert) const; @@ -3766,17 +3735,21 @@ class WebSocket { WebSocket( Stream &strm, const Request &req, bool is_server, - time_t ping_interval_sec = CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND) + time_t ping_interval_sec = CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND, + int max_missed_pongs = CPPHTTPLIB_WEBSOCKET_MAX_MISSED_PONGS) : strm_(strm), req_(req), is_server_(is_server), - ping_interval_sec_(ping_interval_sec) { + ping_interval_sec_(ping_interval_sec), + max_missed_pongs_(max_missed_pongs) { start_heartbeat(); } WebSocket( std::unique_ptr &&owned_strm, const Request &req, bool is_server, - time_t ping_interval_sec = CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND) + time_t ping_interval_sec = CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND, + int max_missed_pongs = CPPHTTPLIB_WEBSOCKET_MAX_MISSED_PONGS) : strm_(*owned_strm), owned_strm_(std::move(owned_strm)), req_(req), - is_server_(is_server), ping_interval_sec_(ping_interval_sec) { + is_server_(is_server), ping_interval_sec_(ping_interval_sec), + max_missed_pongs_(max_missed_pongs) { start_heartbeat(); } @@ -3788,6 +3761,8 @@ class WebSocket { Request req_; bool is_server_; time_t ping_interval_sec_; + int max_missed_pongs_; + int unacked_pings_ = 0; std::atomic closed_{false}; std::mutex write_mutex_; std::thread ping_thread_; @@ -3817,6 +3792,7 @@ class WebSocketClient { void set_read_timeout(time_t sec, time_t usec = 0); void set_write_timeout(time_t sec, time_t usec = 0); void set_websocket_ping_interval(time_t sec); + void set_websocket_max_missed_pongs(int count); void set_tcp_nodelay(bool on); void set_address_family(int family); void set_ipv6_v6only(bool on); @@ -3848,6 +3824,7 @@ class WebSocketClient { time_t write_timeout_usec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND; time_t websocket_ping_interval_sec_ = CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND; + int websocket_max_missed_pongs_ = CPPHTTPLIB_WEBSOCKET_MAX_MISSED_PONGS; int address_family_ = AF_UNSPEC; bool tcp_nodelay_ = CPPHTTPLIB_TCP_NODELAY; bool ipv6_v6only_ = CPPHTTPLIB_IPV6_V6ONLY;