@@ -30,10 +30,10 @@ def __init__(
         self,
         model_config: ModelConfig[LlamaConfig],
         layer_idx: Optional[int] = None,
-        eh_proj_before_attn: bool = False,
+        next_layer_regular: bool = False,
     ):
         config = model_config.pretrained_config
-        self._eh_proj_before_attn = eh_proj_before_attn
+        self._next_layer_regular = next_layer_regular
         super().__init__(
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_attention_heads,
@@ -54,7 +54,7 @@ def __init__(
             tp_size = 1
         # Override the QKV projection. The number of input features
         # is twice as big for EAGLE3 draft models.
-        if not self._eh_proj_before_attn:
+        if not self._next_layer_regular:
             self.qkv_proj = Linear(
                 2 * self.hidden_size,
                 tp_size * self.q_size + 2 * tp_size * self.kv_size,
@@ -76,13 +76,14 @@ def __init__(
         self,
         model_config: LlamaConfig,
         layer_idx: int = 0,
+        is_first_layer: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         super().__init__()
         config = model_config.pretrained_config
         eagle_config = config.eagle_config if hasattr(config, "eagle_config") else {}
         self.layer_idx = layer_idx
-        self._eh_proj_before_attn = eagle_config.get("eh_proj_before_attn", False)
-        self.self_attn = Eagle3Attention(model_config, layer_idx, self._eh_proj_before_attn)
+        self._next_layer_regular = (eagle_config.get("next_layer_regular", True) and not is_first_layer) or eagle_config.get("eh_proj_before_attn", False)
+        self.self_attn = Eagle3Attention(model_config, layer_idx, self._next_layer_regular)

         if config.model_type == "llama4_text":
             inter_size = config.intermediate_size_mlp
@@ -98,7 +99,7 @@ def __init__(
             overridden_tp_size=1
             if model_config.mapping.enable_attention_dp else None,
         )
-        if not self._eh_proj_before_attn:
+        if not self._next_layer_regular:
             self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                            eps=config.rms_norm_eps,
                                            dtype=config.torch_dtype)
@@ -122,7 +123,7 @@ def forward(
         residual = hidden_states

         hidden_states = self.hidden_norm(hidden_states)
-        if not self._eh_proj_before_attn:
+        if not self._next_layer_regular:
             embeds = self.input_layernorm(embeds)
             hidden_states = torch.cat([embeds, hidden_states], dim=-1)

@@ -179,7 +180,7 @@ def __init__(

         if self.num_layers > 1:
             self.midlayer = nn.ModuleList([
-                Eagle3DecoderLayer(model_config, start_layer_idx + i)
+                Eagle3DecoderLayer(model_config, start_layer_idx + i, i == 0)
                 for i in range(self.num_layers)
             ])
         else:
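A minimal standalone sketch of the gating this commit introduces. The `ToyEagle3Layer` module, the `HIDDEN` width, and the bare `nn.Linear` QKV stand-in are illustrative assumptions, not the real TensorRT-LLM classes; only the `next_layer_regular` expression and the 2x-input rule are taken from the diff:

```python
import torch
import torch.nn as nn

HIDDEN = 8  # toy width; the real model uses config.hidden_size


def next_layer_regular(eagle_config: dict, is_first_layer: bool) -> bool:
    # Same expression as the diff: layers after the first are "regular"
    # when next_layer_regular is on (default True); the legacy
    # eh_proj_before_attn flag forces regular behavior for every layer.
    return (eagle_config.get("next_layer_regular", True)
            and not is_first_layer) or eagle_config.get(
                "eh_proj_before_attn", False)


class ToyEagle3Layer(nn.Module):
    # Hypothetical stand-in for Eagle3DecoderLayer + Eagle3Attention.

    def __init__(self, regular: bool):
        super().__init__()
        self.regular = regular
        # Non-regular layers consume [embeds; hidden] concatenated,
        # hence the 2x input features on the QKV projection.
        in_features = HIDDEN if regular else 2 * HIDDEN
        self.qkv_proj = nn.Linear(in_features, HIDDEN)

    def forward(self, embeds, hidden):
        if not self.regular:
            hidden = torch.cat([embeds, hidden], dim=-1)
        return self.qkv_proj(hidden)


# Mirrors the ModuleList change: only layer 0 gets is_first_layer=True.
layers = [ToyEagle3Layer(next_layer_regular({}, i == 0)) for i in range(3)]
embeds, hidden = torch.randn(2, HIDDEN), torch.randn(2, HIDDEN)
for layer in layers:
    hidden = layer(embeds, hidden)
print(hidden.shape)  # torch.Size([2, 8])
```

Net effect with a default `eagle_config`: only the first draft layer concatenates embeddings with the hidden state and widens its QKV input; subsequent layers behave like regular decoder layers, which is what the renamed flag expresses.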