@@ -30,8 +30,10 @@ def __init__(
         self,
         model_config: ModelConfig[LlamaConfig],
         layer_idx: Optional[int] = None,
+        eh_proj_before_attn: bool = False,
     ):
         config = model_config.pretrained_config
+        self._eh_proj_before_attn = eh_proj_before_attn
         super().__init__(
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_attention_heads,
@@ -52,7 +54,7 @@ def __init__(
             tp_size = 1
         # Override the QKV projection. The number of input features
         # is twice as big for EAGLE3 draft models.
-        if not self._next_layer_regular:
+        if not self._eh_proj_before_attn:
             self.qkv_proj = Linear(
                 2 * self.hidden_size,
                 tp_size * self.q_size + 2 * tp_size * self.kv_size,
@@ -74,13 +76,13 @@ def __init__(
         self,
         model_config: LlamaConfig,
         layer_idx: int = 0,
-        is_first_layer: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         super().__init__()
         config = model_config.pretrained_config
+        eagle_config = config.eagle_config if hasattr(config, "eagle_config") else {}
         self.layer_idx = layer_idx
-        self._next_layer_regular = config.eagle_config.get("next_layer_regular", True) and not is_first_layer
-        self.self_attn = Eagle3Attention(model_config, layer_idx, self._next_layer_regular)
+        self._eh_proj_before_attn = eagle_config.get("eh_proj_before_attn", False)
+        self.self_attn = Eagle3Attention(model_config, layer_idx, self._eh_proj_before_attn)
 
         if config.model_type == "llama4_text":
             inter_size = config.intermediate_size_mlp
@@ -96,7 +98,7 @@ def __init__(
             overridden_tp_size=1
             if model_config.mapping.enable_attention_dp else None,
         )
-        if not self._next_layer_regular:
+        if not self._eh_proj_before_attn:
             self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                            eps=config.rms_norm_eps,
                                            dtype=config.torch_dtype)
@@ -120,7 +122,7 @@ def forward(
         residual = hidden_states
 
         hidden_states = self.hidden_norm(hidden_states)
-        if not self._next_layer_regular:
+        if not self._eh_proj_before_attn:
             embeds = self.input_layernorm(embeds)
             hidden_states = torch.cat([embeds, hidden_states], dim=-1)
 
@@ -153,18 +155,20 @@ def __init__(
         super().__init__(model_config)
 
         config = model_config.pretrained_config
+        eagle_config = config.eagle_config if hasattr(config, "eagle_config") else {}
         self.spec_config = model_config.spec_config
         self.dtype = config.torch_dtype
         self.hidden_size = config.hidden_size
         self.mapping = model_config.mapping
         self.num_layers = model_config.pretrained_config.num_hidden_layers
+        self._eh_proj_before_attn = eagle_config.get("eh_proj_before_attn", False)
 
         if hasattr(config, "target_hidden_size"):
             self.hidden_size_in = config.target_hidden_size
         else:
             self.hidden_size_in = config.hidden_size
 
-        self._return_hidden_post_norm = config.eagle_config.get("return_hidden_post_norm", False)
+        self._return_hidden_post_norm = eagle_config.get("return_hidden_post_norm", False)
 
         if self.spec_config.num_capture_layers > 1:
             self.fc = Linear(self.hidden_size_in *
@@ -175,7 +179,7 @@ def __init__(
 
         if self.num_layers > 1:
             self.midlayer = nn.ModuleList([
-                Eagle3DecoderLayer(model_config, start_layer_idx + i, i == 0)
+                Eagle3DecoderLayer(model_config, start_layer_idx + i)
                 for i in range(self.num_layers)
             ])
         else:
@@ -189,6 +193,14 @@ def __init__(
         self.d2t = nn.Parameter(torch.empty((config.draft_vocab_size, ),
                                             dtype=torch.int32),
                                 requires_grad=False)
+        if self._eh_proj_before_attn:
+            self.enorm = RMSNorm(hidden_size=config.hidden_size,
+                                 eps=config.rms_norm_eps,
+                                 dtype=config.torch_dtype)
+            self.eh_proj = nn.Linear(config.hidden_size * 2,
+                                     config.hidden_size,
+                                     bias=eagle_config.get("eh_proj_bias", False),
+                                     dtype=config.torch_dtype)
 
         if self.hidden_size_in != config.hidden_size:
             if model_config.mapping.enable_attention_dp:
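
For reference, the path the added enorm/eh_proj modules implement can be sketched standalone. The snippet below is a minimal illustration only, written against stock PyTorch (nn.RMSNorm needs PyTorch >= 2.4); it is not the TensorRT-LLM RMSNorm/Linear classes and carries none of their dtype or tensor-parallel handling.

import torch
import torch.nn as nn


class EhProjBeforeAttnSketch(nn.Module):
    # Fuse the (normalized) token embeddings with the captured target-model
    # hidden states once, before the first decoder layer, instead of widening
    # qkv_proj to 2 * hidden_size inside the attention module.

    def __init__(self, hidden_size: int, eps: float = 1e-6, bias: bool = False):
        super().__init__()
        self.enorm = nn.RMSNorm(hidden_size, eps=eps)
        self.eh_proj = nn.Linear(2 * hidden_size, hidden_size, bias=bias)

    def forward(self, inputs_embeds: torch.Tensor,
                hidden_states: torch.Tensor) -> torch.Tensor:
        embeds = self.enorm(inputs_embeds)
        fused = torch.cat([embeds, hidden_states], dim=-1)
        # Project back to hidden_size so downstream layers keep a regular-width QKV.
        return self.eh_proj(fused)
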
@@ -230,11 +242,15 @@ def forward(
             inputs_embeds = self.embed_tokens(input_ids).to(self.dtype)
 
         assert hidden_states is not None
-
         # NOTE: If hidden states from the target model have to be concatenated,
-        # we expect that to happen outside the model definition. This helps us
-        # avoid data-dependent control flow and gives us better CUDA graph
-        # coverage.
+        # ideally, we expect that to happen outside the model definition. This
+        # helps us avoid data-dependent control flow and gives us better CUDA
+        # graph coverage.
+        if self._eh_proj_before_attn:
+            input_embeds = self.enorm(inputs_embeds)
+            hidden_states = torch.cat([input_embeds, hidden_states], dim=-1)
+            hidden_states = self.eh_proj(hidden_states)
+
         residual = None
         if self.num_layers > 1:
             for layer in self.midlayer:
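
The model code only assumes that pretrained_config.eagle_config exposes the new keys; how that dict ends up in the draft checkpoint's config is not shown in this diff. A hypothetical fragment, with key names taken from the code above and everything else an assumption:

# Hypothetical eagle_config fragment for an EAGLE3 draft checkpoint.
eagle_config = {
    "eh_proj_before_attn": True,       # fuse embeds + hidden states once, before the first layer
    "eh_proj_bias": False,             # bias of the added eh_proj Linear
    "return_hidden_post_norm": False,  # pre-existing option, unchanged here
}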