@@ -8,12 +8,15 @@ class Falcon(BaseModel):
     def __init__(self, config, device_map=None, use_cache=False):
         super().__init__(config, device_map, use_cache)
 
+    def _is_new_decoder_architecture(self):
+        return getattr(self.model_config, 'new_decoder_architecture', False)
+
     def find_blocks(self):
         self.blocks = self.model.transformer.h
 
     def find_embed_layers(self):
         self.word_embeddings = self.model.transformer.word_embeddings
-        self.rotary_emb = self.model.model.rotary_emb
+        self.rotary_emb = self.model.transformer.rotary_emb
 
     def find_block_name(self):
         self.block_name_prefix = 'model.transformer.h'
@@ -25,30 +28,31 @@ def get_attention_rotary_layers(self):
         return [self.rotary_emb]
 
     def get_layers_except_blocks(self):
-        return [self.word_embeddings, self.rotary_emb, self.model.transformer.ln_f]
+        return [self.word_embeddings, self.rotary_emb, self.model.transformer.ln_f,
+                self.model.lm_head]
+
+    def skip_layer_name(self):
+        return ['lm_head']
 
     def has_bias(self):
-        return False
+        return getattr(self.model_config, 'bias', False)
 
     def get_layernorms_in_block(self, block):
-        if block.config.architectures[0] == 'RWForCausalLM':
-            new_decoder_architecture = False
-        elif block.config.architectures[0] == 'FalconForCausalLM':
-            new_decoder_architecture = True
-        if new_decoder_architecture:
+        if self._is_new_decoder_architecture():
             return {'ln_attn': block.ln_attn, 'ln_mlp': block.ln_mlp}
         else:
-            if block.config.parallel_attn:
+            if getattr(block.config, 'parallel_attn', False):
                 return {'input_layernorm': block.input_layernorm}
             else:
-                return {'post_attention_layernorm': block.post_attention_layernorm}
+                return {
+                    'input_layernorm': block.input_layernorm,
+                    'post_attention_layernorm': block.post_attention_layernorm,
+                }
 
     def get_subsets_in_block(self, block):
-        if block.config.architectures[0] == 'RWForCausalLM':
-            new_decoder_architecture = False
-        elif block.config.architectures[0] == 'FalconForCausalLM':
-            new_decoder_architecture = True
-        if new_decoder_architecture:
+        new_arch = self._is_new_decoder_architecture()
+
+        if new_arch:
             subset1 = {
                 'layers': {
                     'self_attention.query_key_value': (
@@ -79,7 +83,7 @@ def get_subsets_in_block(self, block):
                 'inspect': block.self_attention.query_key_value,
                 'has_kwargs': False,
             }
-            if block.config.parallel_attn:
+            if getattr(block.config, 'parallel_attn', False):
                 subset3 = {
                     'layers': {'mlp.dense_h_to_4h': block.mlp.dense_h_to_4h},
                     'prev_op': [block.input_layernorm],
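For context, a minimal sketch of the pattern this change moves to: reading optional Falcon config fields with `getattr` and a safe default instead of keying off `config.architectures[0]`. The stub configs below are hypothetical stand-ins for the Hugging Face `FalconConfig`; only the field names (`new_decoder_architecture`, `parallel_attn`, `bias`) are taken from the diff above.

# Minimal sketch of the getattr-with-default pattern used in the diff.
# SimpleNamespace objects stand in for transformers' FalconConfig here.
from types import SimpleNamespace

def is_new_decoder_architecture(config):
    # Older RW-style Falcon configs may not define this field at all,
    # so default to False rather than raising AttributeError.
    return getattr(config, 'new_decoder_architecture', False)

old_cfg = SimpleNamespace(parallel_attn=True)             # RW-style config
new_cfg = SimpleNamespace(new_decoder_architecture=True)  # newer Falcon config

assert is_new_decoder_architecture(old_cfg) is False
assert is_new_decoder_architecture(new_cfg) is True
assert getattr(old_cfg, 'bias', False) is False           # same pattern as has_bias()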