
Commit 333b8b9

added back next_layer_regular for multiple layer eagle
1 parent b45647d commit 333b8b9

1 file changed (9 additions, 8 deletions)

tensorrt_llm/_torch/models/modeling_speculative.py
@@ -30,10 +30,10 @@ def __init__(
         self,
         model_config: ModelConfig[LlamaConfig],
         layer_idx: Optional[int] = None,
-        eh_proj_before_attn: bool = False,
+        next_layer_regular: bool = False,
     ):
         config = model_config.pretrained_config
-        self._eh_proj_before_attn = eh_proj_before_attn
+        self._next_layer_regular = next_layer_regular
         super().__init__(
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_attention_heads,
@@ -54,7 +54,7 @@ def __init__(
             tp_size = 1
         # Override the QKV projection. The number of input features
         # is twice as big for EAGLE3 draft models.
-        if not self._eh_proj_before_attn:
+        if not self._next_layer_regular:
             self.qkv_proj = Linear(
                 2 * self.hidden_size,
                 tp_size * self.q_size + 2 * tp_size * self.kv_size,
@@ -76,13 +76,14 @@ def __init__(
         self,
         model_config: LlamaConfig,
         layer_idx: int = 0,
+        is_first_layer: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         super().__init__()
         config = model_config.pretrained_config
         eagle_config = config.eagle_config if hasattr(config, "eagle_config") else {}
         self.layer_idx = layer_idx
-        self._eh_proj_before_attn = eagle_config.get("eh_proj_before_attn", False)
-        self.self_attn = Eagle3Attention(model_config, layer_idx, self._eh_proj_before_attn)
+        self._next_layer_regular = (eagle_config.get("next_layer_regular", True) and not is_first_layer) or eagle_config.get("eh_proj_before_attn", False)
+        self.self_attn = Eagle3Attention(model_config, layer_idx, self._next_layer_regular)
 
         if config.model_type == "llama4_text":
             inter_size = config.intermediate_size_mlp
@@ -98,7 +99,7 @@ def __init__(
             overridden_tp_size=1
             if model_config.mapping.enable_attention_dp else None,
         )
-        if not self._eh_proj_before_attn:
+        if not self._next_layer_regular:
             self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                            eps=config.rms_norm_eps,
                                            dtype=config.torch_dtype)
@@ -122,7 +123,7 @@ def forward(
         residual = hidden_states
 
         hidden_states = self.hidden_norm(hidden_states)
-        if not self._eh_proj_before_attn:
+        if not self._next_layer_regular:
             embeds = self.input_layernorm(embeds)
             hidden_states = torch.cat([embeds, hidden_states], dim=-1)
 
@@ -179,7 +180,7 @@ def __init__(
 
         if self.num_layers > 1:
             self.midlayer = nn.ModuleList([
-                Eagle3DecoderLayer(model_config, start_layer_idx + i)
+                Eagle3DecoderLayer(model_config, start_layer_idx + i, i == 0)
                 for i in range(self.num_layers)
             ])
         else:
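
For reference, a minimal sketch (not part of the commit) of how the next_layer_regular flag introduced here resolves per draft layer. The eagle_config contents and the three-layer loop are illustrative assumptions; the helper function simply mirrors the expression added to Eagle3DecoderLayer.__init__ above.

    def resolve_next_layer_regular(eagle_config: dict, is_first_layer: bool) -> bool:
        # Layers after the first are "regular" (single-width QKV, no extra
        # input_layernorm / embed concatenation) unless the config disables it;
        # eh_proj_before_attn forces regular behavior even for the first layer.
        return (eagle_config.get("next_layer_regular", True) and not is_first_layer) \
            or eagle_config.get("eh_proj_before_attn", False)

    if __name__ == "__main__":
        eagle_config = {}  # defaults: next_layer_regular=True, eh_proj_before_attn=False
        for i in range(3):  # hypothetical 3-layer EAGLE3 draft model
            regular = resolve_next_layer_regular(eagle_config, is_first_layer=(i == 0))
            width = "hidden_size" if regular else "2 * hidden_size"
            print(f"draft layer {i}: next_layer_regular={regular}, qkv input width = {width}")
        # draft layer 0: next_layer_regular=False, qkv input width = 2 * hidden_size
        # draft layer 1: next_layer_regular=True, qkv input width = hidden_size
        # draft layer 2: next_layer_regular=True, qkv input width = hidden_size

In other words, with this change only the first draft layer concatenates the embeddings with the hidden states (hence the doubled QKV input features and the input_layernorm), while subsequent layers in a multi-layer EAGLE drafter behave like regular decoder layers.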
