@@ -661,6 +661,7 @@ def __init__(self,
661661 device : torch .device | str | None ):
662662 super ().__init__ ()
663663 self .norm_eps = args .norm_eps
664+ self .layer_id = layer_id
664665 self .attn = Attention (config , layer_id , args , dtype = dtype , device = device )
665666 self .ffn = MoE (config , layer_id , args , dtype = dtype , device = device )
666667 self .attn_norm = RMSNorm (args .dim , args .norm_eps , dtype = dtype , device = device )
@@ -986,10 +987,23 @@ def _load_expert(self, name: str, weight: torch.Tensor, params_dict: dict[str, n
986987 def load_weights (self , weights : Iterable [tuple [str , torch .Tensor ]]):
987988 params_dict = dict (self .named_parameters ())
988989
990+ def __skip_layers ():
991+ """We might change the number of layers so we can debug the model
992+ with less gpus."""
993+ import re
994+ matches = re .findall (r'layers\.(\d+)\.' , name )
995+ if not matches :
996+ return False
997+ layer_id = int (matches [0 ])
998+ return layer_id >= self .config .num_hidden_layers
999+
9891000 for name , loaded_weight in weights :
9901001 if name .startswith ('mtp.' ):
9911002 continue
9921003
1004+ if __skip_layers ():
1005+ continue
1006+
9931007 if name .endswith ('tie2eid' ):
9941008 name = name .replace ('tie2eid' , 'tid2eid' )
9951009 if '.ffn.' in name :
0 commit comments