Commit 5ac44d4

Make EAGLE-3 chain length K a runtime parameter
target_verify was exported at a static T = chain+1, so K was baked into the .pte and changing it meant a full re-export. Export the verify window as a dynamic dim T in [2, MATVEC_MAX_M] instead, so one .pte serves any K in [1, MATVEC_MAX_M - 1] and the runner selects it with --chain (get_chain_len is only the default).

The verify M never straddles the INT4 dispatch threshold (max == MATVEC_MAX_M), so it resolves to the small-M GEMM over the whole range, and the mid-M SDPA kernel takes M as a runtime arg. The gemma4 mask traces correctly down to T=2; the target's min_forward_len=5 was a conservative prefill note and does not bound verify.

Authored with assistance from Claude Code.

ghstack-source-id: 8cc7587
ghstack-comment-id: 4734205542
Pull-Request: #20345
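The range arithmetic in the message can be checked with a small sketch. The helper names below are hypothetical and not part of the commit; only the value 8 for MATVEC_MAX_M and the relation T = K + 1 are taken from the change.

```python
from typing import Tuple

# Hypothetical helpers illustrating the range arithmetic; not part of the commit.
MATVEC_MAX_M = 8  # value assigned to _MATVEC_MAX_M in export.py


def verify_window_range(matvec_max_m: int = MATVEC_MAX_M) -> Tuple[int, int]:
    """Inclusive [min, max] for the verify window T."""
    return 2, matvec_max_m


def chain_range(matvec_max_m: int = MATVEC_MAX_M) -> Tuple[int, int]:
    """Inclusive [min, max] for the chain length K, derived from T = K + 1."""
    t_min, t_max = verify_window_range(matvec_max_m)
    return t_min - 1, t_max - 1


print(verify_window_range())  # (2, 8)
print(chain_range())          # (1, 7)
```

With MATVEC_MAX_M = 8 this gives K in [1, 7], which matches the clamp the runner applies to --chain.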
1 parent f4ba28a commit 5ac44d4

5 files changed

Lines changed: 67 additions & 43 deletions

backends/cuda/triton/kernels/sdpa_midm.py

Lines changed: 15 additions & 9 deletions
@@ -44,8 +44,9 @@
 # path is appropriate (enough rows to amortize a tiled kernel).
 MIDM_MAX_M = 8

-# Number of key-range partitions for split-K. The verify method exports a static
-# M / B / H / D, so the partial buffers and grid are static-shaped; only the
+# Number of key-range partitions for split-K. B / H / D are static for the
+# exported verify method; M is the dynamic verify length (bounded by MIDM_MAX_M,
+# BLOCK_M covers it), so the grid (NUM_SPLITS x B*H) is static-shaped; the
 # per-split chunk size (derived from the dynamic valid_len) is a runtime scalar.
 # 32 splits x (B*H) heads gives ~1K CTAs at the gemma4 global shape -- ample
 # occupancy on an A100 while keeping the fp32 partials small.
@@ -89,9 +90,9 @@ def _sdpa_midm_splitk_kernel(
     valid_len,
     chunk_size,
     scale,
+    M,
     H: tl.constexpr,
     HKV: tl.constexpr,
-    M: tl.constexpr,
     D: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
@@ -200,8 +201,8 @@ def _sdpa_midm_reduce_kernel(
     soh,
     som,
     sod,
+    M,
     NUM_SPLITS: tl.constexpr,
-    M: tl.constexpr,
     D: tl.constexpr,
     BLOCK_M: tl.constexpr,
 ):
@@ -267,13 +268,17 @@ def _sdpa_midm_op(

     ``valid_len`` (max valid position + 1) bounds the key range; it is split into
     NUM_SPLITS chunks of ``chunk_size`` keys computed in parallel, then reduced.
-    M / B / H / D are static for the exported verify method, so only chunk_size is
-    a runtime (backed-SymInt) scalar -- the grid and partial buffers are static.
+    B / H / D are static for the exported verify method; M is the dynamic verify
+    length (bounded by MIDM_MAX_M). chunk_size (from the dynamic valid_len) is a
+    runtime (backed-SymInt) scalar; the grid (NUM_SPLITS x B*H) is static.
     """
     B, H, M, D = q.shape
     HKV = k.shape[1]
     out = torch.empty_like(q)
-    BLOCK_M = max(16, triton.next_power_of_2(M))
+    # M <= MIDM_MAX_M (8) => next_pow2(M) <= 8 => max(16, .) is always 16. Hardcode
+    # so M can be a runtime (dynamic verify) dim -- next_power_of_2 can't take a
+    # SymInt, and M is a kernel runtime arg used only for the offs_m < M masks.
+    BLOCK_M = 16
     # gemma4 global layers use D=512; a wide key tile + pipelining overflow SMEM
     # there, so shrink both. Small D can afford more.
     BLOCK_N, num_stages = (32, 1) if D >= 512 else (64, 2)
@@ -381,8 +386,9 @@ def midm_sdpa(
 ) -> torch.Tensor:
     """Dispatch: the mid-M op for a small query window when enabled; otherwise
     the standard F.sdpa the model already uses (which the replacement pass swaps
-    for triton::sdpa). M is static per exported method, so the branch resolves at
-    trace time. ``valid_len`` is the shared per-forward key bound."""
+    for triton::sdpa). M (q.shape[2]) is the dynamic verify length; its exported
+    range [2, MIDM_MAX_M] satisfies this guard, so the branch resolves at export.
+    ``valid_len`` is the shared per-forward key bound."""
     M = q.shape[2]
     if enable and 2 <= M <= MIDM_MAX_M:
         return sdpa_midm(q, k, v, input_pos, scale, valid_len=valid_len)
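
The BLOCK_M hardcode in this file rests on a small arithmetic claim: for every M the export allows, the old expression already evaluated to 16. A minimal check, assuming a pure-Python stand-in for triton.next_power_of_2 so it runs without Triton installed (the stand-in and the helper name old_block_m are illustrative, not from the commit):

```python
MIDM_MAX_M = 8  # same value as in sdpa_midm.py


def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n, for n >= 1 (stand-in for the Triton helper)."""
    return 1 << (n - 1).bit_length()


def old_block_m(m: int) -> int:
    """The expression the commit replaces: max(16, next_power_of_2(M))."""
    return max(16, next_power_of_2(m))


# Evaluate the old expression over the whole exported verify range [2, MIDM_MAX_M].
block_m_values = {old_block_m(m) for m in range(2, MIDM_MAX_M + 1)}
print(block_m_values)  # {16}
```

Because the result is constant over the range, replacing the expression with the literal 16 changes no tile size; it only removes a call that, per the diff comment, cannot take a SymInt.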

examples/models/eagle3/export.py

Lines changed: 39 additions & 29 deletions
@@ -9,7 +9,8 @@
 Three methods are lowered together so they share mutable state:
 - "prefill": target prompt prefill (T in [get_min_prefill_chunk,
 get_max_prefill_chunk]) -> next token + fused feature.
-- "target_verify": target forward over the candidate chain (static T=chain+1)
+- "target_verify": target forward over the candidate chain (dynamic T in
+[2, MATVEC_MAX_M] = K+1; --chain selects K at runtime)
 -> per-position greedy ids + fused feature.
 - "draft_decode": draft proposal over its KV cache (T>=1; seed with T>1, step
 with T=1) -> proposed target ids + recurrent feature.
@@ -35,15 +36,17 @@
 supported.

 Scope (this is a fixed-shape ExecuTorch artifact, not a generic EAGLE runtime):
-chain length, the chain_len+1 verify window, the prefill/draft dynamic ranges,
-the CUDA backend, and the small-M INT4 dispatch policy are all baked at export —
-vary the target, chain length, or backend by re-exporting. The caller is
-responsible for pairing a target, draft, and tokenizer that were trained
-together: only target/draft hidden size is checked here; tokenizer identity,
-target vocab size, the d2t/t2d mapping, the tap-layer convention, and the draft's
-training target are NOT validated, and a mismatch can pass export yet silently
-degrade acceptance or correctness. A versioned target/draft/tokenizer manifest +
-runtime validation is left as future work.
+the target, the prefill/draft/verify dynamic ranges, the CUDA backend, and the
+small-M INT4 dispatch policy are all baked at export — vary the target or backend
+by re-exporting. Chain length K is NOT baked: target_verify is dynamic over
+T in [2, MATVEC_MAX_M], so one .pte serves any K in [1, MATVEC_MAX_M - 1]
+(get_chain_len is only the default) and the runner selects K with --chain. The
+caller is responsible for pairing a target, draft, and tokenizer that were
+trained together: only target/draft hidden size is checked here; tokenizer
+identity, target vocab size, the d2t/t2d mapping, the tap-layer convention, and
+the draft's training target are NOT validated, and a mismatch can pass export yet
+silently degrade acceptance or correctness. A versioned target/draft/tokenizer
+manifest + runtime validation is left as future work.
 """

 import argparse
@@ -57,8 +60,9 @@
 from executorch.examples.models.eagle3.speculator import Eagle3Speculator
 from executorch.examples.models.eagle3.target import TARGETS

-# Route the static chain_len+1 verify forward to the small-M INT4 GEMM. Must be
-# <= the shim's GEMM_MAX_M (8 in int4_plain_mm.cuh) and >= the largest chain+1.
+# Route the verify forward to the small-M INT4 GEMM. target_verify is dynamic
+# over T in [2, _MATVEC_MAX_M] (chain_len+1 is only the export example), and the
+# whole range must be <= the shim's GEMM_MAX_M (8 in int4_plain_mm.cuh).
 # Set locally on int4_dispatch (not the global default) so other models' exports
 # keep MATVEC_MAX_M=4 and their dynamic prefill ranges are unaffected.
 _MATVEC_MAX_M = 8
@@ -139,8 +143,9 @@ def _lap(msg: str) -> None:
     hidden = spec.draft.config.hidden_size
     draft_vocab_size = spec.draft.config.draft_vocab_size
     # Verify re-feeds the last confirmed token (its logits are the folded bonus)
-    # plus the K proposals: a fixed chain_len+1 window in one target forward. With
-    # chain_len+1 <= MATVEC_MAX_M the verify forward stays on the small-M GEMM
+    # plus the K proposals: a chain_len+1 window -- only the export example.
+    # target_verify is lowered dynamic over T in [2, MATVEC_MAX_M], and with the
+    # whole range <= MATVEC_MAX_M the verify forward stays on the small-M GEMM
     # rather than the dequant path.
     verify_len = chain_len + 1
     # prefill's dynamic length must take a single INT4 dispatch branch over its
@@ -165,10 +170,18 @@ def _lap(msg: str) -> None:
         )
     _lap("export prefill")

-    print(f"Exporting target_verify (T = {verify_len})...")
+    # Dynamic chain length: verify window T = K+1 dynamic in [2, MATVEC_MAX_M]
+    # so K is a runtime parameter (one .pte serves K in [1, MATVEC_MAX_M-1], the
+    # runner picks it with --chain). max == MATVEC_MAX_M so M never straddles the
+    # INT4 dispatch threshold -> resolves to the small-M GEMM over the whole
+    # range. min=2 is the K=1 window; the target's min_forward_len was a
+    # conservative export note -- the gemma4 mask traces correctly down to T=2.
+    verify_max = int4_dispatch.MATVEC_MAX_M
+    verify_dim = Dim("verify_len", min=2, max=verify_max)
+    print(f"Exporting target_verify (T in [2, {verify_max}], example {verify_len})...")
     # The mid-M SDPA key bound is the dynamic length of kv_window: valid KV
-    # positions = anchor_pos + chain + 1, in [verify_len, max_seq_len].
-    kv_dim = Dim("kv_len", min=verify_len, max=target_config.max_seq_len)
+    # positions = anchor_pos + K + 1, in [2, max_seq_len].
+    kv_dim = Dim("kv_len", min=2, max=target_config.max_seq_len)
     with torch.no_grad():
         verify_ep = export(
             _TargetVerify(spec),
@@ -177,7 +190,7 @@ def _lap(msg: str) -> None:
                 torch.arange(verify_len, dtype=torch.long),
                 torch.zeros((8 * verify_len,), dtype=torch.int32),
             ),
-            dynamic_shapes=({}, {}, {0: kv_dim}),
+            dynamic_shapes=({1: verify_dim}, {0: verify_dim}, {0: kv_dim}),
             strict=True,
         )
     _lap("export target_verify")
@@ -359,38 +372,35 @@ def main() -> None:
             f"--max-prefill (got {args.max_prefill}) or --max-seq-len (got "
             f"{args.max_seq_len})"
         )
-    # target_verify is a single static forward of chain+1 tokens: it must fit the
-    # small-M GEMM (chain+1 <= _MATVEC_MAX_M) and the target's per-forward bounds
-    # [min_forward_len, max_forward].
+    # target_verify is exported dynamic over T in [2, _MATVEC_MAX_M] (see
+    # verify_dim), so --chain only sets the default/example K baked as
+    # get_chain_len; one .pte serves any K in [1, _MATVEC_MAX_M - 1]. The example
+    # K+1 must still fit the small-M GEMM (<= _MATVEC_MAX_M), the dynamic lower
+    # bound (K >= 1 => window >= 2), and the target's per-forward max.
+    # min_forward_len is a conservative prefill note and does NOT bound verify.
     verify_len = args.chain + 1
     if verify_len > _MATVEC_MAX_M:
         p.error(
             f"--chain {args.chain} (verify window {verify_len}) exceeds the "
             f"INT4 small-M GEMM limit {_MATVEC_MAX_M}"
         )
-    if verify_len < spec_t.min_forward_len:
+    if verify_len < 2:
         p.error(
             f"--chain {args.chain} (verify window {verify_len}) is below the "
-            f"target's minimum forward length {spec_t.min_forward_len}"
+            f"minimum verify window of 2 (need --chain >= 1)"
         )
     if verify_len > min(args.max_seq_len - 1, max_forward):
         p.error(
             f"--chain {args.chain} (verify window {verify_len}) exceeds the "
             f"target's per-forward limit {min(args.max_seq_len - 1, max_forward)}"
         )
-<<<<<<< HEAD
-    # Route the static chain_len+1 verify forward to the small-M INT4 GEMM by
-    # raising the dispatch threshold for this export only; restore it so the
-    # process-global default (4) is unchanged for any later use.
-=======
     # Route the verify forward (dynamic T in [2, _MATVEC_MAX_M]) to the small-M
     # custom ops by raising the dispatch thresholds for this export only; restore
     # them so the process-global defaults (4) are unchanged for any later use.
     # Both INT4 and INT8 must be raised: the target's tied lm_head runs in INT8
     # (the embedding is quantized to int8), so the all-position verify logits hit
     # the INT8 dispatch with M = verify_len. If only INT4 were raised, the INT8
     # branch would straddle M=4 and force a data-dependent guard on verify_len.
->>>>>>> b3dd6ec802 (fixup! Add the EAGLE-3 speculator CUDA export)
     import executorch.backends.cuda.quantize_op_dispatch.int4_dispatch as int4_dispatch
     import executorch.backends.cuda.quantize_op_dispatch.int8_dispatch as int8_dispatch
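
The comments in this file repeatedly require that a dynamic range not "straddle" a dispatch threshold. A minimal sketch of that condition, assuming the dispatch rule is "M <= threshold selects the small-M kernel" (inferred from the comments above; the actual int4_dispatch / int8_dispatch code is not shown in this diff, and single_branch is a hypothetical name):

```python
from typing import Optional


def single_branch(m_min: int, m_max: int, threshold: int) -> Optional[str]:
    """Return the branch taken by every M in [m_min, m_max], or None if the
    range straddles the threshold.

    Assumption: the dispatch compares M <= threshold, as the export.py
    comments suggest; this is not taken from the dispatch source.
    """
    if m_max <= threshold:
        return "small_m"
    if m_min > threshold:
        return "large_m"
    return None  # straddles: the branch would depend on the runtime value of M


print(single_branch(2, 8, threshold=8))  # small_m
print(single_branch(2, 8, threshold=4))  # None
```

Under this reading, the verify range [2, 8] resolves to one branch only when the threshold is raised to 8, which is why the export raises both the INT4 and INT8 thresholds from their default of 4.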

examples/models/eagle3/main.cpp

Lines changed: 7 additions & 1 deletion
@@ -109,6 +109,11 @@ DEFINE_bool(
     "current export feeds target_verify a kv_window whose length changes every "
     "round, so capture is unsafe (stale-shape replay). Only enable for an "
     "export whose target_verify inputs all have stable shapes.");
+DEFINE_int32(
+    chain,
+    -1,
+    "Override chain length K at runtime (<=0 uses the .pte's get_chain_len). "
+    "Requires a dynamic-T verify export; clamped to [1, 7] (verify M=K+1<=8).");
 // Chat template + stop tokens default to Gemma 4 IT; override for other models.
 DEFINE_string(
     chat_prefix,
@@ -264,7 +269,8 @@ int main(int argc, char** argv) {
   const int64_t max_prefill = meta("get_max_prefill_chunk");
   const int64_t min_prefill = meta("get_min_prefill_chunk");
   const int64_t max_seq_len = meta("get_max_seq_len");
-  const int64_t K = chain_len;
+  const int64_t K_req = (FLAGS_chain > 0) ? FLAGS_chain : chain_len;
+  const int64_t K = (K_req < 1) ? 1 : (K_req > 7 ? 7 : K_req);

   // EOS: tokenizer/metadata ids, the configured eos, any --stop_ids, and the
   // encoded --stop_token delimiter (all default to the Gemma 4 IT conventions).
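
The runner's selection of K can be mirrored in a few lines of Python to make the fallback and clamp behavior easy to check. This is an illustrative transcription of the two K_req / K expressions above, not code from the commit; select_chain and its parameter names are hypothetical.

```python
def select_chain(flag_chain: int, default_chain: int,
                 k_min: int = 1, k_max: int = 7) -> int:
    """Mirror of the K_req / K expressions in main.cpp.

    flag_chain    -- value of --chain (<= 0 means "not set")
    default_chain -- get_chain_len read from the .pte
    """
    k_req = flag_chain if flag_chain > 0 else default_chain
    # Same result as the nested ternary clamp to [1, 7].
    return max(k_min, min(k_max, k_req))


print(select_chain(-1, 3))  # 3  (flag unset, use the .pte default)
print(select_chain(5, 3))   # 5  (flag overrides the default)
print(select_chain(10, 3))  # 7  (clamped to the upper bound)
```

An out-of-range --chain is clamped rather than rejected, so a request for K = 10 runs with K = 7.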

examples/models/eagle3/target.py

Lines changed: 3 additions & 2 deletions
@@ -73,9 +73,10 @@ class TargetSpec:
     # config -> max tokens accepted in one target forward (e.g. a sliding ring
     # buffer caps it at 2*window; a flat-cache model uses max_seq_len-1).
     max_forward_len: Callable[[Any], int]
-    # Minimum tokens in ANY single target forward the export accepts (some
+    # Minimum tokens the export specializes for a target forward (some
     # attention-mask implementations specialize a lower bound under
-    # torch.export). Applies to both prefill and the static target_verify window.
+    # torch.export). Bounds prefill only; target_verify is exported dynamic over
+    # T in [2, MATVEC_MAX_M] and is not constrained by it.
     min_forward_len: int


examples/models/gemma4_31b/model.py

Lines changed: 3 additions & 2 deletions
@@ -384,8 +384,9 @@ def forward(
         # layers; falls back to F.sdpa otherwise (M==1 decode, large-M prefill,
         # sliding layers, or when disabled). Imported lazily and only when
         # enabled so a CPU / non-mid-M import of the model never pulls in triton
-        # or the CUDA backend. M is static per exported method, so the mid-M
-        # branch resolves at trace time.
+        # or the CUDA backend. M (the verify window) is the dynamic verify length
+        # bounded to [2, MIDM_MAX_M] by the export, so the mid-M branch resolves
+        # at trace time.
         if self.use_midm_sdpa:
             from executorch.backends.cuda.triton.kernels.sdpa_midm import midm_sdpa
