Skip to content

Commit 1651349

Browse files
committed
Add chunked prefill to the eagle3 runner for prompts past the ring-cache cap
A single target prefill forward caps at 2*sliding_window (the sliding-ring KV limit), so prompts beyond that must be chunked. Loop the prompt in window-sized chunks for both the target prefill and the draft seed -- advancing input_pos, accumulating the flat global KV cache, and concatenating the per-chunk features -- then begin the speculative loop from the prompt end. Chunks are capped at the sliding window (not 2*window) for the same exactness reason as the gemma4 runner fix: a larger chunk truncates sliding attention across chunk boundaries (get_sliding_window from metadata, else max_prefill/2). Authored with assistance from Claude Code. ghstack-source-id: cda9ec8 ghstack-comment-id: 4734206971 Pull-Request: #20347
1 parent a355e4a commit 1651349

2 files changed

Lines changed: 116 additions & 32 deletions

File tree

examples/models/eagle3/export.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,7 @@ def _partitioner(name: str):
251251
"get_n_layers": target_config.num_hidden_layers,
252252
"get_max_prefill_chunk": max_prefill,
253253
"get_min_prefill_chunk": target_min,
254+
"get_sliding_window": target_config.sliding_window,
254255
"get_chain_len": chain_len,
255256
"get_draft_vocab_size": draft_vocab_size,
256257
"use_kv_cache": True,
@@ -307,9 +308,10 @@ def main() -> None:
307308
"--max-prefill",
308309
type=int,
309310
default=512,
310-
help="Max prefill length: AOTI compiles prefill kernels for up to this T "
311-
"and the whole prompt must fit in one prefill (the runner does not chunk). "
312-
"Smaller compiles faster.",
311+
help="Max prefill chunk: AOTI compiles prefill kernels for up to this T. "
312+
"The runner chunks the prompt into <= this many tokens per prefill (a "
313+
"longer prompt is fed as multiple chunks), so this bounds compile time, "
314+
"not prompt length. Smaller compiles faster.",
313315
)
314316
p.add_argument(
315317
"--chain", type=int, default=4, help="Draft chain length K (verify K+1)."

examples/models/eagle3/main.cpp

Lines changed: 111 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,13 @@
3131
// forward per round (speedup ~= acceptance length tau).
3232
//
3333
// Features round-trip through the host between method calls (D2H copy + re-feed
34-
// as host tensors). They are small (<= max_prefill x H bf16), so the cost is
35-
// negligible next to the INT4 31B target forward, and it keeps device-tensor
36-
// lifetimes simple.
34+
// as host tensors), which keeps device-tensor lifetimes simple. Chunked prefill
35+
// concatenates per-position features for the whole prompt before draft seeding,
36+
// so the host buffer is prompt_len x H bf16 (~672 MiB at 64K context, H=5376),
37+
// scaling with prompt_len rather than max_prefill. That is negligible next to
38+
// the INT4 31B target forward at today's context lengths; stream draft seeding
39+
// as each prefill chunk completes if it becomes a memory/perf concern at larger
40+
// contexts or hidden sizes.
3741
//
3842
// Run (after exporting model.pte + aoti_cuda_blob.ptd via export.py, sourcing
3943
// the CUDA env, and building the eagle3-cuda preset):
@@ -267,6 +271,22 @@ int main(int argc, char** argv) {
267271
};
268272
const int64_t chain_len = meta("get_chain_len");
269273
const int64_t max_prefill = meta("get_max_prefill_chunk");
274+
// Prefill chunks must not exceed the sliding window: a chunk larger than the
275+
// window overflows the 2*window ring cache across chunk boundaries,
276+
// truncating sliding attention for the first ~(chunk-window) queries of each
277+
// chunk (the global flat-cache layers stay exact). Prefer get_sliding_window
278+
// when the export provides it, else fall back to max_prefill/2.
279+
int64_t prefill_chunk = max_prefill / 2;
280+
{
281+
auto sw = module->get("get_sliding_window");
282+
if (sw.ok()) {
283+
prefill_chunk = sw->toScalar().to<int64_t>();
284+
}
285+
}
286+
// Also bound by the exported prefill range: get_max_prefill_chunk is the
287+
// largest T the prefill kernels were compiled for, which need not be
288+
// 2*sliding_window (small --max-prefill with a larger window), so cap here.
289+
prefill_chunk = std::min(prefill_chunk, max_prefill);
270290
const int64_t min_prefill = meta("get_min_prefill_chunk");
271291
const int64_t max_seq_len = meta("get_max_seq_len");
272292
const int64_t K_req = (FLAGS_chain > 0) ? FLAGS_chain : chain_len;
@@ -307,16 +327,16 @@ int main(int argc, char** argv) {
307327
prompt.insert(prompt.begin(), static_cast<int64_t>(FLAGS_bos_id));
308328
}
309329
const int64_t L = static_cast<int64_t>(prompt.size());
310-
// The runner does not chunk: the whole prompt must fit one prefill, and its
311-
// length must be within the exported prefill range [min_prefill,
312-
// max_prefill].
313-
if (L > max_prefill) {
330+
// A single prefill forward accepts at most max_prefill tokens, and chunks are
330+
// further capped at prefill_chunk, so the prompt is looped in <= prefill_chunk chunks below;
332+
// the flat global KV cache accumulates across chunks. The prompt only has to
333+
// fit the exported context (its features then seed the speculative loop).
334+
if (L >= max_seq_len) {
314335
ET_LOG(
315336
Error,
316-
"Prompt (%" PRId64 " tokens) exceeds max_prefill %" PRId64
317-
"; this runner does not chunk prefill.",
337+
"Prompt (%" PRId64 " tokens) does not fit max_seq_len %" PRId64,
318338
L,
319-
max_prefill);
339+
max_seq_len);
320340
return 1;
321341
}
322342
if (L < min_prefill) {
@@ -332,8 +352,9 @@ int main(int argc, char** argv) {
332352
// The prefill bonus token is always emittable (no KV write past the prompt).
333353
// Each speculative round, however, writes a K-token verify window, so it
334354
// needs anchor_pos + K <= max_seq_len - 1 (enforced in the loop below). Cap
335-
// the total at the positions available; max_new >= 1 since L <= max_prefill <
336-
// max_seq_len.
355+
// the total at the positions available; max_new >= 1 since L < max_seq_len
356+
// (L may exceed max_prefill -- the prompt is fed as chunks; L >= max_seq_len
357+
// is rejected above).
337358
int64_t max_new = std::min<int64_t>(FLAGS_max_new_tokens, max_seq_len - L);
338359
printf(
339360
"Prompt tokens: %" PRId64 ", chain K=%" PRId64 ", max_new=%" PRId64 "\n",
@@ -432,22 +453,48 @@ int main(int argc, char** argv) {
432453
stats.model_load_end_ms = llm::time_in_ms();
433454
stats.inference_start_ms = stats.model_load_end_ms;
434455

435-
// --- Prefill: target over the prompt -> bonus token + per-position feature.
436-
// ---
437-
tok_buf = prompt;
438-
pos_buf.resize(L);
439-
for (int64_t i = 0; i < L; i++) {
440-
pos_buf[i] = i;
441-
}
442-
auto pf = module->execute(
443-
"prefill", {EValue(long_tensor(tok_buf)), EValue(pos_tensor(pos_buf))});
444-
if (pf.error() != Error::Ok) {
445-
ET_LOG(Error, "prefill failed");
446-
return 1;
456+
// The exported prefill forward accepts T in [min_prefill, max_prefill]; pick
457+
// the next chunk so the running tail never drops below min_prefill (it would
458+
// be an out-of-range shape). All but the last one or two chunks are
459+
// prefill_chunk tokens long.
460+
auto next_chunk = [&](int64_t done) {
461+
int64_t remaining = L - done;
462+
int64_t len = std::min(remaining, prefill_chunk);
463+
if (remaining - len > 0 && remaining - len < min_prefill) {
464+
len = remaining - min_prefill;
465+
}
466+
return len;
467+
};
468+
469+
// --- Prefill: target over the prompt (chunked to respect the prefill cap) ->
470+
// bonus token + per-position feature. The flat target KV cache accumulates
471+
// across chunks; the bonus token is the last chunk's output, and the
472+
// per-position features of every chunk are concatenated to seed the draft.
473+
HostFeature feat_prompt;
474+
int64_t anchor = 0;
475+
int64_t prefill_pos = 0;
476+
while (prefill_pos < L) {
477+
int64_t chunk_len = next_chunk(prefill_pos);
478+
tok_buf.assign(
479+
prompt.begin() + prefill_pos, prompt.begin() + prefill_pos + chunk_len);
480+
pos_buf.resize(chunk_len);
481+
for (int64_t i = 0; i < chunk_len; i++) {
482+
pos_buf[i] = prefill_pos + i;
483+
}
484+
auto pf = module->execute(
485+
"prefill", {EValue(long_tensor(tok_buf)), EValue(pos_tensor(pos_buf))});
486+
if (pf.error() != Error::Ok) {
487+
ET_LOG(Error, "prefill failed at pos %" PRId64, prefill_pos);
488+
return 1;
489+
}
490+
anchor = read_ids(pf->at(0).toTensor())[0]; // bonus token after the prompt
491+
HostFeature chunk_feat = read_feature(pf->at(1).toTensor());
492+
feat_prompt.H = chunk_feat.H;
493+
feat_prompt.T += chunk_feat.T;
494+
feat_prompt.data.insert(
495+
feat_prompt.data.end(), chunk_feat.data.begin(), chunk_feat.data.end());
496+
prefill_pos += chunk_len;
447497
}
448-
int64_t anchor =
449-
read_ids(pf->at(0).toTensor())[0]; // bonus token at position L
450-
HostFeature feat_prompt = read_feature(pf->at(1).toTensor());
451498
const int64_t H = feat_prompt.H;
452499
int64_t anchor_pos = L;
453500

@@ -474,10 +521,45 @@ int main(int argc, char** argv) {
474521
if (speculate) {
475522
// Seed the first chain (shifted): draft slot p pairs feat_prompt[p] with
476523
// token_{p+1}; the last slot pairs feat_prompt[L-1] with the bonus and
477-
// predicts position L+1.
524+
// predicts position L+1. Seed in <= prefill_chunk chunks (draft_decode shares
525+
// the prefill shape range), each contiguous from the previous so the draft
526+
// KV cache fills; the last chunk's last row predicts proposal 0 and carries
527+
// the recurrent feature, then K-1 recurrent steps follow (mirroring chain).
478528
std::vector<int64_t> seed_tokens(prompt.begin() + 1, prompt.end());
479529
seed_tokens.push_back(anchor);
480-
proposals = chain(seed_tokens, feat_prompt, 0);
530+
std::vector<int64_t> ids;
531+
HostFeature last_g;
532+
for (int64_t seed_pos = 0; seed_pos < L;) {
533+
int64_t chunk_len = next_chunk(seed_pos);
534+
std::vector<int64_t> chunk_tokens(
535+
seed_tokens.begin() + seed_pos,
536+
seed_tokens.begin() + seed_pos + chunk_len);
537+
draft_decode(
538+
chunk_tokens,
539+
feat_prompt.data.data() + seed_pos * H,
540+
chunk_len,
541+
H,
542+
seed_pos,
543+
ids,
544+
last_g);
545+
seed_pos += chunk_len;
546+
}
547+
proposals.push_back(ids.back());
548+
int64_t last_pos = L - 1;
549+
for (int64_t k = 1; k < K; k++) {
550+
std::vector<int64_t> step_ids;
551+
HostFeature step_g;
552+
draft_decode(
553+
{proposals.back()},
554+
last_g.data.data(),
555+
1,
556+
last_g.H,
557+
last_pos + k,
558+
step_ids,
559+
step_g);
560+
proposals.push_back(step_ids[0]);
561+
last_g = step_g;
562+
}
481563
}
482564

483565
// Stable buffers for target_verify (fixed length K+1) so the CUDA graph

0 commit comments

Comments
 (0)