|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
"""End-to-end EAGLE-3 speculative-decoding smoke test (CPU, no export needed).

Builds tiny matching target + draft models and drives the shifted (vLLM-EAGLE)
method-flow the C++ runner uses -- prefill -> draft chain -> target_verify ->
accept -> reseed -> repeat -- checking the generated tokens equal greedy target
decoding (lossless by construction). Forced-acceptance cases pin the partial,
full, and accepted-EOS paths plus the one-token budget; the random-weight loop
alone can leave them uncovered.

This is CPU eager coverage of the decoding *algorithm*, not the C++ runner
itself: tokenizer integration, device buffers, CUDA-graph capture, and the real
CUDA/AOTI export are exercised manually (examples/models/eagle3/export.py + the
eagle3-cuda runner) and remain tracked as future automated CUDA coverage.
"""
| 21 | + |
| 22 | +import torch |
| 23 | + |
| 24 | +from executorch.examples.models.eagle3.draft import Eagle3Config, Eagle3Draft |
| 25 | +from executorch.examples.models.eagle3.speculator import Eagle3Speculator |
| 26 | +from executorch.examples.models.gemma4_31b.model import Gemma4_31B, Gemma4_31BConfig |
| 27 | + |
# Vocabulary size of the tiny target model. The draft config is told the same
# value, and the partial-acceptance test wraps its "wrong" token modulo it.
_TARGET_VOCAB = 128
| 29 | + |
| 30 | + |
def _build():
    """Construct the tiny target/draft pair shared by every test in this file.

    Returns:
        ``(speculator, target)``: an ``Eagle3Speculator`` wrapping the target
        and draft models, plus the very same target instance so tests can run
        plain greedy decoding against it.
    """
    # Seed before any module is constructed; the models are built with
    # whatever (presumably random) initialisation their constructors apply.
    torch.manual_seed(0)

    # Dimensions used by both the target and the draft.
    hidden, ffn, seq_cap = 32, 64, 128
    q_heads, kv_heads, per_head = 4, 2, 8

    target_cfg = Gemma4_31BConfig(
        vocab_size=_TARGET_VOCAB,
        hidden_size=hidden,
        intermediate_size=ffn,
        num_hidden_layers=6,
        num_attention_heads=q_heads,
        num_key_value_heads=kv_heads,
        head_dim=per_head,
        num_global_key_value_heads=1,
        global_head_dim=8,
        sliding_window=64,
        max_seq_len=seq_cap,
    )
    target = Gemma4_31B(target_cfg).to(torch.float32).eval()

    draft_cfg = Eagle3Config(
        hidden_size=hidden,
        target_hidden_size=hidden,
        intermediate_size=ffn,
        num_attention_heads=q_heads,
        num_key_value_heads=kv_heads,
        head_dim=per_head,
        draft_vocab_size=64,
        target_vocab_size=_TARGET_VOCAB,
        aux_hidden_state_layers=[0, 1, 3],
        max_seq_len=seq_cap,
        has_own_embed=True,
    )
    draft = Eagle3Draft(draft_cfg).to(torch.float32).eval()

    return Eagle3Speculator(target, draft), target
| 72 | + |
| 73 | + |
def _toks(ids):
    """Wrap a flat sequence of token ids into an int64 tensor with a leading batch dim of 1."""
    batch = [ids]
    return torch.tensor(batch, dtype=torch.int64)
| 76 | + |
| 77 | + |
def _reset_kv(target):
    """Zero, in place, every buffer of ``target`` whose qualified name contains ``.kv_cache.``.

    Buffers whose names do not contain that substring are left untouched.
    """
    cache_buffers = [
        buf for name, buf in target.named_buffers() if ".kv_cache." in name
    ]
    for buf in cache_buffers:
        buf.zero_()
| 82 | + |
| 83 | + |
@torch.no_grad()
def _greedy(target, prompt, n):
    """Reference greedy decoding: return the ``n`` argmax tokens following ``prompt``.

    Every step clears the KV cache and re-runs the target over the whole
    sequence so far, so each token comes from a fresh full forward pass.
    """
    context = list(prompt)
    generated = []
    while len(generated) < n:
        _reset_kv(target)
        positions = torch.arange(len(context))
        logits, _taps = target.forward_logits_taps(
            _toks(context), positions, last_logits_only=True
        )
        next_id = int(logits[:, -1].argmax())
        context.append(next_id)
        generated.append(next_id)
    return generated
