
Commit d2145ae

Merge pull request #450 from Medhatt21/fix/falcon-model-and-pile-dataset
fix: complete Falcon model support and add Pile dataset
2 parents 04eeca1 + cb8c146

File tree

llmc/data/dataset/base_dataset.py
llmc/data/dataset/specified_preproc.py
llmc/eval/eval_base.py
llmc/models/falcon.py

4 files changed: +47 -17 lines


llmc/data/dataset/base_dataset.py

Lines changed: 5 additions & 0 deletions
@@ -38,6 +38,7 @@ def __init__(self, tokenizer, calib_cfg, batch_process=None):
         self.seed = calib_cfg['seed']
         self.calib_dataset_field_map = {
             'pileval': 'text',
+            'pile': 'text',
             'c4': 'text',
             'wikitext2': 'text',
             'ptb': 'sentence',
@@ -66,6 +67,10 @@ def build_calib_dataset(self):
            self.calib_dataset = load_dataset(
                'ptb_text_only', 'penn_treebank', split='train'
            )
+        elif self.calib_dataset_name == 'pile':
+            self.calib_dataset = load_dataset(
+                'mit-han-lab/pile-val-backup', split='validation'
+            )
        elif self.calib_dataset_name == 'ultrachat':
            self.calib_dataset = load_dataset(
                'HuggingFaceH4/ultrachat_200k', split='train_sft'
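
For reference, the new 'pile' branch resolves to the public validation split of the Pile backup hosted under mit-han-lab, with the raw documents in the 'text' column (which is why the field map above points 'pile' at 'text'). A minimal sketch of that load, outside of llmc:

# Sketch only, not part of this diff: load the same split the new branch uses.
from datasets import load_dataset

calib_dataset = load_dataset('mit-han-lab/pile-val-backup', split='validation')
print(calib_dataset.column_names)        # expect 'text' among the columns
print(calib_dataset[0]['text'][:200])    # first 200 characters of a document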

llmc/data/dataset/specified_preproc.py

Lines changed: 12 additions & 0 deletions
@@ -47,6 +47,18 @@ def c4_gptq(calib_dataset, tokenizer, n_samples, seq_len):
     return samples


+@PREPROC_REGISTRY
+def pile_gptq(calib_dataset, tokenizer, n_samples, seq_len):
+    trainenc = tokenizer('\n\n'.join(calib_dataset['text'][:1000]), return_tensors='pt')
+    samples = []
+    for _ in range(n_samples):
+        i = random.randint(0, trainenc.input_ids.shape[1] - seq_len - 1)
+        j = i + seq_len
+        inp = trainenc.input_ids[:, i:j]
+        samples.append(inp)
+    return samples
+
+
 @PREPROC_REGISTRY
 def pileval_awq(calib_dataset, tokenizer, n_samples, seq_len):
     dataset = calib_dataset.shuffle(seed=42)
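
Outside of the PREPROC_REGISTRY plumbing, the sampling that pile_gptq performs can be sketched standalone as below. The tokenizer checkpoint and the n_samples/seq_len values are placeholders; in llmc they come from the model and the calibration config.

# Standalone sketch of the GPTQ-style calibration sampling added above:
# tokenize the first 1000 Pile documents as one stream, then cut n_samples
# random windows of seq_len tokens each.
import random

from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('facebook/opt-125m')  # placeholder model
calib_dataset = load_dataset('mit-han-lab/pile-val-backup', split='validation')

n_samples, seq_len = 128, 2048
trainenc = tokenizer('\n\n'.join(calib_dataset['text'][:1000]), return_tensors='pt')
samples = []
for _ in range(n_samples):
    i = random.randint(0, trainenc.input_ids.shape[1] - seq_len - 1)
    samples.append(trainenc.input_ids[:, i:i + seq_len])
print(len(samples), samples[0].shape)  # n_samples windows of shape (1, seq_len)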

llmc/eval/eval_base.py

Lines changed: 10 additions & 1 deletion
@@ -25,14 +25,15 @@ def __init__(self, model, config):
             'wikitext2',
             'c4',
             'ptb',
+            'pile',
             'custom',
             'human_eval',
             'mme',
             'custom_ppl',
             'custom_gen',
             't2v',
             'i2v',
-        ], f'Not support {self.dataset} dataset now.'
+        ], f'Not support {self.eval_dataset_name} dataset now.'
         self.seq_len = self.eval_cfg.get('seq_len', None)
         self.num_samples = self.eval_cfg.get('num_samples', None)
         self.num_eval_tokens = self.eval_cfg.get('num_eval_tokens', None)
@@ -67,6 +68,10 @@ def build_data(self):
            testdata = load_dataset(
                'ptb_text_only', 'penn_treebank', split='test'
            )
+        elif self.eval_dataset_name == 'pile':
+            testdata = load_dataset(
+                'mit-han-lab/pile-val-backup', split='validation'
+            )
        else:
            if self.eval_dataset_name in ['custom_gen', 'custom_ppl', 't2v', 'i2v']:
                testdata = self.get_cutomdata(self.eval_dataset_path)
@@ -91,6 +96,10 @@ def build_data(self):
            testenc = self.tokenizer(
                ' '.join(testdata['sentence']), return_tensors='pt'
            )
+        elif self.eval_dataset_name == 'pile':
+            testenc = self.tokenizer(
+                '\n\n'.join(testdata['text'][:1000]), return_tensors='pt'
+            )
        elif self.eval_dataset_name == 'custom_ppl':
            testenc = self.tokenizer(
                '\n'.join([data['question'] + data['answer'] if 'answer' in data else data['question'] for data in testdata]),  # noqa
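
The evaluation loop itself is untouched by this commit; the diff only builds one long token tensor from the '\n\n'-joined Pile documents. For context, a generic fixed-window perplexity sketch over such a tensor is shown below. The model name and window size are placeholders, and this is not the repository's own evaluator.

# Generic fixed-window perplexity over the joined Pile text: slice the single
# tensor into seq_len windows and average the next-token NLL over all tokens.
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = 'facebook/opt-125m'  # placeholder causal LM
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).eval()

testdata = load_dataset('mit-han-lab/pile-val-backup', split='validation')
testenc = tokenizer('\n\n'.join(testdata['text'][:1000]), return_tensors='pt')

seq_len = 2048
input_ids = testenc.input_ids
n_windows = input_ids.shape[1] // seq_len
nlls = []
with torch.no_grad():
    for k in range(n_windows):
        window = input_ids[:, k * seq_len:(k + 1) * seq_len]
        loss = model(window, labels=window).loss   # mean next-token NLL for this window
        nlls.append(loss * seq_len)
ppl = torch.exp(torch.stack(nlls).sum() / (n_windows * seq_len))
print(f'pile perplexity: {ppl.item():.2f}')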

llmc/models/falcon.py

Lines changed: 20 additions & 16 deletions
@@ -8,12 +8,15 @@ class Falcon(BaseModel):
     def __init__(self, config, device_map=None, use_cache=False):
         super().__init__(config, device_map, use_cache)

+    def _is_new_decoder_architecture(self):
+        return getattr(self.model_config, 'new_decoder_architecture', False)
+
     def find_blocks(self):
         self.blocks = self.model.transformer.h

     def find_embed_layers(self):
         self.word_embeddings = self.model.transformer.word_embeddings
-        self.rotary_emb = self.model.model.rotary_emb
+        self.rotary_emb = self.model.transformer.rotary_emb

     def find_block_name(self):
         self.block_name_prefix = 'model.transformer.h'
@@ -25,30 +28,31 @@ def get_attention_rotary_layers(self):
         return [self.rotary_emb]

     def get_layers_except_blocks(self):
-        return [self.word_embeddings, self.rotary_emb, self.model.transformer.ln_f]
+        return [self.word_embeddings, self.rotary_emb, self.model.transformer.ln_f,
+                self.model.lm_head]
+
+    def skip_layer_name(self):
+        return ['lm_head']

     def has_bias(self):
-        return False
+        return getattr(self.model_config, 'bias', False)

     def get_layernorms_in_block(self, block):
-        if block.config.architectures[0] == 'RWForCausalLM':
-            new_decoder_architecture = False
-        elif block.config.architectures[0] == 'FalconForCausalLM':
-            new_decoder_architecture = True
-        if new_decoder_architecture:
+        if self._is_new_decoder_architecture():
             return {'ln_attn': block.ln_attn, 'ln_mlp': block.ln_mlp}
         else:
-            if block.config.parallel_attn:
+            if getattr(block.config, 'parallel_attn', False):
                 return {'input_layernorm': block.input_layernorm}
             else:
-                return {'post_attention_layernorm': block.post_attention_layernorm}
+                return {
+                    'input_layernorm': block.input_layernorm,
+                    'post_attention_layernorm': block.post_attention_layernorm,
+                }

     def get_subsets_in_block(self, block):
-        if block.config.architectures[0] == 'RWForCausalLM':
-            new_decoder_architecture = False
-        elif block.config.architectures[0] == 'FalconForCausalLM':
-            new_decoder_architecture = True
-        if new_decoder_architecture:
+        new_arch = self._is_new_decoder_architecture()
+
+        if new_arch:
             subset1 = {
                 'layers': {
                     'self_attention.query_key_value': (
@@ -79,7 +83,7 @@ def get_subsets_in_block(self, block):
                 'inspect': block.self_attention.query_key_value,
                 'has_kwargs': False,
             }
-            if block.config.parallel_attn:
+            if getattr(block.config, 'parallel_attn', False):
                 subset3 = {
                     'layers': {'mlp.dense_h_to_4h': block.mlp.dense_h_to_4h},
                     'prev_op': [block.input_layernorm],
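
The refactor above stops inferring the layout from config.architectures and instead reads the Falcon config flags directly, with getattr defaults so older RW-style configs that lack the attributes fall back to the legacy path. A small illustration of that pattern using transformers.FalconConfig; the concrete values below are examples (roughly Falcon-7B-like), not taken from this diff:

# getattr with a default mirrors the checks used in the methods above and
# keeps configs that omit these fields on the False path.
from transformers import FalconConfig

cfg = FalconConfig(new_decoder_architecture=False, parallel_attn=True, bias=False)
print(getattr(cfg, 'new_decoder_architecture', False))  # False -> old decoder layout
print(getattr(cfg, 'parallel_attn', False))             # True  -> single input_layernorm
print(getattr(cfg, 'bias', False))                      # False -> linear layers without bias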
