
Commit b45647d

added eh proj option to eagle draft model
1 parent: 12bc2e2

1 file changed: 28 additions & 12 deletions

tensorrt_llm/_torch/models/modeling_speculative.py
@@ -30,8 +30,10 @@ def __init__(
         self,
         model_config: ModelConfig[LlamaConfig],
         layer_idx: Optional[int] = None,
+        eh_proj_before_attn: bool = False,
     ):
         config = model_config.pretrained_config
+        self._eh_proj_before_attn = eh_proj_before_attn
         super().__init__(
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_attention_heads,
@@ -52,7 +54,7 @@ def __init__(
             tp_size = 1
         # Override the QKV projection. The number of input features
         # is twice as big for EAGLE3 draft models.
-        if not self._next_layer_regular:
+        if not self._eh_proj_before_attn:
             self.qkv_proj = Linear(
                 2 * self.hidden_size,
                 tp_size * self.q_size + 2 * tp_size * self.kv_size,
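
A quick way to see what the flag changes at the attention input: without eh_proj_before_attn, every draft layer consumes the concatenation of embeddings and target hidden states, so QKV takes 2 * hidden_size input features; with the flag set, the draft model fuses the pair down to hidden_size once up front (via the eh_proj module added further down) and the regular QKV width applies. A minimal sketch with illustrative sizes, using plain torch.nn.Linear in place of TRT-LLM's TP-aware Linear:

import torch
import torch.nn as nn

hidden_size = 4096
q_size, kv_size = 4096, 1024          # illustrative head dims, tp_size = 1

# eh_proj_before_attn = False: layer input is [embeds; hidden], double width
qkv_wide = nn.Linear(2 * hidden_size, q_size + 2 * kv_size, bias=False)
# eh_proj_before_attn = True: input was already fused back to hidden_size
qkv_regular = nn.Linear(hidden_size, q_size + 2 * kv_size, bias=False)

x = torch.randn(1, 8, 2 * hidden_size)
print(qkv_wide(x).shape)              # torch.Size([1, 8, 6144])
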
@@ -74,13 +76,13 @@ def __init__(
         self,
         model_config: LlamaConfig,
         layer_idx: int = 0,
-        is_first_layer: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         super().__init__()
         config = model_config.pretrained_config
+        eagle_config = config.eagle_config if hasattr(config, "eagle_config") else {}
         self.layer_idx = layer_idx
-        self._next_layer_regular = config.eagle_config.get("next_layer_regular", True) and not is_first_layer
-        self.self_attn = Eagle3Attention(model_config, layer_idx, self._next_layer_regular)
+        self._eh_proj_before_attn = eagle_config.get("eh_proj_before_attn", False)
+        self.self_attn = Eagle3Attention(model_config, layer_idx, self._eh_proj_before_attn)

         if config.model_type == "llama4_text":
             inter_size = config.intermediate_size_mlp
@@ -96,7 +98,7 @@ def __init__(
             overridden_tp_size=1
             if model_config.mapping.enable_attention_dp else None,
         )
-        if not self._next_layer_regular:
+        if not self._eh_proj_before_attn:
             self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                            eps=config.rms_norm_eps,
                                            dtype=config.torch_dtype)
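
The eagle_config read above guards against checkpoints whose pretrained config has no eagle_config attribute at all, so every downstream .get(...) falls back to its default. One subtlety the diff does not handle: if the attribute exists but is None, hasattr still returns True and the later .get would raise. A slightly more defensive equivalent (hypothetical hardening, not part of the commit):

# Tolerates eagle_config being absent or present-but-None
eagle_config = getattr(config, "eagle_config", None) or {}
eh_proj_before_attn = eagle_config.get("eh_proj_before_attn", False)
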
@@ -120,7 +122,7 @@ def forward(
         residual = hidden_states

         hidden_states = self.hidden_norm(hidden_states)
-        if not self._next_layer_regular:
+        if not self._eh_proj_before_attn:
             embeds = self.input_layernorm(embeds)
             hidden_states = torch.cat([embeds, hidden_states], dim=-1)

@@ -153,18 +155,20 @@ def __init__(
         super().__init__(model_config)

         config = model_config.pretrained_config
+        eagle_config = config.eagle_config if hasattr(config, "eagle_config") else {}
         self.spec_config = model_config.spec_config
         self.dtype = config.torch_dtype
         self.hidden_size = config.hidden_size
         self.mapping = model_config.mapping
         self.num_layers = model_config.pretrained_config.num_hidden_layers
+        self._eh_proj_before_attn = eagle_config.get("eh_proj_before_attn", False)

         if hasattr(config, "target_hidden_size"):
             self.hidden_size_in = config.target_hidden_size
         else:
             self.hidden_size_in = config.hidden_size

-        self._return_hidden_post_norm = config.eagle_config.get("return_hidden_post_norm", False)
+        self._return_hidden_post_norm = eagle_config.get("return_hidden_post_norm", False)

         if self.spec_config.num_capture_layers > 1:
             self.fc = Linear(self.hidden_size_in *
@@ -175,7 +179,7 @@ def __init__(

         if self.num_layers > 1:
             self.midlayer = nn.ModuleList([
-                Eagle3DecoderLayer(model_config, start_layer_idx + i, i == 0)
+                Eagle3DecoderLayer(model_config, start_layer_idx + i)
                 for i in range(self.num_layers)
             ])
         else:
@@ -189,6 +193,14 @@ def __init__(
         self.d2t = nn.Parameter(torch.empty((config.draft_vocab_size, ),
                                             dtype=torch.int32),
                                 requires_grad=False)
+        if self._eh_proj_before_attn:
+            self.enorm = RMSNorm(hidden_size=config.hidden_size,
+                                 eps=config.rms_norm_eps,
+                                 dtype=config.torch_dtype)
+            self.eh_proj = nn.Linear(config.hidden_size * 2,
+                                     config.hidden_size,
+                                     bias=eagle_config.get("eh_proj_bias", False),
+                                     dtype=config.torch_dtype)

         if self.hidden_size_in != config.hidden_size:
             if model_config.mapping.enable_attention_dp:
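
The two new submodules are materialized only when the flag is set: enorm normalizes the token embeddings, and eh_proj is a plain (non-tensor-parallel) nn.Linear mapping the concatenated 2 * hidden_size features back to hidden_size, with its bias controlled by a second knob, eh_proj_bias. Assuming eagle_config is a plain dict in the draft checkpoint's config (the exact config layout is not shown in the diff), enabling the path would look roughly like:

# Hypothetical eagle_config entries; the key names come from the diff,
# the surrounding config layout is an assumption.
eagle_config = {
    "eh_proj_before_attn": True,   # fuse embeds + target hidden before layer 0
    "eh_proj_bias": False,         # bias of the 2*hidden -> hidden projection
}
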
@@ -230,11 +242,15 @@ def forward(
             inputs_embeds = self.embed_tokens(input_ids).to(self.dtype)

         assert hidden_states is not None
-
         # NOTE: If hidden states from the target model have to be concatenated,
-        # we expect that to happen outside the model definition. This helps us
-        # avoid data-dependent control flow and gives us better CUDA graph
-        # coverage.
+        # ideally, we expect that to happen outside the model definition. This
+        # helps us avoid data-dependent control flow and gives us better CUDA
+        # graph coverage.
+        if self._eh_proj_before_attn:
+            input_embeds = self.enorm(inputs_embeds)
+            hidden_states = torch.cat([input_embeds, hidden_states], dim=-1)
+            hidden_states = self.eh_proj(hidden_states)
+
         residual = None
         if self.num_layers > 1:
             for layer in self.midlayer: