Skip to content

Commit 1047091

Browse files
Dogacel and laikhtewari authored
[None][feat] Eagle 3.1 -- Support post-norm and per-aux fc_norm for Eagle3 draft models (#14988)
Signed-off-by: Doğaç Eldenk <dogacel@gmail.com> Co-authored-by: Laikh Tewari <ltewari@nvidia.com>
1 parent 7193f41 commit 1047091

1 file changed

Lines changed: 22 additions & 2 deletions

File tree

tensorrt_llm/_torch/models/modeling_speculative.py

Lines changed: 22 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -307,7 +307,8 @@ def __init__(
307307
self.hidden_size_in = config.hidden_size
308308

309309
self._return_hidden_post_norm = eagle_config.get(
310-
"return_hidden_post_norm", False)
310+
"return_hidden_post_norm", False) or getattr(
311+
config, "norm_output", False)
311312

312313
# Create auxiliary CUDA stream for MLA operations (only needed for MLA)
313314
self.aux_stream = torch.cuda.Stream() if use_mla else None
@@ -330,6 +331,18 @@ def __init__(
330331
else:
331332
self.input_norm = None
332333

334+
self._use_fc_norm = getattr(config, "fc_norm", False)
335+
if self._use_fc_norm:
336+
self.fc_norm = nn.ModuleList([
337+
RMSNorm(
338+
hidden_size=self.hidden_size_in,
339+
eps=config.rms_norm_eps,
340+
dtype=config.torch_dtype,
341+
) for _ in range(self.spec_config.num_capture_layers)
342+
])
343+
else:
344+
self.fc_norm = None
345+
333346
if self.num_layers > 1:
334347
self.midlayer = nn.ModuleList([
335348
Eagle3DecoderLayer(
@@ -590,7 +603,14 @@ def apply_eagle3_fc(self, hidden_states: torch.Tensor) -> torch.Tensor:
590603

591604
expected_hidden_size = self.model.hidden_size
592605
if hidden_states.shape[-1] != expected_hidden_size:
593-
if self.model._norm_before_fc:
606+
if self.model.fc_norm is not None:
607+
chunks = hidden_states.chunk(len(self.model.fc_norm), dim=-1)
608+
hidden_states = torch.cat([
609+
norm(chunk)
610+
for norm, chunk in zip(self.model.fc_norm, chunks)
611+
],
612+
dim=-1)
613+
elif self.model._norm_before_fc:
594614
hidden_states = self.model.input_norm(hidden_states)
595615
hidden_states = self.model.fc(hidden_states)
596616

0 commit comments

Comments
 (0)