-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Expand file tree
/
Copy pathdraft.py
More file actions
472 lines (411 loc) · 19.2 KB
/
Copy pathdraft.py
File metadata and controls
472 lines (411 loc) · 19.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""EAGLE-3 draft head for vLLM speculator checkpoints.
The draft model fuses target auxiliary hidden states with ``fc``, runs one
Llama-style decoder layer over token embeddings plus the fused feature, and
projects the midlayer output to reduced-vocabulary draft logits. The midlayer
output ``g`` is reused as the recurrent feature for drafting; ``fc`` is used
only for target auxiliary hidden states.
Draft ids map back to target ids with ``target_id = draft_id + d2t[draft_id]``.
Speculator checkpoints store the decoder layer under ``layers.0.*`` and may
include ``embed_tokens``, ``d2t``, and ``t2d``.
"""
import os
from dataclasses import dataclass, field
import torch
import torch.nn as nn
from torch.nn import functional as F
@dataclass
class Eagle3Config:
    """Hyper-parameters of the EAGLE-3 draft head.

    The defaults are fallbacks only; ``Eagle3Draft.from_checkpoint`` fills
    these fields from the checkpoint's ``config.json``.
    """

    # Width of the draft decoder layer and of the recurrent feature.
    hidden_size: int = 5376
    # Width of one target auxiliary hidden state (one slice of the fc input).
    target_hidden_size: int = 5376
    # Inner width of the gated MLP.
    intermediate_size: int = 21504
    # Attention geometry: query heads, key/value heads, per-head width.
    num_attention_heads: int = 32
    num_key_value_heads: int = 16
    head_dim: int = 256
    # Base of the rotary position embedding frequencies.
    rope_theta: float = 10_000.0
    # Epsilon shared by every RMSNorm in the head.
    rms_norm_eps: float = 1e-6
    # Sizes of the reduced draft vocabulary and of the target vocabulary.
    draft_vocab_size: int = 32000
    target_vocab_size: int = 262144
    # Target layers whose hidden states are concatenated and fed to fc.
    aux_hidden_state_layers: list = field(default_factory=lambda: [2, 30, 57])

    # Store the attention residual after hidden_norm instead of before it.
    norm_before_residual: bool = True
    # Apply an RMSNorm over the concatenated aux features before fc
    # (gpt-oss-style speculators checkpoints); not supported here.
    norm_before_fc: bool = False
    # The head ships its own embed_tokens (set during load).
    has_own_embed: bool = False

    # Number of slots in the draft KV cache.
    max_seq_len: int = 4096
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Return ``x`` with the halves of its last dim swapped and the second negated.

    For a last dimension split as ``[a, b]`` the result is ``[-b, a]``, the
    rotation used when applying rotary position embeddings.
    """
    first_half, second_half = torch.chunk(x, 2, dim=-1)
    return torch.cat((second_half.neg(), first_half), dim=-1)
