Commit 12bc2e2

Author: Izzy Putterman
PostNorm and multilayer options
Signed-off-by: Izzy Putterman <iputterman@nvidia.com>
1 parent 0e36484 commit 12bc2e2

1 file changed: 29 additions & 22 deletions

tensorrt_llm/_torch/models/modeling_speculative.py

@@ -52,19 +52,20 @@ def __init__(
             tp_size = 1
         # Override the QKV projection. The number of input features
         # is twice as big for EAGLE3 draft models.
-        self.qkv_proj = Linear(
-            2 * self.hidden_size,
-            tp_size * self.q_size + 2 * tp_size * self.kv_size,
-            bias=config.attention_bias,
-            dtype=config.torch_dtype,
-            mapping=self.qkv_proj.mapping,
-            tensor_parallel_mode=TensorParallelMode.COLUMN,
-            weights_loading_config=WeightsLoadingConfig(
-                weight_mode=WeightMode.FUSED_QKV_LINEAR),
-            quant_config=model_config.get_quant_config(),
-            skip_create_weights_in_init=model_config.
-            skip_create_weights_in_init,
-        )
+        if not self._next_layer_regular:
+            self.qkv_proj = Linear(
+                2 * self.hidden_size,
+                tp_size * self.q_size + 2 * tp_size * self.kv_size,
+                bias=config.attention_bias,
+                dtype=config.torch_dtype,
+                mapping=self.qkv_proj.mapping,
+                tensor_parallel_mode=TensorParallelMode.COLUMN,
+                weights_loading_config=WeightsLoadingConfig(
+                    weight_mode=WeightMode.FUSED_QKV_LINEAR),
+                quant_config=model_config.get_quant_config(),
+                skip_create_weights_in_init=model_config.
+                skip_create_weights_in_init,
+            )
 
 
 class Eagle3DecoderLayer(DecoderLayer):
@@ -73,12 +74,13 @@ def __init__(
         self,
         model_config: LlamaConfig,
         layer_idx: int = 0,
+        is_first_layer: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         super().__init__()
         config = model_config.pretrained_config
         self.layer_idx = layer_idx
-
-        self.self_attn = Eagle3Attention(model_config, layer_idx)
+        self._next_layer_regular = config.eagle_config.get("next_layer_regular", True) and not is_first_layer
+        self.self_attn = Eagle3Attention(model_config, layer_idx, self._next_layer_regular)
 
         if config.model_type == "llama4_text":
             inter_size = config.intermediate_size_mlp
@@ -94,9 +96,10 @@ def __init__(
             overridden_tp_size=1
             if model_config.mapping.enable_attention_dp else None,
         )
-        self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
-                                       eps=config.rms_norm_eps,
-                                       dtype=config.torch_dtype)
+        if not self._next_layer_regular:
+            self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
+                                           eps=config.rms_norm_eps,
+                                           dtype=config.torch_dtype)
 
         self.hidden_norm = RMSNorm(hidden_size=config.hidden_size,
                                    eps=config.rms_norm_eps,
@@ -116,10 +119,10 @@ def forward(
     ) -> torch.Tensor:
         residual = hidden_states
 
-        embeds = self.input_layernorm(embeds)
         hidden_states = self.hidden_norm(hidden_states)
-
-        hidden_states = torch.cat([embeds, hidden_states], dim=-1)
+        if not self._next_layer_regular:
+            embeds = self.input_layernorm(embeds)
+            hidden_states = torch.cat([embeds, hidden_states], dim=-1)
 
         hidden_states = self.self_attn(
             position_ids=position_ids,
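
For intuition, here is a standalone sketch (not the TensorRT-LLM code) of the forward-path difference this hunk introduces: the hidden-state norm always runs, but only a non-"regular" layer (in practice the first draft layer) normalizes the target-model embeds and concatenates them, so its attention input is 2 * hidden_size wide. The toy shapes and the use of torch.nn.RMSNorm (PyTorch >= 2.4) are illustrative assumptions.

import torch

hidden_size = 8
embeds = torch.randn(1, 4, hidden_size)         # embeddings captured from the target model
hidden_states = torch.randn(1, 4, hidden_size)  # draft-model hidden states
hidden_norm = torch.nn.RMSNorm(hidden_size)     # stand-in for self.hidden_norm
input_norm = torch.nn.RMSNorm(hidden_size)      # stand-in for self.input_layernorm (first layer only)

def attention_input(embeds, hidden_states, next_layer_regular):
    # Mirrors the branch above: hidden_norm always runs; the embeds path
    # (input_layernorm + concat) only runs when the layer is not "regular".
    hidden_states = hidden_norm(hidden_states)
    if not next_layer_regular:
        hidden_states = torch.cat([input_norm(embeds), hidden_states], dim=-1)
    return hidden_states

print(attention_input(embeds, hidden_states, next_layer_regular=False).shape)  # torch.Size([1, 4, 16])
print(attention_input(embeds, hidden_states, next_layer_regular=True).shape)   # torch.Size([1, 4, 8])
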
@@ -160,6 +163,8 @@ def __init__(
             self.hidden_size_in = config.target_hidden_size
         else:
             self.hidden_size_in = config.hidden_size
+
+        self._return_hidden_post_norm = config.eagle_config.get("return_hidden_post_norm", False)
 
         if self.spec_config.num_capture_layers > 1:
             self.fc = Linear(self.hidden_size_in *
@@ -170,7 +175,7 @@ def __init__(
 
         if self.num_layers > 1:
             self.midlayer = nn.ModuleList([
-                Eagle3DecoderLayer(model_config, start_layer_idx + i)
+                Eagle3DecoderLayer(model_config, start_layer_idx + i, i == 0)
                 for i in range(self.num_layers)
             ])
         else:
@@ -249,6 +254,8 @@ def forward(
 
         hidden_states, hidden_states_to_save = self.norm(
             hidden_states, residual)
+        if self._return_hidden_post_norm:
+            return hidden_states, hidden_states
         return hidden_states, hidden_states_to_save
 
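
Taken together, the commit adds two draft-model options read from eagle_config: "next_layer_regular" (default True), which makes every draft layer after the first behave as a regular decoder layer (standard QKV width, no extra input_layernorm or embed concatenation), and "return_hidden_post_norm" (default False), which makes the draft model return the post-final-norm hidden states as the captured hidden states. A minimal sketch of such a config fragment follows; the key names and defaults come from the diff above, but how eagle_config reaches the pretrained config (e.g. via the draft checkpoint's config.json) is an assumption.

# Hypothetical draft-model config fragment; only the two keys and their
# defaults are taken from this commit, everything else is illustrative.
eagle_config = {
    "next_layer_regular": True,        # layers after the first act as regular decoder layers
    "return_hidden_post_norm": False,  # if True, capture the post-norm hidden states
}
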
