255 changes: 255 additions & 0 deletions .github/workflows/release-prism.yml
@@ -0,0 +1,255 @@
name: Release (Prism)

on:
  workflow_dispatch:
    inputs:
      create_release:
        description: 'Create new release'
        required: true
        type: boolean

concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
  cancel-in-progress: true

env:
  BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
  CMAKE_ARGS: "-DLLAMA_BUILD_EXAMPLES=OFF -DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_TOOLS=ON -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON"

jobs:
  macOS-arm64:
    runs-on: macos-14

    steps:
      - name: Clone
        uses: actions/checkout@v6
        with:
          fetch-depth: 0

      - name: ccache
        uses: ggml-org/ccache-action@v1.2.16
        with:
          key: macOS-latest-cmake-arm64
          evict-old-files: 1d

      - name: Build
        run: |
          cmake -B build \
            -DCMAKE_INSTALL_RPATH='@loader_path' \
            -DCMAKE_BUILD_WITH_INSTALL_RPATH=ON \
            -DLLAMA_FATAL_WARNINGS=ON \
            -DGGML_METAL_USE_BF16=ON \
            -DGGML_METAL_EMBED_LIBRARY=ON \
            -DGGML_RPC=ON \
            ${{ env.CMAKE_ARGS }}
          cmake --build build --config Release -j $(sysctl -n hw.logicalcpu)

      - name: Determine tag name
        id: tag
        uses: ./.github/actions/get-tag-name

      - name: Pack artifacts
        run: |
          cp LICENSE ./build/bin/
          tar -czvf llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.tar.gz -s ",./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .

      - name: Upload artifacts
        uses: actions/upload-artifact@v6
        with:
          path: llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.tar.gz
          name: llama-bin-macos-arm64.tar.gz

  linux-cuda:
    runs-on: ubuntu-22.04

    strategy:
      matrix:
        include:
          - cuda: '12.4'
            cuda_pkg: '12-4'
          - cuda: '12.8'
            cuda_pkg: '12-8'
          - cuda: '13.1'
            cuda_pkg: '13-1'

    steps:
      - name: Clone
        uses: actions/checkout@v6
        with:
          fetch-depth: 0

      - name: ccache
        uses: ggml-org/ccache-action@v1.2.16
        with:
          key: ubuntu-22-cmake-cuda-${{ matrix.cuda }}
          evict-old-files: 1d

      - name: Install CUDA toolkit
        run: |
          wget -q https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
          sudo dpkg -i cuda-keyring_1.1-1_all.deb
          sudo apt-get update
          sudo apt-get -y install cuda-toolkit-${{ matrix.cuda_pkg }}
          echo "/usr/local/cuda-${{ matrix.cuda }}/bin" >> $GITHUB_PATH
          echo "CUDA_PATH=/usr/local/cuda-${{ matrix.cuda }}" >> $GITHUB_ENV
          echo "LD_LIBRARY_PATH=/usr/local/cuda-${{ matrix.cuda }}/lib64:$LD_LIBRARY_PATH" >> $GITHUB_ENV

      - name: Build
        run: |
          cmake -B build \
            -DCMAKE_INSTALL_RPATH='$ORIGIN' \
            -DCMAKE_BUILD_WITH_INSTALL_RPATH=ON \
            -DGGML_NATIVE=OFF \
            -DGGML_CUDA=ON \
            ${{ env.CMAKE_ARGS }}
          cmake --build build --config Release -j $(nproc) 2>&1 | grep -v "^nvcc warning"

      - name: Determine tag name
        id: tag
        uses: ./.github/actions/get-tag-name

      - name: Pack artifacts
        run: |
          cp LICENSE ./build/bin/
          tar -czvf llama-${{ steps.tag.outputs.name }}-bin-linux-cuda-${{ matrix.cuda }}-x64.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .

      - name: Upload artifacts
        uses: actions/upload-artifact@v6
        with:
          path: llama-${{ steps.tag.outputs.name }}-bin-linux-cuda-${{ matrix.cuda }}-x64.tar.gz
          name: llama-bin-linux-cuda-${{ matrix.cuda }}-x64.tar.gz

  windows-cuda:
    runs-on: windows-2022

    strategy:
      matrix:
        cuda: ['12.4', '13.1']

    steps:
      - name: Clone
        uses: actions/checkout@v6

      - name: Install ccache
        uses: ggml-org/ccache-action@v1.2.16
        with:
          key: windows-cuda-${{ matrix.cuda }}
          variant: ccache
          evict-old-files: 1d

      - name: Install Cuda Toolkit
        uses: ./.github/actions/windows-setup-cuda
        with:
          cuda_version: ${{ matrix.cuda }}

      - name: Install Ninja
        run: choco install ninja

      - name: Build
        shell: cmd
        run: |
          call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" x64
          cmake -S . -B build -G "Ninja Multi-Config" ^
            -DGGML_NATIVE=OFF ^
            -DGGML_CUDA=ON ^
            -DLLAMA_BUILD_BORINGSSL=ON ^
            -DCMAKE_CUDA_FLAGS="-diag-suppress=221" ^
            ${{ env.CMAKE_ARGS }}
          set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1
          cmake --build build --config Release -j %NINJA_JOBS%

      - name: Determine tag name
        id: tag
        uses: ./.github/actions/get-tag-name

      - name: Pack artifacts
        run: |
          7z a -snl llama-${{ steps.tag.outputs.name }}-bin-win-cuda-${{ matrix.cuda }}-x64.zip .\build\bin\Release\*

      - name: Upload artifacts
        uses: actions/upload-artifact@v6
        with:
          path: llama-${{ steps.tag.outputs.name }}-bin-win-cuda-${{ matrix.cuda }}-x64.zip
          name: llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip

      - name: Copy and pack Cuda runtime
        run: |
          echo "Cuda install location: ${{ env.CUDA_PATH }}"
          $dst='.\build\bin\cudart\'
          robocopy "${{env.CUDA_PATH}}\bin" $dst cudart64_*.dll cublas64_*.dll cublasLt64_*.dll
          robocopy "${{env.CUDA_PATH}}\lib" $dst cudart64_*.dll cublas64_*.dll cublasLt64_*.dll
          robocopy "${{env.CUDA_PATH}}\bin\x64" $dst cudart64_*.dll cublas64_*.dll cublasLt64_*.dll
          7z a cudart-llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip $dst\*

      - name: Upload Cuda runtime
        uses: actions/upload-artifact@v6
        with:
          path: cudart-llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip
          name: cudart-llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip

  release:
    if: ${{ github.event.inputs.create_release == 'true' }}

    permissions:
      contents: write

    runs-on: ubuntu-latest

    needs:
      - macOS-arm64
      - linux-cuda
      - windows-cuda

    steps:
      - name: Clone
        uses: actions/checkout@v6
        with:
          fetch-depth: 0

      - name: Determine tag name
        id: tag
        uses: ./.github/actions/get-tag-name

      - name: Download artifacts
        uses: actions/download-artifact@v7
        with:
          path: ./artifact
          merge-multiple: true

      - name: Move artifacts
        run: |
          mkdir -p release
          mv -v artifact/*.tar.gz release/ 2>/dev/null || true
          mv -v artifact/*.zip release/ 2>/dev/null || true
          ls -lh release/

      - name: Create release
        id: create_release
        uses: ggml-org/action-create-release@v1
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        with:
          tag_name: ${{ steps.tag.outputs.name }}
          body: |
            Pre-built binaries (PrismML fork with Q1_0 1-bit quantization support).

            **macOS:**
            - [macOS Apple Silicon (arm64)](https://github.com/${{ github.repository }}/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.tar.gz)

            **Linux:**
            - [Linux x64 (CUDA 12.4)](https://github.com/${{ github.repository }}/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-linux-cuda-12.4-x64.tar.gz)
            - [Linux x64 (CUDA 12.8)](https://github.com/${{ github.repository }}/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-linux-cuda-12.8-x64.tar.gz)
            - [Linux x64 (CUDA 13.1)](https://github.com/${{ github.repository }}/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-linux-cuda-13.1-x64.tar.gz)

            **Windows:**
            - [Windows x64 (CUDA 12.4)](https://github.com/${{ github.repository }}/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cuda-12.4-x64.zip) - [CUDA 12.4 DLLs](https://github.com/${{ github.repository }}/releases/download/${{ steps.tag.outputs.name }}/cudart-llama-bin-win-cuda-12.4-x64.zip)
            - [Windows x64 (CUDA 13.1)](https://github.com/${{ github.repository }}/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cuda-13.1-x64.zip) - [CUDA 13.1 DLLs](https://github.com/${{ github.repository }}/releases/download/${{ steps.tag.outputs.name }}/cudart-llama-bin-win-cuda-13.1-x64.zip)

      - name: Upload release
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        run: |
          for file in release/*; do
            echo "Uploading $(basename $file)..."
            gh release upload ${{ steps.tag.outputs.name }} "$file" --clobber
          done
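Note: once this workflow is on the default branch, a release build can be started from the Actions tab or with the GitHub CLI, e.g. `gh workflow run release-prism.yml -f create_release=true`. The workflow file name is taken from the path above; the `gh workflow run -f` invocation is shown only as an illustration of how the boolean input would be passed.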
1 change: 1 addition & 0 deletions common/arg.cpp
@@ -398,6 +398,7 @@ const std::vector<ggml_type> kv_cache_types = {
    GGML_TYPE_IQ4_NL,
    GGML_TYPE_Q5_0,
    GGML_TYPE_Q5_1,
    GGML_TYPE_TQ3_0,
};

static ggml_type kv_cache_type_from_str(const std::string & s) {
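With GGML_TYPE_TQ3_0 added to kv_cache_types, the new type should be selectable for the KV cache through the existing flags (e.g. `-ctk tq3_0 -ctv tq3_0`), assuming kv_cache_type_from_str resolves it via ggml_type_name the same way as the other entries; this example is illustrative and not taken from the PR.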
7 changes: 6 additions & 1 deletion ggml/include/ggml.h
@@ -427,7 +427,10 @@ extern "C" {
        // GGML_TYPE_IQ4_NL_4_8 = 37,
        // GGML_TYPE_IQ4_NL_8_8 = 38,
        GGML_TYPE_MXFP4     = 39, // MXFP4 (1 block)
        GGML_TYPE_COUNT     = 40,
        GGML_TYPE_Q1_0      = 40,
        GGML_TYPE_Q1_0_g128 = 41,
        GGML_TYPE_TQ3_0     = 42, // TurboQuant 3-bit polar + QJL (no per-block scale)
        GGML_TYPE_COUNT     = 43,
    };

// precision
@@ -463,6 +466,8 @@ extern "C" {
        GGML_FTYPE_MOSTLY_IQ1_M     = 23, // except 1d tensors
        GGML_FTYPE_MOSTLY_BF16      = 24, // except 1d tensors
        GGML_FTYPE_MOSTLY_MXFP4     = 25, // except 1d tensors
        GGML_FTYPE_MOSTLY_Q1_0      = 26, // except 1d tensors
        GGML_FTYPE_MOSTLY_Q1_0_g128 = 27, // except 1d tensors
    };

// available tensor operations:
36 changes: 36 additions & 0 deletions ggml/src/ggml-common.h
@@ -93,6 +93,13 @@ typedef sycl::half2 ggml_half2;
// QR = QK / number of values before dequantization
// QI = number of 32 bit integers before dequantization

#define QI1_0 (QK1_0 / 32) // Number of int32s needed for QK1_0 bits (QK1_0/32)
#define QR1_0 1 // 1 bit per quantized element (matches the 1-bit nature of Q1_0)

#define QI1_0_g128 (QK1_0_g128 / 32) // Number of int32s needed for QK1_0_g128 bits (QK1_0_g128/32)
#define QR1_0_g128 1 // 1 bit per quantized element (matches the 1-bit nature of Q1_0_g128)


#define QI4_0 (QK4_0 / (4 * QR4_0))
#define QR4_0 2

@@ -167,6 +174,20 @@ typedef sycl::half2 ggml_half2;
#define GGML_EXTENSION __extension__
#endif // _MSC_VER

#define QK1_0 32 // MUST match QK8_0 for vec_dot computation! TODO see if we can do larger blocks later
typedef struct {
    ggml_half d;             // delta
    uint8_t   qs[QK1_0 / 8]; // bits / quants
} block_q1_0;
static_assert(sizeof(block_q1_0) == sizeof(ggml_half) + QK1_0 / 8, "wrong q1_0 block size/padding");

#define QK1_0_g128 128
typedef struct {
    ggml_half d;                  // delta
    uint8_t   qs[QK1_0_g128 / 8]; // bits / quants
} block_q1_0_g128;
static_assert(sizeof(block_q1_0_g128) == sizeof(ggml_half) + QK1_0_g128 / 8, "wrong q1_0_g128 block size/padding");
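
// Illustration only (not part of this PR): a minimal reference dequantizer for the
// block_q1_0 layout above. It assumes ggml_fp16_to_fp32() from ggml.h is in scope,
// LSB-first bit packing in qs, and that a set bit decodes to +d and a clear bit to -d;
// the actual mapping is defined by the fork's quantize/dequantize kernels.
static void dequantize_row_q1_0_ref(const block_q1_0 * x, float * y, int64_t k) {
    const int64_t nb = k / QK1_0;                            // blocks per row (k assumed a multiple of QK1_0)
    for (int64_t i = 0; i < nb; ++i) {
        const float d = ggml_fp16_to_fp32(x[i].d);           // per-block scale
        for (int j = 0; j < QK1_0; ++j) {
            const int bit = (x[i].qs[j/8] >> (j%8)) & 1;     // assumed LSB-first packing
            y[i*QK1_0 + j] = bit ? d : -d;                   // assumed sign mapping: 1 -> +d, 0 -> -d
        }
    }
}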

#define QK4_0 32
typedef struct {
ggml_half d; // delta
@@ -255,6 +276,21 @@ typedef struct {
} block_tq2_0;
static_assert(sizeof(block_tq2_0) == sizeof(ggml_half) + QK_K / 4, "wrong tq2_0 block size/padding");

// TurboQuant 3-bit quantization (3.5 bpw)
// Per TurboQuant paper (Algorithm 2: TurboQuant_prod), ICLR 2026
// Each block of 32 values is quantized as:
//   - 2-bit MSE codebook indices (after random rotation Π·x)
//   - 1-bit QJL residual signs (sign(S·r) where r = x - dequant_mse(quant_mse(x)))
//   - FP16 residual norm ||r||₂ for QJL scaling
// Requires per-model rotation matrices Π and S (stored externally)
#define QK_TQ3_0 32
typedef struct {
    uint8_t   qs[QK_TQ3_0 / 4]; // 2-bit codebook indices, 32 × 2 bits = 8 bytes
    uint8_t   qr[QK_TQ3_0 / 8]; // QJL residual signs,     32 × 1 bit  = 4 bytes
    ggml_half gamma;            // ||residual||₂ for QJL correction scaling
} block_tq3_0;
static_assert(sizeof(block_tq3_0) == QK_TQ3_0/4 + QK_TQ3_0/8 + sizeof(ggml_half), "wrong tq3_0 block size/padding");
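
// Illustration only (not part of this PR): structural unpacking of a block_tq3_0.
// The MSE codebook values, the rotations Π and S, and the QJL residual reconstruction
// live outside the block and are not reproduced here; the packing order (LSB-first) and
// sign convention (1 -> +1, 0 -> -1) are assumptions, not taken from the PR.
static void unpack_tq3_0_ref(const block_tq3_0 * b, int idx[QK_TQ3_0], int sgn[QK_TQ3_0]) {
    for (int j = 0; j < QK_TQ3_0; ++j) {
        idx[j] = (b->qs[j/4] >> (2*(j%4))) & 0x3;            // 2-bit codebook index for value j
        sgn[j] = ((b->qr[j/8] >> (j%8)) & 1) ? +1 : -1;      // QJL residual sign for value j
    }
    // ggml_fp16_to_fp32(b->gamma) would then scale the QJL sign correction during dequantization
}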

//
// Super-block quantization structures
//