Commit 9108a5b

Ring-buffer KV cache, chunked prefill, INT8 embedding, and cleanup
- Sliding window layers use RingKVCache (2×window) instead of flat max_seq_len buffer, reducing KV cache memory for long sequences.
- Prefill is capped to ring buffer size; the C++ runner chunks longer prompts automatically via get_max_prefill_chunk metadata.
- Both recipes now quantize embed_tokens to INT8 per-axis (~1.4 GB savings vs bf16). Embedding packer uses IntxUnpackedToInt8Tensor which supports gather.
- pack_model handles top-level FQNs (no parent module).
- C++ runner aligned with Qwen patterns: #ifdef guards for non-CUDA builds, better weight_sharing error handling, cudaDeviceSynchronize between prefill and decode.
- Test suite split into test_pipeline.py (CPU) and test_cuda_pipeline.py (CUDA) with shared fixtures. New chunked prefill correctness test.
- Prequantized checkpoint available at huggingface.co/SocialLocalMobile/gemma-4-31B-it-HQQ-INT4.
- Added Gemma 4 31B tests to cuda.yml CI workflow.
- Cleaned up stale terminology, docstrings, and comments throughout.
1 parent f04e065 commit 9108a5b

11 files changed: 195 additions & 75 deletions

examples/models/gemma4_31b/README.md

Lines changed: 24 additions & 5 deletions
@@ -32,17 +32,36 @@ Two built-in recipes (see `quantize_and_save.py`):
 | `default` | INT4 min_max linears, INT8 per-axis embedding |
 | `sensitive` | INT8 for edge-layer v_proj/down_proj, INT4 hqq elsewhere, INT8 per-axis embedding |
 
-## Quantize once
+## Prequantized checkpoint
+
+A prequantized checkpoint (sensitive recipe) is available on HuggingFace:
+
+```bash
+huggingface-cli download SocialLocalMobile/gemma-4-31B-it-HQQ-INT4 --local-dir gemma-4-31B-it-HQQ-INT4
+```
+
+> **Note**: This checkpoint is intended for development and testing of the
+> ExecuTorch CUDA export pipeline. Output quality has not been formally
+> evaluated against the base model.
+
+Use it directly with `--prequantized` in the export and inference scripts
+below — no need to run `quantize_and_save.py`.
+
+## Quantize from scratch (optional)
+
+To quantize from the original bf16 checkpoint instead, pass
+`--quant-recipe` to select a recipe (`default` or `sensitive`):
 
 ```bash
 python examples/models/gemma4_31b/quantize_and_save.py \
-  --model-dir ~/local/scripts/models/gemma-4-31B-it \
+  --model-dir /path/to/gemma-4-31B-it \
   --output ./gemma4_31b_int4 \
-  --quant-recipe default
+  --quant-recipe sensitive
 ```
 
-Writes `model.safetensors`, `config.json`, and
-`tokenizer.json` into `--output`.
+See [Quantization recipes](#quantization-recipes) above for details on each
+recipe. Writes `model.safetensors`, `config.json`, and `tokenizer.json` into
+`--output`.
 
 ## Export to ExecuTorch
 

examples/models/gemma4_31b/export.py

Lines changed: 4 additions & 1 deletion
@@ -161,7 +161,9 @@ def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) -
         strict=True,
     )
 
-    max_prefill = config.max_seq_len - 1
+    # Cap prefill length to the ring-buffer KV cache size (2×sliding_window).
+    # Longer prompts are chunked by the runner.
+    max_prefill = min(config.max_seq_len - 1, config.sliding_window * 2)
     seq_dim = Dim("seq_len", min=2, max=max_prefill)
     print(f"Exporting prefill (T in [2, {max_prefill}])...")
     with torch.no_grad():
@@ -199,6 +201,7 @@ def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) -
         "get_max_seq_len": config.max_seq_len,
         "get_vocab_size": config.vocab_size,
         "get_n_layers": config.num_hidden_layers,
+        "get_max_prefill_chunk": max_prefill,
         "use_kv_cache": True,
         "use_sdpa_with_kv_cache": False,
         "enable_dynamic_shape": True,

examples/models/gemma4_31b/main.cpp

Lines changed: 89 additions & 52 deletions
@@ -34,20 +34,20 @@
 #include <cuda_runtime.h>
 #endif
 
-DEFINE_string(model_path, "", "Path to model.pte.");
-DEFINE_string(data_path, "", "Path to model.ptd (CUDA tensor data).");
+DEFINE_string(model_path, "", "Model .pte file path.");
+DEFINE_string(data_path, "", "Data file (.ptd) for CUDA backend.");
 DEFINE_string(tokenizer_path, "", "HuggingFace tokenizer.json path.");
 DEFINE_string(prompt, "Hello", "Prompt text.");
 DEFINE_string(
     prompt_file,
     "",
-    "Optional path to a file with the prompt text (overrides --prompt).");
+    "Path to file containing prompt text (overrides --prompt).");
 DEFINE_double(temperature, 0.8, "Sampling temperature (0 = near-greedy).");
 DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate.");
 DEFINE_bool(
     cuda_graph,
     false,
-    "Enable CUDA graph capture for the decode method.");
+    "Enable CUDA graph capture for the decode method. CUDA only.");
 
 namespace llm = ::executorch::extension::llm;
 using ::executorch::extension::from_blob;
@@ -57,8 +57,6 @@ using ::executorch::runtime::EValue;
 
 using SizesType = executorch::aten::SizesType;
 
-// The model performs sampling on-device and returns a [B, 1] float tensor
-// holding a token ID. Copy it to host and convert to uint64.
 static uint64_t read_token(const executorch::aten::Tensor& output) {
   const void* ptr = output.const_data_ptr();
   float val = 0.0f;
@@ -135,12 +133,14 @@ int main(int argc, char** argv) {
       /*temp_allocator=*/nullptr,
       /*share_memory_arenas=*/true);
 
+  // Get metadata
   auto metadata_result = llm::get_llm_metadata(tokenizer.get(), module.get());
   if (metadata_result.error() != Error::Ok) {
     ET_LOG(Error, "Failed to read model metadata");
     return 1;
   }
 
+#ifdef EXECUTORCH_BUILD_CUDA
   if (FLAGS_cuda_graph) {
     executorch::runtime::BackendOptions<2> cuda_opts;
     cuda_opts.set_option("enable_cuda_graph_for_method", "decode");
@@ -154,14 +154,30 @@
   // load_method.
   {
     executorch::runtime::BackendOptions<1> backend_options;
-    if (backend_options.set_option("weight_sharing_across_methods", true) !=
-            Error::Ok ||
-        executorch::runtime::set_option(
-            "CudaBackend", backend_options.view()) != Error::Ok) {
-      ET_LOG(Error, "Failed to enable weight_sharing_across_methods");
+    auto set_err =
+        backend_options.set_option("weight_sharing_across_methods", true);
+    if (set_err != Error::Ok) {
+      ET_LOG(
+          Error,
+          "Failed to construct weight_sharing_across_methods option: %d",
+          static_cast<int>(set_err));
+      return 1;
+    }
+    auto opt_err =
+        executorch::runtime::set_option("CudaBackend", backend_options.view());
+    if (opt_err != Error::Ok) {
+      ET_LOG(
+          Error,
+          "Failed to enable weight_sharing_across_methods: %d",
+          static_cast<int>(opt_err));
       return 1;
     }
   }
+#else
+  if (FLAGS_cuda_graph) {
+    ET_LOG(Info, "--cuda_graph ignored on non-CUDA build");
+  }
+#endif
 
   printf("Loading methods...\n");
   if (module->load_method("prefill") != Error::Ok) {
@@ -181,6 +197,7 @@
 
   auto eos_ids = llm::get_eos_ids(tokenizer.get(), module.get());
 
+  // Read prompt from file or flag
   std::string prompt_text = FLAGS_prompt;
   if (!FLAGS_prompt_file.empty()) {
     std::ifstream f(FLAGS_prompt_file);
@@ -189,10 +206,11 @@
          Error, "Failed to open prompt file: %s", FLAGS_prompt_file.c_str());
      return 1;
    }
-    prompt_text.assign(
+    prompt_text = std::string(
        (std::istreambuf_iterator<char>(f)), std::istreambuf_iterator<char>());
  }
 
+  // Encode prompt
   auto encode_result = tokenizer->encode(prompt_text);
   if (!encode_result.ok()) {
     ET_LOG(Error, "Failed to encode prompt");
@@ -207,49 +225,66 @@
 
   auto S = [](int64_t v) -> SizesType { return static_cast<SizesType>(v); };
 
-  // Temperature: clamp 0 to a tiny epsilon so the divide in the exported
-  // sampler stays well-defined. Gumbel noise then becomes negligible
-  // relative to logit gaps and we get effectively-greedy sampling.
+#ifdef EXECUTORCH_BUILD_CUDA
+  // CUDA build: model fuses the sampler. Pass temperature as a third input.
   float temp_val =
       FLAGS_temperature <= 0.0 ? 1e-6f : static_cast<float>(FLAGS_temperature);
   auto temp_tensor =
       from_blob(&temp_val, {1}, executorch::aten::ScalarType::Float);
+#endif
 
   // ---------------------------------------------------------------
-  // Prefill
+  // Prefill (chunked to respect ring-buffer KV cache limit)
   // ---------------------------------------------------------------
-  std::string run_method = "prefill";
-  if (num_prompt_tokens == 1) {
-    // prefill was exported with min seq_len=2; decode handles T==1.
-    run_method = "decode";
+  // Sliding layers use a ring buffer sized to 2×sliding_window. A single
+  // prefill call must not exceed this size, otherwise index_copy_ with
+  // wrapped indices produces non-deterministic results on CUDA.
+  int64_t max_prefill_chunk = (*metadata_result)[llm::kMaxSeqLen] - 1;
+  {
+    auto get_result = module->get("get_max_prefill_chunk");
+    if (get_result.ok()) {
+      max_prefill_chunk = get_result->toScalar().to<int64_t>();
+    }
   }
 
-  std::vector<int64_t> token_data(prompt_tokens.begin(), prompt_tokens.end());
-  std::vector<int64_t> pos_data(num_prompt_tokens);
-  for (int64_t i = 0; i < num_prompt_tokens; i++) {
-    pos_data[i] = i;
-  }
-  auto tokens_tensor = from_blob(
-      token_data.data(),
-      {1, S(num_prompt_tokens)},
-      executorch::aten::ScalarType::Long);
-  auto pos_tensor = from_blob(
-      pos_data.data(),
-      {S(num_prompt_tokens)},
-      executorch::aten::ScalarType::Long);
-
-  std::vector<EValue> prefill_inputs = {
-      EValue(tokens_tensor),
-      EValue(pos_tensor),
-      EValue(temp_tensor),
-  };
-
-  auto prefill_result = module->execute(run_method, prefill_inputs);
-  if (prefill_result.error() != Error::Ok) {
-    ET_LOG(Error, "%s failed", run_method.c_str());
-    return 1;
+  uint64_t cur_token = 0;
+  int64_t prefill_pos = 0;
+  while (prefill_pos < num_prompt_tokens) {
+    int64_t chunk_len =
+        std::min(num_prompt_tokens - prefill_pos, max_prefill_chunk);
+
+    std::string run_method = (chunk_len == 1) ? "decode" : "prefill";
+
+    std::vector<int64_t> token_data(
+        prompt_tokens.begin() + prefill_pos,
+        prompt_tokens.begin() + prefill_pos + chunk_len);
+    std::vector<int64_t> pos_data(chunk_len);
+    for (int64_t i = 0; i < chunk_len; i++) {
+      pos_data[i] = prefill_pos + i;
+    }
+    auto tokens_tensor = from_blob(
+        token_data.data(),
+        {1, S(chunk_len)},
+        executorch::aten::ScalarType::Long);
+    auto pos_tensor = from_blob(
+        pos_data.data(), {S(chunk_len)}, executorch::aten::ScalarType::Long);
+
+    std::vector<EValue> prefill_inputs;
+    prefill_inputs.push_back(EValue(tokens_tensor));
+    prefill_inputs.push_back(EValue(pos_tensor));
+#ifdef EXECUTORCH_BUILD_CUDA
+    prefill_inputs.push_back(EValue(temp_tensor));
+#endif
+
+    auto prefill_result = module->execute(run_method, prefill_inputs);
+    if (prefill_result.error() != Error::Ok) {
+      ET_LOG(
+          Error, "%s failed at pos %" PRId64, run_method.c_str(), prefill_pos);
+      return 1;
+    }
+    cur_token = read_token(prefill_result.get()[0].toTensor());
+    prefill_pos += chunk_len;
   }
-  uint64_t cur_token = read_token(prefill_result.get()[0].toTensor());
 
   stats.prompt_eval_end_ms = llm::time_in_ms();
   double prefill_ms =
@@ -261,8 +296,9 @@
       num_prompt_tokens * 1000.0 / prefill_ms);
 
 #ifdef EXECUTORCH_BUILD_CUDA
-  // Make prefill's writes to the shared KV cache visible before decode
-  // potentially runs on a different stream.
+  // Synchronize CUDA device to ensure prefill's writes to shared mutable
+  // buffers (KV cache) are visible to the decode method, which may run on
+  // a different CUDA stream.
   cudaDeviceSynchronize();
 #endif
 
@@ -282,11 +318,12 @@
     decode_token_data[0] = static_cast<int64_t>(cur_token);
     decode_pos_data[0] = pos;
 
-    std::vector<EValue> decode_inputs = {
-        EValue(decode_tokens),
-        EValue(decode_pos),
-        EValue(temp_tensor),
-    };
+    std::vector<EValue> decode_inputs;
+    decode_inputs.push_back(EValue(decode_tokens));
+    decode_inputs.push_back(EValue(decode_pos));
+#ifdef EXECUTORCH_BUILD_CUDA
+    decode_inputs.push_back(EValue(temp_tensor));
+#endif
 
     auto decode_result = module->execute("decode", decode_inputs);
     if (decode_result.error() != Error::Ok) {
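For readers skimming the C++ diff, the new chunking reduces to a short loop. A Python sketch of the same host-side control flow, where `execute` stands in for `Module::execute` and returns the sampled token (illustrative only, not this repo's API):

```python
# Illustrative sketch of the runner's chunked prefill loop.
def chunked_prefill(execute, prompt_tokens, max_prefill_chunk):
    cur_token, pos = None, 0
    while pos < len(prompt_tokens):
        chunk = prompt_tokens[pos : pos + max_prefill_chunk]
        # A 1-token chunk falls back to "decode": prefill was exported
        # with min seq_len = 2.
        method = "decode" if len(chunk) == 1 else "prefill"
        positions = list(range(pos, pos + len(chunk)))
        cur_token = execute(method, chunk, positions)
        pos += len(chunk)
    return cur_token  # token sampled after the final prompt position
```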

examples/models/gemma4_31b/model.md

Lines changed: 7 additions & 1 deletion
@@ -105,14 +105,20 @@ Decoder norms per layer: `input_layernorm`, `post_attention_layernorm`,
 | Method    | Input                                                        | Output (sampled) |
 |-----------|------------------------------------------------------------|------------------|
 | `decode`  | tokens `(1, 1)` + input_pos `(1,)` + temperature `(1,)` | `(1, 1)` float |
-| `prefill` | tokens `(1, T)` + input_pos `(T,)` + temperature `(1,)`, T∈[2, max_seq_len-1] | `(1, 1)` float |
+| `prefill` | tokens `(1, T)` + input_pos `(T,)` + temperature `(1,)`, T∈[2, min(max_seq_len-1, 2×sliding_window)] | `(1, 1)` float |
 
 Both methods share the same KV-cache buffers via
 `MemoryPlanningPass(share_mutable_buffers=True)` and
 `emit_mutable_buffer_names=True`. The exported program performs Gumbel-max
 sampling on-device and returns a single token ID per call so the C++ runner
 only has to feed tokens.
 
+Prefill length is capped to the ring-buffer KV cache size
+(`2 × sliding_window`) to avoid duplicate wrapped indices in
+`index_copy_`. The C++ runner chunks longer prompts automatically using
+the `get_max_prefill_chunk` constant method. Chunked prefill produces
+identical logits to sequential one-token-at-a-time prefill.
+
 ## Quantization
 
 Three modules in `quant/`:
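The equivalence claim added to model.md is mechanically checkable. A minimal sketch of such a test, where `make_model()` and the `(tokens, input_pos)` forward signature are assumptions rather than this repo's API:

```python
import torch

def final_logits(model, tokens, chunk_size):
    # Feed the prompt through a stateful KV-cache model in fixed-size
    # chunks and return the output for the last prompt position.
    out = None
    for start in range(0, tokens.shape[1], chunk_size):
        chunk = tokens[:, start : start + chunk_size]
        pos = torch.arange(start, start + chunk.shape[1])
        out = model(chunk, pos)
    return out

tokens = torch.randint(0, 1000, (1, 9))
# A fresh model (and KV cache) per run so both prefill paths start equal:
# torch.testing.assert_close(final_logits(make_model(), tokens, chunk_size=4),
#                            final_logits(make_model(), tokens, chunk_size=1))
```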

examples/models/gemma4_31b/model.py

Lines changed: 3 additions & 0 deletions
@@ -89,6 +89,9 @@ def update(
         k_val: torch.Tensor,
         v_val: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
+        # seq_len must not exceed buf_size, otherwise wrapped indices contain
+        # duplicates and index_copy_ is non-deterministic on CUDA. The C++
+        # runner must chunk prefill to respect this limit.
         wrapped = input_pos % self.buf_size
         self.k_cache.index_copy_(2, wrapped, k_val)
         self.v_cache.index_copy_(2, wrapped, v_val)
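The hazard described in the new comment is easy to reproduce in isolation; a minimal sketch with a toy buffer size:

```python
import torch

buf_size = 4
input_pos = torch.arange(6)     # seq_len = 6 > buf_size
wrapped = input_pos % buf_size  # tensor([0, 1, 2, 3, 0, 1])

# Positions 0 and 1 appear twice. index_copy_ gives no ordering guarantee
# for duplicate indices on CUDA, so which write wins is unspecified and the
# cache becomes non-deterministic. Chunking prefill to <= buf_size tokens
# per call keeps every wrapped index unique.
assert wrapped.unique().numel() < wrapped.numel()
```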

examples/models/gemma4_31b/quant/pack.py

Lines changed: 4 additions & 2 deletions
@@ -61,11 +61,13 @@ def pack_model(
 
     module_weights: dict[str, dict[str, CanonicalQuantizedWeight]] = defaultdict(dict)
     for fqn, cw in quantized.items():
-        parent_fqn, attr = fqn.rsplit(".", 1)
+        parts = fqn.rsplit(".", 1)
+        parent_fqn = parts[0] if len(parts) > 1 else ""
+        attr = parts[-1]
         module_weights[parent_fqn][attr] = cw
 
     for parent_fqn, weights in module_weights.items():
-        module = model.get_submodule(parent_fqn)
+        module = model.get_submodule(parent_fqn) if parent_fqn else model
         packer = packers.get(type(module))
         if packer is None:
             raise ValueError(
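Why the `rsplit` change is needed: the old two-element unpacking raised `ValueError` for a fully-qualified name with no dot, i.e. a parameter hanging directly off the model. The FQNs below are illustrative only:

```python
# Two-part FQN: unpacking worked before and still does.
"model.embed_tokens.weight".rsplit(".", 1)  # ['model.embed_tokens', 'weight']

# Top-level FQN: only one part, so the old `parent, attr = ...` unpacking
# raised ValueError; the new code maps it to parent_fqn = "" (the model).
"weight".rsplit(".", 1)                     # ['weight']
```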

examples/models/gemma4_31b/quant/recipe.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@
 class QuantConfig:
     """Per-weight quantization parameters."""
 
-    bits: int  # 4, 6, 8
+    bits: int  # 4, 8
     group_size: int  # 32, 64, 128
     symmetric: bool  # True = no zero point
     method: str  # "min_max" | "hqq"

examples/models/gemma4_31b/quant/test_recipe.py

Lines changed: 8 additions & 8 deletions
@@ -13,7 +13,7 @@
 from .recipe import QuantConfig, QuantRecipe, QuantRule
 
 _Q4 = QuantConfig(4, 32, True, "min_max")
-_Q6 = QuantConfig(6, 32, False, "min_max")
+_Q8 = QuantConfig(8, 32, True, "min_max")
 
 
 class TestQuantRecipeGetConfig(unittest.TestCase):
@@ -23,13 +23,13 @@ class TestQuantRecipeGetConfig(unittest.TestCase):
     [
         (
             "first_match_wins",
-            [QuantRule(r".*v_proj\.weight", _Q6), QuantRule(r".*\.weight", _Q4)],
+            [QuantRule(r".*v_proj\.weight", _Q8), QuantRule(r".*\.weight", _Q4)],
             "layers.0.self_attn.v_proj.weight",
-            6,
+            8,
         ),
         (
             "fallthrough_to_catchall",
-            [QuantRule(r".*v_proj\.weight", _Q6), QuantRule(r".*\.weight", _Q4)],
+            [QuantRule(r".*v_proj\.weight", _Q8), QuantRule(r".*\.weight", _Q4)],
             "layers.0.self_attn.q_proj.weight",
             4,
         ),
@@ -85,13 +85,13 @@ def test_layer_filter(self):
         recipe = QuantRecipe(
             rules=[
                 QuantRule(r".*norm\.weight", None),
-                QuantRule(r".*\.(v_proj|down_proj)\.weight", _Q6, layers=edge),
+                QuantRule(r".*\.(v_proj|down_proj)\.weight", _Q8, layers=edge),
                 QuantRule(r".*\.weight", _Q4),
             ]
         )
-        # Edge v_proj → 6-bit
-        self.assertEqual(recipe.get_config("layers.0.self_attn.v_proj.weight").bits, 6)
-        self.assertEqual(recipe.get_config("layers.58.self_attn.v_proj.weight").bits, 6)
+        # Edge v_proj → 8-bit
+        self.assertEqual(recipe.get_config("layers.0.self_attn.v_proj.weight").bits, 8)
+        self.assertEqual(recipe.get_config("layers.58.self_attn.v_proj.weight").bits, 8)
         # Middle v_proj → falls through → 4-bit
         self.assertEqual(recipe.get_config("layers.30.self_attn.v_proj.weight").bits, 4)
         # q_proj always 4-bit
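These cases pin down first-match-wins resolution plus the layer filter. A minimal sketch of that lookup, assuming `QuantRule` exposes `pattern`, `config`, and an optional `layers` set (names inferred from these tests, not verified against `recipe.py`):

```python
import re

def _layer_index(fqn):
    # Hypothetical helper: pull N out of "layers.N." if present.
    m = re.search(r"layers\.(\d+)\.", fqn)
    return int(m.group(1)) if m else None

def get_config(rules, fqn):
    # First match wins; a rule with a layer filter is skipped for weights
    # outside those layers. A matching rule with config=None means
    # "leave this weight unquantized".
    for rule in rules:
        if rule.layers is not None and _layer_index(fqn) not in rule.layers:
            continue
        if re.fullmatch(rule.pattern, fqn):
            return rule.config
    return None
```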
