Commit f04e065
Add Gemma 4 31B-IT model, export, and quantization framework for ExecuTorch
Text-only export of Gemma 4 31B-IT to ExecuTorch with the CUDA backend and
INT4/INT8 weight quantization via a new packing-agnostic quant/ framework.

The quant/ package separates quantization into four concerns:

- recipe.py: declarative QuantRecipe with regex FQN matching
- quantize.py: produces CanonicalQuantizedWeight (min_max, HQQ)
- serialize.py: save/load to safetensors with versioned headers
- pack.py + pack_cuda.py: per-module packer dispatch for CUDA

Two production recipes: "default" (INT4 min_max + INT8 embedding) and
"sensitive" (INT8 for edge-layer v_proj/down_proj, INT4 HQQ elsewhere).

Sliding window attention uses a ring-buffer KV cache (2x window size) for the
50 sliding layers, saving memory for long sequences. The 10 full-attention
layers use a standard flat KV cache.

Includes a C++ runner (main.cpp), an eager inference script, and 60+ unit and
integration tests across quant/ and the pipeline test files.
1 parent d8da621 · commit f04e065

28 files changed: 4,340 additions, 2 deletions
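To make the commit message's ring-buffer KV cache concrete, here is a minimal sketch (ours, not the commit's implementation; the class name and shapes are illustrative). A buffer of twice the window size absorbs decode-step writes contiguously and compacts once per window, so attention always reads a contiguous slice of the most recent `window` entries, and memory stays at 2x the window rather than growing with sequence length.

```python
import torch


class RingKVCache:
    """Sliding-window KV cache backed by a 2x-window ring buffer (sketch)."""

    def __init__(self, window: int, n_heads: int, head_dim: int, dtype=torch.bfloat16):
        self.window = window
        self.k = torch.zeros(2 * window, n_heads, head_dim, dtype=dtype)
        self.v = torch.zeros(2 * window, n_heads, head_dim, dtype=dtype)
        self.pos = 0  # next write slot

    def update(self, k_step: torch.Tensor, v_step: torch.Tensor):
        # k_step, v_step: (n_heads, head_dim) for one decode step (T=1).
        if self.pos == 2 * self.window:
            # Buffer full: slide the live window back to the front.
            # clone() guards against aliasing during the in-place copy.
            self.k[: self.window] = self.k[self.window :].clone()
            self.v[: self.window] = self.v[self.window :].clone()
            self.pos = self.window
        self.k[self.pos] = k_step
        self.v[self.pos] = v_step
        self.pos += 1
        start = max(0, self.pos - self.window)
        # Always a contiguous view of the most recent min(pos, window) steps.
        return self.k[start : self.pos], self.v[start : self.pos]
```

A full-attention layer, by contrast, keeps a flat cache sized to the maximum sequence length, which is why the commit gives the ring treatment only to the 50 sliding layers.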

.github/workflows/cuda.yml

Lines changed: 3 additions & 0 deletions
```diff
@@ -148,6 +148,9 @@ jobs:
       # Run Qwen 3.5 MoE tests (quantize roundtrip + TurboQuant KV cache + sampler)
       python -m pytest examples/models/qwen3_5_moe/test_quantize_roundtrip.py examples/models/qwen3_5_moe/test_turboquant.py examples/models/qwen3_5_moe/test_sampler.py -v -o "addopts="

+      # Run Gemma 4 31B tests (quant unit tests + pipeline integration tests)
+      python -m pytest examples/models/gemma4_31b/quant/ examples/models/gemma4_31b/test_pipeline.py examples/models/gemma4_31b/test_cuda_pipeline.py -v -o "addopts="
+
   export-model-cuda-artifact:
     name: export-model-cuda-artifact
     # Skip this job if the pull request is from a fork (HuggingFace secrets are not available)
```

Makefile

Lines changed: 11 additions & 1 deletion
```diff
@@ -91,7 +91,7 @@
 #
 # ==============================================================================

-.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx voxtral_tts-cpu voxtral_tts-cuda whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx parakeet-vulkan dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu qwen3_5_moe-cuda qwen3_5_moe-metal clean help
+.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx voxtral_tts-cpu voxtral_tts-cuda whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx parakeet-vulkan dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu gemma4_31b-cuda qwen3_5_moe-cuda qwen3_5_moe-metal clean help

 help:
 	@echo "This Makefile adds targets to build runners for various models on various backends. Run using \`make <target>\`. Available targets:"
@@ -126,6 +126,7 @@ help:
 	@echo " llava-cpu - Build Llava runner with CPU backend"
 	@echo " gemma3-cuda - Build Gemma3 runner with CUDA backend"
 	@echo " gemma3-cpu - Build Gemma3 runner with CPU backend"
+	@echo " gemma4_31b-cuda - Build Gemma 4 31B runner with CUDA backend"
 	@echo " qwen3_5_moe-cuda - Build Qwen3.5 MoE runner with CUDA backend"
 	@echo " qwen3_5_moe-metal - Build Qwen3.5 MoE runner with Metal backend"
 	@echo " clean - Clean build artifacts"
@@ -425,6 +426,15 @@ qwen3_5_moe-cuda:
 	@echo "✓ Build complete!"
 	@echo " Binary: cmake-out/examples/models/qwen3_5_moe/qwen3_5_moe_runner"

+gemma4_31b-cuda:
+	@echo "==> Building and installing ExecuTorch with CUDA..."
+	cmake --workflow --preset llm-release-cuda
+	@echo "==> Building Gemma 4 31B runner with CUDA..."
+	cd examples/models/gemma4_31b && cmake --workflow --preset gemma4-31b-cuda
+	@echo ""
+	@echo "✓ Build complete!"
+	@echo " Binary: cmake-out/examples/models/gemma4_31b/gemma4_31b_runner"
+
 qwen3_5_moe-metal:
 	@echo "==> Building and installing ExecuTorch with Metal..."
 	cmake --workflow --preset llm-release-metal
```
examples/models/gemma4/text_decoder/__init__.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -6,5 +6,14 @@
 # LICENSE file in the root directory of this source tree.

 from .convert_weights import convert_hf_to_custom  # noqa: F401
+from .gemma4_attention import (  # noqa: F401
+    apply_rotary_emb,
+    apply_rotary_emb_single,
+    Gemma4KVCache,
+    precompute_freqs_cis,
+    rotate_half,
+)
 from .gemma4_config import Gemma4Config  # noqa: F401
+from .gemma4_decoder_layer import Gemma4MLP  # noqa: F401
 from .gemma4_model import create_gemma4_model, Gemma4Model  # noqa: F401
+from .gemma4_norm import RMSNorm, RMSNormNoWeight  # noqa: F401
```

examples/models/gemma4/text_decoder/gemma4_norm.py

Lines changed: 38 additions & 1 deletion
```diff
@@ -5,9 +5,46 @@
 # pyre-unsafe
 # LICENSE file in the root directory of this source tree.

+"""Gemma 4 RMSNorm: a self-contained re-implementation.
+
+Numerically identical to ``transformers.models.gemma4.modeling_gemma4.Gemma4RMSNorm``
+(same float32 upcast and ``pow(mean_squared, -0.5)`` normalization), but
+without the transformers import, so this module is exportable and dependency-light.
+"""
+
 from functools import partial

-from transformers.models.gemma4.modeling_gemma4 import Gemma4RMSNorm as RMSNorm
+import torch
+from torch import nn
+
+
+class RMSNorm(nn.Module):
+    """Gemma4 RMSNorm: ``y = (x / rms(x)) * weight``, computed in float32.
+
+    Unlike Gemma 2/3 (``(1 + weight)``), Gemma 4 multiplies by ``weight``
+    directly. Pass ``with_scale=False`` for the v-norm and the (unused here)
+    router norm, which omit the learnable weight entirely.
+    """
+
+    def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True):
+        super().__init__()
+        self.eps = eps
+        self.with_scale = with_scale
+        if with_scale:
+            self.weight = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x: torch.Tensor) -> torch.Tensor:
+        # Match transformers' use of pow(mean_squared, -0.5) over rsqrt;
+        # the comment there cites Torch/JAX compiler differences.
+        mean_squared = x.pow(2).mean(-1, keepdim=True) + self.eps
+        return x * torch.pow(mean_squared, -0.5)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        normed = self._norm(x.float())
+        if self.with_scale:
+            normed = normed * self.weight.float()
+        return normed.type_as(x)
+

 # V-norm in attention uses RMSNorm without learnable weight.
 RMSNormNoWeight = partial(RMSNorm, with_scale=False)
```
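As a quick check of the ``pow(mean_squared, -0.5)`` choice the comment defends, the sketch below (ours, not part of the commit) confirms it agrees with the ``torch.rsqrt`` formulation to float32 precision; the preference is about Torch/JAX compiler behavior, not accuracy.

```python
import torch

x = torch.randn(2, 8, 256)
eps = 1e-6
mean_squared = x.pow(2).mean(-1, keepdim=True) + eps

# Both normalizations compute x / sqrt(mean(x^2) + eps); they agree in float32.
via_pow = x * torch.pow(mean_squared, -0.5)
via_rsqrt = x * torch.rsqrt(mean_squared)
assert torch.allclose(via_pow, via_rsqrt, atol=1e-6)
```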
examples/models/gemma4_31b/CMakeLists.txt

Lines changed: 67 additions & 0 deletions

```cmake
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

cmake_minimum_required(VERSION 3.24)
project(gemma4_31b)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..)

include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)

set(_common_include_directories ${EXECUTORCH_ROOT}/..)

# gflags
set(gflags_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/gflags)
find_package(gflags REQUIRED)

# executorch
list(APPEND CMAKE_FIND_ROOT_PATH ${CMAKE_CURRENT_BINARY_DIR}/../../..)
find_package(executorch CONFIG REQUIRED FIND_ROOT_PATH_BOTH)
executorch_target_link_options_shared_lib(executorch)

set(link_libraries executorch gflags)

# CPU ops (for the host-side helpers that aren't delegated to CUDA)
list(APPEND link_libraries optimized_native_cpu_ops_lib cpublas eigen_blas)
executorch_target_link_options_shared_lib(optimized_native_cpu_ops_lib)

# Extensions
list(
  APPEND
  link_libraries
  extension_llm_runner
  extension_module
  extension_data_loader
  extension_tensor
  extension_flat_tensor
)

# CUDA backend (the only supported backend for this example for now)
if(EXECUTORCH_BUILD_CUDA)
  find_package(CUDAToolkit REQUIRED)
  list(APPEND link_libraries aoti_cuda_backend)
  executorch_target_link_options_shared_lib(aoti_cuda_backend)
  add_compile_definitions(EXECUTORCH_BUILD_CUDA)
else()
  message(FATAL_ERROR "Set EXECUTORCH_BUILD_CUDA=ON")
endif()

# Tokenizer (HuggingFace tokenizer.json)
list(APPEND link_libraries tokenizers::tokenizers)

add_executable(gemma4_31b_runner main.cpp)
target_include_directories(
  gemma4_31b_runner PUBLIC ${_common_include_directories}
)
target_link_libraries(gemma4_31b_runner PUBLIC ${link_libraries})

if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug")
  target_link_options_gc_sections(gemma4_31b_runner)
  target_link_options(gemma4_31b_runner PRIVATE "LINKER:-s")
endif()
```
examples/models/gemma4_31b/CMakePresets.json

Lines changed: 52 additions & 0 deletions

```json
{
  "version": 6,
  "configurePresets": [
    {
      "name": "gemma4-31b-base",
      "hidden": true,
      "binaryDir": "${sourceDir}/../../../cmake-out/examples/models/gemma4_31b",
      "cacheVariables": {
        "CMAKE_BUILD_TYPE": "Release",
        "CMAKE_FIND_ROOT_PATH": "${sourceDir}/../../../cmake-out",
        "CMAKE_PREFIX_PATH": "${sourceDir}/../../../cmake-out"
      }
    },
    {
      "name": "gemma4-31b-cuda",
      "displayName": "Gemma 4 31B runner (CUDA)",
      "inherits": ["gemma4-31b-base"],
      "cacheVariables": {
        "EXECUTORCH_BUILD_CUDA": "ON"
      },
      "condition": {
        "type": "inList",
        "string": "${hostSystemName}",
        "list": ["Linux", "Windows"]
      }
    }
  ],
  "buildPresets": [
    {
      "name": "gemma4-31b-cuda",
      "displayName": "Build Gemma 4 31B runner (CUDA)",
      "configurePreset": "gemma4-31b-cuda",
      "targets": ["gemma4_31b_runner"]
    }
  ],
  "workflowPresets": [
    {
      "name": "gemma4-31b-cuda",
      "displayName": "Configure and build Gemma 4 31B runner (CUDA)",
      "steps": [
        {
          "type": "configure",
          "name": "gemma4-31b-cuda"
        },
        {
          "type": "build",
          "name": "gemma4-31b-cuda"
        }
      ]
    }
  ]
}
```
examples/models/gemma4_31b/README.md

Lines changed: 93 additions & 0 deletions

# Gemma 4 31B-IT

Text-only export of Google's Gemma 4 31B-IT to ExecuTorch with INT4/INT8
weight quantization. Currently supports the CUDA backend.

For architecture and design notes, see [model.md](model.md).
## When to use which script

The full bf16 weights for 31B (~62 GB) often don't fit in available RAM. The
recommended flow is to quantize once and reuse the quantized checkpoint for
both export and eager inference:

| Script | Purpose | Peak memory |
|---|---|---|
| `quantize_and_save.py` | bf16 HF checkpoint → quantized checkpoint (one-time) | ~30 GB CPU |
| `export.py --prequantized <dir>` | quantized checkpoint → `model.pte` + `model.ptd` | ~24 GB CPU + CUDA for packing |
| `inference.py --prequantized <dir>` | quantized checkpoint → eager generation under `torch.compile` | ~24 GB GPU |
| `export.py --model-dir <hf>` | one-shot bf16 → quantize → export (no intermediate file) | ~30 GB CPU + CUDA for packing |

The quantized checkpoint is a safetensors file with int values plus per-group
scales and a JSON header describing each weight's `QuantConfig`. There is no
tensor subclass and no backend-specific packing: packing for the target
backend happens at load time via `quant.pack_model()`.
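To see what such a checkpoint holds, here is a short inspection sketch using the stock safetensors API; the exact header keys written by quant/serialize.py are an assumption here, not documented behavior.

```python
import json
from safetensors import safe_open

# Open the quantized checkpoint and print its JSON header plus a few
# tensor entries (int weights and their per-group scales).
with safe_open("./gemma4_31b_int4/model.safetensors", framework="pt") as f:
    header = f.metadata() or {}  # str -> str mapping; values may be JSON
    for key, value in list(header.items())[:3]:
        print(key, json.loads(value) if value.startswith("{") else value)
    for name in list(f.keys())[:4]:
        tensor = f.get_tensor(name)
        print(name, tensor.dtype, tuple(tensor.shape))
```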
## Quantization recipes

Two built-in recipes (see `quantize_and_save.py`):

| Recipe | Description |
|---|---|
| `default` | INT4 min_max linears, INT8 per-axis embedding |
| `sensitive` | INT8 for edge-layer v_proj/down_proj, INT4 HQQ elsewhere, INT8 per-axis embedding |
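The `sensitive` recipe above suggests the general shape of a regex-matched recipe. The sketch below is illustrative only: the field names, scheme strings, and layer indices are all assumptions, not the actual `QuantRecipe` API from quant/recipe.py.

```python
import re
from dataclasses import dataclass


@dataclass
class Rule:
    pattern: str  # regex over fully-qualified module names (FQNs)
    scheme: str   # hypothetical scheme tag, e.g. "int4_hqq"


# Hypothetical rendering of a "sensitive"-style recipe: first match wins.
SENSITIVE_RULES = [
    Rule(r"layers\.(0|59)\.(self_attn\.v_proj|mlp\.down_proj)$", "int8_min_max"),
    Rule(r"(q|k|v|o|gate|up|down)_proj$", "int4_hqq"),
    Rule(r"embed_tokens$", "int8_per_axis"),
]


def scheme_for(fqn: str) -> str | None:
    """Return the quantization scheme for a module FQN, or None to skip it."""
    return next((r.scheme for r in SENSITIVE_RULES if re.search(r.pattern, fqn)), None)


print(scheme_for("model.layers.0.mlp.down_proj"))   # int8_min_max (edge layer)
print(scheme_for("model.layers.30.mlp.down_proj"))  # int4_hqq
```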
## Quantize once

```bash
python examples/models/gemma4_31b/quantize_and_save.py \
  --model-dir ~/local/scripts/models/gemma-4-31B-it \
  --output ./gemma4_31b_int4 \
  --quant-recipe default
```

This writes `model.safetensors`, `config.json`, and `tokenizer.json` into
`--output`.
## Export to ExecuTorch

```bash
python examples/models/gemma4_31b/export.py \
  --prequantized ./gemma4_31b_int4 \
  --output-dir ./gemma4_31b_exports \
  --max-seq-len 4096 \
  --backend cuda
```

This writes `model.pte` and `model.ptd` into `--output-dir`.
## Eager inference

```bash
python examples/models/gemma4_31b/inference.py \
  --prequantized ./gemma4_31b_int4 \
  --prompt "Write a short joke about saving RAM." \
  --max-new-tokens 128 \
  --temperature 0.8
```

This is useful for confirming that the quantized model produces sensible text
before spending the export and lowering time.
## Build the runner

```bash
make gemma4_31b-cuda
```

The binary lands at `cmake-out/examples/models/gemma4_31b/gemma4_31b_runner`.
## Run the .pte

```bash
./gemma4_31b_runner \
  --model_path ./gemma4_31b_exports/model.pte \
  --data_path ./gemma4_31b_exports/aoti_cuda_blob.ptd \
  --tokenizer_path ./gemma4_31b_int4/tokenizer.json \
  --prompt "Write a short joke about saving RAM." \
  --max_new_tokens 128 \
  --temperature 0.8
```

For benchmarking, add `--cuda_graph` to capture the decode method in a CUDA
graph (decode is fully static, with `T=1`).
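As a conceptual aside (ours, not the runner's implementation): a fixed-shape decode step is what makes CUDA graph capture possible, because the captured kernels can simply be replayed for every token. A minimal PyTorch sketch of the capture-and-replay pattern, with arbitrary sizes standing in for the model:

```python
import torch

# Stand-in for one decode step: any fixed-shape computation will do.
step = torch.nn.Linear(4096, 4096, device="cuda")
static_input = torch.zeros(1, 4096, device="cuda")  # T=1: shapes never change

# Warm-up pass, then capture the step into a CUDA graph.
step(static_input)
torch.cuda.synchronize()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = step(static_input)

# Per token: overwrite the static input buffer in place and replay the graph.
static_input.copy_(torch.randn_like(static_input))
graph.replay()
print(static_output.shape)
```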
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
