@@ -30,8 +30,10 @@ def __init__(
         self,
         model_config: ModelConfig[LlamaConfig],
         layer_idx: Optional[int] = None,
+        next_layer_regular: bool = False,
     ):
         config = model_config.pretrained_config
+        self._next_layer_regular = next_layer_regular
         super().__init__(
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_attention_heads,
@@ -78,8 +80,9 @@ def __init__(
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         super().__init__()
         config = model_config.pretrained_config
+        eagle_config = config.eagle_config if hasattr(config, "eagle_config") else {}
         self.layer_idx = layer_idx
-        self._next_layer_regular = config.eagle_config.get("next_layer_regular", True) and not is_first_layer
+        self._next_layer_regular = (eagle_config.get("next_layer_regular", True) and not is_first_layer) or eagle_config.get("eh_proj_before_attn", False)
         self.self_attn = Eagle3Attention(model_config, layer_idx, self._next_layer_regular)

         if config.model_type == "llama4_text":
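A note on the new flag: in Eagle-3 drafting, the first decoder layer attends over the concatenation of the token embeddings and the captured target hidden states, so its attention projection expects a 2 * hidden_size input, while "regular" layers take hidden_size. The hunks above only show the flag being threaded into Eagle3Attention; as a rough sketch under that assumption (the names below are illustrative, not the library's), the flag would pick the projection's input width:

    # Illustrative only: how next_layer_regular could select the attention
    # input width. The actual Eagle3Attention wiring is not shown in this diff.
    def qkv_in_features(hidden_size: int, next_layer_regular: bool) -> int:
        return hidden_size if next_layer_regular else 2 * hidden_size

With eh_proj_before_attn enabled, the concatenation is projected back to hidden_size before the first layer runs (see the forward hunk below), which is why the expression above forces _next_layer_regular to True in that case.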
@@ -153,18 +156,20 @@ def __init__(
         super().__init__(model_config)

         config = model_config.pretrained_config
+        eagle_config = config.eagle_config if hasattr(config, "eagle_config") else {}
         self.spec_config = model_config.spec_config
         self.dtype = config.torch_dtype
         self.hidden_size = config.hidden_size
         self.mapping = model_config.mapping
         self.num_layers = model_config.pretrained_config.num_hidden_layers
+        self._eh_proj_before_attn = eagle_config.get("eh_proj_before_attn", False)

         if hasattr(config, "target_hidden_size"):
             self.hidden_size_in = config.target_hidden_size
         else:
             self.hidden_size_in = config.hidden_size

-        self._return_hidden_post_norm = config.eagle_config.get("return_hidden_post_norm", False)
+        self._return_hidden_post_norm = eagle_config.get("return_hidden_post_norm", False)

         if self.spec_config.num_capture_layers > 1:
             self.fc = Linear(self.hidden_size_in *
@@ -189,6 +194,14 @@ def __init__(
         self.d2t = nn.Parameter(torch.empty((config.draft_vocab_size, ),
                                             dtype=torch.int32),
                                 requires_grad=False)
+        if self._eh_proj_before_attn:
+            self.enorm = RMSNorm(hidden_size=config.hidden_size,
+                                 eps=config.rms_norm_eps,
+                                 dtype=config.torch_dtype)
+            self.eh_proj = nn.Linear(config.hidden_size * 2,
+                                     config.hidden_size,
+                                     bias=eagle_config.get("eh_proj_bias", False),
+                                     dtype=config.torch_dtype)

         if self.hidden_size_in != config.hidden_size:
             if model_config.mapping.enable_attention_dp:
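For intuition, here is a standalone sketch of what the enorm/eh_proj pair added above computes, written with plain PyTorch modules instead of the TRT-LLM RMSNorm/Linear wrappers (torch.nn.RMSNorm needs PyTorch >= 2.4; the sizes are made up for illustration):

    import torch
    import torch.nn as nn

    hidden_size, num_tokens = 64, 8               # illustrative sizes

    enorm = nn.RMSNorm(hidden_size, eps=1e-6)     # stands in for the model's RMSNorm
    eh_proj = nn.Linear(2 * hidden_size, hidden_size, bias=False)

    inputs_embeds = torch.randn(num_tokens, hidden_size)
    target_hidden = torch.randn(num_tokens, hidden_size)  # captured hidden states, already fused

    x = torch.cat([enorm(inputs_embeds), target_hidden], dim=-1)  # [num_tokens, 2 * hidden_size]
    x = eh_proj(x)                                                # [num_tokens, hidden_size]
    assert x.shape == (num_tokens, hidden_size)

The projection brings the concatenated features back to hidden_size before any decoder layer sees them, which is what lets every draft layer use the regular attention width.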
@@ -230,11 +243,15 @@ def forward(
             inputs_embeds = self.embed_tokens(input_ids).to(self.dtype)

         assert hidden_states is not None
-
         # NOTE: If hidden states from the target model have to be concatenated,
-        # we expect that to happen outside the model definition. This helps us
-        # avoid data-dependent control flow and gives us better CUDA graph
-        # coverage.
+        # ideally, we expect that to happen outside the model definition. This
+        # helps us avoid data-dependent control flow and gives us better CUDA
+        # graph coverage.
+        if self._eh_proj_before_attn:
+            input_embeds = self.enorm(inputs_embeds)
+            hidden_states = torch.cat([input_embeds, hidden_states], dim=-1)
+            hidden_states = self.eh_proj(hidden_states)
+
         residual = None
         if self.num_layers > 1:
             for layer in self.midlayer:
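As far as configuration goes, every key read above comes from eagle_config; a hypothetical set of values enabling this path could look like the following (where exactly eagle_config is declared in the draft checkpoint's config is not shown in this diff):

    # Hypothetical eagle_config values; the keys match the eagle_config.get(...) calls above.
    eagle_config = {
        "eh_proj_before_attn": True,   # build enorm/eh_proj and apply them before the first layer
        "eh_proj_bias": False,         # bias of the 2*hidden_size -> hidden_size projection
        "next_layer_regular": True,    # non-first layers use the regular attention width
        "return_hidden_post_norm": False,
    }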