@@ -307,7 +307,8 @@ def __init__(
307307 self .hidden_size_in = config .hidden_size
308308
309309 self ._return_hidden_post_norm = eagle_config .get (
310- "return_hidden_post_norm" , False )
310+ "return_hidden_post_norm" , False ) or getattr (
311+ config , "norm_output" , False )
311312
312313 # Create auxiliary CUDA stream for MLA operations (only needed for MLA)
313314 self .aux_stream = torch .cuda .Stream () if use_mla else None
@@ -330,6 +331,18 @@ def __init__(
330331 else :
331332 self .input_norm = None
332333
334+ self ._use_fc_norm = getattr (config , "fc_norm" , False )
335+ if self ._use_fc_norm :
336+ self .fc_norm = nn .ModuleList ([
337+ RMSNorm (
338+ hidden_size = self .hidden_size_in ,
339+ eps = config .rms_norm_eps ,
340+ dtype = config .torch_dtype ,
341+ ) for _ in range (self .spec_config .num_capture_layers )
342+ ])
343+ else :
344+ self .fc_norm = None
345+
333346 if self .num_layers > 1 :
334347 self .midlayer = nn .ModuleList ([
335348 Eagle3DecoderLayer (
@@ -590,7 +603,14 @@ def apply_eagle3_fc(self, hidden_states: torch.Tensor) -> torch.Tensor:
590603
591604 expected_hidden_size = self .model .hidden_size
592605 if hidden_states .shape [- 1 ] != expected_hidden_size :
593- if self .model ._norm_before_fc :
606+ if self .model .fc_norm is not None :
607+ chunks = hidden_states .chunk (len (self .model .fc_norm ), dim = - 1 )
608+ hidden_states = torch .cat ([
609+ norm (chunk )
610+ for norm , chunk in zip (self .model .fc_norm , chunks )
611+ ],
612+ dim = - 1 )
613+ elif self .model ._norm_before_fc :
594614 hidden_states = self .model .input_norm (hidden_states )
595615 hidden_states = self .model .fc (hidden_states )
596616
0 commit comments