3131// forward per round (speedup ~= acceptance length tau).
3232//
3333// Features round-trip through the host between method calls (D2H copy + re-feed
34- // as host tensors). They are small (<= max_prefill x H bf16), so the cost is
35- // negligible next to the INT4 31B target forward, and it keeps device-tensor
36- // lifetimes simple.
34+ // as host tensors), which keeps device-tensor lifetimes simple. Chunked prefill
35+ // concatenates per-position features for the whole prompt before draft seeding,
36+ // so the host buffer is prompt_len x H bf16 (~672 MiB at 64K context, H=5376),
37+ // scaling with prompt_len rather than max_prefill. That is negligible next to
38+ // the INT4 31B target forward at today's context lengths; stream draft seeding
39+ // as each prefill chunk completes if it becomes a memory/perf concern at larger
40+ // contexts or hidden sizes.
3741//
3842// Run (after exporting model.pte + aoti_cuda_blob.ptd via export.py, sourcing
3943// the CUDA env, and building the eagle3-cuda preset):
@@ -267,6 +271,22 @@ int main(int argc, char** argv) {
267271 };
268272 const int64_t chain_len = meta (" get_chain_len" );
269273 const int64_t max_prefill = meta (" get_max_prefill_chunk" );
274+ // Prefill chunks must not exceed the sliding window: a chunk larger than the
275+ // window overflows the 2*window ring cache across chunk boundaries,
276+ // truncating sliding attention for the first ~(chunk-window) queries of each
277+ // chunk (the global flat-cache layers stay exact). Prefer get_sliding_window
278+ // when the export provides it, else fall back to max_prefill/2.
279+ int64_t prefill_chunk = max_prefill / 2 ;
280+ {
281+ auto sw = module ->get (" get_sliding_window" );
282+ if (sw.ok ()) {
283+ prefill_chunk = sw->toScalar ().to <int64_t >();
284+ }
285+ }
286+ // Also bound by the exported prefill range: get_max_prefill_chunk is the
287+ // largest T the prefill kernels were compiled for, which need not be
288+ // 2*sliding_window (small --max-prefill with a larger window), so cap here.
289+ prefill_chunk = std::min (prefill_chunk, max_prefill);
270290 const int64_t min_prefill = meta (" get_min_prefill_chunk" );
271291 const int64_t max_seq_len = meta (" get_max_seq_len" );
272292 const int64_t K_req = (FLAGS_chain > 0 ) ? FLAGS_chain : chain_len;
@@ -307,16 +327,16 @@ int main(int argc, char** argv) {
307327 prompt.insert (prompt.begin (), static_cast <int64_t >(FLAGS_bos_id));
308328 }
309329 const int64_t L = static_cast <int64_t >(prompt.size ());
310- // The runner does not chunk: the whole prompt must fit one prefill, and its
311- // length must be within the exported prefill range [min_prefill,
312- // max_prefill].
313- if (L > max_prefill) {
330+ // A single prefill forward caps at max_prefill, and each chunk is further
331+ // held to prefill_chunk (sliding-window bound above), so longer prompts loop;
332+ // the flat global KV cache accumulates across chunks. The prompt only has to
333+ // fit the exported context (its features then seed the speculative loop).
334+ if (L >= max_seq_len) {
314335 ET_LOG (
315336 Error,
316- " Prompt (%" PRId64 " tokens) exceeds max_prefill %" PRId64
317- " ; this runner does not chunk prefill." ,
337+ " Prompt (%" PRId64 " tokens) does not fit max_seq_len %" PRId64,
318338 L,
319- max_prefill );
339+ max_seq_len );
320340 return 1 ;
321341 }
322342 if (L < min_prefill) {
@@ -332,8 +352,9 @@ int main(int argc, char** argv) {
332352 // The prefill bonus token is always emittable (no KV write past the prompt).
333353 // Each speculative round, however, writes a K-token verify window, so it
334354 // needs anchor_pos + K <= max_seq_len - 1 (enforced in the loop below). Cap
335- // the total at the positions available; max_new >= 1 since L <= max_prefill <
336- // max_seq_len.
355+ // the total at the positions available; max_new >= 1 since L < max_seq_len
356+ // (L may exceed max_prefill -- the prompt is fed as chunks; L >= max_seq_len
357+ // is rejected above).
337358 int64_t max_new = std::min<int64_t >(FLAGS_max_new_tokens, max_seq_len - L);
338359 printf (
339360 " Prompt tokens: %" PRId64 " , chain K=%" PRId64 " , max_new=%" PRId64 " \n " ,
@@ -432,22 +453,48 @@ int main(int argc, char** argv) {
432453 stats.model_load_end_ms = llm::time_in_ms ();
433454 stats.inference_start_ms = stats.model_load_end_ms ;
434455
435- // --- Prefill: target over the prompt -> bonus token + per-position feature.
436- // ---
437- tok_buf = prompt;
438- pos_buf.resize (L);
439- for (int64_t i = 0 ; i < L; i++) {
440- pos_buf[i] = i;
441- }
442- auto pf = module ->execute (
443- " prefill" , {EValue (long_tensor (tok_buf)), EValue (pos_tensor (pos_buf))});
444- if (pf.error () != Error::Ok) {
445- ET_LOG (Error, " prefill failed" );
446- return 1 ;
456+ // The exported prefill forward accepts T in [min_prefill, max_prefill]; pick
457+ // the next chunk so the running tail never drops below min_prefill (it would
458+ // be an out-of-range shape). All but the last one or two chunks are
459+ // prefill_chunk.
460+ auto next_chunk = [&](int64_t done) {
461+ int64_t remaining = L - done;
462+ int64_t len = std::min (remaining, prefill_chunk);
463+ if (remaining - len > 0 && remaining - len < min_prefill) {
464+ len = remaining - min_prefill;
465+ }
466+ return len;
467+ };
468+
469+ // --- Prefill: target over the prompt (chunked to respect the prefill cap) ->
470+ // bonus token + per-position feature. The flat target KV cache accumulates
471+ // across chunks; the bonus token is the last chunk's output, and the
472+ // per-position features of every chunk are concatenated to seed the draft.
473+ HostFeature feat_prompt;
474+ int64_t anchor = 0 ;
475+ int64_t prefill_pos = 0 ;
476+ while (prefill_pos < L) {
477+ int64_t chunk_len = next_chunk (prefill_pos);
478+ tok_buf.assign (
479+ prompt.begin () + prefill_pos, prompt.begin () + prefill_pos + chunk_len);
480+ pos_buf.resize (chunk_len);
481+ for (int64_t i = 0 ; i < chunk_len; i++) {
482+ pos_buf[i] = prefill_pos + i;
483+ }
484+ auto pf = module ->execute (
485+ " prefill" , {EValue (long_tensor (tok_buf)), EValue (pos_tensor (pos_buf))});
486+ if (pf.error () != Error::Ok) {
487+ ET_LOG (Error, " prefill failed at pos %" PRId64, prefill_pos);
488+ return 1 ;
489+ }
490+ anchor = read_ids (pf->at (0 ).toTensor ())[0 ]; // bonus token after the prompt
491+ HostFeature chunk_feat = read_feature (pf->at (1 ).toTensor ());
492+ feat_prompt.H = chunk_feat.H ;
493+ feat_prompt.T += chunk_feat.T ;
494+ feat_prompt.data .insert (
495+ feat_prompt.data .end (), chunk_feat.data .begin (), chunk_feat.data .end ());
496+ prefill_pos += chunk_len;
447497 }
448- int64_t anchor =
449- read_ids (pf->at (0 ).toTensor ())[0 ]; // bonus token at position L
450- HostFeature feat_prompt = read_feature (pf->at (1 ).toTensor ());
451498 const int64_t H = feat_prompt.H ;
452499 int64_t anchor_pos = L;
453500
@@ -474,10 +521,45 @@ int main(int argc, char** argv) {
474521 if (speculate) {
475522 // Seed the first chain (shifted): draft slot p pairs feat_prompt[p] with
476523 // token_{p+1}; the last slot pairs feat_prompt[L-1] with the bonus and
477- // predicts position L+1.
524+ // predicts position L+1. Seed in <= prefill_chunk chunks (draft_decode shares
525+ // the prefill shape range), each contiguous from the previous so the draft
526+ // KV cache fills; the last chunk's last row predicts proposal 0 and carries
527+ // the recurrent feature, then K-1 recurrent steps follow (mirroring chain).
478528 std::vector<int64_t > seed_tokens (prompt.begin () + 1 , prompt.end ());
479529 seed_tokens.push_back (anchor);
480- proposals = chain (seed_tokens, feat_prompt, 0 );
530+ std::vector<int64_t > ids;
531+ HostFeature last_g;
532+ for (int64_t seed_pos = 0 ; seed_pos < L;) {
533+ int64_t chunk_len = next_chunk (seed_pos);
534+ std::vector<int64_t > chunk_tokens (
535+ seed_tokens.begin () + seed_pos,
536+ seed_tokens.begin () + seed_pos + chunk_len);
537+ draft_decode (
538+ chunk_tokens,
539+ feat_prompt.data .data () + seed_pos * H,
540+ chunk_len,
541+ H,
542+ seed_pos,
543+ ids,
544+ last_g);
545+ seed_pos += chunk_len;
546+ }
547+ proposals.push_back (ids.back ());
548+ int64_t last_pos = L - 1 ;
549+ for (int64_t k = 1 ; k < K; k++) {
550+ std::vector<int64_t > step_ids;
551+ HostFeature step_g;
552+ draft_decode (
553+ {proposals.back ()},
554+ last_g.data .data (),
555+ 1 ,
556+ last_g.H ,
557+ last_pos + k,
558+ step_ids,
559+ step_g);
560+ proposals.push_back (step_ids[0 ]);
561+ last_g = step_g;
562+ }
481563 }
482564
483565 // Stable buffers for target_verify (fixed length K+1) so the CUDA graph
0 commit comments