class Eagle3KVCache(nn.Module):
    """Flat key/value buffers for the single EAGLE-3 draft decoder layer.

    Both buffers have shape ``(max_batch_size, num_kv_heads, max_seq_len,
    head_dim)`` and are non-persistent (excluded from ``state_dict``).
    ``update`` writes new entries along the sequence axis and hands back the
    complete buffers; the caller supplies a causal mask to pick the valid
    slots, so prefill (T>1) and single-step decode (T=1) share one path.
    """

    _BUFFER_NAMES = ("k_cache", "v_cache")

    def __init__(
        self,
        max_batch_size: int,
        max_seq_len: int,
        num_kv_heads: int,
        head_dim: int,
    ):
        super().__init__()
        cache_shape = (max_batch_size, num_kv_heads, max_seq_len, head_dim)
        for buffer_name in self._BUFFER_NAMES:
            self.register_buffer(
                buffer_name, torch.zeros(cache_shape), persistent=False
            )

    def update(
        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Write ``k_val``/``v_val`` at sequence slots ``input_pos`` in place.

        Returns the full ``(k_cache, v_cache)`` buffers, not just the slice
        that was written.
        """
        for cache, fresh in ((self.k_cache, k_val), (self.v_cache, v_val)):
            # dim 2 is the sequence axis of the cache layout.
            cache.index_copy_(2, input_pos, fresh)
        return self.k_cache, self.v_cache

    def allocate(self, dtype: torch.dtype, device) -> None:
        """Replace both buffers with zeroed ones of the given dtype/device."""
        cache_shape = self.k_cache.shape
        for buffer_name in self._BUFFER_NAMES:
            self.register_buffer(
                buffer_name,
                torch.zeros(cache_shape, dtype=dtype, device=device),
                persistent=False,
            )

    def reset(self) -> None:
        """Zero both buffers in place."""
        for cache in (self.k_cache, self.v_cache):
            cache.zero_()
class Eagle3Attention(nn.Module):
    """Llama-style GQA attention over a doubled-width (2 * hidden) input.

    q/k/v project from the concatenation of two hidden-size streams; the
    output projection maps back to ``hidden_size``. Rotary embeddings are
    applied to q and k. The module owns a batch-size-1 KV cache that is used
    only when an explicit attention mask is passed to ``forward``.
    """

    def __init__(self, config: Eagle3Config):
        super().__init__()
        self.n_heads = config.num_attention_heads
        self.n_kv_heads = config.num_key_value_heads
        self.head_dim = config.head_dim

        fused_width = 2 * config.hidden_size
        q_width = self.n_heads * self.head_dim
        kv_width = self.n_kv_heads * self.head_dim
        self.q_proj = nn.Linear(fused_width, q_width, bias=False)
        self.k_proj = nn.Linear(fused_width, kv_width, bias=False)
        self.v_proj = nn.Linear(fused_width, kv_width, bias=False)
        self.o_proj = nn.Linear(q_width, config.hidden_size, bias=False)

        # RoPE inverse frequencies, one per pair of head dimensions; kept as a
        # float32 non-persistent buffer.
        exponents = (
            torch.arange(0, self.head_dim, 2, dtype=torch.float32) / self.head_dim
        )
        self.register_buffer(
            "inv_freq", 1.0 / (config.rope_theta**exponents), persistent=False
        )

        self.kv_cache = Eagle3KVCache(
            max_batch_size=1,
            max_seq_len=config.max_seq_len,
            num_kv_heads=self.n_kv_heads,
            head_dim=self.head_dim,
        )

    def _split_heads(self, flat: torch.Tensor, num_heads: int) -> torch.Tensor:
        """Reshape (B, T, num_heads * head_dim) to (B, num_heads, T, head_dim)."""
        batch, seq_len, _ = flat.shape
        return flat.view(batch, seq_len, num_heads, self.head_dim).transpose(1, 2)

    def _project_rope(
        self, x: torch.Tensor, positions: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project ``x`` to per-head q/k/v and rotate q and k by ``positions``.

        The rotation angles are computed in float32 from ``positions`` and
        ``inv_freq`` and only then cast to the dtype of q.
        """
        q = self._split_heads(self.q_proj(x), self.n_heads)
        k = self._split_heads(self.k_proj(x), self.n_kv_heads)
        v = self._split_heads(self.v_proj(x), self.n_kv_heads)

        half_angles = torch.outer(positions.float(), self.inv_freq)
        angles = torch.cat((half_angles, half_angles), dim=-1)
        cos = angles.cos().to(q.dtype)
        sin = angles.sin().to(q.dtype)

        q = q * cos + _rotate_half(q) * sin
        k = k * cos + _rotate_half(k) * sin
        return q, k, v

    def forward(
        self,
        x: torch.Tensor,
        positions: torch.Tensor,
        attn_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """Causal attention, stateless or KV-cached.

        With ``attn_mask=None`` attention runs over the given sequence only
        (``is_causal``) and the cache is untouched. With an explicit mask the
        new K/V are written to the KV cache at ``positions`` and attention
        reads the whole cache under that mask, so one call serves both
        prefill and single-step draft decode.
        """
        batch, seq_len, _ = x.shape
        q, k, v = self._project_rope(x, positions)

        if attn_mask is None:
            attended = F.scaled_dot_product_attention(
                q, k, v, is_causal=True, enable_gqa=True
            )
        else:
            cached_k, cached_v = self.kv_cache.update(positions, k, v)
            attended = F.scaled_dot_product_attention(
                q, cached_k, cached_v, attn_mask=attn_mask, enable_gqa=True
            )

        merged = (
            attended.transpose(1, 2)
            .contiguous()
            .view(batch, seq_len, self.n_heads * self.head_dim)
        )
        return self.o_proj(merged)
class Eagle3MLP(nn.Module):
    """Bias-free gated MLP: ``down(silu(gate(x)) * up(x))``."""

    def __init__(self, config: Eagle3Config):
        super().__init__()
        hidden, inner = config.hidden_size, config.intermediate_size
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP to the last dimension of ``x``."""
        gate = F.silu(self.gate_proj(x))
        return self.down_proj(gate * self.up_proj(x))
class Eagle3Midlayer(nn.Module):
    """The single EAGLE-3 decoder layer, fed by two separately normed streams.

    Token embeddings and the feature stream each get their own RMSNorm, are
    concatenated along the last dim for attention, and the result passes
    through a residual attention step followed by a residual MLP step.
    """

    def __init__(self, config: Eagle3Config):
        super().__init__()
        width, eps = config.hidden_size, config.rms_norm_eps
        self.input_layernorm = nn.RMSNorm(width, eps=eps)
        self.hidden_norm = nn.RMSNorm(width, eps=eps)
        self.self_attn = Eagle3Attention(config)
        self.post_attention_layernorm = nn.RMSNorm(width, eps=eps)
        self.mlp = Eagle3MLP(config)
        self.norm_before_residual = config.norm_before_residual

    def forward(
        self,
        input_embeds: torch.Tensor,
        feature: torch.Tensor,
        positions: torch.Tensor,
        attn_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """Run the layer and return its hidden-size output.

        ``positions`` and ``attn_mask`` are passed straight to the attention
        module. The attention residual is the normed feature when
        ``norm_before_residual`` is set, otherwise the raw feature.
        """
        embeds_normed = self.input_layernorm(input_embeds)
        feature_normed = self.hidden_norm(feature)

        if self.norm_before_residual:
            skip = feature_normed
        else:
            skip = feature

        attn_in = torch.cat((embeds_normed, feature_normed), dim=-1)
        hidden = skip + self.self_attn(attn_in, positions, attn_mask)

        # Second residual branch: post-attention norm followed by the MLP.
        return hidden + self.mlp(self.post_attention_layernorm(hidden))
class Eagle3Draft(nn.Module):
    """EAGLE-3 draft head: ``fc`` fusion, one decoder layer, reduced-vocab logits.

    ``fuse`` maps concatenated target auxiliary hidden states to a
    hidden-size feature; ``forward`` / ``forward_cached`` run the midlayer
    over token embeddings plus a feature and return draft logits together
    with the midlayer output ``g``; ``draft_to_target`` maps draft ids back
    to target ids through the ``d2t`` offset table.
    """

    def __init__(self, config: Eagle3Config):
        super().__init__()
        self.config = config
        self.fc = nn.Linear(
            len(config.aux_hidden_state_layers) * config.target_hidden_size,
            config.hidden_size,
            bias=False,
        )
        self.midlayer = Eagle3Midlayer(config)
        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(
            config.hidden_size, config.draft_vocab_size, bias=False
        )
        if config.has_own_embed:
            self.embed_tokens = nn.Embedding(
                config.target_vocab_size, config.hidden_size
            )
        # Placeholder vocab-mapping buffers: d2t[draft_id] is the offset to the
        # target vocab id; t2d masks which target ids are in the draft vocab.
        # They are non-persistent (not part of the state dict), so
        # from_checkpoint re-registers them from the checkpoint tensors when
        # present. The all-zero d2t placeholder makes draft_to_target an
        # identity mapping.
        self.register_buffer(
            "d2t",
            torch.zeros(config.draft_vocab_size, dtype=torch.long),
            persistent=False,
        )
        self.register_buffer("t2d", torch.zeros(1, dtype=torch.bool), persistent=False)
        # cache_positions[i] = i; used to build the causal mask over the KV cache
        # without introducing dynamic-shape index tensors at runtime.
        self.register_buffer(
            "cache_positions",
            torch.arange(config.max_seq_len, dtype=torch.long),
            persistent=False,
        )
        # Eager-only end of the valid contiguous cache prefix (see forward_cached).
        self._cache_valid_end = 0

    def fuse(self, aux: torch.Tensor) -> torch.Tensor:
        """Fuse concatenated target aux hidden states (B,T,3*D) -> feature (B,T,D)."""
        return self.fc(aux)

    def embed(self, ids: torch.Tensor) -> torch.Tensor:
        """Embed token ids with the head's own table.

        Only valid when the checkpoint shipped its own ``embed_tokens``; heads
        that reuse the target embedding must source embeddings from the target.

        Raises:
            RuntimeError: if the head has no own ``embed_tokens``.
        """
        if not self.config.has_own_embed:
            raise RuntimeError(
                "this draft head has no own embed_tokens (has_own_embed=False); "
                "provide token embeddings from the target model instead"
            )
        return self.embed_tokens(ids)

    def forward(
        self,
        input_embeds: torch.Tensor,
        feature: torch.Tensor,
        positions: torch.Tensor,
        attn_mask: torch.Tensor = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the midlayer over a sequence.

        ``attn_mask=None`` runs stateless over the full sequence; an explicit
        mask uses the incremental KV cache (see ``forward_cached``).

        Returns (draft_logits, g):
            draft_logits: (B, T, draft_vocab_size) over the reduced vocab.
            g: (B, T, hidden) midlayer output — the recurrent feature.
        """
        g = self.midlayer(input_embeds, feature, positions, attn_mask)
        draft_logits = self.lm_head(self.norm(g))
        return draft_logits, g

    def _build_causal_mask(self, positions: torch.Tensor) -> torch.Tensor:
        """Boolean (1, 1, T, max_seq_len) causal mask (True = attend).

        Query position p attends to cache slot j iff j <= p. This is correct
        only under the contiguous-from-0 invariant of ``forward_cached``: a query
        at p attends to slots 0..p, all of which must already hold this
        sequence's K/V. Rejected speculative tokens sit at slots > p (the next
        query's p shrinks on rollback) and are excluded by the causal bound, so
        they need no extra masking. A non-contiguous seed (e.g. writing only
        slot 10 after reset) would wrongly attend to the zeroed slots 0..9 — see
        ``forward_cached``.
        """
        q_pos = positions.unsqueeze(1)  # (T, 1)
        cache_pos = self.cache_positions.unsqueeze(0)  # (1, max_seq_len)
        return (q_pos >= cache_pos).unsqueeze(0).unsqueeze(0)

    def forward_cached(
        self,
        input_embeds: torch.Tensor,
        feature: torch.Tensor,
        positions: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """KV-cached forward for prefill (T>1) and single-step draft decode (T=1).

        Writes K/V at ``positions`` and attends over the cache. Like the target's
        KV cache, attention scores against the whole ``max_seq_len`` buffer under
        a static causal mask (export-friendly, no data-dependent slicing); what
        this avoids versus a full recompute is re-running the prefix's
        projections and MLP, not the attention score width.

        Invariant: writes must be contiguous from position 0. Seed a fresh
        sequence (after ``reset_cache``) starting at position 0 and only ever
        extend with the next contiguous positions; offset or gapped seeds attend
        to unwritten (zeroed) slots and are not supported. Batch size must be 1.

        Raises:
            ValueError: if the batch size is not 1 or, in eager mode, if
                ``positions`` violates the contiguity invariant.
        """
        if input_embeds.shape[0] != 1:
            raise ValueError("forward_cached supports batch size 1 only")
        if not torch.compiler.is_compiling():
            self._validate_contiguous(positions)
        return self.forward(
            input_embeds, feature, positions, self._build_causal_mask(positions)
        )

    def _validate_contiguous(self, positions: torch.Tensor) -> None:
        """Eager-only guard for the contiguous-from-0 cache invariant.

        Tracks the end of the valid contiguous prefix (reset by ``reset_cache``).
        A write may overwrite already-written slots (speculative rollback) but
        must be contiguous and ascending and must not start beyond the valid
        prefix, which would leave unwritten (zeroed) slots below it in the
        attention window. A rollback overwrite truncates the valid prefix to the
        end of the write, so a slot above it is treated as stale and a later
        write that skips it is rejected until it is rewritten. Skipped under
        export/compile, where positions are traced tensors and the runner owns
        the contract.

        Raises:
            ValueError: if ``positions`` is empty, is not a contiguous
                ascending run, or starts beyond the valid prefix.
        """
        length = int(positions.shape[0])
        if length == 0:
            # Fail with a clear message instead of an IndexError on positions[0].
            raise ValueError("forward_cached positions must be non-empty")
        start = int(positions[0])
        # Build the reference run in the caller's dtype so int32 positions are
        # compared like-for-like rather than against an int64 arange.
        expected = torch.arange(
            start, start + length, dtype=positions.dtype, device=positions.device
        )
        if not torch.equal(positions, expected):
            raise ValueError(
                f"forward_cached positions must be contiguous ascending, "
                f"got {positions.tolist()}"
            )
        if start > self._cache_valid_end:
            raise ValueError(
                f"non-contiguous cache seed: positions start at {start} but only "
                f"{self._cache_valid_end} slot(s) are valid; seed from 0 after "
                f"reset_cache"
            )
        # A write defines the valid prefix up to its end; slots above it (from an
        # earlier longer write that this one rolled back) are now stale.
        self._cache_valid_end = start + length

    def reset_cache(self) -> None:
        """Zero the KV cache and mark the valid prefix as empty."""
        self.midlayer.self_attn.kv_cache.reset()
        self._cache_valid_end = 0

    def allocate_kv_cache(self, dtype: torch.dtype, device) -> None:
        """(Re)allocate the KV cache in a given dtype/device (zeroed)."""
        self.midlayer.self_attn.kv_cache.allocate(dtype, device)

    def draft_to_target(self, draft_ids: torch.Tensor) -> torch.Tensor:
        """Map draft vocab ids to target ids: ``draft_id + d2t[draft_id]``."""
        return draft_ids + self.d2t[draft_ids]

    @staticmethod
    def from_checkpoint(
        model_dir: str,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        max_seq_len: int = 4096,
    ) -> tuple["Eagle3Draft", Eagle3Config]:
        """Build a draft head from a speculator checkpoint directory.

        Reads ``config.json`` and the safetensors weights from ``model_dir``,
        casts floating-point weights to ``dtype``, allocates the KV cache in
        ``dtype`` on ``device`` and returns the model in eval mode together
        with its config. ``d2t`` / ``t2d`` are taken from the checkpoint when
        it contains them; otherwise the placeholder buffers from ``__init__``
        are kept.

        Raises:
            ValueError: if the checkpoint sets ``norm_before_fc``.
        """
        import json

        # JSON is UTF-8; do not rely on the locale's default encoding.
        with open(os.path.join(model_dir, "config.json"), encoding="utf-8") as f:
            cfg = json.load(f)
        tlc = cfg["transformer_layer_config"]
        config = Eagle3Config(
            hidden_size=tlc["hidden_size"],
            target_hidden_size=cfg.get("target_hidden_size") or tlc["hidden_size"],
            intermediate_size=tlc["intermediate_size"],
            num_attention_heads=tlc["num_attention_heads"],
            num_key_value_heads=tlc["num_key_value_heads"],
            head_dim=tlc["head_dim"],
            rope_theta=tlc["rope_parameters"]["rope_theta"],
            rms_norm_eps=tlc["rms_norm_eps"],
            draft_vocab_size=cfg["draft_vocab_size"],
            target_vocab_size=tlc.get("vocab_size", 262144),
            aux_hidden_state_layers=cfg["eagle_aux_hidden_state_layer_ids"],
            norm_before_residual=cfg.get("norm_before_residual", False),
            norm_before_fc=cfg.get("norm_before_fc", False),
            max_seq_len=max_seq_len,
        )
        if config.norm_before_fc:
            # This checkpoint variant requires an input RMSNorm before fc.
            raise ValueError(
                "norm_before_fc=True checkpoints are not supported "
                "(would need an input RMSNorm before fc)"
            )
        raw = _load_safetensors(model_dir)
        config.has_own_embed = "embed_tokens.weight" in raw
        # Cast checkpoint weights after module construction so inv_freq stays fp32.
        model = Eagle3Draft(config)
        # The single decoder layer is stored as layers.0.* on disk; rewrite
        # only the leading prefix.
        state_dict = {
            (
                k.replace("layers.0.", "midlayer.", 1)
                if k.startswith("layers.0.")
                else k
            ): (v.to(dtype) if v.is_floating_point() else v)
            for k, v in raw.items()
        }
        # d2t/t2d are index/mask tensors (their checkpoint shape differs from the
        # placeholder buffers) and are non-persistent, so they must be removed
        # before the strict load. They are optional in the checkpoint: register
        # the ones that are present and keep the placeholders otherwise.
        for buffer_name in ("d2t", "t2d"):
            tensor = state_dict.pop(buffer_name, None)
            if tensor is not None:
                model.register_buffer(buffer_name, tensor, persistent=False)
        model.load_state_dict(state_dict, strict=True, assign=True)
        # Allocate the KV cache directly in the compute dtype on the target
        # device *before* moving weights, so the float32 placeholder cache from
        # __init__ is freed without ever being copied to the device. The
        # subsequent .to(device) is a no-op for the (already-placed) cache and
        # carries no dtype argument, so inv_freq stays float32.
        model.allocate_kv_cache(dtype, device)
        model = model.to(device)
        assert (
            model.midlayer.self_attn.inv_freq.dtype == torch.float32
        ), "RoPE inv_freq must remain float32"
        return model.eval(), config
def _load_safetensors(model_dir: str) -> dict:
    """Load a monolithic or sharded safetensors checkpoint to CPU tensors.

    A ``model.safetensors`` file takes precedence; otherwise the shard files
    named in ``model.safetensors.index.json`` are read in sorted order.

    Returns:
        A dict mapping tensor name to CPU tensor.

    Raises:
        FileNotFoundError: if neither the monolithic file nor the index exists.
        ValueError: if the same tensor name appears in more than one shard.
    """
    import json

    index = os.path.join(model_dir, "model.safetensors.index.json")
    mono = os.path.join(model_dir, "model.safetensors")
    if os.path.exists(mono):
        shards = ["model.safetensors"]
    elif os.path.exists(index):
        # JSON is UTF-8; do not rely on the locale's default encoding.
        with open(index, encoding="utf-8") as f:
            shards = sorted(set(json.load(f)["weight_map"].values()))
    else:
        raise FileNotFoundError(
            f"no model.safetensors or model.safetensors.index.json in {model_dir}"
        )

    # Imported only after the checkpoint layout is resolved, so a missing
    # checkpoint is reported as such even when safetensors is not installed.
    from safetensors import safe_open

    raw = {}
    for shard in shards:
        with safe_open(
            os.path.join(model_dir, shard), framework="pt", device="cpu"
        ) as f:
            for k in f.keys():
                if k in raw:
                    raise ValueError(f"duplicate tensor {k!r} across shards ({shard})")
                raw[k] = f.get_tensor(k)
    return raw