| 96 | + |
| 97 | + |
def _accept_len(proposals, verify_ids):
    """Return how many leading ``proposals`` equal the verifier's ids.

    ``verify_ids`` is indexed as ``verify_ids[0, j]``; counting stops at the
    first position where the proposal and the verifier disagree.
    """
    accepted = 0
    while accepted < len(proposals) and proposals[accepted] == int(
        verify_ids[0, accepted]
    ):
        accepted += 1
    return accepted
| 106 | + |
| 107 | + |
def _truncate_at_eos(tokens, eos_ids):
    """Keep ``tokens`` up to and including the first stop token.

    Returns:
        ``(tokens, hit_eos)``: the (possibly shortened) list and whether a
        member of ``eos_ids`` was found.
    """
    stop_at = next(
        (idx for idx, tok in enumerate(tokens) if tok in eos_ids), None
    )
    if stop_at is None:
        return tokens, False
    return tokens[: stop_at + 1], True
| 114 | + |
| 115 | + |
@torch.no_grad()
def _speculative_decode(
    spec, prompt, K, num_gen, force=None, eos_ids=None, accept_out=None
):
    """The shifted one-target-forward-per-round loop the C++ runner implements.

    Args:
        spec: speculator exposing ``target``, ``draft``, ``prefill``,
            ``draft_decode`` and ``target_verify``.
        prompt: list of prompt token ids.
        K: number of draft proposals verified per round.
        num_gen: maximum number of tokens to return.
        force: optional ``force(emitted) -> list[K]`` that overrides the
            draft's proposal *values* (the draft chain is still run to reseed)
            so tests can pin the acceptance count.
        eos_ids: optional stop-token ids; a round is truncated at the first
            emitted stop token (matching the runner).
        accept_out: optional list that receives each round's acceptance count.

    Returns:
        The emitted token ids, at most ``num_gen`` long.
    """
    _reset_kv(spec.target)
    spec.draft.reset_cache()
    stop_ids = eos_ids if eos_ids else set()
    prompt_len = len(prompt)

    bonus, feat = spec.prefill(_toks(prompt), torch.arange(prompt_len))
    anchor = int(bonus)
    anchor_pos = prompt_len
    emitted = [anchor]
    if num_gen <= 1 or anchor in stop_ids:
        return emitted[:num_gen]  # prefill bonus suffices; no draft round runs

    def chain(seed_tokens, seed_feat, seed_pos):
        # Seed the draft with the given tokens/features, then roll it forward
        # one token at a time until K proposals have been collected.
        next_tok, next_feat = spec.draft_decode(
            _toks(seed_tokens), seed_feat, seed_pos
        )
        drafted = [int(next_tok[0, -1])]
        next_tok, next_feat = next_tok[:, -1:], next_feat[:, -1:]
        base = int(seed_pos[-1])
        for step in range(1, K):
            next_tok, next_feat = spec.draft_decode(
                next_tok, next_feat, torch.tensor([base + step])
            )
            drafted.append(int(next_tok[0, 0]))
        return drafted

    def propose(seed_tokens, seed_feat, seed_pos):
        # Always run the draft chain; ``force`` only replaces the values.
        drafted = chain(seed_tokens, seed_feat, seed_pos)
        if force is None:
            return drafted
        return force(emitted)

    proposals = propose(prompt[1:] + [anchor], feat, torch.arange(prompt_len))
    while len(emitted) < num_gen:
        verify_pos = torch.arange(anchor_pos, anchor_pos + K + 1)
        vids, vfeat = spec.target_verify(_toks([anchor] + proposals), verify_pos)
        accepted = _accept_len(proposals, vids)
        if accept_out is not None:
            accept_out.append(accepted)

        # The verifier's id at the first mismatch (or after all K proposals)
        # is emitted right after the accepted proposals.
        corrected = int(vids[0, accepted])
        confirmed = proposals[:accepted] + [corrected]
        budget = num_gen - len(emitted)
        fresh, hit_eos = _truncate_at_eos(confirmed[:budget], stop_ids)
        emitted.extend(fresh)
        if hit_eos or len(emitted) >= num_gen:
            break

        proposals = propose(
            confirmed,
            vfeat[:, : accepted + 1],
            torch.arange(anchor_pos, anchor_pos + accepted + 1),
        )
        anchor = corrected
        anchor_pos += accepted + 1
    return emitted[:num_gen]
