Skip to content

Commit a355e4a

Browse files
committed
Fix sliding-window chunked prefill in the gemma4-31B runner
The runner previously chunked prefill at get_max_prefill_chunk = 2*sliding_window (2048 tokens). A chunk larger than the window overflows the 2*window ring KV cache across chunk boundaries: after writing a 2048-token chunk the ring holds only the most recent 2048 positions, so the first ~(chunk - window) queries of every chunk after the first lose the tail of the previous chunk that is still inside their 1024-token window. Those sliding-layer queries then attend over a truncated window, which propagates into their hidden states and into the global KV entries those positions write, changing the output. The global flat-cache layers themselves are not subject to this overflow. Fix: cap the prefill chunk at the sliding window, read from the get_sliding_window metadata (now exported) or, when that is absent, taken as max_prefill/2, since the export sets max_prefill = 2*sliding_window (unless max_seq_len - 1 is smaller). Decode is unaffected. Also exports get_min_prefill_chunk from the CUDA export so the runner can reject or avoid prompt and chunk lengths below the exported prefill minimum, and adds --max_prefill_chunk to override the chunk size for testing. Authored with assistance from Claude Code. ghstack-source-id: 1c793dd ghstack-comment-id: 4734206312 Pull-Request: #20346
1 parent 5ac44d4 commit a355e4a

2 files changed

Lines changed: 78 additions & 6 deletions

File tree

