-
Notifications
You must be signed in to change notification settings - Fork 1k
Expand file tree
/
Copy pathdraft.py
More file actions
292 lines (250 loc) · 11.4 KB
/
Copy pathdraft.py
File metadata and controls
292 lines (250 loc) · 11.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""EAGLE-3 draft head for vLLM speculator checkpoints.
The draft model fuses target auxiliary hidden states with ``fc``, runs one
Llama-style decoder layer over token embeddings plus the fused feature, and
projects the midlayer output to reduced-vocabulary draft logits. The midlayer
output ``g`` is reused as the recurrent feature for drafting; ``fc`` is used
only for target auxiliary hidden states.
Draft ids map back to target ids with ``target_id = draft_id + d2t[draft_id]``.
Speculator checkpoints store the decoder layer under ``layers.0.*`` and may
include ``embed_tokens``, ``d2t``, and ``t2d``.
"""
import os
from dataclasses import dataclass, field
import torch
import torch.nn as nn
from torch.nn import functional as F
@dataclass
class Eagle3Config:
hidden_size: int = 5376
target_hidden_size: int = 5376
intermediate_size: int = 21504
num_attention_heads: int = 32
num_key_value_heads: int = 16
head_dim: int = 256
rope_theta: float = 10_000.0
rms_norm_eps: float = 1e-6
draft_vocab_size: int = 32000
target_vocab_size: int = 262144
aux_hidden_state_layers: list = field(default_factory=lambda: [2, 30, 57])
# norm_before_residual: store the attention residual after hidden_norm.
# norm_before_fc: apply an RMSNorm over the concatenated aux features before
# fc (gpt-oss-style speculators checkpoints); not supported here.
# has_own_embed: the head ships its own embed_tokens (set during load).
norm_before_residual: bool = True
norm_before_fc: bool = False
has_own_embed: bool = False
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
class Eagle3Attention(nn.Module):
"""Llama GQA attention; q/k/v project from the doubled-width (2*hidden) input."""
def __init__(self, config: Eagle3Config):
super().__init__()
self.n_heads = config.num_attention_heads
self.n_kv_heads = config.num_key_value_heads
self.head_dim = config.head_dim
in_dim = 2 * config.hidden_size
self.q_proj = nn.Linear(in_dim, self.n_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(in_dim, self.n_kv_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(in_dim, self.n_kv_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(
self.n_heads * self.head_dim, config.hidden_size, bias=False
)
inv_freq = 1.0 / (
config.rope_theta
** (torch.arange(0, self.head_dim, 2, dtype=torch.float32) / self.head_dim)
)
self.register_buffer("inv_freq", inv_freq, persistent=False)
def forward(self, x: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
B, T, _ = x.shape
q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
freqs = torch.outer(positions.float(), self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos().to(q.dtype)
sin = emb.sin().to(q.dtype)
q = q * cos + _rotate_half(q) * sin
k = k * cos + _rotate_half(k) * sin
y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
y = y.transpose(1, 2).contiguous().view(B, T, self.n_heads * self.head_dim)
return self.o_proj(y)
class Eagle3MLP(nn.Module):
def __init__(self, config: Eagle3Config):
super().__init__()
self.gate_proj = nn.Linear(
config.hidden_size, config.intermediate_size, bias=False
)
self.up_proj = nn.Linear(
config.hidden_size, config.intermediate_size, bias=False
)
self.down_proj = nn.Linear(
config.intermediate_size, config.hidden_size, bias=False
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
class Eagle3Midlayer(nn.Module):
"""Single EAGLE-3 decoder layer with dual input norms over two streams."""
def __init__(self, config: Eagle3Config):
super().__init__()
self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.hidden_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.self_attn = Eagle3Attention(config)
self.post_attention_layernorm = nn.RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.mlp = Eagle3MLP(config)
self.norm_before_residual = config.norm_before_residual
def forward(
self,
input_embeds: torch.Tensor,
feature: torch.Tensor,
positions: torch.Tensor,
) -> torch.Tensor:
normed_embeds = self.input_layernorm(input_embeds)
normed_feature = self.hidden_norm(feature)
residual = normed_feature if self.norm_before_residual else feature
x = torch.cat((normed_embeds, normed_feature), dim=-1)
x = self.self_attn(x, positions)
x = residual + x
residual = x
x = self.post_attention_layernorm(x)
x = self.mlp(x)
return residual + x
class Eagle3Draft(nn.Module):
def __init__(self, config: Eagle3Config):
super().__init__()
self.config = config
self.fc = nn.Linear(
len(config.aux_hidden_state_layers) * config.target_hidden_size,
config.hidden_size,
bias=False,
)
self.midlayer = Eagle3Midlayer(config)
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.lm_head = nn.Linear(
config.hidden_size, config.draft_vocab_size, bias=False
)
if config.has_own_embed:
self.embed_tokens = nn.Embedding(
config.target_vocab_size, config.hidden_size
)
# d2t/t2d are loaded from the checkpoint (assign=True adopts their
# shapes/dtypes): d2t[draft_id] is the offset to the target vocab id;
# t2d masks which target ids are in the draft vocab.
self.register_buffer(
"d2t",
torch.zeros(config.draft_vocab_size, dtype=torch.long),
persistent=False,
)
self.register_buffer("t2d", torch.zeros(1, dtype=torch.bool), persistent=False)
def fuse(self, aux: torch.Tensor) -> torch.Tensor:
"""Fuse concatenated target aux hidden states (B,T,3*D) -> feature (B,T,D)."""
return self.fc(aux)
def embed(self, ids: torch.Tensor) -> torch.Tensor:
"""Embed token ids with the head's own table.
Only valid when the checkpoint shipped its own ``embed_tokens``; heads
that reuse the target embedding must source embeddings from the target.
"""
if not self.config.has_own_embed:
raise RuntimeError(
"this draft head has no own embed_tokens (has_own_embed=False); "
"provide token embeddings from the target model instead"
)
return self.embed_tokens(ids)
def forward(
self,
input_embeds: torch.Tensor,
feature: torch.Tensor,
positions: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Run the midlayer over a sequence.
Returns (draft_logits, g):
draft_logits: (B, T, draft_vocab_size) over the reduced vocab.
g: (B, T, hidden) midlayer output — the recurrent feature.
"""
g = self.midlayer(input_embeds, feature, positions)
draft_logits = self.lm_head(self.norm(g))
return draft_logits, g
def draft_to_target(self, draft_ids: torch.Tensor) -> torch.Tensor:
return draft_ids + self.d2t[draft_ids]
@staticmethod
def from_checkpoint(
model_dir: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16
) -> tuple["Eagle3Draft", Eagle3Config]:
import json
with open(os.path.join(model_dir, "config.json")) as f:
cfg = json.load(f)
tlc = cfg["transformer_layer_config"]
config = Eagle3Config(
hidden_size=tlc["hidden_size"],
target_hidden_size=cfg.get("target_hidden_size") or tlc["hidden_size"],
intermediate_size=tlc["intermediate_size"],
num_attention_heads=tlc["num_attention_heads"],
num_key_value_heads=tlc["num_key_value_heads"],
head_dim=tlc["head_dim"],
rope_theta=tlc["rope_parameters"]["rope_theta"],
rms_norm_eps=tlc["rms_norm_eps"],
draft_vocab_size=cfg["draft_vocab_size"],
target_vocab_size=tlc.get("vocab_size", 262144),
aux_hidden_state_layers=cfg["eagle_aux_hidden_state_layer_ids"],
norm_before_residual=cfg.get("norm_before_residual", False),
norm_before_fc=cfg.get("norm_before_fc", False),
)
if config.norm_before_fc:
# This checkpoint variant requires an input RMSNorm before fc.
raise ValueError(
"norm_before_fc=True checkpoints are not supported "
"(would need an input RMSNorm before fc)"
)
raw = _load_safetensors(model_dir)
config.has_own_embed = "embed_tokens.weight" in raw
# Cast checkpoint weights after module construction so inv_freq stays fp32.
model = Eagle3Draft(config)
# The single decoder layer is stored as layers.0.* on disk.
state_dict = {
(k.replace("layers.0.", "midlayer.") if k.startswith("layers.0.") else k): (
v.to(dtype) if v.is_floating_point() else v
)
for k, v in raw.items()
}
# d2t/t2d are index/mask tensors (their checkpoint shape differs from the
# placeholder buffers); register them directly, load the rest strict.
model.register_buffer("d2t", state_dict.pop("d2t"), persistent=False)
model.register_buffer("t2d", state_dict.pop("t2d"), persistent=False)
model.load_state_dict(state_dict, strict=True, assign=True)
model = model.to(device)
assert (
model.midlayer.self_attn.inv_freq.dtype == torch.float32
), "RoPE inv_freq must remain float32"
return model.eval(), config
def _load_safetensors(model_dir: str) -> dict:
"""Load a monolithic or sharded safetensors checkpoint to CPU tensors."""
import json
from safetensors import safe_open
index = os.path.join(model_dir, "model.safetensors.index.json")
mono = os.path.join(model_dir, "model.safetensors")
if os.path.exists(mono):
shards = ["model.safetensors"]
elif os.path.exists(index):
with open(index) as f:
shards = sorted(set(json.load(f)["weight_map"].values()))
else:
raise FileNotFoundError(
f"no model.safetensors or model.safetensors.index.json in {model_dir}"
)
raw = {}
for shard in shards:
with safe_open(
os.path.join(model_dir, shard), framework="pt", device="cpu"
) as f:
for k in f.keys():
if k in raw:
raise ValueError(f"duplicate tensor {k!r} across shards ({shard})")
raw[k] = f.get_tensor(k)
return raw