@@ -52,19 +52,20 @@ def __init__(
             tp_size = 1
         # Override the QKV projection. The number of input features
         # is twice as big for EAGLE3 draft models.
-        self.qkv_proj = Linear(
-            2 * self.hidden_size,
-            tp_size * self.q_size + 2 * tp_size * self.kv_size,
-            bias=config.attention_bias,
-            dtype=config.torch_dtype,
-            mapping=self.qkv_proj.mapping,
-            tensor_parallel_mode=TensorParallelMode.COLUMN,
-            weights_loading_config=WeightsLoadingConfig(
-                weight_mode=WeightMode.FUSED_QKV_LINEAR),
-            quant_config=model_config.get_quant_config(),
-            skip_create_weights_in_init=model_config.
-            skip_create_weights_in_init,
-        )
+        if not self._next_layer_regular:
+            self.qkv_proj = Linear(
+                2 * self.hidden_size,
+                tp_size * self.q_size + 2 * tp_size * self.kv_size,
+                bias=config.attention_bias,
+                dtype=config.torch_dtype,
+                mapping=self.qkv_proj.mapping,
+                tensor_parallel_mode=TensorParallelMode.COLUMN,
+                weights_loading_config=WeightsLoadingConfig(
+                    weight_mode=WeightMode.FUSED_QKV_LINEAR),
+                quant_config=model_config.get_quant_config(),
+                skip_create_weights_in_init=model_config.
+                skip_create_weights_in_init,
+            )


 class Eagle3DecoderLayer(DecoderLayer):
@@ -73,12 +74,13 @@ def __init__(
         self,
         model_config: LlamaConfig,
         layer_idx: int = 0,
+        is_first_layer: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         super().__init__()
         config = model_config.pretrained_config
         self.layer_idx = layer_idx
-
-        self.self_attn = Eagle3Attention(model_config, layer_idx)
+        self._next_layer_regular = config.eagle_config.get("next_layer_regular", True) and not is_first_layer
+        self.self_attn = Eagle3Attention(model_config, layer_idx, self._next_layer_regular)

         if config.model_type == "llama4_text":
             inter_size = config.intermediate_size_mlp
@@ -94,9 +96,10 @@ def __init__(
             overridden_tp_size=1
             if model_config.mapping.enable_attention_dp else None,
         )
-        self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
-                                       eps=config.rms_norm_eps,
-                                       dtype=config.torch_dtype)
+        if not self._next_layer_regular:
+            self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
+                                           eps=config.rms_norm_eps,
+                                           dtype=config.torch_dtype)

         self.hidden_norm = RMSNorm(hidden_size=config.hidden_size,
                                    eps=config.rms_norm_eps,
@@ -116,10 +119,10 @@ def forward(
     ) -> torch.Tensor:
         residual = hidden_states

-        embeds = self.input_layernorm(embeds)
         hidden_states = self.hidden_norm(hidden_states)
-
-        hidden_states = torch.cat([embeds, hidden_states], dim=-1)
+        if not self._next_layer_regular:
+            embeds = self.input_layernorm(embeds)
+            hidden_states = torch.cat([embeds, hidden_states], dim=-1)

         hidden_states = self.self_attn(
             position_ids=position_ids,
@@ -160,6 +163,8 @@ def __init__(
             self.hidden_size_in = config.target_hidden_size
         else:
             self.hidden_size_in = config.hidden_size
+
+        self._return_hidden_post_norm = config.eagle_config.get("return_hidden_post_norm", False)

         if self.spec_config.num_capture_layers > 1:
             self.fc = Linear(self.hidden_size_in *
@@ -170,7 +175,7 @@ def __init__(

         if self.num_layers > 1:
             self.midlayer = nn.ModuleList([
-                Eagle3DecoderLayer(model_config, start_layer_idx + i)
+                Eagle3DecoderLayer(model_config, start_layer_idx + i, i == 0)
                 for i in range(self.num_layers)
             ])
         else:
@@ -249,6 +254,8 @@ def forward(

         hidden_states, hidden_states_to_save = self.norm(
             hidden_states, residual)
+        if self._return_hidden_post_norm:
+            return hidden_states, hidden_states
         return hidden_states, hidden_states_to_save


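For context (not part of the patch), here is a minimal sketch of the behavior the new next_layer_regular flag selects. It uses a hypothetical SketchEagle3Layer with plain Linear/LayerNorm stand-ins (the real code uses attention, RMSNorm, tensor parallelism and quantization, all omitted here): the first Eagle3 layer normalizes the target-model embeddings and concatenates them with the hidden states, which is why its QKV input is 2 * hidden_size wide, while subsequent layers configured as regular skip both the input_layernorm and the concatenation.

import torch
from torch import nn

HIDDEN = 64  # hypothetical hidden size, for illustration only


class SketchEagle3Layer(nn.Module):
    # Toy stand-in for Eagle3DecoderLayer: attention/MLP collapse into one
    # Linear, and LayerNorm is used where the real code uses RMSNorm.

    def __init__(self, is_first_layer: bool, next_layer_regular: bool = True):
        super().__init__()
        # Mirrors the patch: only layers after the first can be "regular".
        self.next_layer_regular = next_layer_regular and not is_first_layer
        if not self.next_layer_regular:
            # First layer: embeddings are normalized and concatenated with the
            # hidden states, so the attention input is 2 * hidden_size wide.
            self.input_layernorm = nn.LayerNorm(HIDDEN)
            self.attn = nn.Linear(2 * HIDDEN, HIDDEN)
        else:
            # Regular layers: no concat, standard hidden_size-wide input.
            self.attn = nn.Linear(HIDDEN, HIDDEN)
        self.hidden_norm = nn.LayerNorm(HIDDEN)

    def forward(self, embeds: torch.Tensor,
                hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.hidden_norm(hidden_states)
        if not self.next_layer_regular:
            embeds = self.input_layernorm(embeds)
            hidden_states = torch.cat([embeds, hidden_states], dim=-1)
        return self.attn(hidden_states)


embeds = torch.randn(2, 8, HIDDEN)
hidden = torch.randn(2, 8, HIDDEN)
first = SketchEagle3Layer(is_first_layer=True)
later = SketchEagle3Layer(is_first_layer=False)
out = later(embeds, first(embeds, hidden))
print(out.shape)  # torch.Size([2, 8, 64])

The second toggle in the patch, return_hidden_post_norm, only changes the draft model's outputs: when set, the post-norm hidden states are returned both as the model output and as the hidden states to save, instead of the separate hidden_states_to_save tensor produced by the fused add-and-norm.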