99Three methods are lowered together so they share mutable state:
1010 - "prefill": target prompt prefill (T in [get_min_prefill_chunk,
1111 get_max_prefill_chunk]) -> next token + fused feature.
12- - "target_verify": target forward over the candidate chain (static T=chain+1)
12+ - "target_verify": target forward over the candidate chain (dynamic
13+ T = K+1 in [2, MATVEC_MAX_M]; --chain selects K at runtime)
1314 -> per-position greedy ids + fused feature.
1415 - "draft_decode": draft proposal over its KV cache (T>=1; seed with T>1, step
1516 with T=1) -> proposed target ids + recurrent feature.
3536supported.
3637
3738Scope (this is a fixed-shape ExecuTorch artifact, not a generic EAGLE runtime):
38- chain length, the chain_len+1 verify window, the prefill/draft dynamic ranges,
39- the CUDA backend, and the small-M INT4 dispatch policy are all baked at export —
40- vary the target, chain length, or backend by re-exporting. The caller is
41- responsible for pairing a target, draft, and tokenizer that were trained
42- together: only target/draft hidden size is checked here; tokenizer identity,
43- target vocab size, the d2t/t2d mapping, the tap-layer convention, and the draft's
44- training target are NOT validated, and a mismatch can pass export yet silently
45- degrade acceptance or correctness. A versioned target/draft/tokenizer manifest +
46- runtime validation is left as future work.
39+ the target, the prefill/draft/verify dynamic ranges, the CUDA backend, and the
40+ small-M INT4 dispatch policy are all baked at export — vary the target or backend
41+ by re-exporting. Chain length K is NOT baked: target_verify is dynamic over
42+ T in [2, MATVEC_MAX_M], so one .pte serves any K in [1, MATVEC_MAX_M - 1]
43+ (get_chain_len is only the default) and the runner selects K with --chain. The
44+ caller is responsible for pairing a target, draft, and tokenizer that were
45+ trained together: only target/draft hidden size is checked here; tokenizer
46+ identity, target vocab size, the d2t/t2d mapping, the tap-layer convention, and
47+ the draft's training target are NOT validated, and a mismatch can pass export yet
48+ silently degrade acceptance or correctness. A versioned target/draft/tokenizer
49+ manifest + runtime validation is left as future work.
4750"""
4851
4952import argparse
5760from executorch .examples .models .eagle3 .speculator import Eagle3Speculator
5861from executorch .examples .models .eagle3 .target import TARGETS
5962
60- # Route the static chain_len+1 verify forward to the small-M INT4 GEMM. Must be
61- # <= the shim's GEMM_MAX_M (8 in int4_plain_mm.cuh) and >= the largest chain+1.
63+ # Route the verify forward to the small-M INT4 GEMM. target_verify is dynamic
64+ # over T in [2, _MATVEC_MAX_M] (chain_len+1 is only the export example), and the
65+ # whole range must be <= the shim's GEMM_MAX_M (8 in int4_plain_mm.cuh).
6266# Set locally on int4_dispatch (and int8_dispatch) rather than as the global default,
6367# so other models' exports keep MATVEC_MAX_M=4 and their dynamic prefill ranges are unaffected.
6468_MATVEC_MAX_M = 8
@@ -139,8 +143,9 @@ def _lap(msg: str) -> None:
139143 hidden = spec .draft .config .hidden_size
140144 draft_vocab_size = spec .draft .config .draft_vocab_size
141145 # Verify re-feeds the last confirmed token (its logits are the folded bonus)
142- # plus the K proposals: a fixed chain_len+1 window in one target forward. With
143- # chain_len+1 <= MATVEC_MAX_M the verify forward stays on the small-M GEMM
146+ # plus the K proposals: a chain_len+1 window -- only the export example.
147+ # target_verify is lowered dynamic over T in [2, MATVEC_MAX_M], and with the
148+ # whole range <= MATVEC_MAX_M the verify forward stays on the small-M GEMM
144149 # rather than the dequant path.
145150 verify_len = chain_len + 1
146151 # prefill's dynamic length must take a single INT4 dispatch branch over its
@@ -165,10 +170,18 @@ def _lap(msg: str) -> None:
165170 )
166171 _lap ("export prefill" )
167172
168- print (f"Exporting target_verify (T = { verify_len } )..." )
173+ # Dynamic chain length: verify window T = K+1 is dynamic in [2, MATVEC_MAX_M],
174+ # so K is a runtime parameter (one .pte serves K in [1, MATVEC_MAX_M-1]; the
175+ # runner picks it with --chain). max == MATVEC_MAX_M so M never straddles the
176+ # INT4 dispatch threshold -> resolves to the small-M GEMM over the whole
177+ # range. min=2 is the K=1 window; the target's min_forward_len was a
178+ # conservative export note -- the gemma4 mask traces correctly down to T=2.
179+ verify_max = int4_dispatch .MATVEC_MAX_M
180+ verify_dim = Dim ("verify_len" , min = 2 , max = verify_max )
181+ print (f"Exporting target_verify (T in [2, { verify_max } ], example { verify_len } )..." )
169182 # The mid-M SDPA key bound is the dynamic length of kv_window: valid KV
170- # positions = anchor_pos + chain + 1, in [verify_len , max_seq_len].
171- kv_dim = Dim ("kv_len" , min = verify_len , max = target_config .max_seq_len )
183+ # positions = anchor_pos + K + 1, in [2, max_seq_len].
184+ kv_dim = Dim ("kv_len" , min = 2 , max = target_config .max_seq_len )
172185 with torch .no_grad ():
173186 verify_ep = export (
174187 _TargetVerify (spec ),
@@ -177,7 +190,7 @@ def _lap(msg: str) -> None:
177190 torch .arange (verify_len , dtype = torch .long ),
178191 torch .zeros ((8 * verify_len ,), dtype = torch .int32 ),
179192 ),
180- dynamic_shapes = ({}, {}, {0 : kv_dim }),
193+ dynamic_shapes = ({1 : verify_dim }, {0 : verify_dim }, {0 : kv_dim }),
181194 strict = True ,
182195 )
183196 _lap ("export target_verify" )
@@ -359,38 +372,35 @@ def main() -> None:
359372 f"--max-prefill (got { args .max_prefill } ) or --max-seq-len (got "
360373 f"{ args .max_seq_len } )"
361374 )
362- # target_verify is a single static forward of chain+1 tokens: it must fit the
363- # small-M GEMM (chain+1 <= _MATVEC_MAX_M) and the target's per-forward bounds
364- # [min_forward_len, max_forward].
375+ # target_verify is exported dynamic over T in [2, _MATVEC_MAX_M] (see
376+ # verify_dim), so --chain only sets the default/example K baked as
377+ # get_chain_len; one .pte serves any K in [1, _MATVEC_MAX_M - 1]. The example
378+ # K+1 must still fit the small-M GEMM (<= _MATVEC_MAX_M), meet the dynamic lower
379+ # bound (K >= 1 => window >= 2), and stay within the target's per-forward max.
380+ # min_forward_len is a conservative prefill note and does NOT bound verify.
365381 verify_len = args .chain + 1
366382 if verify_len > _MATVEC_MAX_M :
367383 p .error (
368384 f"--chain { args .chain } (verify window { verify_len } ) exceeds the "
369385 f"INT4 small-M GEMM limit { _MATVEC_MAX_M } "
370386 )
371- if verify_len < spec_t . min_forward_len :
387+ if verify_len < 2 :
372388 p .error (
373389 f"--chain { args .chain } (verify window { verify_len } ) is below the "
374- f"target's minimum forward length { spec_t . min_forward_len } "
390+ f"minimum verify window of 2 (need --chain >= 1) "
375391 )
376392 if verify_len > min (args .max_seq_len - 1 , max_forward ):
377393 p .error (
378394 f"--chain { args .chain } (verify window { verify_len } ) exceeds the "
379395 f"target's per-forward limit { min (args .max_seq_len - 1 , max_forward )} "
380396 )
381- < << << << HEAD
382- # Route the static chain_len+1 verify forward to the small-M INT4 GEMM by
383- # raising the dispatch threshold for this export only; restore it so the
384- # process-global default (4) is unchanged for any later use.
385- == == == =
386397 # Route the verify forward (dynamic T in [2, _MATVEC_MAX_M]) to the small-M
387398 # custom ops by raising the dispatch thresholds for this export only; restore
388399 # them so the process-global defaults (4) are unchanged for any later use.
389400 # Both INT4 and INT8 must be raised: the target's tied lm_head runs in INT8
390401 # (the embedding is quantized to int8), so the all-position verify logits hit
391402 # the INT8 dispatch with M = verify_len. If only INT4 were raised, the INT8
392403 # branch would straddle M=4 and force a data-dependent guard on verify_len.
393- >> >> >> > b3dd6ec802 (fixup ! Add the EAGLE - 3 speculator CUDA export )
394404 import executorch .backends .cuda .quantize_op_dispatch .int4_dispatch as int4_dispatch
395405 import executorch .backends .cuda .quantize_op_dispatch .int8_dispatch as int8_dispatch
396406
0 commit comments