
Commit 9e36d62

mergennachin and mnachin authored
Add Gemma 4 31B-IT model, export, and quantization framework for ExecuTorch (pytorch#19213)
Text-only export of Gemma 4 31B-IT to ExecuTorch with INT4/INT8 weight quantization. Quantized weights use torchao's native tensor subclasses (Int4Tensor, IntxUnpackedToInt8Tensor) for serialization, aligning with the torchao ecosystem.

The quant/ package separates quantization into independent modules:
- recipe.py: declarative QuantRecipe with regex FQN matching and per-layer overrides (see the illustrative sketch after the metrics table below)
- quantize.py: quantize_weight / dequantize_weight / quantize_model, returning torchao subclasses directly. 8-bit fully delegates to IntxUnpackedToInt8Tensor.from_hp (min_max and HQQ); 4-bit uses torchao primitives plus manual Int4Tensor construction (pending mslk availability for from_hp)
- pack.py: pack_model (bulk, groups by parent for MoE) and pack_one (streaming). Dispatches via isinstance(_, TorchAOBaseTensor)
- pack_cuda.py: converts Int4Tensor to IntxUnpackedToInt8Tensor (int4 values unpacked to int8) and passes INT8 IntxUnpackedToInt8Tensor through unchanged. No CUDA is required for packing; the CUDA-specific tinygemm conversion is a source transform applied at export time
- gguf.py: unpacks Q4_K/Q6_K GGUF blocks directly to Int4Tensor/IntxUnpackedToInt8Tensor, with a streaming iterator

Serialization uses torchao's safetensors integration (torchao.prototype.safetensors), not a custom format. Checkpoints are compatible with torchao's save_pretrained/load_pretrained and can be loaded by vLLM. The framework is designed to be promoted and reused for Qwen 3.5 MoE and other models: adding a new model requires only a QuantRecipe and, optionally, a custom packer.

Quantization recipes: "default" (INT4 min_max linears + INT8 per-axis embedding) and "sensitive" (INT8 for edge-layer v_proj/down_proj, INT4 HQQ asymmetric elsewhere).

Dual-path INT4 linear dispatch: IntxUnpackedToInt8Tensor's F.linear dispatch dequantizes to bf16 and calls cuBLAS, optimal for prefill (12x faster than tinygemm at T=2048). For decode, a model-agnostic source transform (backends/cuda/transforms/int4_linear_dispatch.py) converts to Int4TilePackedTo4dTensor (tinygemm), optimal for M=1. Export flow: prefill export first (dequant + cuBLAS), then the tinygemm transform, then decode export. inference.py applies the tinygemm transform for fast eager decode.

Split-K flash decoding: ReplaceEdgeOpWithTritonOpPass in the CUDA backend selects triton::sdpa_decode_splitk for SDPA nodes where L_q = 1 and L_kv exceeds 2048. At 128K context, full-attention decode SDPA improves from 15.7 ms/layer to 0.7 ms/layer (22x). Sliding-window layers (ring buffer <= 2048) use standard triton::sdpa. No model code changes are needed; the pass inspects Q/K shapes in the exported graph automatically.

GGUF support: inference.py --gguf and export.py --gguf load community-quantized GGUF files directly. The tied embed/lm_head is untied: the embedding is dequantized to bf16 for gather, while lm_head keeps INT4 for matmul.

Ring-buffer KV cache: sliding-window layers use RingKVCache (2x window) instead of flat max_seq_len buffers. The C++ runner chunks long prompts automatically via get_max_prefill_chunk metadata. Chunked prefill produces logits identical to sequential prefill (verified by test).

Includes: a C++ runner with BOS/EOS handling, chunked prefill, and #ifdef guards for non-CUDA builds; eager inference with torch.compile; unit and integration tests across quant/tests/, tests/, and backends/cuda/tests/.
```
┌──────────────────┬───────────────────┐
│ Metric           │ Value             │
├──────────────────┼───────────────────┤
│ Prompt tokens    │ 513               │
├──────────────────┼───────────────────┤
│ Generated tokens │ 128               │
├──────────────────┼───────────────────┤
│ Prefill          │ 766 tok/s (670ms) │
├──────────────────┼───────────────────┤
│ Decode           │ 21.5 tok/s        │
├──────────────────┼───────────────────┤
│ TTFT             │ 89ms              │
├──────────────────┼───────────────────┤
│ GPU peak         │ 25.1GB            │
├──────────────────┼───────────────────┤
│ Model load       │ 28.8s             │
└──────────────────┴───────────────────┘
```

---------

Co-authored-by: mnachin <mnachin@fb.com>
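The recipe-driven layer selection described above can be pictured with a small, self-contained sketch. This is illustrative only: the name QuantRecipe comes from the commit, but the fields, regex patterns, layer indices, and spec_for() method shown here are assumptions and do not reproduce the actual quant/recipe.py API.

```python
# Illustrative sketch only: a minimal regex-FQN recipe matcher in the spirit of
# quant/recipe.py. Field names, patterns, and spec_for() are assumed here and
# are not the actual QuantRecipe API from this commit.
import re
from dataclasses import dataclass, field
from typing import Dict


@dataclass
class QuantSpec:
    bits: int                 # 4 or 8
    group_size: int = 32      # 0 means per-axis (per-channel) quantization
    scheme: str = "min_max"   # e.g. "min_max" or "hqq"


@dataclass
class QuantRecipe:
    default: QuantSpec
    overrides: Dict[str, QuantSpec] = field(default_factory=dict)  # regex -> spec

    def spec_for(self, fqn: str) -> QuantSpec:
        """Return the spec for a parameter FQN; the first matching override wins."""
        for pattern, spec in self.overrides.items():
            if re.search(pattern, fqn):
                return spec
        return self.default


# A "sensitive"-style recipe: INT8 for a few edge-layer v_proj/down_proj weights
# (layer indices chosen purely for illustration), INT4 HQQ everywhere else.
recipe = QuantRecipe(
    default=QuantSpec(bits=4, scheme="hqq"),
    overrides={
        r"layers\.(0|1|60|61)\..*(v_proj|down_proj)": QuantSpec(bits=8),
        r"embed_tokens": QuantSpec(bits=8, group_size=0),
    },
)

print(recipe.spec_for("model.layers.0.self_attn.v_proj.weight").bits)  # 8
print(recipe.spec_for("model.layers.30.mlp.gate_proj.weight").bits)    # 4
```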
1 parent aa0d465 commit 9e36d62

41 files changed

Lines changed: 6404 additions & 3 deletions


.github/workflows/cuda.yml

Lines changed: 4 additions & 0 deletions
@@ -148,6 +148,10 @@ jobs:
           # Run Qwen 3.5 MoE tests (quantize roundtrip + TurboQuant KV cache + sampler)
           python -m pytest examples/models/qwen3_5_moe/test_quantize_roundtrip.py examples/models/qwen3_5_moe/test_turboquant.py examples/models/qwen3_5_moe/test_sampler.py -v -o "addopts="

+          # Run Gemma 4 31B tests (quant unit tests + pipeline integration tests)
+          pip install gguf
+          python -m pytest examples/models/gemma4_31b/quant/tests/ examples/models/gemma4_31b/tests/ -v -o "addopts="
+
   export-model-cuda-artifact:
     name: export-model-cuda-artifact
     # Skip this job if the pull request is from a fork (HuggingFace secrets are not available)

Makefile

Lines changed: 11 additions & 1 deletion
@@ -91,7 +91,7 @@
 #
 # ==============================================================================

-.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx voxtral_tts-cpu voxtral_tts-cuda whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx parakeet-vulkan dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu qwen3_5_moe-cuda qwen3_5_moe-metal clean help
+.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx voxtral_tts-cpu voxtral_tts-cuda whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx parakeet-vulkan dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu gemma4_31b-cuda qwen3_5_moe-cuda qwen3_5_moe-metal clean help

 help:
 	@echo "This Makefile adds targets to build runners for various models on various backends. Run using \`make <target>\`. Available targets:"
@@ -126,6 +126,7 @@ help:
 	@echo " llava-cpu - Build Llava runner with CPU backend"
 	@echo " gemma3-cuda - Build Gemma3 runner with CUDA backend"
 	@echo " gemma3-cpu - Build Gemma3 runner with CPU backend"
+	@echo " gemma4_31b-cuda - Build Gemma 4 31B runner with CUDA backend"
 	@echo " qwen3_5_moe-cuda - Build Qwen3.5 MoE runner with CUDA backend"
 	@echo " qwen3_5_moe-metal - Build Qwen3.5 MoE runner with Metal backend"
 	@echo " clean - Clean build artifacts"
@@ -425,6 +426,15 @@ qwen3_5_moe-cuda:
 	@echo "✓ Build complete!"
 	@echo " Binary: cmake-out/examples/models/qwen3_5_moe/qwen3_5_moe_runner"

+gemma4_31b-cuda:
+	@echo "==> Building and installing ExecuTorch with CUDA..."
+	cmake --workflow --preset llm-release-cuda
+	@echo "==> Building Gemma 4 31B runner with CUDA..."
+	cd examples/models/gemma4_31b && cmake --workflow --preset gemma4-31b-cuda
+	@echo ""
+	@echo "✓ Build complete!"
+	@echo " Binary: cmake-out/examples/models/gemma4_31b/gemma4_31b_runner"
+
 qwen3_5_moe-metal:
 	@echo "==> Building and installing ExecuTorch with Metal..."
 	cmake --workflow --preset llm-release-metal

backends/cuda/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
@@ -110,7 +110,8 @@ set(_aoti_cuda_shim_sources runtime/shims/memory.cpp
 # Only build CUDA shims when CUDA language/toolchain is available.
 if(CMAKE_CUDA_COMPILER)
   list(APPEND _aoti_cuda_shim_sources runtime/shims/int4mm.cu
-       runtime/shims/sort.cu runtime/shims/rand.cu
+       runtime/shims/int4_plain_mm.cu runtime/shims/sort.cu
+       runtime/shims/rand.cu
   )
 endif()

backends/cuda/cuda_backend.py

Lines changed: 16 additions & 0 deletions
@@ -226,6 +226,8 @@ def get_supported_fallback_kernels(cls) -> Dict[str, Any]:
             "at::_ops::_weight_int4pack_mm::call": None,
             "at::_ops::sort_stable::call": None,
             "aoti_torch_cuda_randint_low_out": None,
+            "executorch_cuda::int4_plain_mm": None,
+            "aoti_torch_cuda_int4_plain_mm": None,
         }

     @classmethod
@@ -298,6 +300,20 @@ def get_aoti_compile_options(
             "aot_inductor.emit_multi_arch_kernel": emit_multi_arch_kernel,
         }

+        try:
+            import torch
+
+            options["aot_inductor.custom_ops_to_c_shims"] = {
+                torch.ops.executorch_cuda.int4_plain_mm.default: [
+                    "AOTITorchError aoti_torch_cuda_int4_plain_mm("
+                    "AtenTensorHandle, AtenTensorHandle, AtenTensorHandle, "
+                    "AtenTensorHandle, int64_t, AtenTensorHandle*)"
+                ],
+            }
+        except AttributeError:
+            # int4_dispatch.py not imported — op not registered, skip C shim mapping
+            pass
+
         # Parse compile_specs to check for platform

         platform = "linux"

backends/cuda/int4_dispatch.py

Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Int4Tensor F.linear dispatch for CUDA — runs at eager / export trace time.

This module overrides Int4Tensor's F.linear dispatch so that torch.export
traces through our custom op and dequant logic instead of torchao's default
(mslk/tinygemm). The code here executes during eager inference and during
AOTI export tracing — it does NOT run at .pte runtime.

At .pte runtime, the captured graph is executed by the AOTI-generated .so:
- The custom op ``executorch_cuda::int4_plain_mm`` maps to a C shim that
  runs the W4A8 dp4a matvec kernel (backends/cuda/runtime/shims/).
- The inline dequant + F.linear is compiled by inductor into fused Triton
  dequant + cuBLAS matmul kernels.

Dispatch strategy (determines what gets captured in the export graph):
    Decode (M<=4): Custom op ``executorch_cuda::int4_plain_mm``
    Prefill (M>4): Inline dequant + F.linear (standard PyTorch ops)

Import this module before using nn.Linear with Int4Tensor weights::

    import executorch.backends.cuda.int4_dispatch  # noqa: F401
"""

import torch
import torch.nn.functional as F
from torch.library import impl, Library
from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor

# ---------------------------------------------------------------------------
# Custom op for decode (M=1): dp4a matvec in C shim, dequant+F.linear in eager
# ---------------------------------------------------------------------------

_lib = Library("executorch_cuda", "DEF")
_lib.define(
    "int4_plain_mm(Tensor self, Tensor qdata, Tensor scale, Tensor zero, int group_size) -> Tensor"
)


@impl(_lib, "int4_plain_mm", "Meta")
def _meta(self, qdata, scale, zero, group_size):
    return torch.empty(
        self.shape[0], qdata.shape[0], dtype=self.dtype, device=self.device
    )


@impl(_lib, "int4_plain_mm", "CUDA")
def _cuda(self, qdata, scale, zero, group_size):
    return _dequant_matmul(self, qdata, scale, zero, group_size)


def _dequant_matmul(x, qdata, scale, zero, group_size):
    """Dequant INT4 weights to input dtype and call F.linear."""
    N, K_half = qdata.shape
    K = K_half * 2
    n_groups = K // group_size
    gs_half = group_size // 2
    dtype = x.dtype

    p = qdata.to(torch.uint8).reshape(N, n_groups, gs_half)
    low = (p & 0x0F).to(dtype)
    high = ((p >> 4) & 0x0F).to(dtype)
    data = torch.stack([low, high], dim=-1).reshape(N, n_groups, group_size)

    s = scale.to(dtype).t().unsqueeze(-1)
    z = zero.to(dtype).t().unsqueeze(-1)
    w_deq = ((data - z) * s).reshape(N, K)

    return F.linear(x, w_deq)


# ---------------------------------------------------------------------------
# Int4Tensor F.linear dispatch
# ---------------------------------------------------------------------------

aten = torch.ops.aten
_implements = Int4Tensor.implements
_implements_torch_function = Int4Tensor.implements_torch_function


@_implements([aten.linear.default])
@_implements_torch_function([F.linear])
def _(func, types, args, kwargs):
    input_tensor = args[0]
    weight_tensor = args[1]
    bias = args[2] if len(args) > 2 else None

    orig_shape = input_tensor.shape
    x_2d = input_tensor.reshape(-1, orig_shape[-1])

    qdata = weight_tensor.qdata
    scale = weight_tensor.scale
    zero = weight_tensor.zero_point
    gs = weight_tensor.block_size[-1]

    M = x_2d.shape[0]
    if M <= 4:
        out = torch.ops.executorch_cuda.int4_plain_mm(x_2d, qdata, scale, zero, gs)
    else:
        out = _dequant_matmul(x_2d, qdata, scale, zero, gs)

    out = out.reshape(*orig_shape[:-1], -1)
    if bias is not None:
        out = out + bias
    return out
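The nibble layout `_dequant_matmul` expects (two 4-bit values per byte, the even-indexed element in the low nibble and the odd-indexed element in the high nibble, packed group by group) can be checked with a small standalone round-trip against F.linear on a fake-quantized weight. This sketch is not part of the commit and does not construct a real torchao Int4Tensor; the shapes and min_max scheme here are illustrative, and it only exercises the function above.

```python
# Standalone round-trip (illustrative, not part of the commit): pack a bf16
# weight into the low/high-nibble layout _dequant_matmul expects, then compare
# against F.linear on the fake-quantized weight.
import torch
import torch.nn.functional as F

from executorch.backends.cuda.int4_dispatch import _dequant_matmul

N, K, group_size = 8, 64, 32
n_groups = K // group_size

w = torch.randn(N, K, dtype=torch.bfloat16)

# Per-group min/max quantization to unsigned 4-bit values in [0, 15].
w_g = w.float().reshape(N, n_groups, group_size)
w_min = w_g.amin(dim=-1, keepdim=True)
w_max = w_g.amax(dim=-1, keepdim=True)
scale = (w_max - w_min) / 15.0
zero = -w_min / scale                      # dequant is (q - zero) * scale
q = torch.clamp(torch.round(w_g / scale + zero), 0, 15).to(torch.uint8)

# Pack two values per byte: element 2i -> low nibble, element 2i+1 -> high nibble.
pairs = q.reshape(N, n_groups, group_size // 2, 2)
qdata = (pairs[..., 0] | (pairs[..., 1] << 4)).reshape(N, K // 2)

# _dequant_matmul transposes scale/zero internally, so pass them as [n_groups, N].
scale_t = scale.squeeze(-1).t().contiguous()
zero_t = zero.squeeze(-1).t().contiguous()

x = torch.randn(2, K, dtype=torch.bfloat16)
out = _dequant_matmul(x, qdata, scale_t, zero_t, group_size)

w_fake = ((q.float() - zero) * scale).reshape(N, K).to(torch.bfloat16)
ref = F.linear(x, w_fake)
print((out.float() - ref.float()).abs().max())  # small: bf16 rounding only
```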
Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <cuda.h>
#include <cuda_runtime.h>

#include <executorch/backends/aoti/utils.h>
#include <executorch/backends/cuda/runtime/shims/int4_plain_mm.h>
#include <executorch/backends/cuda/runtime/shims/int4_plain_mm.cuh>
#include <executorch/backends/cuda/runtime/shims/memory.h>
#include <executorch/runtime/platform/log.h>

namespace executorch::backends::cuda {
#ifdef __cplusplus
extern "C" {
#endif

AOTITorchError aoti_torch_cuda_int4_plain_mm(
    Tensor* self,
    Tensor* qdata,
    Tensor* scale,
    Tensor* zero,
    int64_t group_size,
    Tensor** ret0) {
  ET_CHECK_OR_RETURN_ERROR(
      self != nullptr,
      InvalidArgument,
      "aoti_torch_cuda_int4_plain_mm: self is null");

  ET_CHECK_OR_RETURN_ERROR(
      qdata != nullptr,
      InvalidArgument,
      "aoti_torch_cuda_int4_plain_mm: qdata is null");

  ET_CHECK_OR_RETURN_ERROR(
      scale != nullptr,
      InvalidArgument,
      "aoti_torch_cuda_int4_plain_mm: scale is null");

  ET_CHECK_OR_RETURN_ERROR(
      zero != nullptr,
      InvalidArgument,
      "aoti_torch_cuda_int4_plain_mm: zero is null");

  ET_CHECK_OR_RETURN_ERROR(
      ret0 != nullptr,
      InvalidArgument,
      "aoti_torch_cuda_int4_plain_mm: ret0 is null");

  int32_t M = self->size(0);
  int32_t N = qdata->size(0);
  Tensor* C = nullptr;
  std::array<int64_t, 2> c_shape = {M, N};
  std::array<int64_t, 2> c_stride = {N, 1};
  aoti_torch_empty_strided(
      2,
      c_shape.data(),
      c_stride.data(),
      static_cast<int32_t>(
          executorch::backends::aoti::slim::c10::ScalarType::BFloat16),
      static_cast<int32_t>(
          executorch::backends::aoti::slim::c10::DeviceType::CUDA),
      0,
      &C);

  _int4_plain_mm_cuda(*self, *qdata, *scale, *zero, group_size, C);
  ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR();

  *ret0 = C;
  return Error::Ok;
}

#ifdef __cplusplus
}
#endif
} // namespace executorch::backends::cuda