examples/models/gemma4_31b/export.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,8 +195,9 @@ def _export_cuda(
195195

196196
# Prefill (T>=2): shim does dequant+cuBLAS (optimal for large M).
197197
max_prefill = min(config.max_seq_len - 1, config.sliding_window * 2)
198-
seq_dim = Dim("seq_len", min=5, max=max_prefill)
199-
print(f"Exporting prefill (T in [2, {max_prefill}])...")
198+
min_prefill = 5
199+
seq_dim = Dim("seq_len", min=min_prefill, max=max_prefill)
200+
print(f"Exporting prefill (T in [{min_prefill}, {max_prefill}])...")
200201
with torch.no_grad():
201202
prefill_ep = export(
202203
model,
@@ -255,6 +256,8 @@ def _export_cuda(
255256
"get_vocab_size": config.vocab_size,
256257
"get_n_layers": config.num_hidden_layers,
257258
"get_max_prefill_chunk": max_prefill,
259+
"get_min_prefill_chunk": min_prefill,
260+
"get_sliding_window": config.sliding_window,
258261
"use_kv_cache": True,
259262
"use_sdpa_with_kv_cache": False,
260263
"enable_dynamic_shape": True,
@@ -379,6 +382,7 @@ def _export_mlx(
379382
"get_vocab_size": config.vocab_size,
380383
"get_n_layers": config.num_hidden_layers,
381384
"get_max_prefill_chunk": max_prefill,
385+
"get_sliding_window": config.sliding_window,
382386
"use_kv_cache": True,
383387
"use_sdpa_with_kv_cache": False,
384388
"enable_dynamic_shape": True,

examples/models/gemma4_31b/main.cpp

Lines changed: 72 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ DEFINE_double(temperature, 0.8, "Sampling temperature (0 = near-greedy).");
7171
DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate.");
7272
DEFINE_int32(bos_id, 2, "BOS token id to prepend (Gemma convention: 2).");
7373
DEFINE_int32(eos_id, 1, "EOS token id (Gemma convention: 1).");
74+
DEFINE_int32(
75+
max_prefill_chunk,
76+
0,
77+
"Override the prefill chunk size (<=0 uses metadata). Experiment: chunking "
78+
"above sliding_window is inexact for sliding layers across boundaries.");
7479
DEFINE_bool(
7580
raw_prompt,
7681
false,
@@ -180,13 +185,55 @@ int main(int argc, char** argv) {
180185
return 1;
181186
}
182187

183-
int64_t max_prefill_chunk = (*metadata_result)[llm::kMaxSeqLen] - 1;
188+
int64_t exported_max_prefill = (*metadata_result)[llm::kMaxSeqLen] - 1;
184189
{
185190
auto get_result = module->get("get_max_prefill_chunk");
186191
if (get_result.ok()) {
187-
max_prefill_chunk = get_result->toScalar().to<int64_t>();
192+
exported_max_prefill = get_result->toScalar().to<int64_t>();
188193
}
189194
}
195+
// Cap prefill chunks at the sliding window: a chunk larger than the window
196+
// overflows the 2*window ring cache across chunk boundaries, truncating
197+
// sliding attention for the first ~(chunk-window) queries of each chunk (the
198+
// global flat-cache layers stay exact). The export sets max_prefill =
199+
// 2*sliding_window, so window = max_prefill/2 (prefer get_sliding_window
200+
// metadata when present).
201+
int64_t sliding_window = exported_max_prefill / 2;
202+
{
203+
auto sw = module->get("get_sliding_window");
204+
if (sw.ok()) {
205+
sliding_window = sw->toScalar().to<int64_t>();
206+
}
207+
}
208+
int64_t max_prefill_chunk = std::min(sliding_window, exported_max_prefill);
209+
if (FLAGS_max_prefill_chunk > 0) {
210+
max_prefill_chunk =
211+
std::min<int64_t>(FLAGS_max_prefill_chunk, exported_max_prefill);
212+
}
213+
// The exported prefill accepts T in [min_prefill, max_prefill]; a final chunk
214+
// shorter than min_prefill (and > 1) is an out-of-range shape. Read the lower
215+
// bound so chunking can avoid it (fallback 1 keeps older exports working: a
216+
// length-1 tail already routes to decode).
217+
int64_t min_prefill = 1;
218+
{
219+
auto r = module->get("get_min_prefill_chunk");
220+
if (r.ok()) {
221+
min_prefill = r->toScalar().to<int64_t>();
222+
}
223+
}
224+
// A --max_prefill_chunk below the exported minimum has no valid prefill shape
225+
// (and a cap of 1 would make the tail adjustment compute chunk_len == 0 and
226+
// loop forever), so reject it rather than feed an out-of-range / zero chunk.
227+
if (FLAGS_max_prefill_chunk > 0 && max_prefill_chunk < min_prefill) {
228+
ET_LOG(
229+
Error,
230+
"--max_prefill_chunk (%d) is below the exported prefill minimum "
231+
"(%" PRId64 "); use a value >= %" PRId64 " or omit it.",
232+
FLAGS_max_prefill_chunk,
233+
min_prefill,
234+
min_prefill);
235+
return 1;
236+
}
190237

191238
auto S = [](int64_t v) -> SizesType { return static_cast<SizesType>(v); };
192239

@@ -295,6 +342,21 @@ int main(int argc, char** argv) {
295342
printf("Prompt tokens: %" PRId64 "\n", num_prompt_tokens);
296343
stats.num_prompt_tokens = num_prompt_tokens;
297344

345+
// A prompt of 2..min_prefill-1 tokens has no valid prefill shape (the CUDA
346+
// export specializes prefill to T >= min_prefill) and is too long for the
347+
// single-token decode path, so reject it. A 1-token prompt is fine: it goes
348+
// through decode below.
349+
if (num_prompt_tokens > 1 && num_prompt_tokens < min_prefill) {
350+
ET_LOG(
351+
Error,
352+
"Prompt (%" PRId64
353+
" tokens) is below the exported prefill minimum %" PRId64
354+
"; use a longer prompt.",
355+
num_prompt_tokens,
356+
min_prefill);
357+
return 1;
358+
}
359+
298360
stats.inference_start_ms = llm::time_in_ms();
299361

300362
// ---------------------------------------------------------------
@@ -309,8 +371,14 @@ int main(int argc, char** argv) {
309371
TensorPtr device_out_token;
310372
#endif
311373
while (prefill_pos < num_prompt_tokens) {
312-
int64_t chunk_len =
313-
std::min(num_prompt_tokens - prefill_pos, max_prefill_chunk);
374+
int64_t remaining = num_prompt_tokens - prefill_pos;
375+
int64_t chunk_len = std::min(remaining, max_prefill_chunk);
376+
// Shrink this chunk so the tail it leaves is never in (1, min_prefill):
377+
// such a tail would be an out-of-range prefill shape. A length-1 tail is
378+
// fine (routed to decode below); a >= min_prefill tail is fine too.
379+
if (remaining - chunk_len > 1 && remaining - chunk_len < min_prefill) {
380+
chunk_len = remaining - min_prefill;
381+
}
314382

315383
std::vector<int64_t> token_data(
316384
prompt_tokens.begin() + prefill_pos,

0 commit comments

Comments
 (0)