diff --git a/.github/workflows/cuda.yml b/.github/workflows/cuda.yml index c3b7c058ee6..d1b954820ef 100644 --- a/.github/workflows/cuda.yml +++ b/.github/workflows/cuda.yml @@ -148,6 +148,9 @@ jobs: # Run Qwen 3.5 MoE tests (quantize roundtrip + TurboQuant KV cache + sampler) python -m pytest examples/models/qwen3_5_moe/test_quantize_roundtrip.py examples/models/qwen3_5_moe/test_turboquant.py examples/models/qwen3_5_moe/test_sampler.py -v -o "addopts=" + # Run Gemma 4 31B tests (quant unit tests + pipeline integration tests) + python -m pytest examples/models/gemma4_31b/quant/ examples/models/gemma4_31b/test_pipeline.py examples/models/gemma4_31b/test_cuda_pipeline.py -v -o "addopts=" + export-model-cuda-artifact: name: export-model-cuda-artifact # Skip this job if the pull request is from a fork (HuggingFace secrets are not available) diff --git a/Makefile b/Makefile index 3c0eac14bce..ba61dddce44 100644 --- a/Makefile +++ b/Makefile @@ -91,7 +91,7 @@ # # ============================================================================== -.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx voxtral_tts-cpu voxtral_tts-cuda whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx parakeet-vulkan dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu qwen3_5_moe-cuda qwen3_5_moe-metal clean help +.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx voxtral_tts-cpu voxtral_tts-cuda whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx parakeet-vulkan dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu gemma4_31b-cuda qwen3_5_moe-cuda qwen3_5_moe-metal clean help help: @echo "This Makefile adds targets to build runners for various models on various backends. Run using \`make \`. Available targets:" @@ -126,6 +126,7 @@ help: @echo " llava-cpu - Build Llava runner with CPU backend" @echo " gemma3-cuda - Build Gemma3 runner with CUDA backend" @echo " gemma3-cpu - Build Gemma3 runner with CPU backend" + @echo " gemma4_31b-cuda - Build Gemma 4 31B runner with CUDA backend" @echo " qwen3_5_moe-cuda - Build Qwen3.5 MoE runner with CUDA backend" @echo " qwen3_5_moe-metal - Build Qwen3.5 MoE runner with Metal backend" @echo " clean - Clean build artifacts" @@ -425,6 +426,15 @@ qwen3_5_moe-cuda: @echo "✓ Build complete!" @echo " Binary: cmake-out/examples/models/qwen3_5_moe/qwen3_5_moe_runner" +gemma4_31b-cuda: + @echo "==> Building and installing ExecuTorch with CUDA..." + cmake --workflow --preset llm-release-cuda + @echo "==> Building Gemma 4 31B runner with CUDA..." + cd examples/models/gemma4_31b && cmake --workflow --preset gemma4-31b-cuda + @echo "" + @echo "✓ Build complete!" + @echo " Binary: cmake-out/examples/models/gemma4_31b/gemma4_31b_runner" + qwen3_5_moe-metal: @echo "==> Building and installing ExecuTorch with Metal..." 
cmake --workflow --preset llm-release-metal diff --git a/examples/models/gemma4/text_decoder/__init__.py b/examples/models/gemma4/text_decoder/__init__.py index 25d7c5c7a16..51c96f0717f 100644 --- a/examples/models/gemma4/text_decoder/__init__.py +++ b/examples/models/gemma4/text_decoder/__init__.py @@ -6,5 +6,14 @@ # LICENSE file in the root directory of this source tree. from .convert_weights import convert_hf_to_custom # noqa: F401 +from .gemma4_attention import ( # noqa: F401 + apply_rotary_emb, + apply_rotary_emb_single, + Gemma4KVCache, + precompute_freqs_cis, + rotate_half, +) from .gemma4_config import Gemma4Config # noqa: F401 +from .gemma4_decoder_layer import Gemma4MLP # noqa: F401 from .gemma4_model import create_gemma4_model, Gemma4Model # noqa: F401 +from .gemma4_norm import RMSNorm, RMSNormNoWeight # noqa: F401 diff --git a/examples/models/gemma4/text_decoder/gemma4_norm.py b/examples/models/gemma4/text_decoder/gemma4_norm.py index 17e42a43ca1..2c8fec67525 100644 --- a/examples/models/gemma4/text_decoder/gemma4_norm.py +++ b/examples/models/gemma4/text_decoder/gemma4_norm.py @@ -5,9 +5,46 @@ # pyre-unsafe # LICENSE file in the root directory of this source tree. +"""Gemma 4 RMSNorm — self-contained re-implementation. + +Numerically identical to ``transformers.models.gemma4.modeling_gemma4.Gemma4RMSNorm`` +(same float32 upcast and ``pow(mean_squared, -0.5)`` normalization), but +without the transformers import so this module is exportable and dep-light. +""" + from functools import partial -from transformers.models.gemma4.modeling_gemma4 import Gemma4RMSNorm as RMSNorm +import torch +from torch import nn + + +class RMSNorm(nn.Module): + """Gemma4 RMSNorm: ``y = (x / rms(x)) * weight``, computed in float32. + + Unlike Gemma 2/3 (``(1 + weight)``) Gemma 4 multiplies by ``weight`` directly. + Pass ``with_scale=False`` for the v-norm and the (unused-here) router norm, + which omit the learnable weight entirely. + """ + + def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True): + super().__init__() + self.eps = eps + self.with_scale = with_scale + if with_scale: + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x: torch.Tensor) -> torch.Tensor: + # Match transformers' use of pow(mean_squared, -0.5) over rsqrt; + # the comment there cites Torch/JAX compiler differences. + mean_squared = x.pow(2).mean(-1, keepdim=True) + self.eps + return x * torch.pow(mean_squared, -0.5) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + normed = self._norm(x.float()) + if self.with_scale: + normed = normed * self.weight.float() + return normed.type_as(x) + # V-norm in attention uses RMSNorm without learnable weight. RMSNormNoWeight = partial(RMSNorm, with_scale=False) diff --git a/examples/models/gemma4_31b/CMakeLists.txt b/examples/models/gemma4_31b/CMakeLists.txt new file mode 100644 index 00000000000..8d536a47fc5 --- /dev/null +++ b/examples/models/gemma4_31b/CMakeLists.txt @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +cmake_minimum_required(VERSION 3.24) +project(gemma4_31b) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..) + +include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) + +set(_common_include_directories ${EXECUTORCH_ROOT}/..) 
+ +# gflags +set(gflags_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/gflags) +find_package(gflags REQUIRED) + +# executorch +list(APPEND CMAKE_FIND_ROOT_PATH ${CMAKE_CURRENT_BINARY_DIR}/../../..) +find_package(executorch CONFIG REQUIRED FIND_ROOT_PATH_BOTH) +executorch_target_link_options_shared_lib(executorch) + +set(link_libraries executorch gflags) + +# CPU ops (for the host-side helpers that aren't delegated to CUDA) +list(APPEND link_libraries optimized_native_cpu_ops_lib cpublas eigen_blas) +executorch_target_link_options_shared_lib(optimized_native_cpu_ops_lib) + +# Extensions +list( + APPEND + link_libraries + extension_llm_runner + extension_module + extension_data_loader + extension_tensor + extension_flat_tensor +) + +# CUDA backend (the only supported backend for this example for now) +if(EXECUTORCH_BUILD_CUDA) + find_package(CUDAToolkit REQUIRED) + list(APPEND link_libraries aoti_cuda_backend) + executorch_target_link_options_shared_lib(aoti_cuda_backend) + add_compile_definitions(EXECUTORCH_BUILD_CUDA) +else() + message(FATAL_ERROR "Set EXECUTORCH_BUILD_CUDA=ON") +endif() + +# Tokenizer (HuggingFace tokenizer.json) +list(APPEND link_libraries tokenizers::tokenizers) + +add_executable(gemma4_31b_runner main.cpp) +target_include_directories( + gemma4_31b_runner PUBLIC ${_common_include_directories} +) +target_link_libraries(gemma4_31b_runner PUBLIC ${link_libraries}) + +if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug") + target_link_options_gc_sections(gemma4_31b_runner) + target_link_options(gemma4_31b_runner PRIVATE "LINKER:-s") +endif() diff --git a/examples/models/gemma4_31b/CMakePresets.json b/examples/models/gemma4_31b/CMakePresets.json new file mode 100644 index 00000000000..97ba7f4c57a --- /dev/null +++ b/examples/models/gemma4_31b/CMakePresets.json @@ -0,0 +1,52 @@ +{ + "version": 6, + "configurePresets": [ + { + "name": "gemma4-31b-base", + "hidden": true, + "binaryDir": "${sourceDir}/../../../cmake-out/examples/models/gemma4_31b", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Release", + "CMAKE_FIND_ROOT_PATH": "${sourceDir}/../../../cmake-out", + "CMAKE_PREFIX_PATH": "${sourceDir}/../../../cmake-out" + } + }, + { + "name": "gemma4-31b-cuda", + "displayName": "Gemma 4 31B runner (CUDA)", + "inherits": ["gemma4-31b-base"], + "cacheVariables": { + "EXECUTORCH_BUILD_CUDA": "ON" + }, + "condition": { + "type": "inList", + "string": "${hostSystemName}", + "list": ["Linux", "Windows"] + } + } + ], + "buildPresets": [ + { + "name": "gemma4-31b-cuda", + "displayName": "Build Gemma 4 31B runner (CUDA)", + "configurePreset": "gemma4-31b-cuda", + "targets": ["gemma4_31b_runner"] + } + ], + "workflowPresets": [ + { + "name": "gemma4-31b-cuda", + "displayName": "Configure and build Gemma 4 31B runner (CUDA)", + "steps": [ + { + "type": "configure", + "name": "gemma4-31b-cuda" + }, + { + "type": "build", + "name": "gemma4-31b-cuda" + } + ] + } + ] +} diff --git a/examples/models/gemma4_31b/README.md b/examples/models/gemma4_31b/README.md new file mode 100644 index 00000000000..3dcb958d8cc --- /dev/null +++ b/examples/models/gemma4_31b/README.md @@ -0,0 +1,112 @@ +# Gemma 4 31B-IT + +Text-only export of Google's Gemma 4 31B-IT to ExecuTorch with INT4/INT8 +weight quantization. Currently supports the CUDA backend. + +For architecture and design notes see [model.md](model.md). + +## When to use which script + +The full bf16 weights for 31B (~62 GB) often don't fit in available RAM. 
The +recommended flow is to quantize once and reuse the quantized checkpoint for +both export and eager inference: + +| Script | Purpose | Peak memory | +|---|---|---| +| `quantize_and_save.py` | bf16 HF checkpoint → quantized checkpoint (one-time) | ~30 GB CPU | +| `export.py --prequantized ` | quantized checkpoint → `model.pte` + `model.ptd` | ~24 GB CPU + CUDA for packing | +| `inference.py --prequantized ` | quantized checkpoint → eager generation under `torch.compile` | ~24 GB GPU | +| `export.py --model-dir ` | one-shot bf16 → quantize → export (no intermediate file) | ~30 GB CPU + CUDA for packing | + +The quantized checkpoint is a safetensors file with int values + per-group +scales and a JSON header describing each weight's `QuantConfig`. No tensor +subclass or backend-specific packing — packing for the target backend happens +at load time via `quant.pack_model()`. + +## Quantization recipes + +Two built-in recipes (see `quantize_and_save.py`): + +| Recipe | Description | +|---|---| +| `default` | INT4 min_max linears, INT8 per-axis embedding | +| `sensitive` | INT8 for edge-layer v_proj/down_proj, INT4 hqq elsewhere, INT8 per-axis embedding | + +## Prequantized checkpoint + +A prequantized checkpoint (sensitive recipe) is available on HuggingFace: + +```bash +huggingface-cli download SocialLocalMobile/gemma-4-31B-it-HQQ-INT4 --local-dir gemma-4-31B-it-HQQ-INT4 +``` + +> **Note**: This checkpoint is intended for development and testing of the +> ExecuTorch CUDA export pipeline. Output quality has not been formally +> evaluated against the base model. + +Use it directly with `--prequantized` in the export and inference scripts +below — no need to run `quantize_and_save.py`. + +## Quantize from scratch (optional) + +To quantize from the original bf16 checkpoint instead, pass +`--quant-recipe` to select a recipe (`default` or `sensitive`): + +```bash +python examples/models/gemma4_31b/quantize_and_save.py \ + --model-dir /path/to/gemma-4-31B-it \ + --output ./gemma4_31b_int4 \ + --quant-recipe sensitive +``` + +See [Quantization recipes](#quantization-recipes) above for details on each +recipe. Writes `model.safetensors`, `config.json`, and `tokenizer.json` into +`--output`. + +## Export to ExecuTorch + +```bash +python examples/models/gemma4_31b/export.py \ + --prequantized ./gemma4_31b_int4 \ + --output-dir ./gemma4_31b_exports \ + --max-seq-len 4096 \ + --backend cuda +``` + +Writes `model.pte` and `model.ptd` into `--output-dir`. + +## Eager inference + +```bash +python examples/models/gemma4_31b/inference.py \ + --prequantized ./gemma4_31b_int4 \ + --prompt "Write a short joke about saving RAM." \ + --max-new-tokens 128 \ + --temperature 0.8 +``` + +Useful before spending the export+lowering time to confirm the quantized +model produces sensible text. + +## Build the runner + +```bash +make gemma4_31b-cuda +``` + +The binary lands at `cmake-out/examples/models/gemma4_31b/gemma4_31b_runner`. + +## Run the .pte + +```bash +./gemma4_31b_runner \ + --model_path ./gemma4_31b_exports/model.pte \ + --data_path ./gemma4_31b_exports/aoti_cuda_blob.ptd \ + --tokenizer_path ./gemma4_31b_int4/tokenizer.json \ + --prompt "Write a short joke about saving RAM." \ + --max_new_tokens 128 \ + --temperature 0.8 +``` + +For benchmarking, add `--cuda_graph` to capture the decode method in a CUDA +graph (decode is fully static — `T=1`). 
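+
+For long prompts, the runner can also read the prompt from a file via
+`--prompt_file` (it overrides `--prompt`; see the flag definitions in
+`main.cpp`). For example, with a hypothetical prompt file:
+
+```bash
+./gemma4_31b_runner \
+  --model_path ./gemma4_31b_exports/model.pte \
+  --data_path ./gemma4_31b_exports/aoti_cuda_blob.ptd \
+  --tokenizer_path ./gemma4_31b_int4/tokenizer.json \
+  --prompt_file ./long_prompt.txt \
+  --max_new_tokens 128
+```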
diff --git a/examples/models/gemma4_31b/__init__.py b/examples/models/gemma4_31b/__init__.py new file mode 100644 index 00000000000..2e41cd717f6 --- /dev/null +++ b/examples/models/gemma4_31b/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/examples/models/gemma4_31b/export.py b/examples/models/gemma4_31b/export.py new file mode 100644 index 00000000000..f2bf054015e --- /dev/null +++ b/examples/models/gemma4_31b/export.py @@ -0,0 +1,300 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Export Gemma 4 31B-IT to ExecuTorch (.pte + .ptd). + +Two methods are exported and lowered together so they share KV-cache buffers: + - "decode": T=1, static shape, returns the next sampled token. + - "prefill": T>=2, dynamic shape, returns the next sampled token. + +Two input paths: + --prequantized Load a quantized checkpoint (from quantize_and_save.py) + and pack for the target backend. No re-quantization. + --model-dir Load bf16 checkpoint, quantize, pack, and export + in one shot. + +Backends: + --backend cuda (default) CUDA via tinygemm INT4 + CudaPartitioner. +""" + +import argparse +import os + +import torch +import torch.nn as nn + +from executorch.examples.models.gemma4_31b.model import ( + Gemma4_31B, + Gemma4_31BConfig, + materialize_runtime_buffers, +) + + +# --------------------------------------------------------------------------- +# Load paths + + +def load_prequantized_model( + prequantized_dir: str, + max_seq_len: int = 4096, + backend: str = "cuda", +) -> tuple[Gemma4_31B, Gemma4_31BConfig]: + """Load a quantized checkpoint and pack for the target backend.""" + config = Gemma4_31BConfig.from_hf_config( + os.path.join(prequantized_dir, "config.json") + ) + config.max_seq_len = max_seq_len + + print("Building model on meta device...") + with torch.device("meta"): + model = Gemma4_31B(config) + + safetensors_path = os.path.join(prequantized_dir, "model.safetensors") + print(f"Loading quantized checkpoint from {safetensors_path}...") + _pack_for_backend(model, safetensors_path, backend) + model.eval() + + print(f"Model: {config.num_hidden_layers} layers, hidden={config.hidden_size}") + return model, config + + +def load_and_quantize( + model_dir: str, + recipe_name: str, + max_seq_len: int = 4096, + backend: str = "cuda", +) -> tuple[Gemma4_31B, Gemma4_31BConfig]: + """Load bf16 checkpoint, quantize, pack — one shot.""" + from executorch.examples.models.gemma4_31b.quant import pack_model, quantize_model + from executorch.examples.models.gemma4_31b.quantize_and_save import _RECIPES + + recipe = _RECIPES[recipe_name] + + print("Loading checkpoint (lazy, shard-by-shard)...") + model, config = Gemma4_31B.from_hf_checkpoint(model_dir, max_seq_len=max_seq_len) + + if model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr(): + print("Untying embed_tokens / lm_head...") + model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) + + print(f"Quantizing with recipe '{recipe_name}'...") + quantized, unquantized = quantize_model(model, recipe) + + print(f"Packing for {backend}...") + with torch.device("meta"): + model = Gemma4_31B(config) + pack_model(model, quantized, unquantized, 
packers=_get_packers(backend)) + model.eval() + + print(f"Model: {config.num_hidden_layers} layers, hidden={config.hidden_size}") + return model, config + + +# --------------------------------------------------------------------------- +# Backend dispatch helpers + + +def _get_packers(backend: str) -> dict: + if backend == "cuda": + from executorch.examples.models.gemma4_31b.quant import DEFAULT_CUDA_PACKERS + + return DEFAULT_CUDA_PACKERS + raise ValueError(f"Unsupported backend: {backend!r}. Supported: 'cuda'.") + + +def _pack_for_backend(model: nn.Module, path: str, backend: str) -> None: + if backend == "cuda": + from executorch.examples.models.gemma4_31b.quant import load_and_pack_for_cuda + + load_and_pack_for_cuda(path, model) + else: + raise ValueError(f"Unsupported backend: {backend!r}. Supported: 'cuda'.") + + +# --------------------------------------------------------------------------- +# Export + lower + + +def export_and_lower( + model: Gemma4_31B, + config: Gemma4_31BConfig, + output_dir: str, + backend: str = "cuda", +) -> None: + """Export and lower the model to ExecuTorch for the given backend.""" + if backend == "cuda": + _export_cuda(model, config, output_dir) + else: + raise ValueError(f"Unsupported backend: {backend!r}. Supported: 'cuda'.") + + +def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) -> None: + import torch._inductor.config as inductor_config + + from executorch.backends.cuda.cuda_backend import CudaBackend + from executorch.backends.cuda.cuda_partitioner import CudaPartitioner + from executorch.exir import ( + EdgeCompileConfig, + ExecutorchBackendConfig, + to_edge_transform_and_lower, + ) + from executorch.exir.passes import MemoryPlanningPass + from torch.export import Dim, export + + inductor_config.coordinate_descent_tuning = False + inductor_config.aot_inductor.compile_wrapper_opt_level = "O0" + + materialize_runtime_buffers(model, dtype=torch.bfloat16) + + print("Exporting decode (T=1)...") + with torch.no_grad(): + decode_ep = export( + model, + ( + torch.tensor([[0]], dtype=torch.long), + torch.tensor([0], dtype=torch.long), + torch.tensor([1.0], dtype=torch.float32), + ), + strict=True, + ) + + # Cap prefill length to the ring-buffer KV cache size (2×sliding_window). + # Longer prompts are chunked by the runner. 
+ max_prefill = min(config.max_seq_len - 1, config.sliding_window * 2) + seq_dim = Dim("seq_len", min=2, max=max_prefill) + print(f"Exporting prefill (T in [2, {max_prefill}])...") + with torch.no_grad(): + prefill_ep = export( + model, + ( + torch.zeros((1, max_prefill), dtype=torch.long), + torch.arange(max_prefill, dtype=torch.long), + torch.tensor([1.0], dtype=torch.float32), + ), + dynamic_shapes=({1: seq_dim}, {0: seq_dim}, None), + strict=True, + ) + + print("Lowering to ExecuTorch with CUDA backend...") + et_prog = to_edge_transform_and_lower( + {"decode": decode_ep, "prefill": prefill_ep}, + partitioner={ + "decode": [ + CudaPartitioner( + [CudaBackend.generate_method_name_compile_spec("decode")] + ) + ], + "prefill": [ + CudaPartitioner( + [CudaBackend.generate_method_name_compile_spec("prefill")] + ) + ], + }, + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + _skip_dim_order=True, + ), + constant_methods={ + "get_max_seq_len": config.max_seq_len, + "get_vocab_size": config.vocab_size, + "get_n_layers": config.num_hidden_layers, + "get_max_prefill_chunk": max_prefill, + "use_kv_cache": True, + "use_sdpa_with_kv_cache": False, + "enable_dynamic_shape": True, + }, + ) + et_program = et_prog.to_executorch( + config=ExecutorchBackendConfig( + extract_delegate_segments=True, + do_quant_fusion_and_const_prop=True, + memory_planning_pass=MemoryPlanningPass( + alloc_graph_input=False, + share_mutable_buffers=True, + ), + emit_mutable_buffer_names=True, + ), + ) + + os.makedirs(output_dir, exist_ok=True) + pte_path = os.path.join(output_dir, "model.pte") + print(f"Saving to {pte_path}...") + with open(pte_path, "wb") as f: + et_program.write_to_file(f) + print(f" {os.path.getsize(pte_path) / 1024**2:.1f} MB") + + if et_program._tensor_data: + et_program.write_tensor_data_to_file(output_dir) + print(f" Saved tensor data (.ptd) to {output_dir}/") + print("Done.") + + +# --------------------------------------------------------------------------- +# CLI + + +def main() -> None: + from executorch.examples.models.gemma4_31b.quantize_and_save import _RECIPES + + parser = argparse.ArgumentParser(description="Export Gemma 4 31B-IT to ExecuTorch.") + src = parser.add_mutually_exclusive_group(required=True) + src.add_argument( + "--model-dir", + default=None, + help="HuggingFace model dir. Triggers load + quantize + export.", + ) + src.add_argument( + "--prequantized", + default=None, + help="Path to a quantized checkpoint directory. 
Skips quantization.", + ) + parser.add_argument( + "--output-dir", + default="./gemma4_31b_exports", + help="Output directory for model.pte / model.ptd.", + ) + parser.add_argument( + "--max-seq-len", + type=int, + default=4096, + help="KV cache size.", + ) + parser.add_argument( + "--quant-recipe", + default="default", + choices=list(_RECIPES), + help="Quantization recipe (only with --model-dir).", + ) + parser.add_argument( + "--backend", + default="cuda", + choices=["cuda"], + help="Target backend for export.", + ) + args = parser.parse_args() + + if args.backend == "cuda" and not torch.cuda.is_available(): + parser.error("CUDA is required for the cuda backend.") + + if args.prequantized: + model, config = load_prequantized_model( + args.prequantized, + max_seq_len=args.max_seq_len, + backend=args.backend, + ) + else: + model, config = load_and_quantize( + args.model_dir, + args.quant_recipe, + max_seq_len=args.max_seq_len, + backend=args.backend, + ) + + export_and_lower(model, config, args.output_dir, backend=args.backend) + + +if __name__ == "__main__": + main() diff --git a/examples/models/gemma4_31b/inference.py b/examples/models/gemma4_31b/inference.py new file mode 100644 index 00000000000..59418f3b746 --- /dev/null +++ b/examples/models/gemma4_31b/inference.py @@ -0,0 +1,192 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Eager inference on a prequantized Gemma 4 31B-IT model (CUDA + torch.compile). + +Loads a quantized checkpoint (from ``quantize_and_save.py``), packs for CUDA, +materializes runtime buffers, optionally compiles with ``torch.compile``, and +generates text autoregressively. The model performs Gumbel-max sampling +on-device, so each forward returns the next token ID as a float tensor of +shape ``[B, 1]``. + +Usage: + python inference.py \\ + --prequantized ./gemma4_31b_int4 \\ + --prompt "Write a short joke about saving RAM." \\ + --max-new-tokens 128 \\ + --temperature 0.8 +""" + +import argparse +import os +import time + +import torch + +from executorch.examples.models.gemma4_31b.export import load_prequantized_model +from executorch.examples.models.gemma4_31b.model import materialize_runtime_buffers + + +def _move_to_cuda(model, config) -> None: + """Move the prequantized model to CUDA and materialize runtime buffers there. + + Parameters are moved individually (not via ``model.cuda()``) to preserve + ``Int4TilePackedTo4dTensor`` subclass identity. Non-meta buffers (e.g. + ``layer_scalar``) are moved to CUDA. Meta-device buffers (KV cache, RoPE, + constants) are materialized directly on CUDA via + ``materialize_runtime_buffers``. 
+ """ + for name, p in model.named_parameters(): + parts = name.rsplit(".", 1) + parent = model.get_submodule(parts[0]) if len(parts) > 1 else model + setattr( + parent, + parts[-1], + torch.nn.Parameter(p.data.to("cuda"), requires_grad=False), + ) + + for fqn, buf in list(model.named_buffers()): + if buf.device.type != "meta": + parts = fqn.rsplit(".", 1) + parent = model.get_submodule(parts[0]) if len(parts) > 1 else model + parent.register_buffer(parts[-1], buf.to("cuda"), persistent=False) + + materialize_runtime_buffers(model, dtype=torch.bfloat16, device="cuda") + + +def generate( + model, + tokenizer, + prompt: str, + max_new_tokens: int = 128, + temperature: float = 0.0, + eos_token_ids=None, + bos_token_id: int = 2, +) -> str: + """Autoregressive generation. Prefill is one-token-at-a-time so a single + compiled graph handles every step; the exported PTE uses a separate + multi-token prefill method, but for eager+compile a uniform decode-shape + forward is simpler and benefits from CUDA-graph friendly shapes. + + ``tokenizers.Tokenizer.from_file`` does not auto-prepend BOS — and Gemma 4 + is unusable without it (the model's logits collapse to a single + high-frequency vocab token if the very first input isn't BOS). We prepend + explicitly here; pass ``bos_token_id=None`` to disable. + """ + if eos_token_ids is None: + eos_token_ids = set() + + input_ids = tokenizer.encode(prompt).ids + if bos_token_id is not None and (not input_ids or input_ids[0] != bos_token_id): + input_ids = [bos_token_id] + input_ids + + temp_val = max(temperature, 1e-6) # avoid div-by-zero in the on-device sampler + temp_tensor = torch.tensor([temp_val], dtype=torch.float32, device="cuda") + + sampled = None + with torch.no_grad(): + # Prefill, one token at a time. + for i, tok_id in enumerate(input_ids): + tok = torch.tensor([[tok_id]], dtype=torch.long, device="cuda") + pos = torch.tensor([i], dtype=torch.long, device="cuda") + sampled = model(tok, pos, temp_tensor) + + # First generated token from the last prefill step. + next_id = int(sampled.item()) + generated = [next_id] + + # Decode loop. + seq_len = len(input_ids) + for i in range(max_new_tokens - 1): + tok = torch.tensor([[next_id]], dtype=torch.long, device="cuda") + pos = torch.tensor([seq_len + i], dtype=torch.long, device="cuda") + sampled = model(tok, pos, temp_tensor) + next_id = int(sampled.item()) + generated.append(next_id) + if next_id in eos_token_ids: + break + + return tokenizer.decode(generated) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Eager inference on prequantized Gemma 4 31B-IT (CUDA)." 
+ ) + parser.add_argument( + "--prequantized", + required=True, + help="Path to a quantized checkpoint directory.", + ) + parser.add_argument("--prompt", default="Hello", help="Input prompt.") + parser.add_argument( + "--max-new-tokens", + type=int, + default=128, + help="Maximum tokens to generate.", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.8, + help="Sampling temperature (0 = near-greedy).", + ) + parser.add_argument( + "--max-seq-len", + type=int, + default=4096, + help="KV cache length to allocate for this run.", + ) + parser.add_argument( + "--no-compile", + action="store_true", + help="Skip torch.compile (slower, but easier to debug).", + ) + args = parser.parse_args() + + if not torch.cuda.is_available(): + parser.error("CUDA is required for inference.") + + print(f"Loading prequantized model from {args.prequantized}...") + model, config = load_prequantized_model( + args.prequantized, max_seq_len=args.max_seq_len + ) + _move_to_cuda(model, config) + model.eval() + + if not args.no_compile: + print("Compiling model with torch.compile...") + model = torch.compile(model, mode="default") + + tokenizer_path = os.path.join(args.prequantized, "tokenizer.json") + from tokenizers import Tokenizer + + tokenizer = Tokenizer.from_file(tokenizer_path) + + # Gemma 4 EOS tokens (from generation_config.json: ids 1, 50, 106). + eos_token_ids = {1, 50, 106} + + print(f"\nPrompt: {args.prompt}") + print("-" * 40) + + t0 = time.perf_counter() + output = generate( + model, + tokenizer, + args.prompt, + max_new_tokens=args.max_new_tokens, + temperature=args.temperature, + eos_token_ids=eos_token_ids, + ) + elapsed = time.perf_counter() - t0 + + print(output) + print("-" * 40) + print(f"Generated in {elapsed:.2f}s") + + +if __name__ == "__main__": + main() diff --git a/examples/models/gemma4_31b/main.cpp b/examples/models/gemma4_31b/main.cpp new file mode 100644 index 00000000000..b12e1b87db9 --- /dev/null +++ b/examples/models/gemma4_31b/main.cpp @@ -0,0 +1,377 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Gemma 4 31B-IT runner for the CUDA ExecuTorch backend. +// +// Drives the prefill + decode methods produced by export.py. +// The exported model performs Gumbel-max sampling on-device and returns a +// single float token ID per call, so this runner only has to feed tokens +// in and decode them via the HuggingFace tokenizer. + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#ifdef EXECUTORCH_BUILD_CUDA +#include +#endif + +DEFINE_string(model_path, "", "Model .pte file path."); +DEFINE_string(data_path, "", "Data file (.ptd) for CUDA backend."); +DEFINE_string(tokenizer_path, "", "HuggingFace tokenizer.json path."); +DEFINE_string(prompt, "Hello", "Prompt text."); +DEFINE_string( + prompt_file, + "", + "Path to file containing prompt text (overrides --prompt)."); +DEFINE_double(temperature, 0.8, "Sampling temperature (0 = near-greedy)."); +DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate."); +DEFINE_bool( + cuda_graph, + false, + "Enable CUDA graph capture for the decode method. 
CUDA only."); + +namespace llm = ::executorch::extension::llm; +using ::executorch::extension::from_blob; +using ::executorch::extension::Module; +using ::executorch::runtime::Error; +using ::executorch::runtime::EValue; + +using SizesType = executorch::aten::SizesType; + +static uint64_t read_token(const executorch::aten::Tensor& output) { + const void* ptr = output.const_data_ptr(); + float val = 0.0f; + +#ifdef EXECUTORCH_BUILD_CUDA + cudaPointerAttributes attrs{}; + bool on_device = cudaPointerGetAttributes(&attrs, ptr) == cudaSuccess && + attrs.type == cudaMemoryTypeDevice; + if (on_device) { + cudaError_t err = + cudaMemcpy(&val, ptr, sizeof(float), cudaMemcpyDeviceToHost); + if (err != cudaSuccess) { + ET_LOG( + Error, + "read_token: cudaMemcpy D2H failed: %s", + cudaGetErrorString(err)); + return 0; + } + } else { + memcpy(&val, ptr, sizeof(float)); + } +#else + memcpy(&val, ptr, sizeof(float)); +#endif + + return static_cast(val); +} + +int main(int argc, char** argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + + if (FLAGS_model_path.empty()) { + ET_LOG(Error, "Must specify --model_path"); + return 1; + } + if (FLAGS_tokenizer_path.empty()) { + ET_LOG(Error, "Must specify --tokenizer_path"); + return 1; + } + + llm::Stats stats; + +#ifdef EXECUTORCH_BUILD_CUDA + size_t gpu_free_bytes = 0, gpu_total_bytes = 0; + cudaMemGetInfo(&gpu_free_bytes, &gpu_total_bytes); + stats.gpu_total_bytes = gpu_total_bytes; + stats.gpu_free_before_load_bytes = gpu_free_bytes; +#endif + + stats.model_load_start_ms = llm::time_in_ms(); + + // Tokenizer + auto tokenizer = std::make_unique(); + if (tokenizer->load(FLAGS_tokenizer_path) != tokenizers::Error::Ok) { + ET_LOG( + Error, + "Failed to load tokenizer from %s", + FLAGS_tokenizer_path.c_str()); + return 1; + } + + // Module: share_memory_arenas=true so prefill and decode see the same + // KV-cache memory (we exported with share_mutable_buffers=True). + std::vector data_files; + if (!FLAGS_data_path.empty()) { + data_files.push_back(FLAGS_data_path); + } + auto module = std::make_unique( + FLAGS_model_path, + data_files, + Module::LoadMode::File, + /*event_tracer=*/nullptr, + /*memory_allocator=*/nullptr, + /*temp_allocator=*/nullptr, + /*share_memory_arenas=*/true); + + // Get metadata + auto metadata_result = llm::get_llm_metadata(tokenizer.get(), module.get()); + if (metadata_result.error() != Error::Ok) { + ET_LOG(Error, "Failed to read model metadata"); + return 1; + } + +#ifdef EXECUTORCH_BUILD_CUDA + if (FLAGS_cuda_graph) { + executorch::runtime::BackendOptions<2> cuda_opts; + cuda_opts.set_option("enable_cuda_graph_for_method", "decode"); + executorch::runtime::set_option("CudaBackend", cuda_opts.view()); + printf("CUDA graph enabled for decode method\n"); + } + + // Cross-method per-FQN weight sharing: prefill + decode share the same + // weight tensors and (more importantly) the same KV-cache buffers, so + // without this flag we would allocate them twice. MUST be set before + // load_method. 
+ { + executorch::runtime::BackendOptions<1> backend_options; + auto set_err = + backend_options.set_option("weight_sharing_across_methods", true); + if (set_err != Error::Ok) { + ET_LOG( + Error, + "Failed to construct weight_sharing_across_methods option: %d", + static_cast(set_err)); + return 1; + } + auto opt_err = + executorch::runtime::set_option("CudaBackend", backend_options.view()); + if (opt_err != Error::Ok) { + ET_LOG( + Error, + "Failed to enable weight_sharing_across_methods: %d", + static_cast(opt_err)); + return 1; + } + } +#else + if (FLAGS_cuda_graph) { + ET_LOG(Info, "--cuda_graph ignored on non-CUDA build"); + } +#endif + + printf("Loading methods...\n"); + if (module->load_method("prefill") != Error::Ok) { + ET_LOG(Error, "Failed to load prefill method"); + return 1; + } + if (module->load_method("decode") != Error::Ok) { + ET_LOG(Error, "Failed to load decode method"); + return 1; + } + stats.model_load_end_ms = llm::time_in_ms(); + +#ifdef EXECUTORCH_BUILD_CUDA + cudaMemGetInfo(&gpu_free_bytes, &gpu_total_bytes); + stats.gpu_free_after_load_bytes = gpu_free_bytes; +#endif + + auto eos_ids = llm::get_eos_ids(tokenizer.get(), module.get()); + + // Read prompt from file or flag + std::string prompt_text = FLAGS_prompt; + if (!FLAGS_prompt_file.empty()) { + std::ifstream f(FLAGS_prompt_file); + if (!f.is_open()) { + ET_LOG( + Error, "Failed to open prompt file: %s", FLAGS_prompt_file.c_str()); + return 1; + } + prompt_text = std::string( + (std::istreambuf_iterator(f)), std::istreambuf_iterator()); + } + + // Encode prompt + auto encode_result = tokenizer->encode(prompt_text); + if (!encode_result.ok()) { + ET_LOG(Error, "Failed to encode prompt"); + return 1; + } + auto prompt_tokens = std::move(*encode_result); + int64_t num_prompt_tokens = static_cast(prompt_tokens.size()); + printf("Prompt tokens: %" PRId64 "\n", num_prompt_tokens); + stats.num_prompt_tokens = num_prompt_tokens; + + stats.inference_start_ms = llm::time_in_ms(); + + auto S = [](int64_t v) -> SizesType { return static_cast(v); }; + +#ifdef EXECUTORCH_BUILD_CUDA + // CUDA build: model fuses the sampler. Pass temperature as a third input. + float temp_val = + FLAGS_temperature <= 0.0 ? 1e-6f : static_cast(FLAGS_temperature); + auto temp_tensor = + from_blob(&temp_val, {1}, executorch::aten::ScalarType::Float); +#endif + + // --------------------------------------------------------------- + // Prefill (chunked to respect ring-buffer KV cache limit) + // --------------------------------------------------------------- + // Sliding layers use a ring buffer sized to 2×sliding_window. A single + // prefill call must not exceed this size, otherwise index_copy_ with + // wrapped indices produces non-deterministic results on CUDA. + int64_t max_prefill_chunk = (*metadata_result)[llm::kMaxSeqLen] - 1; + { + auto get_result = module->get("get_max_prefill_chunk"); + if (get_result.ok()) { + max_prefill_chunk = get_result->toScalar().to(); + } + } + + uint64_t cur_token = 0; + int64_t prefill_pos = 0; + while (prefill_pos < num_prompt_tokens) { + int64_t chunk_len = + std::min(num_prompt_tokens - prefill_pos, max_prefill_chunk); + + std::string run_method = (chunk_len == 1) ? 
"decode" : "prefill"; + + std::vector token_data( + prompt_tokens.begin() + prefill_pos, + prompt_tokens.begin() + prefill_pos + chunk_len); + std::vector pos_data(chunk_len); + for (int64_t i = 0; i < chunk_len; i++) { + pos_data[i] = prefill_pos + i; + } + auto tokens_tensor = from_blob( + token_data.data(), + {1, S(chunk_len)}, + executorch::aten::ScalarType::Long); + auto pos_tensor = from_blob( + pos_data.data(), {S(chunk_len)}, executorch::aten::ScalarType::Long); + + std::vector prefill_inputs; + prefill_inputs.push_back(EValue(tokens_tensor)); + prefill_inputs.push_back(EValue(pos_tensor)); +#ifdef EXECUTORCH_BUILD_CUDA + prefill_inputs.push_back(EValue(temp_tensor)); +#endif + + auto prefill_result = module->execute(run_method, prefill_inputs); + if (prefill_result.error() != Error::Ok) { + ET_LOG( + Error, "%s failed at pos %" PRId64, run_method.c_str(), prefill_pos); + return 1; + } + cur_token = read_token(prefill_result.get()[0].toTensor()); + prefill_pos += chunk_len; + } + + stats.prompt_eval_end_ms = llm::time_in_ms(); + double prefill_ms = + static_cast(stats.prompt_eval_end_ms - stats.inference_start_ms); + printf( + "Prefill: %" PRId64 " tokens in %.1f ms (%.1f tok/s)\n", + num_prompt_tokens, + prefill_ms, + num_prompt_tokens * 1000.0 / prefill_ms); + +#ifdef EXECUTORCH_BUILD_CUDA + // Synchronize CUDA device to ensure prefill's writes to shared mutable + // buffers (KV cache) are visible to the decode method, which may run on + // a different CUDA stream. + cudaDeviceSynchronize(); +#endif + + // --------------------------------------------------------------- + // Decode loop + // --------------------------------------------------------------- + int64_t pos = num_prompt_tokens; + std::vector decode_token_data = {static_cast(cur_token)}; + std::vector decode_pos_data = {pos}; + auto decode_tokens = from_blob( + decode_token_data.data(), {1, 1}, executorch::aten::ScalarType::Long); + auto decode_pos = from_blob( + decode_pos_data.data(), {1}, executorch::aten::ScalarType::Long); + + uint64_t prev_token = cur_token; + for (int32_t step = 0; step < FLAGS_max_new_tokens; step++) { + decode_token_data[0] = static_cast(cur_token); + decode_pos_data[0] = pos; + + std::vector decode_inputs; + decode_inputs.push_back(EValue(decode_tokens)); + decode_inputs.push_back(EValue(decode_pos)); +#ifdef EXECUTORCH_BUILD_CUDA + decode_inputs.push_back(EValue(temp_tensor)); +#endif + + auto decode_result = module->execute("decode", decode_inputs); + if (decode_result.error() != Error::Ok) { + ET_LOG(Error, "Decode step %d failed", step); + return 1; + } + + prev_token = cur_token; + cur_token = read_token(decode_result.get()[0].toTensor()); + + if (step == 0) { + stats.first_token_ms = llm::time_in_ms(); + } + pos++; + + auto decode_str = tokenizer->decode(prev_token, cur_token); + if (decode_str.ok()) { + printf("%s", decode_str->c_str()); + fflush(stdout); + } + + if (eos_ids.find(cur_token) != eos_ids.end()) { + printf("\n"); + break; + } + } + + stats.inference_end_ms = llm::time_in_ms(); + printf("\n"); + + int64_t num_generated = pos - num_prompt_tokens; + stats.num_generated_tokens = num_generated; + double decode_ms = + static_cast(stats.inference_end_ms - stats.prompt_eval_end_ms); + printf( + "Decode: %" PRId64 " tokens in %.1f ms (%.1f tok/s)\n", + num_generated, + decode_ms, + num_generated * 1000.0 / decode_ms); + printf("Prompt tokens: %" PRId64 "\n", num_prompt_tokens); + +#ifdef EXECUTORCH_BUILD_CUDA + cudaMemGetInfo(&gpu_free_bytes, &gpu_total_bytes); + 
stats.gpu_free_after_generate_bytes = gpu_free_bytes; + stats.gpu_peak_usage_mb = + (stats.gpu_total_bytes - gpu_free_bytes) / 1024.0 / 1024.0; +#endif + + llm::print_report(stats); + return 0; +} diff --git a/examples/models/gemma4_31b/model.md b/examples/models/gemma4_31b/model.md new file mode 100644 index 00000000000..1f8ccb16f6f --- /dev/null +++ b/examples/models/gemma4_31b/model.md @@ -0,0 +1,203 @@ +# Gemma 4 31B-IT — Architecture & Design Notes + +Developer reference for `model.py` and the `quant/` package. For +export/build/run instructions see [README.md](README.md). + +The model mirrors the `Gemma4ForConditionalGeneration` text stack from +HuggingFace transformers / vLLM, with the ExecuTorch customizations needed +for `torch.export(strict=True)`. + +## Architecture + +``` +Input tokens (B, T) + | + v +Embedding (vocab=262144, dim=5376) -> *= sqrt(hidden_size) (normalizer) + | + v ++--- Decoder Layer x60 -----------------------------------------+ +| | +| residual = x | +| RMSNorm -> Attention (sliding | full) -> RMSNorm -> +residual | +| residual = x | +| RMSNorm -> MLP (gate_proj, up_proj, down_proj, GELU-tanh) | +| -> RMSNorm -> +residual | +| x *= layer_scalar (per-layer buffer) | +| | ++----------------------------------------------------------------+ + | + v +RMSNorm -> LM Head (tied with embed) -> tanh(logits/30) * 30 + | + v +Gumbel-max sample(temperature) -> next token (B, 1) +``` + +Layer pattern (`5 sliding + 1 full`, repeated 10x — the last layer is full): + +``` +S S S S S F S S S S S F ... S S S S S F (S = sliding, F = full) +``` + +## Attention details + +Two attention flavors, selected by `config.layer_types[layer_idx]`: + +| Property | Sliding (50 layers) | Full (10 layers, idx 5,11,...,59) | +|---------------------|--------------------|-----------------------------------| +| `head_dim` | 256 | 512 | +| `num_kv_heads` | 16 | 4 | +| `num_heads` | 32 | 32 | +| RoPE θ | 10 000 | 1 000 000 | +| RoPE flavor | full neox | proportional, partial=0.25 | +| K = V | no | yes (no `v_proj`) | +| Causal mask | causal | causal | +| Window restriction | 1024 tokens | none | +| Q-norm / K-norm | RMSNorm w/ weight | RMSNorm w/ weight | +| V-norm | RMSNorm no weight | RMSNorm no weight | +| `scaling` | 1.0 | 1.0 | + +Notes: + +- **Proportional partial RoPE**: the inv_freq vector for full-attention layers + has the first `head_dim * partial_rotary_factor / 2 = 64` frequencies real + (computed with denominator `head_dim`, not `rotary_dim` — that's the + proportional part) and the remaining `head_dim/2 - 64 = 192` zero so cos=1 + and sin=0 (identity rotation) for the non-rotated dims. +- **K = V**: on full-attention layers `v_proj` is absent in the checkpoint + and `V` is taken from the pre-norm `K` projection. After `k_norm` / + RoPE on K and `v_norm` (weightless) on V the two diverge, so the cache + still stores them separately. +- **Mask construction**: a single boolean `(1, 1, T_q, T_kv)` mask is built + once per forward at the model level — one for sliding (causal AND + pos_q - pos_k < 1024), one for full (just causal). Layers pick whichever + matches their type and pass it to `F.scaled_dot_product_attention(..., + enable_gqa=True)`. +- **Gemma `scaling=1.0`**: unlike Gemma 2/3, Gemma 4 does not scale Q by + `query_pre_attn_scalar`; QK-norm handles attention magnitude. 
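+A hedged sketch of the proportional inv_freq construction from the first note
+above (illustrative only; the example itself builds its tables through the
+shared `precompute_freqs_cis` helper, and the variable names here are local
+to this note):
+
+```python
+import torch
+
+# Full-attention layers: head_dim=512, partial_rotary_factor=0.25, theta=1e6.
+head_dim, partial, theta = 512, 0.25, 1_000_000.0
+rotary_dim = int(head_dim * partial)                     # 128 rotated dims
+i = torch.arange(0, rotary_dim, 2, dtype=torch.float32)  # 64 real frequencies
+# Denominator is head_dim, not rotary_dim -- the "proportional" part.
+inv_freq = theta ** (-i / head_dim)
+# Zero-pad the remaining 192 half-dims so cos=1, sin=0 (identity rotation).
+inv_freq = torch.cat([inv_freq, torch.zeros(head_dim // 2 - rotary_dim // 2)])
+pos = torch.arange(4096, dtype=torch.float32)            # positions 0..T-1
+freqs = torch.outer(pos, inv_freq)                       # (T, head_dim / 2)
+freqs_cos, freqs_sin = freqs.cos(), freqs.sin()
+```
+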
+ +## Model parameters (text stack) + +| Parameter | Value | +|---------------------------------|------------| +| `vocab_size` | 262 144 | +| `hidden_size` | 5 376 | +| `intermediate_size` | 21 504 | +| `num_hidden_layers` | 60 | +| `num_attention_heads` | 32 | +| `num_key_value_heads` (sliding) | 16 | +| `head_dim` (sliding) | 256 | +| `num_global_key_value_heads` | 4 | +| `global_head_dim` | 512 | +| `sliding_window` | 1024 | +| `rms_norm_eps` | 1e-6 | +| `final_logit_softcapping` | 30.0 | +| `tie_word_embeddings` | true | +| `max_position_embeddings` | 262 144 | + +Decoder norms per layer: `input_layernorm`, `post_attention_layernorm`, +`pre_feedforward_layernorm`, `post_feedforward_layernorm` — all +`RMSNorm` (multiplies by `weight` directly, not `(1 + weight)`). + +## Methods exported (`export.py`) + +| Method | Input | Output (sampled) | +|-----------|------------------------------------------------------------|------------------| +| `decode` | tokens `(1, 1)` + input_pos `(1,)` + temperature `(1,)` | `(1, 1)` float | +| `prefill` | tokens `(1, T)` + input_pos `(T,)` + temperature `(1,)`, T∈[2, min(max_seq_len-1, 2×sliding_window)] | `(1, 1)` float | + +Both methods share the same KV-cache buffers via +`MemoryPlanningPass(share_mutable_buffers=True)` and +`emit_mutable_buffer_names=True`. The exported program performs Gumbel-max +sampling on-device and returns a single token ID per call so the C++ runner +only has to feed tokens. + +Prefill length is capped to the ring-buffer KV cache size +(`2 × sliding_window`) to avoid duplicate wrapped indices in +`index_copy_`. The C++ runner chunks longer prompts automatically using +the `get_max_prefill_chunk` constant method. Chunked prefill produces +identical logits to sequential one-token-at-a-time prefill. + +## Quantization + +Three modules in `quant/`: + +- **Recipe** (`recipe.py`): `QuantConfig` (bits, group_size, symmetric, + method) + `QuantRule` (regex pattern, config, optional layer filter) + + `QuantRecipe` (ordered rules, first match wins). Declares what to + quantize and how — says nothing about packing or backends. +- **Serialize** (`serialize.py`): `CanonicalQuantizedWeight` (int8 qdata + + bf16 scale + optional zero). `save()` / `load()` persist to safetensors + with a JSON header per weight. Packing-agnostic — any backend can read + the file. +- **Packer** (`pack_cuda.py`): converts `CanonicalQuantizedWeight` to + backend runtime format at load time via `pack_model()`. Dispatches per + parent module type (`nn.Linear` → `Int4TilePackedTo4dTensor` for + tinygemm). Extensible via a packers dict. + +The quantize-once flow: + +``` +quantize_and_save.py export.py / inference.py + | | + bf16 weights quantized checkpoint (safetensors) + | | + quantize_weight() load() + | | + CanonicalQuantizedWeight CanonicalQuantizedWeight + | | + save() pack_model() + | | + model.safetensors Int4TilePackedTo4dTensor (runtime) +``` + +`embed_tokens` and `lm_head` start tied; they are untied before +quantization so `lm_head` (a 5376→262 144 matmul, very expensive at decode) +gets quantized. The embedding gets INT8 per-axis quantization (nearly +lossless for index lookup). + +## Runtime buffer materialization + +After weight loading (via `pack_model()` or `from_hf_checkpoint()`), the +model's KV caches, RoPE tables, and scalar constants are still on the meta +device. 
`materialize_runtime_buffers(model, dtype, device)` in `model.py` +replaces them with real tensors: + +- KV caches → zeros in `dtype` (bf16 for inference, bf16 for export) +- RoPE tables → computed per-layer (sliding vs full, different θ and head_dim) +- `embed_normalizer`, `logit_softcap`, `cache_positions` → scalar constants + +Called by `export.py` (device="cpu" for tracing) and `inference.py` +(device="cuda" for eager execution). Having one function avoids duplicating +the RoPE computation and constant setup across scripts. + +## Customizations vs. vLLM / transformers reference + +These exist solely to make the model exportable / efficient under ExecuTorch: + +- **Boolean attention mask** built once per forward and shared across layers + of the same type, instead of HF's per-layer `_create_causal_mask`. +- **Ring-buffer KV cache** for sliding layers (`RingKVCache`, sized to + `2 × sliding_window`) saves memory for long sequences — positions wrap + via modulo and the attention mask reconstructs which slots are valid. + Full-attention layers use a flat `Gemma4KVCache` sized to `max_seq_len`. + Both use `index_copy_(dim=2, ...)` for trace-friendly updates. +- **Per-layer RoPE tables** registered as `persistent=False` buffers (sliding + uses full RoPE, full uses proportional partial RoPE — head_dim and θ + differ, so the table is not shared). +- **On-device Gumbel-max sampling** so the exported program emits a token + rather than a full logits tensor — keeps the runner GPU↔CPU traffic to a + single float per step. +- **Final-logit softcap baked into the graph**, applied before sampling. +- **Meta-device construction + assign-load** keeps peak memory small enough + to load the 31B-parameter checkpoint on one machine. + +## Shared primitives + +The numerically-sensitive math primitives are imported from +`examples.models.gemma4.text_decoder` and shared with the Gemma 4 E2B/E4B +example: `RMSNorm`, `RMSNormNoWeight`, `Gemma4MLP`, `Gemma4KVCache`, +`precompute_freqs_cis`, `apply_rotary_emb`. The 31B-specific pieces +(attention with K=V branch, decoder layer, top-level model with softcap + +sampling, checkpoint loader) live in `model.py`. diff --git a/examples/models/gemma4_31b/model.py b/examples/models/gemma4_31b/model.py new file mode 100644 index 00000000000..7366a57bf46 --- /dev/null +++ b/examples/models/gemma4_31b/model.py @@ -0,0 +1,703 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Gemma 4 31B-IT — export-friendly reference implementation for ExecuTorch. + +Model definition designed for torch.export(strict=True) with the CUDA backend. +All stateful buffers (KV cache, RoPE inv_freq) are registered buffers so they +are captured by share_mutable_buffers across prefill/decode. The numerically +sensitive primitives — RMSNorm, GELU-tanh MLP, proportional/full RoPE, and +the BHSD KV cache — are imported from ``examples.models.gemma4.text_decoder`` +so the 31B and E2B/E4B paths share them. + +Reference: + - HF transformers: src/transformers/models/gemma4/modeling_gemma4.py + - vLLM: vllm/model_executor/models/gemma4.py + +Architecture highlights for the 31B dense variant: + - 60 decoder layers with hybrid attention: every 6th layer is "full" attention + (idx 5, 11, ..., 59 — 10 layers); the remaining 50 use sliding-window + attention with window=1024. 
+ - Sliding layers: head_dim=256, num_kv_heads=16, full RoPE, theta=10000. + - Full layers: head_dim=512, num_kv_heads=4, K=V (no v_proj), and + "proportional" partial RoPE (factor=0.25, theta=1_000_000). + - Q-norm and K-norm with learnable scale; V-norm without scale. + - Per-layer scalar (loaded buffer) multiplied at the end of each layer. + - Final logits are soft-capped: tanh(logits / 30) * 30. + - Embedding is scaled by sqrt(hidden_size) before layer 0. + - Embedding and lm_head are tied (a single weight, untied for quantization + in the export step so lm_head can be 4-bit). +""" + +import json +import os +import re +from dataclasses import dataclass, field +from typing import Optional + +import torch +import torch.nn as nn + +# Shared primitives lifted out of the gemma4 (E2B/E4B) example. These are the +# bits whose semantics are identical for both variants — RMSNorm, the GELU-tanh +# MLP, the proportional/full RoPE table builder, and the BHSD KV cache. +from executorch.examples.models.gemma4.text_decoder import ( + apply_rotary_emb, + Gemma4KVCache, + Gemma4MLP, + precompute_freqs_cis, + RMSNorm, + RMSNormNoWeight, +) +from executorch.examples.models.gemma4_31b.sampler import sample +from torch.nn import functional as F + + +# --------------------------------------------------------------------------- +# Ring-buffer KV cache for sliding window attention + + +class RingKVCache(nn.Module): + """Ring-buffer KV cache for sliding window attention. + + Sized to ``window_size * 2`` (not ``max_seq_len``), saving memory for + long sequences. Positions wrap via modulo; old entries outside the + window are masked out by ``_build_masks``. + """ + + def __init__( + self, + max_batch_size: int, + window_size: int, + num_kv_heads: int, + head_dim: int, + ): + super().__init__() + self.window_size = window_size + self.buf_size = window_size * 2 + cache_shape = (max_batch_size, num_kv_heads, self.buf_size, head_dim) + self.register_buffer("k_cache", torch.zeros(cache_shape), persistent=False) + self.register_buffer("v_cache", torch.zeros(cache_shape), persistent=False) + + def update( + self, + input_pos: torch.Tensor, + k_val: torch.Tensor, + v_val: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + # seq_len must not exceed buf_size, otherwise wrapped indices contain + # duplicates and index_copy_ is non-deterministic on CUDA. The C++ + # runner must chunk prefill to respect this limit. 
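+        # (E.g. with window_size=1024 the buffer has 2048 slots; absolute
+        # positions 2048 and 2049 wrap back to slots 0 and 1.)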
+ wrapped = input_pos % self.buf_size + self.k_cache.index_copy_(2, wrapped, k_val) + self.v_cache.index_copy_(2, wrapped, v_val) + return self.k_cache, self.v_cache + + +# --------------------------------------------------------------------------- +# Config + + +@dataclass +class Gemma4_31BConfig: + # Embedding / shape + vocab_size: int = 262144 + hidden_size: int = 5376 + intermediate_size: int = 21504 + num_hidden_layers: int = 60 + + # Attention shape (sliding layers — also the "default" path) + num_attention_heads: int = 32 + num_key_value_heads: int = 16 + head_dim: int = 256 + + # Attention shape (full-attention layers) + num_global_key_value_heads: int = 4 + global_head_dim: int = 512 + attention_k_eq_v: bool = ( + True # full layers: V is derived from the same projection as K + ) + + # RoPE — split per layer type + sliding_rope_theta: float = 10_000.0 + full_rope_theta: float = 1_000_000.0 + full_partial_rotary_factor: float = 0.25 # proportional RoPE for full attention + + # Norm / activation + rms_norm_eps: float = 1e-6 + hidden_activation: str = "gelu_pytorch_tanh" + + # Sampling / output + final_logit_softcapping: float = 30.0 + tie_word_embeddings: bool = True + + # Sliding window + sliding_window: int = 1024 + + # Hybrid attention pattern + layer_types: list = field(default_factory=list) + + # Runtime + max_seq_len: int = 4096 + + def __post_init__(self): + if not self.layer_types: + # Default hybrid pattern: 5 sliding then 1 full, repeated. + self.layer_types = [ + "full_attention" if (i + 1) % 6 == 0 else "sliding_attention" + for i in range(self.num_hidden_layers) + ] + if len(self.layer_types) != self.num_hidden_layers: + raise ValueError( + f"layer_types length {len(self.layer_types)} != " + f"num_hidden_layers {self.num_hidden_layers}" + ) + + @staticmethod + def from_hf_config(config_path: str) -> "Gemma4_31BConfig": + with open(config_path, "r") as f: + cfg = json.load(f) + if "text_config" in cfg: + cfg = cfg["text_config"] + + rope_params = cfg.get("rope_parameters", {}) + sliding_rope = rope_params.get("sliding_attention", {}) + full_rope = rope_params.get("full_attention", {}) + + return Gemma4_31BConfig( + vocab_size=cfg.get("vocab_size", 262144), + hidden_size=cfg.get("hidden_size", 5376), + intermediate_size=cfg.get("intermediate_size", 21504), + num_hidden_layers=cfg.get("num_hidden_layers", 60), + num_attention_heads=cfg.get("num_attention_heads", 32), + num_key_value_heads=cfg.get("num_key_value_heads", 16), + head_dim=cfg.get("head_dim", 256), + num_global_key_value_heads=cfg.get("num_global_key_value_heads", 4), + global_head_dim=cfg.get("global_head_dim", 512), + attention_k_eq_v=cfg.get("attention_k_eq_v", True), + sliding_rope_theta=sliding_rope.get("rope_theta", 10_000.0), + full_rope_theta=full_rope.get("rope_theta", 1_000_000.0), + full_partial_rotary_factor=full_rope.get("partial_rotary_factor", 0.25), + rms_norm_eps=cfg.get("rms_norm_eps", 1e-6), + hidden_activation=cfg.get("hidden_activation", "gelu_pytorch_tanh"), + final_logit_softcapping=cfg.get("final_logit_softcapping", 30.0), + tie_word_embeddings=cfg.get("tie_word_embeddings", True), + sliding_window=cfg.get("sliding_window", 1024), + layer_types=cfg.get("layer_types", []), + ) + + +# --------------------------------------------------------------------------- +# Attention — single class, branches on layer type via config +# +# RMSNorm, Gemma4MLP, the RoPE helpers, and Gemma4KVCache are imported from +# examples.models.gemma4.text_decoder so the two Gemma 4 variants share their +# 
numerically-sensitive primitives. + + +class Gemma4Attention(nn.Module): + """Gemma 4 attention with QK-norm, per-layer head_dim, RoPE, KV cache, and SDPA. + + The same class handles both sliding and full attention; the per-layer + config picks head_dim, num_kv_heads, RoPE flavor, and the K=V optimization. + """ + + def __init__(self, config: Gemma4_31BConfig, layer_idx: int): + super().__init__() + self.layer_idx = layer_idx + layer_type = config.layer_types[layer_idx] + self.is_sliding = layer_type == "sliding_attention" + + if self.is_sliding: + self.head_dim = config.head_dim + self.n_kv_heads = config.num_key_value_heads + self.rope_theta = config.sliding_rope_theta + self.partial_rotary = 1.0 + self.k_eq_v = False + else: + self.head_dim = config.global_head_dim + self.n_kv_heads = config.num_global_key_value_heads + self.rope_theta = config.full_rope_theta + self.partial_rotary = config.full_partial_rotary_factor + self.k_eq_v = config.attention_k_eq_v + + self.n_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.scaling = 1.0 # Gemma 4 uses scale=1; QK-norm handles normalization. + + # Linear projections. v_proj is omitted on K=V layers to match the checkpoint. + self.q_proj = nn.Linear( + self.hidden_size, self.n_heads * self.head_dim, bias=False + ) + self.k_proj = nn.Linear( + self.hidden_size, self.n_kv_heads * self.head_dim, bias=False + ) + if not self.k_eq_v: + self.v_proj = nn.Linear( + self.hidden_size, self.n_kv_heads * self.head_dim, bias=False + ) + self.o_proj = nn.Linear( + self.n_heads * self.head_dim, self.hidden_size, bias=False + ) + + # Q/K norm have learnable weight; V norm is weightless. + self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.v_norm = RMSNormNoWeight(self.head_dim, eps=config.rms_norm_eps) + + # Precomputed RoPE table for this layer (per-layer because head_dim + # and theta differ between sliding and full attention). For full + # attention layers we pass freq_base_dim=head_dim so the zero-padded + # inv_freq matches HF's "proportional" partial RoPE. + if self.is_sliding: + rotary_dim = self.head_dim + freq_base_dim = None + else: + rotary_dim = int(self.head_dim * self.partial_rotary) + freq_base_dim = self.head_dim + freqs_cos, freqs_sin = precompute_freqs_cis( + rotary_dim, + config.max_seq_len, + theta=self.rope_theta, + freq_base_dim=freq_base_dim, + ) + self.register_buffer("freqs_cos", freqs_cos, persistent=False) + self.register_buffer("freqs_sin", freqs_sin, persistent=False) + + # KV cache. Sliding layers use a ring buffer (2x window) to save + # memory; full layers use a flat buffer (max_seq_len). + if self.is_sliding: + self.kv_cache = RingKVCache( + max_batch_size=1, + window_size=config.sliding_window, + num_kv_heads=self.n_kv_heads, + head_dim=self.head_dim, + ) + else: + self.kv_cache = Gemma4KVCache( + max_batch_size=1, + max_seq_len=config.max_seq_len, + num_kv_heads=self.n_kv_heads, + head_dim=self.head_dim, + use_index_copy=True, + ) + + def forward( + self, + x: torch.Tensor, + input_pos: torch.Tensor, + attn_mask: torch.Tensor, + ) -> torch.Tensor: + B, T, _ = x.shape + + q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim) + # raw_kv is the linear output before any norm — needed for K=V layers + # so V can be derived from the same tensor as K (post-norm differently). 
+ raw_k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim) + if self.k_eq_v: + raw_v = raw_k + else: + raw_v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim) + + # Norms applied per-head (HF unflatten -> norm -> flatten pattern). + q = self.q_norm(q) + k = self.k_norm(raw_k) + v = self.v_norm(raw_v) + + # Move to BHSD for SDPA / KV cache. + q = q.transpose(1, 2) # (B, H, T, D) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # RoPE on Q and K only (V is not rotated). cos/sin are gathered for + # the current positions to avoid baking the full table into the graph. + cos = self.freqs_cos[input_pos] + sin = self.freqs_sin[input_pos] + q, k = apply_rotary_emb(q, k, cos, sin) + + # Update cache and read back full K/V. + k, v = self.kv_cache.update(input_pos, k, v) + + # SDPA with explicit additive mask (already includes causal + + # sliding-window masking; built once per forward at the model level). + # `scale=1.0` matches HF Gemma 4 — Q-norm/K-norm have absorbed the + # 1/sqrt(d) factor into their trained weights, so the standard SDPA + # default of 1/sqrt(head_dim) would over-divide. enable_gqa lets the + # kernel handle the head ratio without us materializing expanded K/V. + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=attn_mask, + is_causal=False, + enable_gqa=True, + scale=self.scaling, + ) + y = y.transpose(1, 2).contiguous().view(B, T, self.n_heads * self.head_dim) + return self.o_proj(y) + + +# --------------------------------------------------------------------------- +# Decoder block — Gemma's "norm sandwich" pattern. + + +class Gemma4DecoderLayer(nn.Module): + def __init__(self, config: Gemma4_31BConfig, layer_idx: int): + super().__init__() + self.layer_idx = layer_idx + self.is_sliding = config.layer_types[layer_idx] == "sliding_attention" + + self.self_attn = Gemma4Attention(config, layer_idx) + self.mlp = Gemma4MLP(config.hidden_size, config.intermediate_size) + + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.pre_feedforward_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_feedforward_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + # Per-layer scalar (loaded from checkpoint) — multiplied at the end of + # each layer. Kept as a buffer (not nn.Parameter) so it isn't quantized. 
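+        # (quantize_model only quantizes named_parameters; persistent buffers
+        # like this one pass through to the unquantized dict.)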
+ self.register_buffer("layer_scalar", torch.ones(1)) + + def forward( + self, + x: torch.Tensor, + input_pos: torch.Tensor, + sliding_mask: torch.Tensor, + full_mask: torch.Tensor, + ) -> torch.Tensor: + attn_mask = sliding_mask if self.is_sliding else full_mask + + residual = x + h = self.input_layernorm(x) + h = self.self_attn(h, input_pos, attn_mask) + h = self.post_attention_layernorm(h) + x = residual + h + + residual = x + h = self.pre_feedforward_layernorm(x) + h = self.mlp(h) + h = self.post_feedforward_layernorm(h) + x = residual + h + + return x * self.layer_scalar + + +# --------------------------------------------------------------------------- +# Top-level model + + +class Gemma4_31B(nn.Module): + def __init__(self, config: Gemma4_31BConfig): + super().__init__() + self.config = config + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList( + [Gemma4DecoderLayer(config, i) for i in range(config.num_hidden_layers)] + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + # Held separately so it can be untied + quantized at export time. + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Constants (registered as buffers so they move with .to(device)). + self.register_buffer( + "embed_normalizer", + torch.tensor(config.hidden_size**0.5), + persistent=False, + ) + self.register_buffer( + "logit_softcap", + torch.tensor(config.final_logit_softcapping), + persistent=False, + ) + # cache_positions[i] = i — used to build attention masks without + # introducing dynamic-shape tensors at runtime. + self.register_buffer( + "cache_positions", + torch.arange(config.max_seq_len, dtype=torch.long), + persistent=False, + ) + + def _build_masks( + self, input_pos: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + """Build boolean (B=1, H=1, T_q, T_kv) masks for full and sliding attention. + + True = attend. Built once per forward, shared across layers of the + same type. Full mask is (T_q, max_seq_len); sliding mask is + (T_q, buf_size) where buf_size = 2 * sliding_window. + """ + # Full attention mask: (T_q, max_seq_len) + cache_pos = self.cache_positions # (max_seq_len,) + q_pos = input_pos.unsqueeze(1) # (T_q, 1) + causal = q_pos >= cache_pos.unsqueeze(0) + full_mask = causal.unsqueeze(0).unsqueeze(0) # (1, 1, T_q, max_seq_len) + + # Sliding attention mask over ring buffer: (T_q, buf_size) + buf_size = self.config.sliding_window * 2 + seq_len = input_pos.shape[0] + total_written = input_pos[0] + seq_len + j = torch.arange(buf_size, dtype=torch.long, device=input_pos.device) + ring_pos = j + ((total_written - 1 - j) // buf_size) * buf_size + delta = q_pos - ring_pos.unsqueeze(0) + sliding = (ring_pos >= 0) & (delta >= 0) & (delta < self.config.sliding_window) + sliding_mask = sliding.unsqueeze(0).unsqueeze(0) # (1, 1, T_q, buf_size) + + return sliding_mask, full_mask + + def forward( + self, + tokens: torch.LongTensor, + input_pos: torch.LongTensor, + temperature: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Run the model. + + Args: + tokens: (B, T) token IDs. + input_pos: (T,) absolute positions for RoPE / KV cache. + temperature: optional 1-D float tensor controlling on-device sampling. + When provided, returns sampled tokens (B, 1) via Gumbel-max; + when None (e.g. eager eval), returns full logits (B, T, V) with + soft-capping applied so callers see post-cap values. + + Returns: + (B, 1) token IDs when sampling, else (B, T, V) float32 logits. 
+ """ + x = self.embed_tokens(tokens) * self.embed_normalizer + + sliding_mask, full_mask = self._build_masks(input_pos) + for layer in self.layers: + x = layer(x, input_pos, sliding_mask, full_mask) + + x = self.norm(x) + + if temperature is None: + logits = self.lm_head(x).float() + cap = self.logit_softcap.float() + return torch.tanh(logits / cap) * cap + + # Decode-time fast path: only materialize logits for the last token. + last = self.lm_head(x[:, -1, :]).float() + cap = self.logit_softcap.float() + last = torch.tanh(last / cap) * cap + return sample(last, temperature) + + # ---------------- checkpoint loading ---------------- + + @staticmethod + def from_hf_checkpoint( + model_dir: str, max_seq_len: int = 4096 + ) -> tuple["Gemma4_31B", Gemma4_31BConfig]: + """Build the model on `meta` and load weights from the HF safetensors checkpoint. + + Uses lazy shard-by-shard loading + assign=True so peak memory stays at + roughly one shard's worth of weights. + """ + config = Gemma4_31BConfig.from_hf_config(os.path.join(model_dir, "config.json")) + config.max_seq_len = max_seq_len + + print( + f"Building Gemma4_31B on meta (layers={config.num_hidden_layers}, " + f"hidden={config.hidden_size}, max_seq_len={max_seq_len})..." + ) + with torch.device("meta"): + model = Gemma4_31B(config) + + print(f"Loading weights from {model_dir}...") + state_dict = _load_and_remap_checkpoint(model_dir, config) + + # Tied embeddings: copy embedding weight into lm_head when missing. + if "lm_head.weight" not in state_dict and "embed_tokens.weight" in state_dict: + state_dict["lm_head.weight"] = state_dict["embed_tokens.weight"] + + missing, unexpected = model.load_state_dict( + state_dict, strict=False, assign=True + ) + + # Runtime buffers (KV caches, RoPE tables, masks) are zero-initialized + # and not in the checkpoint — those are the "expected" missing keys. + runtime_prefixes = ( + ".kv_cache.", + ".freqs_cos", + ".freqs_sin", + "embed_normalizer", + "logit_softcap", + "cache_positions", + ) + actual_missing = set(missing) + expected = {k for k in actual_missing if any(p in k for p in runtime_prefixes)} + extra = actual_missing - expected + if extra: + print(f" WARNING: missing weight keys: {sorted(extra)[:10]}") + if unexpected: + print(f" WARNING: unexpected keys: {sorted(unexpected)[:10]}") + print( + f" Loaded {len(state_dict)} tensors " + f"({len(expected)} runtime buffers OK)" + ) + return model, config + + +# --------------------------------------------------------------------------- +# Weight loading utilities + + +# HuggingFace key -> our model key. Patterns use `{}` for the layer index. 
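+# Illustrative walk-through of the remapping in _hf_to_model_key (layer index
+# 3 is arbitrary):
+#   "model.language_model.layers.3.self_attn.q_proj.weight"
+#     → normalize language_model prefix → "model.layers.3.self_attn.q_proj.weight"
+#     → pattern match below             → "layers.3.self_attn.q_proj.weight"
+# Vision-tower keys match _IGNORED_PREFIXES and map to None (skipped).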
+_HF_KEY_MAP = { + "model.embed_tokens.weight": "embed_tokens.weight", + "model.norm.weight": "norm.weight", + "lm_head.weight": "lm_head.weight", + # Per-layer norms + "model.layers.{}.input_layernorm.weight": "layers.{}.input_layernorm.weight", + "model.layers.{}.post_attention_layernorm.weight": "layers.{}.post_attention_layernorm.weight", + "model.layers.{}.pre_feedforward_layernorm.weight": "layers.{}.pre_feedforward_layernorm.weight", + "model.layers.{}.post_feedforward_layernorm.weight": "layers.{}.post_feedforward_layernorm.weight", + "model.layers.{}.layer_scalar": "layers.{}.layer_scalar", + # Attention projections + "model.layers.{}.self_attn.q_proj.weight": "layers.{}.self_attn.q_proj.weight", + "model.layers.{}.self_attn.k_proj.weight": "layers.{}.self_attn.k_proj.weight", + "model.layers.{}.self_attn.v_proj.weight": "layers.{}.self_attn.v_proj.weight", + "model.layers.{}.self_attn.o_proj.weight": "layers.{}.self_attn.o_proj.weight", + "model.layers.{}.self_attn.q_norm.weight": "layers.{}.self_attn.q_norm.weight", + "model.layers.{}.self_attn.k_norm.weight": "layers.{}.self_attn.k_norm.weight", + # MLP + "model.layers.{}.mlp.gate_proj.weight": "layers.{}.mlp.gate_proj.weight", + "model.layers.{}.mlp.up_proj.weight": "layers.{}.mlp.up_proj.weight", + "model.layers.{}.mlp.down_proj.weight": "layers.{}.mlp.down_proj.weight", +} + +# Multimodal keys we deliberately ignore for the text-only export. +_IGNORED_PREFIXES = ( + "model.vision_tower.", + "model.embed_vision.", +) + + +def _hf_to_model_key(hf_key: str) -> Optional[str]: + # Gemma4ForConditionalGeneration stores the LM under model.language_model.* + norm = hf_key + if norm.startswith("model.language_model."): + norm = norm.replace("model.language_model.", "model.", 1) + + if norm.startswith(_IGNORED_PREFIXES): + return None + + for hf_pat, model_pat in _HF_KEY_MAP.items(): + if "{}" not in hf_pat: + if norm == hf_pat: + return model_pat + continue + regex = re.escape(hf_pat).replace(r"\{\}", r"(\d+)") + m = re.fullmatch(regex, norm) + if m: + return model_pat.replace("{}", m.group(1), 1) + return None + + +def _load_and_remap_checkpoint(model_dir: str, config: Gemma4_31BConfig) -> dict: + """Stream-load safetensors shards and remap keys to model state_dict keys.""" + from safetensors import safe_open + + index_path = os.path.join(model_dir, "model.safetensors.index.json") + if os.path.exists(index_path): + with open(index_path, "r") as f: + index = json.load(f) + shard_files = sorted(set(index["weight_map"].values())) + elif os.path.exists(os.path.join(model_dir, "model.safetensors")): + shard_files = ["model.safetensors"] + else: + raise FileNotFoundError(f"No safetensors checkpoint in {model_dir}") + + state_dict: dict[str, torch.Tensor] = {} + skipped = 0 + for shard_file in shard_files: + shard_path = os.path.join(model_dir, shard_file) + with safe_open(shard_path, framework="pt", device="cpu") as f: + for ckpt_key in f.keys(): + model_key = _hf_to_model_key(ckpt_key) + if model_key is None: + skipped += 1 + continue + tensor = f.get_tensor(ckpt_key) + # layer_scalar in checkpoint is shape (1,) bf16 — keep as-is. 
+ state_dict[model_key] = tensor + if skipped > 0: + print(f" Skipped {skipped} non-text keys (vision tower, etc.)") + return state_dict + + +# --------------------------------------------------------------------------- +# Runtime buffer materialization + + +def materialize_runtime_buffers( + model: Gemma4_31B, + dtype: torch.dtype, + device: str = "cpu", +) -> None: + """Replace meta-device buffers with real tensors and set runtime constants. + + Called after weight loading to fill in KV caches (zeros), RoPE tables + (computed), and scalar constants. Only touches buffers still on the meta + device — loaded (non-meta) buffers are left in place. + """ + config = model.config + + for fqn, buf in list(model.named_buffers()): + if buf.device.type != "meta": + continue + parts = fqn.rsplit(".", 1) + parent = model.get_submodule(parts[0]) if len(parts) > 1 else model + is_kv = ".kv_cache." in fqn + target_dtype = dtype if is_kv else torch.float32 + if buf.dtype == torch.bool: + target_dtype = torch.bool + parent.register_buffer( + parts[-1], + torch.zeros(buf.shape, dtype=target_dtype, device=device), + persistent=False, + ) + + for layer in model.layers: + attn = layer.self_attn + if attn.is_sliding: + rotary_dim, freq_base_dim = attn.head_dim, None + else: + rotary_dim = int(attn.head_dim * attn.partial_rotary) + freq_base_dim = attn.head_dim + cos, sin = precompute_freqs_cis( + rotary_dim, + config.max_seq_len, + theta=attn.rope_theta, + freq_base_dim=freq_base_dim, + ) + attn.register_buffer("freqs_cos", cos.to(device), persistent=False) + attn.register_buffer("freqs_sin", sin.to(device), persistent=False) + + model.register_buffer( + "embed_normalizer", + torch.tensor(config.hidden_size**0.5, device=device), + persistent=False, + ) + model.register_buffer( + "logit_softcap", + torch.tensor(config.final_logit_softcapping, device=device), + persistent=False, + ) + model.register_buffer( + "cache_positions", + torch.arange(config.max_seq_len, dtype=torch.long, device=device), + persistent=False, + ) diff --git a/examples/models/gemma4_31b/quant/README.md b/examples/models/gemma4_31b/quant/README.md new file mode 100644 index 00000000000..01a74434487 --- /dev/null +++ b/examples/models/gemma4_31b/quant/README.md @@ -0,0 +1,88 @@ +# quant/ + +Packing-agnostic quantization framework: **recipe → quantize → serialize → pack**. + +## Files + +| File | Concern | Depends on | +|---|---|---| +| `recipe.py` | **Policy** — what to quantize, what precision, which layers | nothing | +| `quantize.py` | **Computation** — produces canonical weights from fp weights | recipe, torchao | +| `serialize.py` | **Data format** — saves/loads canonical weights to safetensors | recipe | +| `pack.py` | **Packing dispatch** — walks model, dispatches to per-module packers | serialize | +| `pack_cuda.py` | **CUDA packing** — converts canonical to tinygemm/intx runtime format | pack, serialize | + +## Data flow + +``` +QuantRecipe → quantize_model() → CanonicalQuantizedWeight → save() → file → load() → CanonicalQuantizedWeight → pack_model() → runtime model +``` + +`CanonicalQuantizedWeight` is the interchange point — int8 qdata + bf16 +scale + optional zero + config. Everything left of it is backend-agnostic. +Everything right is backend-specific. + +## Adding a new backend + +Write a `pack_.py` with per-module packers and a default registry: + +```python +def pack_linear_for_metal(module, weights): ... 
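+# `weights` is keyed by attribute name, e.g. {"weight": CanonicalQuantizedWeight};
+# the packer converts it to the backend's runtime tensor and mutates `module`
+# in place — the same contract pack_cuda.py's packers follow.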
+DEFAULT_METAL_PACKERS = {nn.Linear: pack_linear_for_metal} +``` + +Call `pack_model(model, quantized, unquantized, packers=DEFAULT_METAL_PACKERS)`. +No changes to recipe, quantize, or serialize. + +Things to consider: + +- **Recipes may need to be backend-aware.** Each backend's kernels have + different constraints (e.g., Metal's `fpa4w` is INT4-only — no INT8 linear + kernel, so the sensitive recipe's 8-bit edge layers would need to be INT4 + or dequantized to bf16). Define per-backend recipes or validate recipe + compatibility at pack time. +- **Source transforms before packing.** Some backends replace model modules + (e.g., MLX swaps `FusedMoEExperts` → `SwitchMLP`, Metal swaps to + `MetalMoEExperts`). These transforms change the module types that + packers dispatch on, so they must run before `pack_model()`. For dense + models (no MoE) this is not needed. +- **Embedding quantization.** Not all backends have a quantized embedding + gather kernel. The packer can dequantize to bf16 at load time — the + disk savings from the canonical format still apply. + +## Adding a new model + +1. Define a `QuantRecipe` with rules for the model's FQN patterns. +2. If the model has custom module types (e.g., `FusedMoEExperts`), write a + per-module packer and extend the packers dict: + ```python + packers = {**DEFAULT_CUDA_PACKERS, FusedMoEExperts: pack_moe_experts} + ``` +3. No changes to the quant package itself. + +## On-disk format + +Safetensors with a `format_version` in the header. Per quantized weight: +`{fqn}.qdata` (int8, nibble-packed for 4-bit), `{fqn}.scale` (bf16), +optionally `{fqn}.zero` (bf16). Header JSON records bits, group_size, +symmetric, and method per weight. Unquantized weights stored as-is. + +## TODO + +- `pack_metal.py` — Metal backend packer. Convert canonical INT4 to + `UIntxWeightOnlyConfig` subclass (torchao experimental) for the + `torchao::_linear_fp_act_4bit_weight` kernel. For MoE models, pack + expert weights into Metal's `gather_qmv` format (asymmetric, unsigned + INT4 with scale + bias buffers). + +- `pack_mlx.py` — MLX backend packer. Convert canonical INT4 to + `IntxWeightOnlyConfig` subclass for the `mlx::gather_qmm` kernel. + For MoE models, stack per-expert weights into `SwitchLinear` format. + +- `gguf.py` — read a GGUF file and convert to `CanonicalQuantizedWeight` + dicts, enabling `load() → pack_model()` from community-quantized GGUF + checkpoints without re-quantizing from bf16. Maps GGUF quant types + (Q4_K, Q6_K, Q8_0, etc.) to `QuantConfig` and unpacks super-blocks + into the canonical qdata + scale + zero layout. For CUDA packing, + Q6_K would be widened to 8-bit (`pack_int8_for_cuda`) since there is + no 6-bit CUDA kernel — lossless, ~33% more memory than true 6-bit. diff --git a/examples/models/gemma4_31b/quant/__init__.py b/examples/models/gemma4_31b/quant/__init__.py new file mode 100644 index 00000000000..23d321f0c0b --- /dev/null +++ b/examples/models/gemma4_31b/quant/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from .pack import ModulePackerFn, pack_model # noqa: F401 +from .pack_cuda import ( # noqa: F401 + DEFAULT_CUDA_PACKERS, + load_and_pack_for_cuda, + pack_embedding_for_cuda, + pack_int4_for_cuda, + pack_int8_for_cuda, + pack_linear_for_cuda, +) +from .quantize import quantize_model, quantize_weight # noqa: F401 +from .recipe import QuantConfig, QuantRecipe, QuantRule # noqa: F401 +from .serialize import ( # noqa: F401 + CanonicalQuantizedWeight, + deserialize, + load, + save, + serialize, +) diff --git a/examples/models/gemma4_31b/quant/pack.py b/examples/models/gemma4_31b/quant/pack.py new file mode 100644 index 00000000000..544e96287e9 --- /dev/null +++ b/examples/models/gemma4_31b/quant/pack.py @@ -0,0 +1,87 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Backend-agnostic model packing: canonical weights → runtime model. + +``pack_model`` walks a model's quantized weights, groups them by parent +module, and dispatches to per-module packer functions. Each backend +(``pack_cuda.py``, future ``pack_metal.py``) provides its own packers dict +mapping module types to packer functions. + +Pure logic — no file I/O, no backend imports. +""" + +from collections import defaultdict +from typing import Callable + +import torch +import torch.nn as nn + +from .serialize import CanonicalQuantizedWeight + +# Packer signature: receives the module + a dict of its quantized weights +# (keyed by attribute name, e.g., {"weight": CQW}), modifies module in-place. +ModulePackerFn = Callable[[nn.Module, dict[str, CanonicalQuantizedWeight]], None] + + +def _assign_unquantized(model: nn.Module, unquantized: dict[str, torch.Tensor]) -> None: + """Assign plain (unquantized) tensors to model parameters and buffers.""" + model_sd_keys = set(model.state_dict().keys()) + for fqn, tensor in unquantized.items(): + if fqn not in model_sd_keys: + continue + parts = fqn.rsplit(".", 1) + parent = model.get_submodule(parts[0]) if len(parts) > 1 else model + attr_name = parts[-1] + if isinstance(getattr(parent, attr_name, None), nn.Parameter): + setattr(parent, attr_name, nn.Parameter(tensor, requires_grad=False)) + else: + parent.register_buffer(attr_name, tensor) + + +def pack_model( + model: nn.Module, + quantized: dict[str, CanonicalQuantizedWeight], + unquantized: dict[str, torch.Tensor], + packers: dict[type, ModulePackerFn], +) -> None: + """Pack canonical weights into ``model`` using the given packers. + + Groups quantized weights by their parent module, then dispatches to the + appropriate per-module packer based on the module's type. Models with + custom module types (e.g., ``FusedMoEExperts``) extend ``packers``. + + Pure logic — no file I/O, no backend dependency. + """ + + _assign_unquantized(model, unquantized) + + module_weights: dict[str, dict[str, CanonicalQuantizedWeight]] = defaultdict(dict) + for fqn, cw in quantized.items(): + parts = fqn.rsplit(".", 1) + parent_fqn = parts[0] if len(parts) > 1 else "" + attr = parts[-1] + module_weights[parent_fqn][attr] = cw + + for parent_fqn, weights in module_weights.items(): + module = model.get_submodule(parent_fqn) if parent_fqn else model + packer = packers.get(type(module)) + if packer is None: + raise ValueError( + f"No packer registered for {type(module).__name__} at '{parent_fqn}'. " + f"Registered types: {[t.__name__ for t in packers]}." 
+ ) + packer(module, weights) + + for fqn, p in model.named_parameters(): + if p.device.type == "meta": + raise RuntimeError( + f"Weight '{fqn}' not found in checkpoint " + f"(model/checkpoint version mismatch?)" + ) + + for p in model.parameters(): + p.requires_grad_(False) diff --git a/examples/models/gemma4_31b/quant/pack_cuda.py b/examples/models/gemma4_31b/quant/pack_cuda.py new file mode 100644 index 00000000000..039f2cbf7ba --- /dev/null +++ b/examples/models/gemma4_31b/quant/pack_cuda.py @@ -0,0 +1,210 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""CUDA packer: canonical → CUDA runtime format. + +Provides per-module packers for the CUDA backend (INT4 via tinygemm, +INT8 via ``IntxUnpackedToInt8Tensor``) and ``load_and_pack_for_cuda`` +as a convenience I/O wrapper. + +The backend-agnostic ``pack_model`` dispatcher lives in ``pack.py``. +""" + +import torch +import torch.nn as nn + +from .pack import ModulePackerFn, pack_model # noqa: F401 +from .serialize import CanonicalQuantizedWeight, load + + +# --------------------------------------------------------------------------- +# Low-level: canonical → Int4TilePackedTo4dTensor (one weight at a time) + + +def pack_int4_for_cuda( + cw: CanonicalQuantizedWeight, + device: str = "cuda", +) -> nn.Parameter: + """Convert a canonical 4-bit weight to ``Int4TilePackedTo4dTensor``. + + Pads K to a multiple of 1024 and N to a multiple of 8 (tinygemm + requirements), nibble-packs, then tile-packs via the CUDA kernel. + Returns an ``nn.Parameter`` wrapping the subclass tensor **on CUDA**. + """ + from torchao.quantization.quantize_.workflows.int4.int4_tile_packed_to_4d_tensor import ( + Int4TilePackedTo4dTensor, + ) + from torchao.quantization.utils import pack_tinygemm_scales_and_zeros + from torchao.utils import find_multiple + + assert cw.config.bits == 4, f"Expected 4-bit, got {cw.config.bits}" + assert cw.qdata.ndim == 2, ( + f"pack_int4_for_cuda requires 2D weight (nn.Linear), got {cw.qdata.ndim}D " + f"shape {tuple(cw.qdata.shape)}." + ) + + original_shape = cw.qdata.shape + N, K = original_shape + gs = cw.config.group_size + inner_k_tiles = 8 + + K_padded = find_multiple(K, 1024) + N_padded = find_multiple(N, 8) + + int_data = cw.qdata.to(torch.int32) + if K_padded != K or N_padded != N: + int_data = torch.nn.functional.pad(int_data, (0, K_padded - K, 0, N_padded - N)) + + scale = cw.scale + n_groups_orig = K // gs + n_groups_padded = K_padded // gs + if n_groups_padded != n_groups_orig or N_padded != N: + scale = torch.nn.functional.pad( + scale, (0, n_groups_padded - n_groups_orig, 0, N_padded - N) + ) + + if cw.zero is not None: + zero = cw.zero + if n_groups_padded != n_groups_orig or N_padded != N: + zero = torch.nn.functional.pad( + zero, (0, n_groups_padded - n_groups_orig, 0, N_padded - N) + ) + else: + # Symmetric: qdata is unsigned [0, 15] (shifted +8 from signed [-8, 7]). + # Standard convention: weight = (q - zp_std) * scale, so zp_std = 8. + zero = torch.full_like(scale, 8.0) + + int_data = int_data.to(device) + scale = scale.to(device) + zero = zero.to(device) + + # Convert zero from standard convention (weight = (q - zp_std) * scale) + # to tinygemm convention (weight = (q - 8) * scale + zp_tg). 
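+    # (For the symmetric path above, zero was filled with 8.0, so zp_tg works
+    # out to 0 and the kernel sees a plain (q - 8) * scale weight.)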
+ # Derivation: (q - zp_std) * scale = (q - 8) * scale + zp_tg + # → zp_tg = (8 - zp_std) * scale + tinygemm_zero = (8 - zero.to(torch.float32)) * scale.to(torch.float32) + + # Tinygemm nibble convention: even index in HIGH nibble, odd in LOW. + # (This differs from serialize.py's _nibble_pack which uses the opposite + # convention for on-disk storage — both are valid, they serve different + # consumers.) + int_data_u8 = (int_data[:, ::2] << 4 | int_data[:, 1::2]).to(torch.uint8) + packed_weight = torch.ops.aten._convert_weight_to_int4pack( + int_data_u8.contiguous(), inner_k_tiles + ) + + scale_and_zero = pack_tinygemm_scales_and_zeros( + scale.to(torch.bfloat16), tinygemm_zero.to(torch.bfloat16), torch.bfloat16 + ) + + subclass = Int4TilePackedTo4dTensor( + qdata=packed_weight, + scale_and_zero=scale_and_zero, + block_size=[1, gs], + shape=torch.Size(original_shape), + ) + return nn.Parameter(subclass, requires_grad=False) + + +# --------------------------------------------------------------------------- +# Per-module packers + + +def pack_int8_for_cuda( + cw: CanonicalQuantizedWeight, +) -> nn.Parameter: + """Convert a canonical 8-bit weight to ``IntxUnpackedToInt8Tensor``. + + Unlike INT4 (which needs tinygemm tile packing), INT8 weights are stored + unpacked. The subclass carries int8 qdata + scales and dequantizes during + matmul — AOTI fuses the ``dequantize → mm`` pattern in the compiled graph. + """ + from torchao.quantization import IntxUnpackedToInt8Tensor + + assert cw.config.bits == 8, f"Expected 8-bit, got {cw.config.bits}" + assert cw.qdata.ndim == 2, f"Expected 2D weight, got {cw.qdata.ndim}D" + + N, K = cw.qdata.shape + n_groups = K // cw.config.group_size + scale = cw.scale.to(torch.bfloat16).reshape(N, n_groups) + zero_point = ( + cw.zero.to(torch.int8).reshape(N, n_groups) + if cw.zero is not None + else torch.zeros(N, n_groups, dtype=torch.int8) + ) + + subclass = IntxUnpackedToInt8Tensor( + qdata=cw.qdata, + scale=scale, + zero_point=zero_point, + target_dtype=torch.int8, + block_size=(1, cw.config.group_size), + dtype=torch.bfloat16, + activation_quantization=None, + ) + return nn.Parameter(subclass, requires_grad=False) + + +def pack_linear_for_cuda( + module: nn.Module, weights: dict[str, CanonicalQuantizedWeight] +) -> None: + """Pack a quantized ``nn.Linear`` for CUDA. + + 4-bit weights use ``Int4TilePackedTo4dTensor`` (tinygemm kernel, requires + CUDA for packing). 8-bit weights use ``IntxUnpackedToInt8Tensor`` (AOTI + fuses the dequantize-matmul pattern). Both stay as tensor subclasses so + the export graph captures quantized ops. + """ + cw = weights["weight"] + if cw.config.bits == 4: + packed = pack_int4_for_cuda(cw, device="cuda") + module.weight = nn.Parameter(packed.data.to("cpu"), requires_grad=False) + torch.cuda.empty_cache() + elif cw.config.bits == 8: + module.weight = pack_int8_for_cuda(cw) + else: + raise ValueError(f"Unsupported bit width: {cw.config.bits}") + + +def pack_embedding_for_cuda( + module: nn.Module, weights: dict[str, CanonicalQuantizedWeight] +) -> None: + """Pack a quantized ``nn.Embedding`` for CUDA. + + Uses ``IntxUnpackedToInt8Tensor`` which supports embedding gather. + Only INT8 is supported — ``Int4TilePackedTo4dTensor`` does not + implement the embedding op. + """ + cw = weights["weight"] + if cw.config.bits != 8: + raise ValueError( + f"Only 8-bit embedding quantization is supported on CUDA, " + f"got {cw.config.bits}-bit." 
+ ) + module.weight = pack_int8_for_cuda(cw) + + +DEFAULT_CUDA_PACKERS: dict[type, ModulePackerFn] = { + nn.Linear: pack_linear_for_cuda, + nn.Embedding: pack_embedding_for_cuda, +} + + +# --------------------------------------------------------------------------- +# Load + pack (I/O wrapper) + + +def load_and_pack_for_cuda( + path: str, + model: nn.Module, + packers: dict[type, ModulePackerFn] | None = None, +) -> None: + """Read a quantized safetensors file and pack into ``model`` for CUDA. + + Thin wrapper: ``load`` + ``pack_model``. + """ + quantized, unquantized = load(path) + pack_model(model, quantized, unquantized, packers or DEFAULT_CUDA_PACKERS) diff --git a/examples/models/gemma4_31b/quant/quantize.py b/examples/models/gemma4_31b/quant/quantize.py new file mode 100644 index 00000000000..0ebfd032681 --- /dev/null +++ b/examples/models/gemma4_31b/quant/quantize.py @@ -0,0 +1,225 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Quantize weights to canonical form. + +``quantize_weight`` quantizes a single tensor given a ``QuantConfig``, +dispatching to the appropriate algorithm based on ``config.method``: + + - ``"min_max"``: standard symmetric/asymmetric quantization via torchao's + ``choose_qparams_affine`` + ``quantize_affine``. Runs on CPU or CUDA. + - ``"hqq"``: Half-Quadratic Quantization — iteratively refines scales via + a proximal solver for better accuracy. ``symmetric=False`` optimizes both + scale and zero (requires CUDA). ``symmetric=True`` optimizes scale only + (CPU or CUDA). + +``quantize_model`` walks a model's parameters, applies a ``QuantRecipe``, +and returns two dicts: quantized weights as ``CanonicalQuantizedWeight`` +and unquantized weights as plain tensors. + +Both are model-agnostic — they work for any ``nn.Module`` and any weight +shape (2D linears, 3D fused-expert stacks, etc.). +""" + +import torch +import torch.nn as nn + +from .recipe import QuantConfig, QuantRecipe + +from .serialize import CanonicalQuantizedWeight + + +# --------------------------------------------------------------------------- +# Per-weight quantization + + +def _quantize_min_max( + weight: torch.Tensor, + config: QuantConfig, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Standard min/max quantization. Returns (int_data, scale, zero_point).""" + from torchao.quantization.quant_primitives import ( + choose_qparams_affine, + MappingType, + quantize_affine, + ) + + if config.bits == 4: + qmin, qmax = (-8, 7) if config.symmetric else (0, 15) + elif config.bits == 8: + qmin, qmax = -128, 127 + else: + raise ValueError(f"Unsupported bits={config.bits}") + + mapping = MappingType.SYMMETRIC if config.symmetric else MappingType.ASYMMETRIC + block_size = tuple([1] * (weight.ndim - 1) + [config.group_size]) + + scale, zero_point = choose_qparams_affine( + weight.float(), + mapping, + block_size, + target_dtype=torch.int8, + quant_min=qmin, + quant_max=qmax, + scale_dtype=torch.bfloat16, + zero_point_dtype=torch.bfloat16, + ) + int_data = quantize_affine( + weight.float(), + block_size, + scale, + zero_point, + output_dtype=torch.int8, + quant_min=qmin, + quant_max=qmax, + ) + return int_data, scale, zero_point + + +def _quantize_hqq_asymmetric( + weight: torch.Tensor, + config: QuantConfig, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Full HQQ (asymmetric, optimizes scale + zero). Requires CUDA. 
+ + Returns (int_data, scale, zero_point) in canonical layout. + """ + from torchao.quantization.quant_primitives import ( + _choose_qparams_and_quantize_affine_hqq, + ) + + device = weight.device + if device.type != "cuda": + device = torch.device("cuda") + + W_q, scale, zero, _shape = _choose_qparams_and_quantize_affine_hqq( + weight, + nbits=config.bits, + group_size=config.group_size, + axis=1, + compute_dtype=torch.bfloat16, + device=str(device), + raw_output=True, + ) + + int_data = W_q.to(torch.int8) + scale = scale.to(torch.bfloat16).reshape(*weight.shape[:-1], -1) + zero = zero.to(torch.bfloat16).reshape(*weight.shape[:-1], -1) + + return int_data, scale, zero + + +def _quantize_hqq_symmetric( + weight: torch.Tensor, + config: QuantConfig, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Scale-only HQQ (symmetric, optimizes scale only). Runs on CPU or CUDA. + + Returns (int_data, scale, zero_point) where zero_point is all zeros. + """ + from torchao.quantization.quant_primitives import ( + _choose_qparams_and_quantize_scale_only_hqq, + ) + + if config.bits == 4: + qmin, qmax = -8, 7 + elif config.bits == 8: + qmin, qmax = -128, 127 + else: + raise ValueError(f"Unsupported bits={config.bits}") + + # scale_only_hqq requires 2D. For 3D+, flatten → quantize → reshape. + orig_shape = weight.shape + weight_2d = weight.reshape(-1, weight.shape[-1]) if weight.ndim > 2 else weight + + qdata, scale = _choose_qparams_and_quantize_scale_only_hqq( + weight_2d, + [1, config.group_size], + qmin, + qmax, + ) + + int_data = qdata.to(torch.int8).reshape(orig_shape) + scale = scale.to(torch.bfloat16).reshape(*orig_shape[:-1], -1) + zero_point = torch.zeros_like(scale) + + return int_data, scale, zero_point + + +def quantize_weight( + weight: torch.Tensor, + config: QuantConfig, +) -> CanonicalQuantizedWeight: + """Quantize ``weight`` to canonical form. + + Dispatches to the algorithm specified by ``config.method``. The input is + processed in float32 internally for numerical stability. Does NOT pad or + pack for any backend. + """ + if config.method == "min_max": + int_data, scale, zero_point = _quantize_min_max(weight, config) + elif config.method == "hqq": + if config.symmetric: + int_data, scale, zero_point = _quantize_hqq_symmetric(weight, config) + else: + int_data, scale, zero_point = _quantize_hqq_asymmetric(weight, config) + else: + raise ValueError( + f"Unknown quantization method: {config.method!r}. " + f"Supported: 'min_max', 'hqq'." + ) + + # Normalize 4-bit to unsigned [0, 15] for uniform storage and nibble + # packing. Symmetric min_max produces [-8, 7]; shift to [0, 15]. + # HQQ already produces [0, 15] (asymmetric internally). + if config.bits == 4 and config.symmetric: + int_data = int_data + 8 + + return CanonicalQuantizedWeight( + qdata=int_data.to(torch.int8), + scale=scale.to(torch.bfloat16), + zero=zero_point.to(torch.bfloat16) if not config.symmetric else None, + config=config, + ) + + +# --------------------------------------------------------------------------- +# Per-model quantization + + +def quantize_model( + model: nn.Module, + recipe: QuantRecipe, + dtype: torch.dtype = torch.bfloat16, +) -> tuple[dict[str, CanonicalQuantizedWeight], dict[str, torch.Tensor]]: + """Walk model parameters + persistent buffers, apply recipe. + + For each parameter matched by a recipe rule: quantize to canonical. + Parameters that match ``None`` (skip) rules and persistent buffers go + into the unquantized dict (cast to ``dtype``). 
Non-persistent buffers + (KV cache, RoPE tables, etc.) are excluded. + + Returns ``(quantized, unquantized)`` dicts keyed by FQN. + """ + quantized: dict[str, CanonicalQuantizedWeight] = {} + unquantized: dict[str, torch.Tensor] = {} + persistent_keys = set(model.state_dict().keys()) + + n_params = sum(1 for _ in model.named_parameters()) + for i, (fqn, param) in enumerate(model.named_parameters()): + config = recipe.get_config(fqn) + if config is None: + unquantized[fqn] = param.data.to(dtype) + else: + quantized[fqn] = quantize_weight(param.data, config) + print(f" Quantized {i + 1}/{n_params}: {fqn}", end="\r") + print() + + for fqn, buf in model.named_buffers(): + if fqn in persistent_keys and fqn not in quantized: + unquantized[fqn] = buf.data + + return quantized, unquantized diff --git a/examples/models/gemma4_31b/quant/recipe.py b/examples/models/gemma4_31b/quant/recipe.py new file mode 100644 index 00000000000..49294c9b579 --- /dev/null +++ b/examples/models/gemma4_31b/quant/recipe.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Quantization recipe: declares what to quantize and how. + +A ``QuantRecipe`` is an ordered list of ``QuantRule`` objects matched against +weight FQNs. First match wins. The recipe says nothing about packing format, +tensor subclass, or target backend. +""" + +import re +from dataclasses import dataclass, field +from typing import Optional + + +@dataclass(frozen=True) +class QuantConfig: + """Per-weight quantization parameters.""" + + bits: int # 4, 8 + group_size: int # 32, 64, 128 + symmetric: bool # True = no zero point + method: str # "min_max" | "hqq" + + +@dataclass +class QuantRule: + """A single recipe rule: regex pattern + config + optional layer filter.""" + + pattern: str # regex matched against weight FQN + config: Optional[QuantConfig] # None = skip (leave unquantized) + layers: Optional[set[int]] = field(default=None, repr=False) # None = all layers + + +@dataclass +class QuantRecipe: + """Ordered list of rules. First match wins.""" + + rules: list[QuantRule] + + def get_config(self, fqn: str) -> Optional[QuantConfig]: + """Return the ``QuantConfig`` for a weight FQN, or ``None`` to skip.""" + layer_idx = self._extract_layer_idx(fqn) + for rule in self.rules: + if rule.layers is not None: + if layer_idx is None or layer_idx not in rule.layers: + continue + if re.fullmatch(rule.pattern, fqn): + return rule.config + return None + + @staticmethod + def _extract_layer_idx(fqn: str) -> Optional[int]: + m = re.search(r"layers\.(\d+)\.", fqn) + return int(m.group(1)) if m else None diff --git a/examples/models/gemma4_31b/quant/serialize.py b/examples/models/gemma4_31b/quant/serialize.py new file mode 100644 index 00000000000..5996599ad90 --- /dev/null +++ b/examples/models/gemma4_31b/quant/serialize.py @@ -0,0 +1,215 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Serialize and persist quantized weights. + +Two layers: + + - **serialize / deserialize** — convert between ``CanonicalQuantizedWeight`` + objects and plain tensors + JSON metadata. Pure logic, no I/O. The output + is a ``(tensors_dict, metadata_dict)`` pair that any file writer can + consume. 
+ - **save / load** — write/read the serialized form to/from safetensors on + disk. Thin I/O wrappers around ``safetensors.save_file`` / + ``safetensors.safe_open``. + +For 4-bit weights, qdata is nibble-packed (two values per byte) during +serialization to keep file size at ~0.5 bytes/param. +""" + +import json +from dataclasses import dataclass +from typing import Optional + +import torch +from safetensors import safe_open +from safetensors.torch import save_file + +from .recipe import QuantConfig + +# Bump when the on-disk layout changes in a backward-incompatible way +# (e.g., different nibble-pack convention, renamed keys, new required fields). +# The loader rejects files with an unsupported version rather than silently +# producing corrupt data. +FORMAT_VERSION = "1" +_SUPPORTED_VERSIONS = {FORMAT_VERSION} + + +@dataclass +class CanonicalQuantizedWeight: + """Packing-free quantized weight representation. + + ``qdata`` int8 values: [0, 15] for 4-bit (both symmetric and asymmetric + are stored as unsigned after shifting), [-128, 127] for 8-bit. + ``scale`` bf16 per-group scales, shape ``[*weight_shape[:-1], K // group_size]``. + ``zero`` bf16 per-group zero points (``None`` when symmetric). + ``config`` the ``QuantConfig`` that produced this. + """ + + qdata: torch.Tensor + scale: torch.Tensor + zero: Optional[torch.Tensor] + config: QuantConfig + + +# --------------------------------------------------------------------------- +# Nibble packing for 4-bit on-disk storage. +# +# Two 4-bit values are packed into one byte to halve file size. The +# convention is: even-indexed values go into the LOW nibble (bits 0-3), +# odd-indexed values go into the HIGH nibble (bits 4-7). +# +# values: [v0, v1, v2, v3, ...] (each in [0, 15]) +# packed: [v0 | (v1 << 4), v2 | (v3 << 4), ...] +# byte 0: bits 0-3 = v0, bits 4-7 = v1 +# +# To unpack: low = byte & 0x0F, high = (byte >> 4) & 0x0F. +# +# This matches the Triton fused_moe kernel's unpack convention +# ((byte >> (k%2)*4) & 0xF) and Qwen's _quantize_experts_int4 packing. +# Note: tinygemm uses the OPPOSITE convention (even=HIGH, odd=LOW) — the +# CUDA packer in pack_cuda.py handles that conversion separately. + + +def _nibble_pack(qdata: torch.Tensor) -> torch.Tensor: + """Pack int8 values (each in [0, 15]) into half the last dim. + + Even-indexed values → low nibble, odd-indexed → high nibble. + """ + assert qdata.shape[-1] % 2 == 0, f"Last dim must be even, got {qdata.shape}" + low = qdata[..., ::2].to(torch.uint8) + high = qdata[..., 1::2].to(torch.uint8) + return (low | (high << 4)).to(torch.int8).contiguous() + + +def _nibble_unpack(packed: torch.Tensor, orig_last_dim: int) -> torch.Tensor: + """Unpack nibble-packed int8 → original last dim. + + Low nibble (bits 0-3) → even indices, high nibble (bits 4-7) → odd indices. + """ + p = packed.to(torch.uint8) + low = (p & 0x0F).to(torch.int8) + high = ((p >> 4) & 0x0F).to(torch.int8) + return torch.stack([low, high], dim=-1).reshape(*packed.shape[:-1], orig_last_dim) + + +# --------------------------------------------------------------------------- +# Serialize / deserialize (pure logic, no I/O) + + +def serialize( + quantized: dict[str, CanonicalQuantizedWeight], + unquantized: dict[str, torch.Tensor], +) -> tuple[dict[str, torch.Tensor], dict[str, str]]: + """Convert quantized + unquantized weights to plain tensors + metadata. + + Returns ``(tensors, header)`` ready for any file writer. Quantized + weights become ``{fqn}.qdata``, ``{fqn}.scale``, optionally + ``{fqn}.zero``. 
For 4-bit, qdata is nibble-packed. + """ + tensors: dict[str, torch.Tensor] = {} + quant_meta: dict[str, dict] = {} + + for fqn, cw in quantized.items(): + qdata = cw.qdata + if cw.config.bits == 4: + qdata = _nibble_pack(qdata) + tensors[f"{fqn}.qdata"] = qdata.contiguous() + tensors[f"{fqn}.scale"] = cw.scale.contiguous() + if cw.zero is not None: + tensors[f"{fqn}.zero"] = cw.zero.contiguous() + quant_meta[fqn] = { + "bits": cw.config.bits, + "group_size": cw.config.group_size, + "symmetric": cw.config.symmetric, + "method": cw.config.method, + "shape": list(cw.qdata.shape), + } + + for fqn, tensor in unquantized.items(): + tensors[fqn] = tensor.contiguous() + + header = {"format_version": FORMAT_VERSION} + if quant_meta: + header["quant"] = json.dumps(quant_meta) + + return tensors, header + + +def deserialize( + tensors: dict[str, torch.Tensor], + header: dict[str, str], +) -> tuple[dict[str, CanonicalQuantizedWeight], dict[str, torch.Tensor]]: + """Reconstruct quantized + unquantized weights from plain tensors + metadata. + + Inverse of ``serialize``. Returns ``(quantized, unquantized)`` dicts. + """ + version = header.get("format_version", "1") + if version not in _SUPPORTED_VERSIONS: + raise ValueError( + f"Unsupported format version {version!r}. " + f"This code supports {sorted(_SUPPORTED_VERSIONS)}. " + f"Update the quant package or re-quantize the model." + ) + + quant_meta = json.loads(header.get("quant", "{}")) + + quantized: dict[str, CanonicalQuantizedWeight] = {} + consumed_keys: set[str] = set() + + for fqn, meta in quant_meta.items(): + config = QuantConfig( + bits=meta["bits"], + group_size=meta["group_size"], + symmetric=meta["symmetric"], + method=meta["method"], + ) + qdata = tensors[f"{fqn}.qdata"] + consumed_keys.add(f"{fqn}.qdata") + + original_shape = meta["shape"] + if config.bits == 4: + qdata = _nibble_unpack(qdata, original_shape[-1]) + + scale = tensors[f"{fqn}.scale"] + consumed_keys.add(f"{fqn}.scale") + + zero = tensors.get(f"{fqn}.zero") + if zero is not None: + consumed_keys.add(f"{fqn}.zero") + + quantized[fqn] = CanonicalQuantizedWeight( + qdata=qdata, scale=scale, zero=zero, config=config + ) + + unquantized = {k: v for k, v in tensors.items() if k not in consumed_keys} + + return quantized, unquantized + + +# --------------------------------------------------------------------------- +# Save / load (I/O wrappers) + + +def save( + quantized: dict[str, CanonicalQuantizedWeight], + unquantized: dict[str, torch.Tensor], + path: str, +) -> int: + """Serialize and write to safetensors. Returns the number of tensors written.""" + tensors, header = serialize(quantized, unquantized) + save_file(tensors, path, metadata=header) + return len(tensors) + + +def load( + path: str, +) -> tuple[dict[str, CanonicalQuantizedWeight], dict[str, torch.Tensor]]: + """Read safetensors and deserialize. Returns ``(quantized, unquantized)``.""" + with safe_open(path, framework="pt", device="cpu") as f: + header = f.metadata() + tensors = {k: f.get_tensor(k) for k in f.keys()} + return deserialize(tensors, header) diff --git a/examples/models/gemma4_31b/quant/test_pack_cuda.py b/examples/models/gemma4_31b/quant/test_pack_cuda.py new file mode 100644 index 00000000000..5a20d02998b --- /dev/null +++ b/examples/models/gemma4_31b/quant/test_pack_cuda.py @@ -0,0 +1,360 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""Unit tests for quant/pack_cuda.py. Requires CUDA.""" + +import os +import tempfile +import unittest + +import torch +import torch.nn as nn + +from .pack_cuda import ( + DEFAULT_CUDA_PACKERS, + load_and_pack_for_cuda, + pack_embedding_for_cuda, + pack_int4_for_cuda, + pack_int8_for_cuda, + pack_linear_for_cuda, + pack_model, +) +from .quantize import quantize_weight +from .recipe import QuantConfig + +from .serialize import save + + +class TestPackInt4ForCuda(unittest.TestCase): + def setUp(self): + if not torch.cuda.is_available(): + self.skipTest("CUDA required") + + def test_symmetric_works(self): + config = QuantConfig(bits=4, group_size=32, symmetric=True, method="min_max") + cw = quantize_weight(torch.randn(128, 256, dtype=torch.bfloat16), config) + self.assertEqual(pack_int4_for_cuda(cw).shape, torch.Size([128, 256])) + + def test_rejects_1d(self): + config = QuantConfig(bits=4, group_size=32, symmetric=True, method="min_max") + cw = quantize_weight(torch.randn(1, 128, dtype=torch.bfloat16), config) + cw.qdata = cw.qdata.squeeze(0) + with self.assertRaises(AssertionError): + pack_int4_for_cuda(cw) + + def test_rejects_8bit(self): + config = QuantConfig(bits=8, group_size=32, symmetric=True, method="min_max") + cw = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), config) + with self.assertRaises(AssertionError): + pack_int4_for_cuda(cw) + + def test_different_group_sizes(self): + for gs in (32, 64, 128): + with self.subTest(group_size=gs): + config = QuantConfig( + bits=4, group_size=gs, symmetric=False, method="min_max" + ) + cw = quantize_weight( + torch.randn(128, 256, dtype=torch.bfloat16), config + ) + self.assertEqual(pack_int4_for_cuda(cw).shape, torch.Size([128, 256])) + + def test_matmul_approximates_original(self): + """Packed weight produces matmul output close to the original.""" + torch.manual_seed(0) + # Use dimensions already aligned to tinygemm requirements + # (K multiple of 1024, N multiple of 8) to avoid padding effects. + weight = torch.randn(256, 1024, dtype=torch.bfloat16) + x = torch.randn(1, 1024, dtype=torch.bfloat16) + + original_out = torch.nn.functional.linear(x.cuda(), weight.cuda()) + + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + cw = quantize_weight(weight, config) + packed = pack_int4_for_cuda(cw) + + packed_out = torch.nn.functional.linear(x.cuda(), packed.cuda()) + + rel_error = ( + packed_out.float() - original_out.float() + ).abs().mean() / original_out.float().abs().mean() + self.assertLess(rel_error.item(), 0.15) + + def test_symmetric_matmul_approximates_original(self): + """Symmetric 4-bit (e.g. 
HQQ) packs correctly for tinygemm.""" + torch.manual_seed(0) + weight = torch.randn(256, 1024, dtype=torch.bfloat16) + x = torch.randn(1, 1024, dtype=torch.bfloat16) + + original_out = torch.nn.functional.linear(x.cuda(), weight.cuda()) + + config = QuantConfig(bits=4, group_size=32, symmetric=True, method="min_max") + cw = quantize_weight(weight, config) + packed = pack_int4_for_cuda(cw) + + packed_out = torch.nn.functional.linear(x.cuda(), packed.cuda()) + + rel_error = ( + packed_out.float() - original_out.float() + ).abs().mean() / original_out.float().abs().mean() + self.assertLess(rel_error.item(), 0.15) + + +class TestPackInt8ForCuda(unittest.TestCase): + def setUp(self): + if not torch.cuda.is_available(): + self.skipTest("CUDA required") + + def test_rejects_4bit(self): + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + cw = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), config) + with self.assertRaises(AssertionError): + pack_int8_for_cuda(cw) + + def test_matmul_approximates_original(self): + torch.manual_seed(0) + weight = torch.randn(256, 128, dtype=torch.bfloat16) + x = torch.randn(1, 128, dtype=torch.bfloat16) + + original_out = torch.nn.functional.linear(x.cuda(), weight.cuda()) + + config = QuantConfig(bits=8, group_size=32, symmetric=True, method="min_max") + cw = quantize_weight(weight, config) + packed = pack_int8_for_cuda(cw) + + packed_out = torch.nn.functional.linear(x.cuda(), packed.cuda()) + + rel_error = ( + packed_out.float() - original_out.float() + ).abs().mean() / original_out.float().abs().mean() + self.assertLess(rel_error.item(), 0.02) + + def test_per_axis_gather_approximates_original(self): + """Per-axis INT8 (group_size == K) works for embedding gather.""" + torch.manual_seed(0) + weight = torch.randn(1000, 64, dtype=torch.bfloat16) + ids = torch.tensor([0, 1, 42, 500, 999]) + + original = weight[ids] + + config = QuantConfig(bits=8, group_size=64, symmetric=True, method="min_max") + cw = quantize_weight(weight, config) + packed = pack_int8_for_cuda(cw) + + emb = nn.Embedding(1000, 64) + emb.weight = nn.Parameter(packed, requires_grad=False) + emb.to("cuda") + packed_out = emb(ids.cuda()) + + rel_error = ( + packed_out.cpu().float() - original.float() + ).abs().mean() / original.float().abs().mean() + self.assertLess(rel_error.item(), 0.02) + + +class TestPackLinearForCuda(unittest.TestCase): + def setUp(self): + if not torch.cuda.is_available(): + self.skipTest("CUDA required") + + def test_4bit_modifies_module_in_place(self): + module = nn.Linear(128, 256, bias=False) + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + cw = quantize_weight(torch.randn(256, 128, dtype=torch.bfloat16), config) + pack_linear_for_cuda(module, {"weight": cw}) + self.assertEqual(module.weight.device.type, "cpu") + self.assertEqual(module.weight.shape, torch.Size([256, 128])) + + def test_8bit_modifies_module_in_place(self): + module = nn.Linear(128, 64, bias=False) + config = QuantConfig(bits=8, group_size=32, symmetric=True, method="min_max") + cw = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), config) + pack_linear_for_cuda(module, {"weight": cw}) + self.assertEqual(module.weight.shape, torch.Size([64, 128])) + + +class TestPackEmbeddingForCuda(unittest.TestCase): + def setUp(self): + if not torch.cuda.is_available(): + self.skipTest("CUDA required") + + def test_gather_approximates_original(self): + """INT8 quantized embedding gather matches bf16 gather.""" + torch.manual_seed(0) + 
weight = torch.randn(1000, 64, dtype=torch.bfloat16) + ids = torch.tensor([0, 1, 42, 500, 999]) + + original = weight[ids] + + config = QuantConfig(bits=8, group_size=64, symmetric=True, method="min_max") + cw = quantize_weight(weight, config) + + module = nn.Embedding(1000, 64) + pack_embedding_for_cuda(module, {"weight": cw}) + module.to("cuda") + packed_out = module(ids.cuda()) + + rel_error = ( + packed_out.cpu().float() - original.float() + ).abs().mean() / original.float().abs().mean() + self.assertLess(rel_error.item(), 0.02) + + def test_rejects_4bit(self): + config = QuantConfig(bits=4, group_size=32, symmetric=True, method="min_max") + cw = quantize_weight(torch.randn(100, 64, dtype=torch.bfloat16), config) + module = nn.Embedding(100, 64) + with self.assertRaises(ValueError): + pack_embedding_for_cuda(module, {"weight": cw}) + + +class TestLoadAndPackForCuda(unittest.TestCase): + def setUp(self): + if not torch.cuda.is_available(): + self.skipTest("CUDA required") + + def test_pack_model_in_memory(self): + """pack_model works with in-memory dicts (no file I/O).""" + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + cw = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), config) + unq = {"norm.weight": torch.randn(64, dtype=torch.bfloat16)} + + with torch.device("meta"): + model = nn.ModuleDict( + { + "proj": nn.Linear(128, 64, bias=False), + "norm": nn.LayerNorm(64, bias=False), + } + ) + pack_model(model, {"proj.weight": cw}, unq, DEFAULT_CUDA_PACKERS) + + self.assertEqual(model.proj.weight.shape, torch.Size([64, 128])) + self.assertEqual(model.norm.weight.shape, torch.Size([64])) + + def test_pack_model_mixed_precision(self): + """pack_model handles 4-bit and 8-bit weights in the same model.""" + q4_config = QuantConfig( + bits=4, group_size=32, symmetric=False, method="min_max" + ) + q8_config = QuantConfig(bits=8, group_size=32, symmetric=True, method="min_max") + cw4 = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), q4_config) + cw8 = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), q8_config) + + with torch.device("meta"): + model = nn.ModuleDict( + { + "q_proj": nn.Linear(128, 64, bias=False), + "v_proj": nn.Linear(128, 64, bias=False), + } + ) + pack_model( + model, + {"q_proj.weight": cw4, "v_proj.weight": cw8}, + {}, + DEFAULT_CUDA_PACKERS, + ) + + self.assertEqual(model.q_proj.weight.shape, torch.Size([64, 128])) + self.assertEqual(model.v_proj.weight.shape, torch.Size([64, 128])) + # Verify different subclass types + self.assertNotEqual( + type(model.q_proj.weight.data).__name__, + type(model.v_proj.weight.data).__name__, + ) + + def test_dispatches_by_module_type(self): + """load_and_pack_for_cuda reads from disk and dispatches.""" + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + cw = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), config) + + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "m.safetensors") + save({"proj.weight": cw}, {}, path) + + with torch.device("meta"): + model2 = nn.ModuleDict({"proj": nn.Linear(128, 64, bias=False)}) + load_and_pack_for_cuda(path, model2) + + self.assertEqual(model2.proj.weight.shape, torch.Size([64, 128])) + self.assertEqual(model2.proj.weight.device.type, "cpu") + + def test_unknown_module_type_raises(self): + """Unregistered module types get a clear error.""" + + class CustomModule(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(32, 64)) + + config = 
QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + cw = quantize_weight(torch.randn(32, 64, dtype=torch.bfloat16), config) + + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "m.safetensors") + save({"custom.weight": cw}, {}, path) + + with torch.device("meta"): + model2 = nn.ModuleDict({"custom": CustomModule()}) + with self.assertRaises(ValueError) as ctx: + load_and_pack_for_cuda(path, model2) + self.assertIn("CustomModule", str(ctx.exception)) + + def test_missing_weight_raises(self): + """A meta-device parameter after loading means the checkpoint was incomplete.""" + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + cw = quantize_weight(torch.randn(32, 64, dtype=torch.bfloat16), config) + + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "m.safetensors") + # Only save weight for 'a', not 'b' + save({"a.weight": cw}, {}, path) + + with torch.device("meta"): + model2 = nn.ModuleDict( + { + "a": nn.Linear(64, 32, bias=False), + "b": nn.Linear(64, 32, bias=False), + } + ) + with self.assertRaises(RuntimeError) as ctx: + load_and_pack_for_cuda(path, model2) + self.assertIn("b.weight", str(ctx.exception)) + + def test_custom_packer_via_dict(self): + """Models can extend the packer dict with custom module types.""" + call_log = [] + + class MyModule(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(32, 64)) + + def my_packer(module, weights): + call_log.append(("my_packer", list(weights.keys()))) + cw = weights["weight"] + module.weight = nn.Parameter( + cw.qdata.to(torch.bfloat16), requires_grad=False + ) + + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + cw = quantize_weight(torch.randn(32, 64, dtype=torch.bfloat16), config) + + custom_packers = {**DEFAULT_CUDA_PACKERS, MyModule: my_packer} + + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "m.safetensors") + save({"m.weight": cw}, {}, path) + + with torch.device("meta"): + model2 = nn.ModuleDict({"m": MyModule()}) + load_and_pack_for_cuda(path, model2, packers=custom_packers) + + self.assertEqual(len(call_log), 1) + self.assertEqual(call_log[0], ("my_packer", ["weight"])) + self.assertEqual(model2.m.weight.device.type, "cpu") + + +if __name__ == "__main__": + unittest.main() diff --git a/examples/models/gemma4_31b/quant/test_quantize.py b/examples/models/gemma4_31b/quant/test_quantize.py new file mode 100644 index 00000000000..214a22f718b --- /dev/null +++ b/examples/models/gemma4_31b/quant/test_quantize.py @@ -0,0 +1,194 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Unit tests for quant/quantize.py. + +Tests the public API: ``quantize_weight`` and ``quantize_model``. Organized +by resource requirement (CPU vs CUDA), not by internal codepath. 
+""" + +import unittest + +import torch +import torch.nn as nn +from parameterized import parameterized + +from .quantize import quantize_model, quantize_weight +from .recipe import QuantConfig, QuantRecipe, QuantRule + + +# --------------------------------------------------------------------------- +# quantize_weight — CPU (uses min_max; tests the output contract) + + +class TestQuantizeWeight(unittest.TestCase): + @parameterized.expand( + [ + ("4bit_asym", 4, 32, False, (64, 128), 0, 15), + ("4bit_sym", 4, 32, True, (64, 128), 0, 15), + ("4bit_gs64", 4, 64, False, (32, 128), 0, 15), + ("8bit_sym", 8, 32, True, (32, 64), -128, 127), + ("3d_expert", 4, 32, False, (8, 64, 128), 0, 15), + ] + ) + def test_output_structure(self, _name, bits, gs, sym, shape, qmin, qmax): + config = QuantConfig(bits=bits, group_size=gs, symmetric=sym, method="min_max") + cw = quantize_weight(torch.randn(*shape, dtype=torch.bfloat16), config) + + self.assertEqual(cw.qdata.shape, shape) + self.assertEqual(cw.qdata.dtype, torch.int8) + self.assertEqual(cw.scale.shape, (*shape[:-1], shape[-1] // gs)) + self.assertGreaterEqual(cw.qdata.min().item(), qmin) + self.assertLessEqual(cw.qdata.max().item(), qmax) + + if sym: + self.assertIsNone(cw.zero) + else: + self.assertIsNotNone(cw.zero) + self.assertEqual(cw.zero.shape, cw.scale.shape) + + self.assertEqual(cw.config, config) + + def test_fp32_input(self): + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + cw = quantize_weight(torch.randn(32, 64, dtype=torch.float32), config) + self.assertEqual(cw.qdata.shape, (32, 64)) + + def test_dequant_approximates_original(self): + torch.manual_seed(0) + weight = torch.randn(64, 128, dtype=torch.float32) + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + cw = quantize_weight(weight, config) + scale = cw.scale.float().repeat_interleave(config.group_size, dim=-1) + zero = cw.zero.float().repeat_interleave(config.group_size, dim=-1) + dequant = (cw.qdata.float() - zero) * scale + rel_error = (dequant - weight).abs().mean() / weight.abs().mean() + self.assertLess(rel_error.item(), 0.15) + + @parameterized.expand( + [ + ("unknown_method", QuantConfig(4, 32, False, "bogus"), "bogus"), + ("unsupported_bits", QuantConfig(3, 32, False, "min_max"), None), + ] + ) + def test_invalid_config_raises(self, _name, config, expected_substr): + with self.assertRaises(ValueError) as ctx: + quantize_weight(torch.randn(32, 64), config) + if expected_substr: + self.assertIn(expected_substr, str(ctx.exception)) + + +# --------------------------------------------------------------------------- +# quantize_weight — CUDA (HQQ-specific behavior only) + + +class TestQuantizeWeightHQQ(unittest.TestCase): + def setUp(self): + if not torch.cuda.is_available(): + self.skipTest("CUDA required for HQQ") + + def test_dequant_approximates_original(self): + torch.manual_seed(0) + weight = torch.randn(64, 128, dtype=torch.float32, device="cuda") + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="hqq") + cw = quantize_weight(weight, config) + scale = cw.scale.cpu().float().repeat_interleave(config.group_size, dim=-1) + zero = cw.zero.cpu().float().repeat_interleave(config.group_size, dim=-1) + dequant = (cw.qdata.cpu().float() - zero) * scale + rel_error = (dequant - weight.cpu()).abs().mean() / weight.cpu().abs().mean() + self.assertLess(rel_error.item(), 0.15) + + def test_symmetric_scale_only(self): + """symmetric=True dispatches to scale-only HQQ (no zero).""" + config = 
QuantConfig(bits=4, group_size=32, symmetric=True, method="hqq") + cw = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), config) + self.assertIsNone(cw.zero) + self.assertGreaterEqual(cw.qdata.min().item(), 0) + self.assertLessEqual(cw.qdata.max().item(), 15) + + def test_cpu_input_accepted(self): + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="hqq") + cw = quantize_weight(torch.randn(32, 64, dtype=torch.bfloat16), config) + self.assertEqual(cw.qdata.shape, (32, 64)) + + +# --------------------------------------------------------------------------- +# quantize_model + + +class TestQuantizeModel(unittest.TestCase): + def test_applies_recipe(self): + model = nn.ModuleDict( + { + "embed": nn.Embedding(32, 16), + "proj": nn.Linear(16, 32, bias=False), + "norm": nn.LayerNorm(32), + } + ) + model.to(dtype=torch.bfloat16) + for p in model.parameters(): + p.data.normal_(0, 0.02) + + recipe = QuantRecipe( + rules=[ + QuantRule(r"embed\.weight", None), + QuantRule(r"norm\.weight", None), + QuantRule(r".*\.weight", QuantConfig(4, 16, False, "min_max")), + ] + ) + + quantized, unquantized = quantize_model(model, recipe) + + self.assertIn("proj.weight", quantized) + self.assertEqual(quantized["proj.weight"].qdata.shape, (32, 16)) + self.assertIn("embed.weight", unquantized) + self.assertIn("norm.weight", unquantized) + self.assertNotIn("embed.weight", quantized) + self.assertNotIn("norm.weight", quantized) + + def test_persistent_buffers_included(self): + model = nn.Module() + model.weight = nn.Parameter(torch.randn(16, 32, dtype=torch.bfloat16)) + model.register_buffer("scalar", torch.ones(1)) + model.register_buffer("temp", torch.zeros(4), persistent=False) + + recipe = QuantRecipe(rules=[QuantRule(r".*", None)]) + _, unquantized = quantize_model(model, recipe) + + self.assertIn("scalar", unquantized) + self.assertNotIn("temp", unquantized) + + def test_unquantized_cast_to_dtype(self): + model = nn.ModuleDict({"proj": nn.Linear(16, 8, bias=False)}) + model.proj.weight.data = torch.randn(8, 16, dtype=torch.float32) + + recipe = QuantRecipe(rules=[QuantRule(r".*", None)]) + _, unquantized = quantize_model(model, recipe, dtype=torch.float16) + + self.assertEqual(unquantized["proj.weight"].dtype, torch.float16) + + def test_empty_model(self): + quantized, unquantized = quantize_model(nn.Module(), QuantRecipe(rules=[])) + self.assertEqual(len(quantized), 0) + self.assertEqual(len(unquantized), 0) + + def test_all_quantized(self): + model = nn.ModuleDict({"a": nn.Linear(32, 16, bias=False)}) + model.to(dtype=torch.bfloat16) + for p in model.parameters(): + p.data.normal_(0, 0.02) + + config = QuantConfig(bits=4, group_size=16, symmetric=False, method="min_max") + quantized, unquantized = quantize_model( + model, QuantRecipe(rules=[QuantRule(r".*", config)]) + ) + self.assertEqual(len(quantized), 1) + self.assertIn("a.weight", quantized) + self.assertEqual(len(unquantized), 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/examples/models/gemma4_31b/quant/test_recipe.py b/examples/models/gemma4_31b/quant/test_recipe.py new file mode 100644 index 00000000000..5b7afd992e0 --- /dev/null +++ b/examples/models/gemma4_31b/quant/test_recipe.py @@ -0,0 +1,163 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Unit tests for quant/recipe.py. 
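The matching semantics the recipe tests below pin down (the first rule whose regex fully matches the FQN wins, a rule whose config is None means leave the tensor unquantized, and an optional layers set restricts a rule to specific decoder layers and otherwise falls through to later rules) can be sketched as follows. The ToyRule/ToyConfig names are illustrative stand-ins, not the classes in quant/recipe.py.

```python
import re
from dataclasses import dataclass
from typing import Optional, Set


@dataclass(frozen=True)
class ToyConfig:
    bits: int
    group_size: int


@dataclass
class ToyRule:
    pattern: str
    config: Optional[ToyConfig]
    layers: Optional[Set[int]] = None  # restrict this rule to specific decoder layers


_LAYER_RE = re.compile(r"layers\.(\d+)\.")


def get_config(rules, fqn: str) -> Optional[ToyConfig]:
    """First rule whose regex *fully* matches wins; config=None means 'leave unquantized'."""
    for rule in rules:
        if not re.fullmatch(rule.pattern, fqn):
            continue
        if rule.layers is not None:
            m = _LAYER_RE.search(fqn)
            if m is None or int(m.group(1)) not in rule.layers:
                continue  # layer filter not satisfied: fall through to later rules
        return rule.config
    return None  # nothing matched: the tensor is left alone


rules = [
    ToyRule(r".*norm\.weight", None),                                  # norms stay in bf16
    ToyRule(r".*\.v_proj\.weight", ToyConfig(8, 32), layers={0, 58}),  # edge layers in int8
    ToyRule(r".*\.weight", ToyConfig(4, 32)),                          # everything else int4
]
assert get_config(rules, "layers.0.self_attn.v_proj.weight").bits == 8
assert get_config(rules, "layers.30.self_attn.v_proj.weight").bits == 4
assert get_config(rules, "layers.0.input_layernorm.weight") is None
```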
CPU only — no CUDA, no model, no torchao.""" + +import unittest + +from parameterized import parameterized + +from .recipe import QuantConfig, QuantRecipe, QuantRule + +_Q4 = QuantConfig(4, 32, True, "min_max") +_Q8 = QuantConfig(8, 32, True, "min_max") + + +class TestQuantRecipeGetConfig(unittest.TestCase): + """Tests for ``QuantRecipe.get_config`` — the core matching logic.""" + + @parameterized.expand( + [ + ( + "first_match_wins", + [QuantRule(r".*v_proj\.weight", _Q8), QuantRule(r".*\.weight", _Q4)], + "layers.0.self_attn.v_proj.weight", + 8, + ), + ( + "fallthrough_to_catchall", + [QuantRule(r".*v_proj\.weight", _Q8), QuantRule(r".*\.weight", _Q4)], + "layers.0.self_attn.q_proj.weight", + 4, + ), + ( + "none_rule_skips", + [ + QuantRule(r"embed_tokens\.weight", None), + QuantRule(r".*\.weight", _Q4), + ], + "embed_tokens.weight", + None, + ), + ( + "unmatched_returns_none", + [QuantRule(r"foo", _Q4)], + "bar.weight", + None, + ), + ( + "empty_recipe", + [], + "anything", + None, + ), + ( + "fullmatch_not_partial", + [QuantRule(r"foo", _Q4)], + "foo.bar", + None, + ), + ( + "fullmatch_exact", + [QuantRule(r"foo", _Q4)], + "foo", + 4, + ), + ] + ) + def test_get_config(self, _name, rules, fqn, expected_bits): + recipe = QuantRecipe(rules=rules) + config = recipe.get_config(fqn) + if expected_bits is None: + self.assertIsNone(config) + else: + self.assertEqual(config.bits, expected_bits) + + +class TestQuantRecipeLayerFilter(unittest.TestCase): + """Tests for the ``layers`` field on ``QuantRule``.""" + + def test_layer_filter(self): + edge = set(range(5)) | set(range(55, 60)) + recipe = QuantRecipe( + rules=[ + QuantRule(r".*norm\.weight", None), + QuantRule(r".*\.(v_proj|down_proj)\.weight", _Q8, layers=edge), + QuantRule(r".*\.weight", _Q4), + ] + ) + # Edge v_proj → 8-bit + self.assertEqual(recipe.get_config("layers.0.self_attn.v_proj.weight").bits, 8) + self.assertEqual(recipe.get_config("layers.58.self_attn.v_proj.weight").bits, 8) + # Middle v_proj → falls through → 4-bit + self.assertEqual(recipe.get_config("layers.30.self_attn.v_proj.weight").bits, 4) + # q_proj always 4-bit + self.assertEqual(recipe.get_config("layers.0.self_attn.q_proj.weight").bits, 4) + # Non-layer FQN skips layer-filtered rule, hits catch-all + self.assertEqual(recipe.get_config("lm_head.weight").bits, 4) + + def test_layer_filter_with_none_config(self): + """Skip rule scoped to specific layers.""" + recipe = QuantRecipe( + rules=[ + QuantRule(r".*\.weight", None, layers={0}), + QuantRule(r".*\.weight", _Q4), + ] + ) + self.assertIsNone(recipe.get_config("layers.0.mlp.gate_proj.weight")) + self.assertEqual(recipe.get_config("layers.1.mlp.gate_proj.weight").bits, 4) + + +class TestProductionRecipes(unittest.TestCase): + """Regression tests for the production recipes in quantize_and_save.py.""" + + def test_default_recipe(self): + from executorch.examples.models.gemma4_31b.quantize_and_save import ( + GEMMA4_31B_DEFAULT_RECIPE, + ) + + r = GEMMA4_31B_DEFAULT_RECIPE + self.assertIsNone(r.get_config("layers.0.input_layernorm.weight")) + self.assertIsNone(r.get_config("layers.5.self_attn.q_norm.weight")) + self.assertIsNone(r.get_config("norm.weight")) + embed_cfg = r.get_config("embed_tokens.weight") + self.assertEqual(embed_cfg.bits, 8) + self.assertEqual(embed_cfg.group_size, 5376) + for fqn in ( + "layers.0.self_attn.q_proj.weight", + "layers.0.self_attn.v_proj.weight", + "layers.0.mlp.gate_proj.weight", + "layers.0.mlp.down_proj.weight", + "lm_head.weight", + ): + cfg = r.get_config(fqn) + 
self.assertEqual(cfg.bits, 4, fqn) + self.assertEqual(cfg.method, "min_max", fqn) + + def test_sensitive_recipe(self): + from executorch.examples.models.gemma4_31b.quantize_and_save import ( + GEMMA4_31B_SENSITIVE_RECIPE, + ) + + r = GEMMA4_31B_SENSITIVE_RECIPE + self.assertIsNone(r.get_config("layers.0.input_layernorm.weight")) + embed_cfg = r.get_config("embed_tokens.weight") + self.assertEqual(embed_cfg.bits, 8) + self.assertEqual(embed_cfg.group_size, 5376) + # Edge v_proj/down_proj → int8 + self.assertEqual(r.get_config("layers.0.self_attn.v_proj.weight").bits, 8) + self.assertEqual(r.get_config("layers.0.mlp.down_proj.weight").bits, 8) + self.assertEqual(r.get_config("layers.58.self_attn.v_proj.weight").bits, 8) + # Middle v_proj/down_proj → int4 + self.assertEqual(r.get_config("layers.30.self_attn.v_proj.weight").bits, 4) + self.assertEqual(r.get_config("layers.30.mlp.down_proj.weight").bits, 4) + # q_proj always int4 + self.assertEqual(r.get_config("layers.0.self_attn.q_proj.weight").bits, 4) + self.assertEqual(r.get_config("layers.30.self_attn.q_proj.weight").bits, 4) + + +if __name__ == "__main__": + unittest.main() diff --git a/examples/models/gemma4_31b/quant/test_serialize.py b/examples/models/gemma4_31b/quant/test_serialize.py new file mode 100644 index 00000000000..d84e53d0a0b --- /dev/null +++ b/examples/models/gemma4_31b/quant/test_serialize.py @@ -0,0 +1,207 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Unit tests for quant/serialize.py — data format and I/O only. + +Tests nibble pack/unpack and save/load. Does NOT test +quantize_weight (that lives in test_quantize.py). Save/load tests use +hand-built CanonicalQuantizedWeight fixtures to avoid coupling to the +quantizer. 
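The nibble pack/unpack helpers exercised in TestNibblePack below store two 4-bit codes per byte, halving the on-disk size of int4 tensors. A rough reference implementation, assuming unsigned codes in [0, 15] and an even last dimension (illustrative only, not the _nibble_pack/_nibble_unpack from quant/serialize.py):

```python
import torch


def nibble_pack(q: torch.Tensor) -> torch.Tensor:
    """Pack pairs along the last dim: even index -> low nibble, odd index -> high nibble."""
    assert q.shape[-1] % 2 == 0
    u = q.to(torch.uint8)  # codes are 0..15, so the cast is lossless
    return u[..., 0::2] | (u[..., 1::2] << 4)  # uint8 buffer, last dim halved


def nibble_unpack(packed: torch.Tensor, last_dim: int) -> torch.Tensor:
    lo = packed & 0x0F
    hi = (packed >> 4) & 0x0F
    # Interleave back to [q0, q1, q2, q3, ...] and restore the original last dim.
    return torch.stack([lo, hi], dim=-1).reshape(*packed.shape[:-1], last_dim).to(torch.int8)


q = torch.randint(0, 16, (8, 64), dtype=torch.int8)
packed = nibble_pack(q)
assert packed.shape == (8, 32)
assert torch.equal(nibble_unpack(packed, 64), q)
```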
+""" + +import json +import os +import tempfile +import unittest + +import torch +from safetensors import safe_open + +from .recipe import QuantConfig + +from .serialize import ( + _nibble_pack, + _nibble_unpack, + CanonicalQuantizedWeight, + deserialize, + load, + save, + serialize, +) + + +def _make_cqw( + shape: tuple[int, ...], + config: QuantConfig, +) -> CanonicalQuantizedWeight: + """Build a CanonicalQuantizedWeight with random data (no actual quantization).""" + K = shape[-1] + n_groups = K // config.group_size + scale_shape = (*shape[:-1], n_groups) + + if config.bits == 4: + qdata = torch.randint(0, 16, shape, dtype=torch.int8) + else: + qdata = torch.randint(-128, 128, shape, dtype=torch.int8) + + return CanonicalQuantizedWeight( + qdata=qdata, + scale=torch.randn(scale_shape, dtype=torch.bfloat16), + zero=( + torch.randn(scale_shape, dtype=torch.bfloat16) + if not config.symmetric + else None + ), + config=config, + ) + + +# --------------------------------------------------------------------------- +# Nibble pack / unpack + + +class TestNibblePack(unittest.TestCase): + def test_roundtrip(self): + qdata = torch.randint(0, 16, (8, 64), dtype=torch.int8) + packed = _nibble_pack(qdata) + self.assertEqual(packed.shape, (8, 32)) + self.assertTrue(torch.equal(qdata, _nibble_unpack(packed, 64))) + + def test_rejects_odd_last_dim(self): + with self.assertRaises(AssertionError): + _nibble_pack(torch.zeros(4, 33, dtype=torch.int8)) + + def test_3d(self): + """Nibble pack works for 3D tensors (MoE expert weights).""" + qdata = torch.randint(0, 16, (4, 32, 64), dtype=torch.int8) + packed = _nibble_pack(qdata) + self.assertEqual(packed.shape, (4, 32, 32)) + self.assertTrue(torch.equal(qdata, _nibble_unpack(packed, 64))) + + +# --------------------------------------------------------------------------- +# save / load + + +class TestSerializeDeserialize(unittest.TestCase): + """Pure logic layer — no disk I/O.""" + + def test_roundtrip(self): + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + cw = _make_cqw((64, 128), config) + unq = {"embed": torch.randn(8, 8, dtype=torch.bfloat16)} + + tensors, header = serialize({"w": cw}, unq) + q, u = deserialize(tensors, header) + + self.assertTrue(torch.equal(cw.qdata, q["w"].qdata)) + self.assertTrue(torch.equal(cw.scale, q["w"].scale)) + self.assertTrue(torch.equal(cw.zero, q["w"].zero)) + self.assertEqual(cw.config, q["w"].config) + self.assertTrue(torch.equal(unq["embed"], u["embed"])) + + def test_rejects_unsupported_version(self): + tensors, header = serialize({}, {"w": torch.randn(4, 4)}) + header["format_version"] = "99" + with self.assertRaises(ValueError) as ctx: + deserialize(tensors, header) + self.assertIn("99", str(ctx.exception)) + + +class TestSaveLoad(unittest.TestCase): + """I/O layer — roundtrip through safetensors on disk.""" + + def test_roundtrip_asymmetric(self): + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + cw = _make_cqw((64, 128), config) + unq = {"embed.weight": torch.randn(32, 64, dtype=torch.bfloat16)} + + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "m.safetensors") + save({"w": cw}, unq, path) + q, u = load(path) + + self.assertTrue(torch.equal(cw.qdata, q["w"].qdata)) + self.assertTrue(torch.equal(cw.scale, q["w"].scale)) + self.assertTrue(torch.equal(cw.zero, q["w"].zero)) + self.assertEqual(cw.config, q["w"].config) + self.assertTrue(torch.equal(unq["embed.weight"], u["embed.weight"])) + + def test_roundtrip_symmetric(self): + 
config = QuantConfig(bits=4, group_size=32, symmetric=True, method="min_max") + cw = _make_cqw((32, 64), config) + + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "m.safetensors") + save({"w": cw}, {}, path) + q, _ = load(path) + + self.assertIsNone(q["w"].zero) + self.assertTrue(torch.equal(cw.qdata, q["w"].qdata)) + + def test_roundtrip_3d(self): + """3D quantized weights (MoE experts) roundtrip correctly.""" + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + cw = _make_cqw((8, 64, 128), config) + + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "m.safetensors") + save({"experts.w1": cw}, {}, path) + q, _ = load(path) + + self.assertTrue(torch.equal(cw.qdata, q["experts.w1"].qdata)) + self.assertEqual(q["experts.w1"].scale.shape, (8, 64, 4)) + + def test_4bit_nibble_packed_on_disk(self): + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + cw = _make_cqw((64, 128), config) + + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "m.safetensors") + save({"w": cw}, {}, path) + with safe_open(path, framework="pt", device="cpu") as f: + on_disk = f.get_tensor("w.qdata") + self.assertEqual(on_disk.shape, (64, 64)) + + def test_8bit_not_nibble_packed(self): + config = QuantConfig(bits=8, group_size=32, symmetric=True, method="min_max") + cw = _make_cqw((32, 64), config) + + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "m.safetensors") + save({"w": cw}, {}, path) + with safe_open(path, framework="pt", device="cpu") as f: + on_disk = f.get_tensor("w.qdata") + self.assertEqual(on_disk.shape, (32, 64)) # no packing for 8-bit + + def test_header_metadata(self): + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + cw = _make_cqw((32, 64), config) + + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "m.safetensors") + save({"foo.weight": cw}, {}, path) + with safe_open(path, framework="pt", device="cpu") as f: + meta = json.loads(f.metadata()["quant"]) + + self.assertIn("foo.weight", meta) + self.assertEqual(meta["foo.weight"]["bits"], 4) + self.assertEqual(meta["foo.weight"]["group_size"], 32) + self.assertFalse(meta["foo.weight"]["symmetric"]) + self.assertEqual(meta["foo.weight"]["method"], "min_max") + + def test_empty_quantized(self): + unq = {"w": torch.randn(8, 8, dtype=torch.bfloat16)} + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "m.safetensors") + save({}, unq, path) + q, u = load(path) + self.assertEqual(len(q), 0) + self.assertTrue(torch.equal(unq["w"], u["w"])) + + +if __name__ == "__main__": + unittest.main() diff --git a/examples/models/gemma4_31b/quantize_and_save.py b/examples/models/gemma4_31b/quantize_and_save.py new file mode 100644 index 00000000000..7a9eb9900f2 --- /dev/null +++ b/examples/models/gemma4_31b/quantize_and_save.py @@ -0,0 +1,135 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Quantize Gemma 4 31B-IT and save as a quantized checkpoint. + +Produces a packing-agnostic safetensors file (int values + per-group scales + +JSON header) that can later be loaded and packed for any backend via +``quant.load()`` and ``quant.pack_model()``. + +No CUDA is needed — quantization runs on CPU. CUDA is only required at +load-and-pack time. 
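Taken together with ``quant.load()`` and ``quant.pack_model()``, the intended offline/online split looks roughly like the sketch below on a toy two-module model. It is a call-sequence illustration under the signatures the unit tests exercise (quantize_model(model, recipe), save(q, u, path), load(path), pack_model(model, q, u, packers)), not a supported entry point; the recipe, temp path, and module names are made up, and the pack step targets the CUDA packers as in TestLoadAndPackForCuda.

```python
import os
import tempfile

import torch
import torch.nn as nn

from executorch.examples.models.gemma4_31b.quant import (
    DEFAULT_CUDA_PACKERS,
    load,
    pack_model,
    QuantConfig,
    quantize_model,
    QuantRecipe,
    QuantRule,
    save,
)


def make_skeleton() -> nn.ModuleDict:
    return nn.ModuleDict(
        {"proj": nn.Linear(128, 64, bias=False), "norm": nn.LayerNorm(64, bias=False)}
    )


# Offline (CPU): quantize against a recipe and write the packing-agnostic checkpoint.
model = make_skeleton().to(dtype=torch.bfloat16)
recipe = QuantRecipe(
    rules=[
        QuantRule(r".*norm\.weight", None),                              # norms stay bf16
        QuantRule(r".*\.weight", QuantConfig(4, 32, False, "min_max")),  # everything else int4
    ]
)
quantized, unquantized = quantize_model(model, recipe)
path = os.path.join(tempfile.mkdtemp(), "toy.safetensors")
save(quantized, unquantized, path)

# Online (load-and-pack time): rebuild the skeleton on meta and pack for the backend.
with torch.device("meta"):
    skeleton = make_skeleton()
q, u = load(path)
pack_model(skeleton, q, u, DEFAULT_CUDA_PACKERS)
```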
+ +Usage: + python quantize_and_save.py \\ + --model-dir ~/local/scripts/models/gemma-4-31B-it \\ + --output ./gemma4_31b_int4 \\ + --quant-recipe default +""" + +import argparse +import os +import shutil + +import torch.nn as nn + +from executorch.examples.models.gemma4_31b.model import Gemma4_31B +from executorch.examples.models.gemma4_31b.quant import ( + QuantConfig, + quantize_model, + QuantRecipe, + QuantRule, + save, +) + +# --------------------------------------------------------------------------- +# Production recipes for Gemma 4 31B. +# +# Layer sensitivity: +# - v_proj and down_proj are the most sensitive to quantization error +# (first/last quarter of layers especially so). +# - q_proj, k_proj, o_proj, gate_proj, up_proj tolerate 4-bit well. +# - embed_tokens is an index lookup — INT8 per-axis is nearly lossless. +# - Norms and layer_scalar are tiny and must stay unquantized. + +_INT4 = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") +_INT4_HQQ = QuantConfig(bits=4, group_size=32, symmetric=True, method="hqq") +_INT8 = QuantConfig(bits=8, group_size=32, symmetric=True, method="min_max") +_INT8_PER_AXIS = QuantConfig(bits=8, group_size=5376, symmetric=True, method="min_max") +_EDGE_LAYERS = set(range(15)) | set(range(45, 60)) + +GEMMA4_31B_DEFAULT_RECIPE = QuantRecipe( + rules=[ + QuantRule(r"embed_tokens\.weight", _INT8_PER_AXIS), + QuantRule(r".*norm\.weight", None), + QuantRule(r".*\.weight", _INT4), + ] +) + +GEMMA4_31B_SENSITIVE_RECIPE = QuantRecipe( + rules=[ + QuantRule(r"embed_tokens\.weight", _INT8_PER_AXIS), + QuantRule(r".*norm\.weight", None), + QuantRule(r".*\.(v_proj|down_proj)\.weight", _INT8, layers=_EDGE_LAYERS), + QuantRule(r".*\.weight", _INT4_HQQ), + ] +) + +_RECIPES = { + "default": GEMMA4_31B_DEFAULT_RECIPE, + "sensitive": GEMMA4_31B_SENSITIVE_RECIPE, +} + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Quantize Gemma 4 31B-IT and save as a quantized checkpoint." + ) + parser.add_argument( + "--model-dir", + required=True, + help="HuggingFace Gemma 4 31B-IT model dir.", + ) + parser.add_argument( + "--output", + default="./gemma4_31b_int4", + help="Output directory.", + ) + parser.add_argument( + "--quant-recipe", + default="default", + choices=list(_RECIPES), + help="'default': int4 min_max linears + int8 per-axis embedding. 
" + "'sensitive': int8 for edge-layer v_proj/down_proj, int4 hqq elsewhere.", + ) + parser.add_argument( + "--backend", + default="cuda", + choices=["cuda"], + help="Target backend (the quantized checkpoint is backend-agnostic, " + "but this may influence default recipe selection in the future).", + ) + args = parser.parse_args() + + recipe = _RECIPES[args.quant_recipe] + + print("Loading checkpoint (lazy, shard-by-shard)...") + model, _ = Gemma4_31B.from_hf_checkpoint(args.model_dir) + + if model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr(): + print("Untying embed_tokens / lm_head...") + model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) + + print(f"Quantizing with recipe '{args.quant_recipe}'...") + quantized, unquantized = quantize_model(model, recipe) + + os.makedirs(args.output, exist_ok=True) + safetensors_path = os.path.join(args.output, "model.safetensors") + print("Saving quantized checkpoint...") + n_tensors = save(quantized, unquantized, safetensors_path) + + for filename in ("config.json", "tokenizer.json", "tokenizer_config.json"): + src = os.path.join(args.model_dir, filename) + if os.path.exists(src): + shutil.copy2(src, os.path.join(args.output, filename)) + + size_mb = os.path.getsize(safetensors_path) / (1024 * 1024) + print(f"Saved {n_tensors} tensors ({size_mb:.1f} MB) to {args.output}/") + print(f"Done. Use with: python export.py --prequantized {args.output}") + + +if __name__ == "__main__": + main() diff --git a/examples/models/gemma4_31b/sampler.py b/examples/models/gemma4_31b/sampler.py new file mode 100644 index 00000000000..45e4e17887a --- /dev/null +++ b/examples/models/gemma4_31b/sampler.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""GPU-side Gumbel-max sampler. + +Mirrors ``examples/models/qwen3_5_moe/sampler.py``: a single-output sampler +that lets one exported program be re-driven with different temperatures +without re-export. ``temperature=None`` is a no-op (returns logits). +""" + +from typing import Optional + +import torch + + +def sample( + logits: torch.Tensor, + temperature: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Draw a single token per batch row using the Gumbel-max trick. + + Args: + logits: ``[B, V]`` float32 logits (already soft-capped if applicable). + temperature: 0-D or 1-D float tensor; clamped to >= 1e-6 so a 0 + temperature still works ("near-greedy"). When ``None`` the call + short-circuits and returns ``logits`` unchanged. + + Returns: + ``[B, 1]`` float32 token IDs (``argmax(logits/T + gumbel_noise)``), + or the unmodified logits when ``temperature`` is ``None``. + """ + if temperature is None: + return logits + + logits = logits / temperature.clamp(min=1e-6) + noise = torch.rand_like(logits) + gumbel = -torch.log(-torch.log(noise + 1e-20) + 1e-20) + return (logits + gumbel).argmax(dim=-1, keepdim=True).float() diff --git a/examples/models/gemma4_31b/test_cuda_pipeline.py b/examples/models/gemma4_31b/test_cuda_pipeline.py new file mode 100644 index 00000000000..faae59f160f --- /dev/null +++ b/examples/models/gemma4_31b/test_cuda_pipeline.py @@ -0,0 +1,172 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""CUDA-specific integration tests for the Gemma 4 31B-IT pipeline. + +Tests pack → inference → export on a tiny model using the CUDA backend. +Backend-agnostic tests (quantize, save, load) live in ``test_pipeline.py``. + +Requires CUDA. + +Usage: + python -m pytest examples/models/gemma4_31b/test_cuda_pipeline.py -v +""" + +import os +import tempfile +import unittest + +import torch +import torch.nn as nn + +from executorch.examples.models.gemma4_31b.export import ( + export_and_lower, + load_prequantized_model, +) +from executorch.examples.models.gemma4_31b.inference import _move_to_cuda, generate +from executorch.examples.models.gemma4_31b.model import Gemma4_31B +from executorch.examples.models.gemma4_31b.quant import ( + DEFAULT_CUDA_PACKERS, + pack_model, + quantize_model, +) +from executorch.examples.models.gemma4_31b.test_pipeline import ( + build_hf_checkpoint, + DEFAULT_RECIPE, + MockTokenizer, + save_checkpoint, + TINY_CONFIG, +) + + +def _require_cuda(testcase: unittest.TestCase) -> None: + if not torch.cuda.is_available(): + testcase.skipTest("CUDA required") + + +class TestCudaInference(unittest.TestCase): + def setUp(self): + _require_cuda(self) + + def test_generate(self): + """save → load → pack → generate (sampling + greedy).""" + with tempfile.TemporaryDirectory() as tmpdir: + save_checkpoint(tmpdir) + model, config = load_prequantized_model( + tmpdir, max_seq_len=TINY_CONFIG.max_seq_len + ) + _move_to_cuda(model, config) + model.eval() + tokenizer = MockTokenizer(TINY_CONFIG.vocab_size) + + torch.manual_seed(0) + out = generate(model, tokenizer, prompt="hi", max_new_tokens=5, temperature=1.0) + self.assertIsInstance(out, str) + ids_part = out[len("" + + +def config_dict() -> dict: + cfg = TINY_CONFIG + return { + "vocab_size": cfg.vocab_size, + "hidden_size": cfg.hidden_size, + "intermediate_size": cfg.intermediate_size, + "num_hidden_layers": cfg.num_hidden_layers, + "num_attention_heads": cfg.num_attention_heads, + "num_key_value_heads": cfg.num_key_value_heads, + "head_dim": cfg.head_dim, + "num_global_key_value_heads": cfg.num_global_key_value_heads, + "global_head_dim": cfg.global_head_dim, + "attention_k_eq_v": cfg.attention_k_eq_v, + "rope_parameters": { + "sliding_attention": {"rope_theta": cfg.sliding_rope_theta}, + "full_attention": { + "rope_theta": cfg.full_rope_theta, + "partial_rotary_factor": cfg.full_partial_rotary_factor, + }, + }, + "rms_norm_eps": cfg.rms_norm_eps, + "hidden_activation": cfg.hidden_activation, + "final_logit_softcapping": cfg.final_logit_softcapping, + "tie_word_embeddings": cfg.tie_word_embeddings, + "sliding_window": cfg.sliding_window, + "layer_types": cfg.layer_types, + } + + +def build_random_tiny_model() -> Gemma4_31B: + torch.manual_seed(42) + model = Gemma4_31B(TINY_CONFIG) + model.to(dtype=torch.bfloat16) + for p in model.parameters(): + if p.device.type != "meta": + p.data.normal_(0, 0.02) + model.eval() + return model + + +def save_checkpoint(output_dir: str): + model = build_random_tiny_model() + model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) + quantized, unquantized = quantize_model(model, DEFAULT_RECIPE) + os.makedirs(output_dir, exist_ok=True) + save(quantized, unquantized, os.path.join(output_dir, "model.safetensors")) + with open(os.path.join(output_dir, "config.json"), "w") as f: + json.dump(config_dict(), f) + + +def build_hf_checkpoint(output_dir: str) -> None: + model = build_random_tiny_model() + sd = model.state_dict() + sd.pop("lm_head.weight", None) + hf_sd = 
{f"model.language_model.{k}": v.contiguous() for k, v in sd.items()} + save_file(hf_sd, os.path.join(output_dir, "model.safetensors")) + with open(os.path.join(output_dir, "config.json"), "w") as f: + json.dump(config_dict(), f) + + +# --------------------------------------------------------------------------- +# Tests (CPU only, no backend dependency) + + +class TestQuantizeSaveLoadRoundtrip(unittest.TestCase): + def test_roundtrip_preserves_weights(self): + """quantize → save → load recovers all weights and configs.""" + model = build_random_tiny_model() + model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) + quantized, unquantized = quantize_model(model, DEFAULT_RECIPE) + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "model.safetensors") + save(quantized, unquantized, path) + q_loaded, u_loaded = load(path) + + self.assertEqual(set(quantized.keys()), set(q_loaded.keys())) + for fqn in quantized: + self.assertEqual(quantized[fqn].config, q_loaded[fqn].config) + self.assertTrue(torch.equal(quantized[fqn].qdata, q_loaded[fqn].qdata)) + self.assertTrue(torch.equal(quantized[fqn].scale, q_loaded[fqn].scale)) + + self.assertEqual(set(unquantized.keys()), set(u_loaded.keys())) + for fqn in unquantized: + self.assertTrue(torch.equal(unquantized[fqn], u_loaded[fqn])) + + def test_embedding_quantized_as_int8(self): + """embed_tokens is quantized to INT8 per-axis, not skipped.""" + model = build_random_tiny_model() + model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) + quantized, unquantized = quantize_model(model, DEFAULT_RECIPE) + + self.assertIn("embed_tokens.weight", quantized) + self.assertNotIn("embed_tokens.weight", unquantized) + self.assertEqual(quantized["embed_tokens.weight"].config.bits, 8) + + def test_corrupted_checkpoint_missing_key(self): + """Renaming a key in the safetensors file makes it absent after load.""" + from safetensors import safe_open + + with tempfile.TemporaryDirectory() as tmpdir: + save_checkpoint(tmpdir) + path = os.path.join(tmpdir, "model.safetensors") + + with safe_open(path, framework="pt", device="cpu") as f: + header = f.metadata() + tensors = {k: f.get_tensor(k) for k in f.keys()} + tensors["norm.BOGUS"] = tensors.pop("norm.weight") + save_file(tensors, path, metadata=header) + + q, u = load(path) + self.assertNotIn("norm.weight", u) + self.assertIn("norm.BOGUS", u) + + +if __name__ == "__main__": + unittest.main()