| 174 | + |
| 175 | + |
# Fixed prompt token ids shared by every test below.
_PROMPT = [
    2,
    7,
    3,
    21,
    9,
    14,
]
| 177 | + |
| 178 | + |
def test_speculative_decode_matches_greedy_e2e():
    """Unforced speculative decoding must reproduce plain greedy target decoding."""
    spec, target = _build()
    budget = 16
    speculative = _speculative_decode(spec, _PROMPT, K=4, num_gen=budget)
    assert len(speculative) == budget
    reference = _greedy(target, _PROMPT, budget)
    assert speculative == reference
| 185 | + |
| 186 | + |
def test_full_acceptance_loop_is_lossless():
    """Every round fully accepts (a == K) and the output still equals greedy.

    Proposing the target's own greedy continuation deterministically exercises
    the a == K reseed and the folded-bonus path across rounds, which a
    random-weight run may never hit.
    """
    spec, target = _build()
    K = 4
    num_gen = 16
    G = _greedy(target, _PROMPT, num_gen + K + 1)

    def oracle(emitted):
        # The next K greedy tokens after everything emitted so far.
        start = len(emitted)
        return G[start : start + K]

    accepts = []
    got = _speculative_decode(
        spec, _PROMPT, K=K, num_gen=num_gen, force=oracle, accept_out=accepts
    )
    assert got == G[:num_gen]
    assert accepts
    assert all(count == K for count in accepts)
| 205 | + |
| 206 | + |
def test_partial_acceptance_loop_is_lossless():
    """Every round accepts exactly K-1 proposals and the output still equals greedy.

    The first K-1 proposals are the greedy tokens and the last is deliberately
    wrong, so the corrected token must be the greedy token at the mismatch for
    the loop to stay lossless.
    """
    spec, target = _build()
    K = 4
    num_gen = 16
    G = _greedy(target, _PROMPT, num_gen + K + 1)

    def all_but_last(emitted):
        start = len(emitted)
        window = G[start : start + K]  # slicing copies, so G is not mutated
        # Shift the last greedy token by one (mod vocab) to force a mismatch.
        window[-1] = (window[-1] + 1) % _TARGET_VOCAB
        return window

    accepts = []
    got = _speculative_decode(
        spec,
        _PROMPT,
        K=K,
        num_gen=num_gen,
        force=all_but_last,
        accept_out=accepts,
    )
    assert got == G[:num_gen]
    assert accepts
    assert all(0 < count < K for count in accepts)
| 227 | + |
| 228 | + |
def test_accepted_proposal_eos_stops_emission():
    """An accepted *proposal* that is a stop token must end emission at once.

    The stop token must be neither the prefill bonus nor a corrected token.
    ``G[0]`` is the prefill bonus and ``G[1 : K + 1]`` are the proposals forced
    in round one (all accepted, since they are the greedy continuation).
    ``G[2]`` -- the second forced proposal, i.e. the third emitted token -- is
    the preferred stop token. If it happens to equal the bonus, decoding would
    stop before any proposal is verified, so another round-one proposal that
    differs from the bonus is used instead.
    """
    spec, target = _build()
    K, num_gen = 4, 16
    G = _greedy(target, _PROMPT, num_gen + K + 1)

    round_one = G[1 : K + 1]
    stop = next((t for t in [G[2], *round_one] if t != G[0]), None)
    assert stop is not None, (
        "degenerate fixture: every round-one proposal equals the prefill bonus"
    )
    eos = {stop}

    got = _speculative_decode(
        spec,
        _PROMPT,
        K=K,
        num_gen=num_gen,
        force=lambda emitted: G[len(emitted) : len(emitted) + K],
        eos_ids=eos,
    )
    assert len(got) >= 2  # reached an accepted proposal, not just the bonus
    assert len(got) <= K + 1  # stopped on a round-one proposal, not a corrected token
    assert got[-1] in eos  # stopped exactly on the stop token
    assert all(t not in eos for t in got[:-1])  # no earlier stop token was skipped
    assert got == G[: len(got)]  # lossless prefix
| 248 | + |
| 249 | + |
def test_num_gen_one_returns_only_prefill_bonus():
    """A one-token request returns the free prefill bonus without a draft round."""
    spec, target = _build()
    bonus_only = _speculative_decode(spec, _PROMPT, K=4, num_gen=1)
    assert bonus_only == _greedy(target, _PROMPT, 1)
| 256 | + |
| 257 | + |
# Allow running this file directly: hand it to pytest in quiet mode and
# propagate pytest's exit status to the shell.
if __name__ == "__main__":
    import pytest

    exit_status = pytest.main([__file__, "-q"])
    raise SystemExit(exit_status)
0 commit comments