3 changes: 3 additions & 0 deletions .github/workflows/cuda.yml
@@ -148,6 +148,9 @@ jobs:
          # Run Qwen 3.5 MoE tests (quantize roundtrip + TurboQuant KV cache + sampler)
          python -m pytest examples/models/qwen3_5_moe/test_quantize_roundtrip.py examples/models/qwen3_5_moe/test_turboquant.py examples/models/qwen3_5_moe/test_sampler.py -v -o "addopts="
          # Run Gemma 4 31B tests (quant unit tests + pipeline integration tests)
          python -m pytest examples/models/gemma4_31b/quant/ examples/models/gemma4_31b/test_pipeline.py examples/models/gemma4_31b/test_cuda_pipeline.py -v -o "addopts="
  export-model-cuda-artifact:
    name: export-model-cuda-artifact
    # Skip this job if the pull request is from a fork (HuggingFace secrets are not available)
12 changes: 11 additions & 1 deletion Makefile
@@ -91,7 +91,7 @@
#
# ==============================================================================

-.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx voxtral_tts-cpu voxtral_tts-cuda whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx parakeet-vulkan dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu qwen3_5_moe-cuda qwen3_5_moe-metal clean help
+.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx voxtral_tts-cpu voxtral_tts-cuda whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx parakeet-vulkan dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu gemma4_31b-cuda qwen3_5_moe-cuda qwen3_5_moe-metal clean help

help:
	@echo "This Makefile adds targets to build runners for various models on various backends. Run using \`make <target>\`. Available targets:"
@@ -126,6 +126,7 @@ help:
@echo " llava-cpu - Build Llava runner with CPU backend"
@echo " gemma3-cuda - Build Gemma3 runner with CUDA backend"
@echo " gemma3-cpu - Build Gemma3 runner with CPU backend"
@echo " gemma4_31b-cuda - Build Gemma 4 31B runner with CUDA backend"
@echo " qwen3_5_moe-cuda - Build Qwen3.5 MoE runner with CUDA backend"
@echo " qwen3_5_moe-metal - Build Qwen3.5 MoE runner with Metal backend"
@echo " clean - Clean build artifacts"
@@ -425,6 +426,15 @@ qwen3_5_moe-cuda:
@echo "✓ Build complete!"
@echo " Binary: cmake-out/examples/models/qwen3_5_moe/qwen3_5_moe_runner"

gemma4_31b-cuda:
	@echo "==> Building and installing ExecuTorch with CUDA..."
	cmake --workflow --preset llm-release-cuda
	@echo "==> Building Gemma 4 31B runner with CUDA..."
	cd examples/models/gemma4_31b && cmake --workflow --preset gemma4-31b-cuda
	@echo ""
	@echo "✓ Build complete!"
	@echo " Binary: cmake-out/examples/models/gemma4_31b/gemma4_31b_runner"

qwen3_5_moe-metal:
	@echo "==> Building and installing ExecuTorch with Metal..."
	cmake --workflow --preset llm-release-metal
9 changes: 9 additions & 0 deletions examples/models/gemma4/text_decoder/__init__.py
@@ -6,5 +6,14 @@
# LICENSE file in the root directory of this source tree.

from .convert_weights import convert_hf_to_custom # noqa: F401
from .gemma4_attention import (  # noqa: F401
    apply_rotary_emb,
    apply_rotary_emb_single,
    Gemma4KVCache,
    precompute_freqs_cis,
    rotate_half,
)
from .gemma4_config import Gemma4Config # noqa: F401
from .gemma4_decoder_layer import Gemma4MLP # noqa: F401
from .gemma4_model import create_gemma4_model, Gemma4Model # noqa: F401
from .gemma4_norm import RMSNorm, RMSNormNoWeight # noqa: F401
39 changes: 38 additions & 1 deletion examples/models/gemma4/text_decoder/gemma4_norm.py
@@ -5,9 +5,46 @@
# pyre-unsafe
# LICENSE file in the root directory of this source tree.

"""Gemma 4 RMSNorm — self-contained re-implementation.

Numerically identical to ``transformers.models.gemma4.modeling_gemma4.Gemma4RMSNorm``
(same float32 upcast and ``pow(mean_squared, -0.5)`` normalization), but
without the transformers import so this module is exportable and dep-light.
"""

from functools import partial

-from transformers.models.gemma4.modeling_gemma4 import Gemma4RMSNorm as RMSNorm
import torch
from torch import nn


class RMSNorm(nn.Module):
    """Gemma4 RMSNorm: ``y = (x / rms(x)) * weight``, computed in float32.

    Unlike Gemma 2/3 (``(1 + weight)``), Gemma 4 multiplies by ``weight`` directly.
    Pass ``with_scale=False`` for the v-norm and the router norm (unused here),
    which omit the learnable weight entirely.
    """

    def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True):
        super().__init__()
        self.eps = eps
        self.with_scale = with_scale
        if with_scale:
            self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # Match transformers' use of pow(mean_squared, -0.5) over rsqrt;
        # the comment there cites Torch/JAX compiler differences.
        mean_squared = x.pow(2).mean(-1, keepdim=True) + self.eps
        return x * torch.pow(mean_squared, -0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normed = self._norm(x.float())
        if self.with_scale:
            normed = normed * self.weight.float()
        return normed.type_as(x)


# V-norm in attention uses RMSNorm without learnable weight.
RMSNormNoWeight = partial(RMSNorm, with_scale=False)
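A minimal sanity check, as a sketch: assuming the repo root is on `PYTHONPATH`, the class should agree with the reference formula `y = x / sqrt(mean(x^2) + eps) * weight` up to float rounding, and `RMSNormNoWeight` should carry no learnable weight.

```python
import torch

from examples.models.gemma4.text_decoder import RMSNorm, RMSNormNoWeight

# RMSNorm(x) should match x / sqrt(mean(x^2) + eps) * weight in float32,
# then cast back to the input dtype.
norm = RMSNorm(dim=8, eps=1e-6)
x = torch.randn(2, 4, 8, dtype=torch.bfloat16)
ref = x.float() / torch.sqrt(x.float().pow(2).mean(-1, keepdim=True) + 1e-6)
ref = (ref * norm.weight.float()).type_as(x)
assert torch.allclose(norm(x), ref, atol=1e-3)

# The weight-free variant (used for the attention v-norm) has no parameter.
vnorm = RMSNormNoWeight(8)
assert not hasattr(vnorm, "weight")
```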
67 changes: 67 additions & 0 deletions examples/models/gemma4_31b/CMakeLists.txt
@@ -0,0 +1,67 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

cmake_minimum_required(VERSION 3.24)
project(gemma4_31b)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..)

include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)

set(_common_include_directories ${EXECUTORCH_ROOT}/..)

# gflags
set(gflags_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/gflags)
find_package(gflags REQUIRED)

# executorch
list(APPEND CMAKE_FIND_ROOT_PATH ${CMAKE_CURRENT_BINARY_DIR}/../../..)
find_package(executorch CONFIG REQUIRED FIND_ROOT_PATH_BOTH)
executorch_target_link_options_shared_lib(executorch)

set(link_libraries executorch gflags)

# CPU ops (for the host-side helpers that aren't delegated to CUDA)
list(APPEND link_libraries optimized_native_cpu_ops_lib cpublas eigen_blas)
executorch_target_link_options_shared_lib(optimized_native_cpu_ops_lib)

# Extensions
list(
  APPEND
  link_libraries
  extension_llm_runner
  extension_module
  extension_data_loader
  extension_tensor
  extension_flat_tensor
)

# CUDA backend (the only supported backend for this example for now)
if(EXECUTORCH_BUILD_CUDA)
  find_package(CUDAToolkit REQUIRED)
  list(APPEND link_libraries aoti_cuda_backend)
  executorch_target_link_options_shared_lib(aoti_cuda_backend)
  add_compile_definitions(EXECUTORCH_BUILD_CUDA)
else()
  message(FATAL_ERROR "Set EXECUTORCH_BUILD_CUDA=ON")
endif()

# Tokenizer (HuggingFace tokenizer.json)
list(APPEND link_libraries tokenizers::tokenizers)

add_executable(gemma4_31b_runner main.cpp)
target_include_directories(
  gemma4_31b_runner PUBLIC ${_common_include_directories}
)
target_link_libraries(gemma4_31b_runner PUBLIC ${link_libraries})

if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug")
  target_link_options_gc_sections(gemma4_31b_runner)
  target_link_options(gemma4_31b_runner PRIVATE "LINKER:-s")
endif()
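For readers who prefer explicit commands, a rough bash equivalent of the configure-and-build flow this file expects (paths mirror the presets in the next file and assume ExecuTorch itself was already built into `./cmake-out` with CUDA enabled):

```bash
# Run from the repo root. Configure the runner against the installed
# ExecuTorch build, forcing the CUDA backend on (required above).
cmake -S examples/models/gemma4_31b \
      -B cmake-out/examples/models/gemma4_31b \
      -DCMAKE_BUILD_TYPE=Release \
      -DEXECUTORCH_BUILD_CUDA=ON \
      -DCMAKE_PREFIX_PATH="$PWD/cmake-out" \
      -DCMAKE_FIND_ROOT_PATH="$PWD/cmake-out"
cmake --build cmake-out/examples/models/gemma4_31b --target gemma4_31b_runner
```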
52 changes: 52 additions & 0 deletions examples/models/gemma4_31b/CMakePresets.json
@@ -0,0 +1,52 @@
{
  "version": 6,
  "configurePresets": [
    {
      "name": "gemma4-31b-base",
      "hidden": true,
      "binaryDir": "${sourceDir}/../../../cmake-out/examples/models/gemma4_31b",
      "cacheVariables": {
        "CMAKE_BUILD_TYPE": "Release",
        "CMAKE_FIND_ROOT_PATH": "${sourceDir}/../../../cmake-out",
        "CMAKE_PREFIX_PATH": "${sourceDir}/../../../cmake-out"
      }
    },
    {
      "name": "gemma4-31b-cuda",
      "displayName": "Gemma 4 31B runner (CUDA)",
      "inherits": ["gemma4-31b-base"],
      "cacheVariables": {
        "EXECUTORCH_BUILD_CUDA": "ON"
      },
      "condition": {
        "type": "inList",
        "string": "${hostSystemName}",
        "list": ["Linux", "Windows"]
      }
    }
  ],
  "buildPresets": [
    {
      "name": "gemma4-31b-cuda",
      "displayName": "Build Gemma 4 31B runner (CUDA)",
      "configurePreset": "gemma4-31b-cuda",
      "targets": ["gemma4_31b_runner"]
    }
  ],
  "workflowPresets": [
    {
      "name": "gemma4-31b-cuda",
      "displayName": "Configure and build Gemma 4 31B runner (CUDA)",
      "steps": [
        {
          "type": "configure",
          "name": "gemma4-31b-cuda"
        },
        {
          "type": "build",
          "name": "gemma4-31b-cuda"
        }
      ]
    }
  ]
}
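Tying the presets together: the workflow preset simply chains the configure and build presets, so the two invocations below are equivalent (run from `examples/models/gemma4_31b/`; workflow presets require CMake 3.25+):

```bash
cmake --workflow --preset gemma4-31b-cuda
# ...or the same thing as two explicit steps:
cmake --preset gemma4-31b-cuda
cmake --build --preset gemma4-31b-cuda
```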
112 changes: 112 additions & 0 deletions examples/models/gemma4_31b/README.md
@@ -0,0 +1,112 @@
# Gemma 4 31B-IT

Text-only export of Google's Gemma 4 31B-IT to ExecuTorch with INT4/INT8
weight quantization. Only the CUDA backend is currently supported.

For architecture and design notes see [model.md](model.md).

## When to use which script

The full bf16 weights of the 31B model (~62 GB) often don't fit in available
RAM. The recommended flow is to quantize once and reuse the quantized
checkpoint for both export and eager inference:

| Script | Purpose | Peak memory |
|---|---|---|
| `quantize_and_save.py` | bf16 HF checkpoint → quantized checkpoint (one-time) | ~30 GB CPU |
| `export.py --prequantized <dir>` | quantized checkpoint → `model.pte` + `model.ptd` | ~24 GB CPU + CUDA for packing |
| `inference.py --prequantized <dir>` | quantized checkpoint → eager generation under `torch.compile` | ~24 GB GPU |
| `export.py --model-dir <hf>` | one-shot bf16 → quantize → export (no intermediate file) | ~30 GB CPU + CUDA for packing |

The quantized checkpoint is a safetensors file containing integer weight
values plus per-group scales, with a JSON header describing each weight's
`QuantConfig`. There is no tensor subclass or backend-specific packing;
packing for the target backend happens at load time via `quant.pack_model()`.
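A minimal sketch of peeking at such a checkpoint, assuming the per-weight `QuantConfig` JSON lives in the safetensors metadata header (the exact key layout is an assumption, not the documented format):

```python
import json

from safetensors import safe_open

# safe_open reads the header without loading tensor data; metadata()
# returns the header's free-form string-to-string map.
with safe_open("gemma4_31b_int4/model.safetensors", framework="pt") as f:
    for key, value in (f.metadata() or {}).items():
        print(key, json.loads(value))  # hypothetical: per-weight QuantConfig
    print(list(f.keys()))  # int weight tensors and their per-group scales
```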

## Quantization recipes

Two built-in recipes (see `quantize_and_save.py`):

| Recipe | Description |
|---|---|
| `default` | INT4 min_max linears, INT8 per-axis embedding |
| `sensitive` | INT8 for edge-layer v_proj/down_proj, INT4 hqq elsewhere, INT8 per-axis embedding |

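As a rough illustration of what the recipes express (the real `quantize_and_save.py` logic, field names, and the exact definition of "edge layers" may differ; this sketch assumes HF-style weight names and treats the first and last two blocks as edges):

```python
# Hypothetical recipe sketch; not the shipped implementation.
def sensitive_recipe(name: str, num_layers: int) -> dict:
    if "embed_tokens" in name:
        return {"bits": 8, "granularity": "per_axis"}
    edge = {0, 1, num_layers - 2, num_layers - 1}
    layer = int(name.split("layers.")[1].split(".")[0]) if "layers." in name else -1
    if layer in edge and ("v_proj" in name or "down_proj" in name):
        return {"bits": 8, "algo": "min_max"}
    return {"bits": 4, "algo": "hqq", "group_size": 64}
```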
## Prequantized checkpoint

A prequantized checkpoint (sensitive recipe) is available on HuggingFace:

```bash
huggingface-cli download SocialLocalMobile/gemma-4-31B-it-HQQ-INT4 --local-dir gemma-4-31B-it-HQQ-INT4
```
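The same download via the `huggingface_hub` Python API:

```python
from huggingface_hub import snapshot_download

# Fetches the checkpoint files into the given directory.
snapshot_download(
    repo_id="SocialLocalMobile/gemma-4-31B-it-HQQ-INT4",
    local_dir="gemma-4-31B-it-HQQ-INT4",
)
```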

> **Note**: This checkpoint is intended for development and testing of the
> ExecuTorch CUDA export pipeline. Output quality has not been formally
> evaluated against the base model.

Use it directly with `--prequantized` in the export and inference scripts
below; there is no need to run `quantize_and_save.py`.

## Quantize from scratch (optional)

To quantize from the original bf16 checkpoint instead, pass
`--quant-recipe` to select a recipe (`default` or `sensitive`):

```bash
python examples/models/gemma4_31b/quantize_and_save.py \
--model-dir /path/to/gemma-4-31B-it \
--output ./gemma4_31b_int4 \
--quant-recipe sensitive
```

See [Quantization recipes](#quantization-recipes) above for details on each
recipe. The script writes `model.safetensors`, `config.json`, and
`tokenizer.json` into `--output`.

## Export to ExecuTorch

```bash
python examples/models/gemma4_31b/export.py \
--prequantized ./gemma4_31b_int4 \
--output-dir ./gemma4_31b_exports \
--max-seq-len 4096 \
--backend cuda
```

The script writes `model.pte` and the accompanying weight data file
(`aoti_cuda_blob.ptd` for the CUDA backend, as used by the runner below)
into `--output-dir`.

## Eager inference

```bash
python examples/models/gemma4_31b/inference.py \
--prequantized ./gemma4_31b_int4 \
--prompt "Write a short joke about saving RAM." \
--max-new-tokens 128 \
--temperature 0.8
```

This is a useful way to confirm that the quantized model produces sensible
text before spending time on export and lowering.

## Build the runner

```bash
make gemma4_31b-cuda
```

The binary lands at `cmake-out/examples/models/gemma4_31b/gemma4_31b_runner`.

## Run the .pte

```bash
./gemma4_31b_runner \
--model_path ./gemma4_31b_exports/model.pte \
--data_path ./gemma4_31b_exports/aoti_cuda_blob.ptd \
--tokenizer_path ./gemma4_31b_int4/tokenizer.json \
--prompt "Write a short joke about saving RAM." \
--max_new_tokens 128 \
--temperature 0.8
```

For benchmarking, add `--cuda_graph` to capture the decode method in a CUDA
graph (decode is fully static, with `T=1`).
5 changes: 5 additions & 0 deletions examples/models/gemma4_31b/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.