Skip to content

Commit 1833677

Browse files
committed
Add EAGLE-3 end-to-end speculative-decode test
Tiny matching target + draft on CPU drive the shifted method-flow the runner implements -- prefill, draft chain, target_verify, accept, reseed -- and the generated tokens are checked against greedy decoding (lossless by construction). Forced-acceptance cases pin the partial, full, and accepted-EOS paths plus the one-token budget; the random-weight loop alone can leave them uncovered. This covers the decoding algorithm on CPU only; the CUDA/AOTI export (export.py) and the C++ runner are exercised manually. Registered in pytest.ini next to the draft/speculator unit tests. Authored with assistance from Claude Code. ghstack-source-id: 5179d07 ghstack-comment-id: 4661635123 Pull-Request: #20157
1 parent ac69057 commit 1833677

2 files changed

Lines changed: 262 additions & 0 deletions

File tree

examples/models/eagle3/test_e2e.py

Lines changed: 261 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,261 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""End-to-end EAGLE-3 speculative-decoding smoke test (CPU, no export needed).
8+
9+
Builds tiny matching target + draft models and drives the shifted (vLLM-EAGLE)
10+
method-flow the C++ runner uses -- prefill -> draft chain -> target_verify ->
11+
accept -> reseed -> repeat -- checking the generated tokens equal greedy target
12+
decoding (lossless by construction). Forced-acceptance cases pin the partial,
13+
full, and accepted-EOS paths plus the one-token budget; the random-weight loop
14+
alone can leave them uncovered.
15+
16+
This is CPU eager coverage of the decoding *algorithm*, not the C++ runner
17+
itself: tokenizer integration, device buffers, CUDA-graph capture, and the real
18+
CUDA/AOTI export are exercised manually (examples/models/eagle3/export.py + the
19+
eagle3-cuda runner) and remain tracked as future automated CUDA coverage.
20+
"""
21+
22+
import torch
23+
24+
from executorch.examples.models.eagle3.draft import Eagle3Config, Eagle3Draft
25+
from executorch.examples.models.eagle3.speculator import Eagle3Speculator
26+
from executorch.examples.models.gemma4_31b.model import Gemma4_31B, Gemma4_31BConfig
27+
28+
# Vocabulary size of the tiny target model. Passed as ``vocab_size`` to the
# target config and as ``target_vocab_size`` to the draft config, and used as
# the modulus when a test fabricates a deliberately wrong token id.
_TARGET_VOCAB = 128
29+
30+
31+
def _build():
    """Construct the tiny target/draft pair used by every test in this file.

    The torch RNG is seeded with 0 first so that any random initialisation
    performed by the model constructors is repeatable across calls. Both
    models are cast to float32 and switched to eval mode.

    Returns:
        ``(speculator, target)`` where ``speculator`` is an
        ``Eagle3Speculator`` wrapping ``target`` and the draft model.
    """
    torch.manual_seed(0)

    # Target first, then draft: keep this order so RNG consumption is stable.
    target_cfg = Gemma4_31BConfig(
        vocab_size=_TARGET_VOCAB,
        hidden_size=32,
        intermediate_size=64,
        num_hidden_layers=6,
        num_attention_heads=4,
        num_key_value_heads=2,
        head_dim=8,
        num_global_key_value_heads=1,
        global_head_dim=8,
        sliding_window=64,
        max_seq_len=128,
    )
    target = Gemma4_31B(target_cfg).to(torch.float32).eval()

    # The draft's target_hidden_size / target_vocab_size mirror the target
    # config above (32 and _TARGET_VOCAB respectively).
    draft_cfg = Eagle3Config(
        hidden_size=32,
        target_hidden_size=32,
        intermediate_size=64,
        num_attention_heads=4,
        num_key_value_heads=2,
        head_dim=8,
        draft_vocab_size=64,
        target_vocab_size=_TARGET_VOCAB,
        aux_hidden_state_layers=[0, 1, 3],
        max_seq_len=128,
        has_own_embed=True,
    )
    draft = Eagle3Draft(draft_cfg).to(torch.float32).eval()

    speculator = Eagle3Speculator(target, draft)
    return speculator, target
72+
73+
74+
def _toks(ids):
    """Return ``ids`` (a flat list of token ids) as a ``(1, len(ids))`` int64 tensor."""
    row = torch.tensor(ids, dtype=torch.long)
    # Add the leading batch dimension of size one.
    return row.unsqueeze(0)
76+
77+
78+
def _reset_kv(target):
    """Zero in place every buffer of ``target`` whose name contains ``.kv_cache.``.

    Buffers with other names are left untouched.
    """
    cache_buffers = [
        tensor
        for qualified_name, tensor in target.named_buffers()
        if ".kv_cache." in qualified_name
    ]
    for tensor in cache_buffers:
        tensor.zero_()
82+
83+
84+
@torch.no_grad()
def _greedy(target, prompt, n):
    """Reference decoder: generate ``n`` tokens by plain greedy argmax.

    Every step clears the target's KV cache and re-runs the whole sequence
    (prompt plus everything generated so far) from position 0, then takes the
    argmax of the logits at the last position. The second value returned by
    ``forward_logits_taps`` is ignored.

    Returns:
        The list of ``n`` generated token ids (the prompt is not included).
    """
    context = list(prompt)  # copy so the caller's prompt is never mutated
    generated = []
    while len(generated) < n:
        _reset_kv(target)
        positions = torch.arange(len(context))
        logits, _ = target.forward_logits_taps(
            _toks(context), positions, last_logits_only=True
        )
        next_token = int(logits[:, -1].argmax())
        context.append(next_token)
        generated.append(next_token)
    return generated
96+
97+
98+
def _accept_len(proposals, verify_ids):
    """Return how many leading ``proposals`` agree with row 0 of ``verify_ids``.

    Proposal ``j`` is compared with ``verify_ids[0, j]``; counting stops at the
    first disagreement (greedy acceptance).
    """
    verifier_row = verify_ids[0]
    accepted = 0
    while accepted < len(proposals) and proposals[accepted] == int(
        verifier_row[accepted]
    ):
        accepted += 1
    return accepted
106+
107+
108+
def _truncate_at_eos(tokens, eos_ids):
    """Trim ``tokens`` just after its first stop token.

    Returns:
        ``(tokens_up_to_and_including_stop, True)`` when some token is in
        ``eos_ids``, otherwise ``(tokens, False)`` with the input unchanged.
    """
    first_stop = next(
        (index for index, token in enumerate(tokens) if token in eos_ids), None
    )
    if first_stop is None:
        return tokens, False
    return tokens[: first_stop + 1], True
114+
115+
116+
@torch.no_grad()
def _speculative_decode(
    spec, prompt, K, num_gen, force=None, eos_ids=None, accept_out=None
):
    """Run the shifted, one-target-forward-per-round speculative loop.

    Flow: prefill -> draft chain -> target_verify -> accept -> reseed the
    draft -> repeat, until ``num_gen`` tokens are emitted or a stop token is
    emitted.

    Args:
        spec: speculator exposing ``target``, ``draft``, ``prefill``,
            ``draft_decode`` and ``target_verify``.
        prompt: list of prompt token ids.
        K: number of draft proposals per round.
        num_gen: maximum number of tokens to return.
        force: optional ``force(emitted) -> list`` whose result replaces the
            draft's proposal *values*. The draft chain is still executed first
            so the draft is reseeded exactly as in an unforced run.
        eos_ids: optional collection of stop token ids; a round is cut at the
            first emitted stop token (inclusive).
        accept_out: optional list that receives each round's acceptance count.

    Returns:
        The emitted token ids, at most ``num_gen`` long.
    """
    target = spec.target
    _reset_kv(target)
    spec.draft.reset_cache()
    stop_ids = eos_ids or set()

    prompt_len = len(prompt)
    bonus, feat = spec.prefill(_toks(prompt), torch.arange(prompt_len))
    anchor = int(bonus)
    anchor_pos = prompt_len
    emitted = [anchor]
    if num_gen <= 1 or anchor in stop_ids:
        # The prefill bonus alone satisfies the request; no draft round runs.
        return emitted[:num_gen]

    def chain(seed_tokens, seed_feat, seed_pos):
        # One multi-token draft call over the seed, then K-1 single-token
        # steps, each fed the previous step's outputs.
        tids, g = spec.draft_decode(_toks(seed_tokens), seed_feat, seed_pos)
        drafted = [int(tids[0, -1])]
        tok, f = tids[:, -1:], g[:, -1:]
        base = int(seed_pos[-1])
        for step in range(1, K):
            tok, f = spec.draft_decode(tok, f, torch.tensor([base + step]))
            drafted.append(int(tok[0, 0]))
        return drafted

    def propose(seed_tokens, seed_feat, seed_pos):
        # Always run the draft chain; only its values may be overridden.
        drafted = chain(seed_tokens, seed_feat, seed_pos)
        if force is None:
            return drafted
        return force(emitted)

    # Shifted seeding: prompt tokens 1..L-1 plus the bonus, paired with the
    # prefill features at positions 0..L-1.
    proposals = propose(prompt[1:] + [anchor], feat, torch.arange(prompt_len))

    while len(emitted) < num_gen:
        verify_pos = torch.arange(anchor_pos, anchor_pos + K + 1)
        vids, vfeat = spec.target_verify(_toks([anchor] + proposals), verify_pos)

        a = _accept_len(proposals, vids)
        if accept_out is not None:
            accept_out.append(a)

        # The verifier's id right after the accepted prefix is emitted too.
        corrected = int(vids[0, a])
        round_tokens = proposals[:a] + [corrected]
        remaining = num_gen - len(emitted)
        new, hit_eos = _truncate_at_eos(round_tokens[:remaining], stop_ids)
        emitted += new
        if hit_eos or len(emitted) >= num_gen:
            break

        # Reseed the draft from the accepted tokens and their verify features,
        # using the positions of the round that just finished.
        proposals = propose(
            round_tokens,
            vfeat[:, : a + 1],
            torch.arange(anchor_pos, anchor_pos + a + 1),
        )
        anchor, anchor_pos = corrected, anchor_pos + a + 1

    return emitted[:num_gen]
174+
175+
176+
# Fixed six-token prompt shared by every test below; all ids are smaller than
# _TARGET_VOCAB.
_PROMPT = [2, 7, 3, 21, 9, 14]
177+
178+
179+
def test_speculative_decode_matches_greedy_e2e():
    """Unforced speculative decoding (K=4) yields the same 16 tokens as greedy."""
    spec, target = _build()
    budget = 16
    decoded = _speculative_decode(spec, _PROMPT, K=4, num_gen=budget)
    assert len(decoded) == budget
    reference = _greedy(target, _PROMPT, budget)
    assert decoded == reference
185+
186+
187+
def test_full_acceptance_loop_is_lossless():
    """Every round fully accepts (a == K) and the output is still greedy."""
    # Proposing the target's own greedy continuation makes each round accept
    # all K proposals, deterministically covering the a == K reseed and the
    # folded-bonus path that a random-weight run may never reach.
    spec, target = _build()
    K, num_gen = 4, 16
    # Extra K + 1 tokens so the proposal window never runs off the end.
    G = _greedy(target, _PROMPT, num_gen + K + 1)

    def force(emitted):
        start = len(emitted)
        return G[start : start + K]

    accepts = []
    got = _speculative_decode(
        spec, _PROMPT, K=K, num_gen=num_gen, force=force, accept_out=accepts
    )
    assert got == G[:num_gen]
    assert accepts
    assert set(accepts) == {K}
205+
206+
207+
def test_partial_acceptance_loop_is_lossless():
    """Every round accepts some but not all proposals; output stays greedy."""
    # Each forced window is greedy for its first K-1 entries and deliberately
    # wrong in the last one, so the corrected token must be the greedy token
    # at the mismatch for the loop to remain lossless.
    spec, target = _build()
    K, num_gen = 4, 16
    G = _greedy(target, _PROMPT, num_gen + K + 1)

    def force(emitted):
        start = len(emitted)
        window = G[start : start + K]  # slicing copies, so G is not mutated
        window[-1] = (window[-1] + 1) % _TARGET_VOCAB
        return window

    accepts = []
    got = _speculative_decode(
        spec, _PROMPT, K=K, num_gen=num_gen, force=force, accept_out=accepts
    )
    assert got == G[:num_gen]
    assert accepts
    assert min(accepts) > 0 and max(accepts) < K
227+
228+
229+
def test_accepted_proposal_eos_stops_emission():
    """A stop token arriving as an accepted proposal ends emission at once."""
    # The stop token must be an accepted proposal, not the prefill bonus or a
    # corrected token, and nothing may be emitted after it.
    spec, target = _build()
    K, num_gen = 4, 16
    G = _greedy(target, _PROMPT, num_gen + K + 1)

    # G[2] is the third emitted token overall: G[0] is the prefill bonus, so
    # G[2] is the second proposal of the first (fully accepted) round.
    # NOTE(review): this relies on G[0] != G[2]; if they were equal the decode
    # would stop on the prefill bonus and the length check below would fail.
    stop_ids = {G[2]}

    def force(emitted):
        start = len(emitted)
        return G[start : start + K]

    got = _speculative_decode(
        spec, _PROMPT, K=K, num_gen=num_gen, force=force, eos_ids=stop_ids
    )
    assert len(got) >= 2  # reached an accepted proposal, not just the bonus
    assert got[-1] in stop_ids  # stopped exactly on the stop token
    assert not any(t in stop_ids for t in got[:-1])  # nothing emitted after EOS
    assert got == G[: len(got)]  # lossless prefix
248+
249+
250+
def test_num_gen_one_returns_only_prefill_bonus():
    """A one-token request is served by the prefill bonus; no draft round runs."""
    spec, target = _build()
    got = _speculative_decode(spec, _PROMPT, K=4, num_gen=1)
    want = _greedy(target, _PROMPT, 1)
    assert got == want
256+
257+
258+
if __name__ == "__main__":
    # Allow running this file directly: hand it to pytest in quiet mode and
    # exit with pytest's return code.
    import pytest

    exit_code = pytest.main([__file__, "-q"])
    raise SystemExit(exit_code)

pytest.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ testpaths =
9898
examples/models/llava/test
9999
examples/models/eagle3/test_draft.py
100100
examples/models/eagle3/test_speculator.py
101+
examples/models/eagle3/test_e2e.py
101102
examples/models/gemma4_31b/test_eagle_tap.py
102103

103104
# exir

0 commit comments

Comments
 (0)