
Commit 8d614a8

Author: Izzy Putterman

Merge pull request #1 from IzzyPutterman/mashkenazi/eagle-options-eh-proj

Added eh_proj option to the EAGLE draft model.

2 parents 12bc2e2 + 333b8b9, commit 8d614a8

1 file changed: tensorrt_llm/_torch/models/modeling_speculative.py (23 additions, 6 deletions)
@@ -30,8 +30,10 @@ def __init__(
         self,
         model_config: ModelConfig[LlamaConfig],
         layer_idx: Optional[int] = None,
+        next_layer_regular: bool = False,
     ):
         config = model_config.pretrained_config
+        self._next_layer_regular = next_layer_regular
         super().__init__(
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_attention_heads,
@@ -78,8 +80,9 @@ def __init__(
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         super().__init__()
         config = model_config.pretrained_config
+        eagle_config = config.eagle_config if hasattr(config, "eagle_config") else {}
         self.layer_idx = layer_idx
-        self._next_layer_regular = config.eagle_config.get("next_layer_regular", True) and not is_first_layer
+        self._next_layer_regular = (eagle_config.get("next_layer_regular", True) and not is_first_layer) or eagle_config.get("eh_proj_before_attn", False)
         self.self_attn = Eagle3Attention(model_config, layer_idx, self._next_layer_regular)

         if config.model_type == "llama4_text":
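
Read together, the two changed lines mean `eh_proj_before_attn` now forces `_next_layer_regular` on for every layer, including the first, which the old expression always excluded. A minimal sketch of just that predicate (standalone Python; the config dict and `is_first_layer` flag are stand-ins for the diff's surroundings):

    def next_layer_regular(eagle_config: dict, is_first_layer: bool) -> bool:
        # Old behavior: regular only on non-first layers, if the config allows it.
        # New behavior: also forced True whenever eh_proj runs before attention.
        return (eagle_config.get("next_layer_regular", True)
                and not is_first_layer) or eagle_config.get("eh_proj_before_attn", False)

    assert not next_layer_regular({}, is_first_layer=True)
    assert next_layer_regular({"eh_proj_before_attn": True}, is_first_layer=True)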
@@ -153,18 +156,20 @@ def __init__(
         super().__init__(model_config)

         config = model_config.pretrained_config
+        eagle_config = config.eagle_config if hasattr(config, "eagle_config") else {}
         self.spec_config = model_config.spec_config
         self.dtype = config.torch_dtype
         self.hidden_size = config.hidden_size
         self.mapping = model_config.mapping
         self.num_layers = model_config.pretrained_config.num_hidden_layers
+        self._eh_proj_before_attn = eagle_config.get("eh_proj_before_attn", False)

         if hasattr(config, "target_hidden_size"):
             self.hidden_size_in = config.target_hidden_size
         else:
             self.hidden_size_in = config.hidden_size

-        self._return_hidden_post_norm = config.eagle_config.get("return_hidden_post_norm", False)
+        self._return_hidden_post_norm = eagle_config.get("return_hidden_post_norm", False)

         if self.spec_config.num_capture_layers > 1:
             self.fc = Linear(self.hidden_size_in *
@@ -189,6 +194,14 @@ def __init__(
         self.d2t = nn.Parameter(torch.empty((config.draft_vocab_size, ),
                                             dtype=torch.int32),
                                 requires_grad=False)
+        if self._eh_proj_before_attn:
+            self.enorm = RMSNorm(hidden_size=config.hidden_size,
+                                 eps=config.rms_norm_eps,
+                                 dtype=config.torch_dtype)
+            self.eh_proj = nn.Linear(config.hidden_size * 2,
+                                     config.hidden_size,
+                                     bias=eagle_config.get("eh_proj_bias", False),
+                                     dtype=config.torch_dtype)

         if self.hidden_size_in != config.hidden_size:
             if model_config.mapping.enable_attention_dp:
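
Shape-wise, the two new modules implement a 2H -> H fusion: the draft token embedding is RMS-normalized, concatenated with the target model's hidden state on the last dimension, then projected back to the model width. A self-contained shape check, assuming plain PyTorch 2.4+ (torch.nn.RMSNorm here is only a stand-in for the repo's own RMSNorm; H and the tensor shapes are illustrative):

    import torch
    import torch.nn as nn

    H = 4096                                   # hypothetical hidden_size
    enorm = nn.RMSNorm(H, eps=1e-5)            # stand-in for the repo's RMSNorm
    eh_proj = nn.Linear(2 * H, H, bias=False)  # eh_proj_bias defaults to False

    inputs_embeds = torch.randn(2, 8, H)       # [batch, seq, H] draft-token embeddings
    hidden_states = torch.randn(2, 8, H)       # [batch, seq, H] target-model hidden states

    fused = eh_proj(torch.cat([enorm(inputs_embeds), hidden_states], dim=-1))
    assert fused.shape == (2, 8, H)            # back to model width before attention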
@@ -230,11 +243,15 @@ def forward(
         inputs_embeds = self.embed_tokens(input_ids).to(self.dtype)

         assert hidden_states is not None
-
         # NOTE: If hidden states from the target model have to be concatenated,
-        # we expect that to happen outside the model definition. This helps us
-        # avoid data-dependent control flow and gives us better CUDA graph
-        # coverage.
+        # ideally, we expect that to happen outside the model definition. This
+        # helps us avoid data-dependent control flow and gives us better CUDA
+        # graph coverage.
+        if self._eh_proj_before_attn:
+            input_embeds = self.enorm(inputs_embeds)
+            hidden_states = torch.cat([input_embeds, hidden_states], dim=-1)
+            hidden_states = self.eh_proj(hidden_states)
+
         residual = None
         if self.num_layers > 1:
             for layer in self.midlayer:
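
Nothing in this commit shows where `eagle_config` comes from; both constructors just read it off the pretrained config when present. To exercise the new path, the config would need to carry entries like the following (key names and defaults are taken from the diff; attaching the dict to the config object is plumbing outside this file):

    # Keys this diff reads, with their defaults. How eagle_config gets attached
    # to the pretrained config is not shown in this commit.
    eagle_config = {
        "eh_proj_before_attn": True,   # enable the enorm + eh_proj fusion path
        "eh_proj_bias": False,         # bias of the 2H -> H projection
        "next_layer_regular": True,    # forced on anyway when eh_proj_before_attn is set
        "return_hidden_post_norm": False,
    }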
