Skip to content

Commit 1507798

Browse files
committed
Refactor GPTBigCode model conversion code
1 parent 6bdf78a commit 1507798

4 files changed

Lines changed: 253 additions & 230 deletions

File tree

src/transformers/generation/utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2469,9 +2469,7 @@ def greedy_search(
24692469
# Store scores, attentions and hidden_states when required
24702470
if return_dict_in_generate:
24712471
if output_scores:
2472-
scores += (next_tokens_scores,) if outputs.logits.shape[1] == 1 else (
2473-
outputs.logits,
2474-
)
2472+
scores += (next_tokens_scores,)
24752473
if output_attentions:
24762474
decoder_attentions += (
24772475
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)

src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py

Lines changed: 109 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,61 +1,138 @@
11
import argparse
22
import os
33
from pathlib import Path
4+
import re
45

56
import torch
67
from transformers.models.gpt_bigcode.merge_fast_llm_checkpoint import merge_checkpoint
8+
from transformers.models.gpt_bigcode import GPTBigCodeConfig, GPTBigCodeForCausalLM, GPTBigCodeModel
9+
10+
11+
# The simple map of names for "automated" rules: the operation name found in a
# Fast-LLM parameter key (the part between the layer index and the trailing
# weight/bias suffix) mapped to the corresponding GPTBigCode sub-module path.
NAME_MAP = {
    "_mlp._layer_1": "mlp.c_fc",
    "_mlp._layer_2": "mlp.c_proj",
    "layer_norm_1": "ln_1",
    "layer_norm_2": "ln_2",
    # "attention.dense": "attn.c_proj",
    "self_attn.dense": "attn.c_proj",
    # "self_attention.query_key_value": "attn.c_attn",
}
21+
22+
23+
def convert_fast_llm_checkpoint(state_dict, config):
    """Convert a merged Fast-LLM checkpoint into GPTBigCode format.

    Args:
        state_dict: Mapping of Fast-LLM parameter names to tensors. The keys
            read here are ``_layers.0._word_embeddings_weight``, optionally
            ``_layers.0._position_embeddings_weight``,
            ``_layers.{i}.self_attn.query.{weight|bias}``,
            ``_layers.{i}.self_attn.key_value.{weight|bias}`` and, for every
            remaining entry, ``_layers.{i}.{op}.{weight|bias}``. The mapping
            passed in is left untouched (a shallow copy is consumed instead).
        config: Dict-like Fast-LLM configuration. Must provide ``vocab_size``,
            ``max_position_embeddings``, ``hidden_size``, ``num_layers``,
            ``num_attention_heads``, ``ffn_hidden_size``,
            ``use_rotary_embeddings``, ``rotary_embedding_scale`` and
            ``use_position_embeddings``; ``window_size`` (preferred) or
            ``attention_window_size`` are optional.

    Returns:
        A ``(output_state_dict, hf_config)`` tuple: the re-keyed state dict and
        the ``GPTBigCodeConfig`` built from ``config``.

    Raises:
        KeyError: If a required config entry or parameter is missing, or if a
            layer operation has no entry in ``NAME_MAP``.
        ValueError: If a remaining parameter name does not look like
            ``_layers.{i}.{op}.{suffix}``, or if ``final_layernorm`` is not
            found at index ``n_layer + 1``.
    """
    # Work on a shallow copy so that popping entries below does not empty the
    # caller's dict; the tensors themselves are shared, not duplicated.
    remaining = dict(state_dict)

    # The converted output model.
    output_state_dict = {}

    # "window_size" wins over "attention_window_size" when both are present.
    if "window_size" in config:
        attention_window_size = config["window_size"]
    else:
        attention_window_size = config.get("attention_window_size", None)

    # Use a distinct name for the output config instead of rebinding `config`,
    # so the Fast-LLM dict and the GPTBigCode object are never confused.
    hf_config = GPTBigCodeConfig(
        architectures=["GPTBigCodeLMHeadModel"],
        vocab_size=config["vocab_size"],
        n_positions=config["max_position_embeddings"],
        n_embd=config["hidden_size"],
        n_layer=config["num_layers"],
        n_head=config["num_attention_heads"],
        n_inner=config["ffn_hidden_size"],
        activation_function="gelu",  # TODO
        multi_query=True,  # TODO
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        summary_type="cls_index",
        summary_use_proj=True,
        summary_activation=None,
        summary_proj_to_labels=True,
        summary_first_dropout=0.1,
        scale_attn_weights=True,
        use_cache=True,
        bos_token_id=0,  # TODO: can we remove these?
        eos_token_id=0,
        attention_softmax_in_fp32=True,
        scale_attention_softmax_in_fp32=True,
        use_rotary_embeddings=config["use_rotary_embeddings"],
        rotary_embedding_scale=config["rotary_embedding_scale"],
        use_position_embeddings=config["use_position_embeddings"],
        attention_window_size=attention_window_size,
    )

    # Truncate the word embeddings to the vocab-size.
    word_embeddings = remaining.pop("_layers.0._word_embeddings_weight")[: hf_config.vocab_size, :]
    output_state_dict["transformer.wte.weight"] = word_embeddings
    if hf_config.use_position_embeddings:
        output_state_dict["transformer.wpe.weight"] = remaining.pop("_layers.0._position_embeddings_weight")

    # Layer-0 is the word/position embeddings.
    # Layers 1 to n_layer need to be re-mapped from 0 to n_layer-1.
    # _layers.{layer_index}.{op}.{w/b}

    # Concatenate the separate query and key_value parameters into the single
    # c_attn parameter (query first, along dim 0).
    for layer_index in range(1, hf_config.n_layer + 1):
        for weight_or_bias in ("weight", "bias"):
            query = remaining.pop(f"_layers.{layer_index}.self_attn.query.{weight_or_bias}")
            key_value = remaining.pop(f"_layers.{layer_index}.self_attn.key_value.{weight_or_bias}")
            output_state_dict[f"transformer.h.{layer_index - 1}.attn.c_attn.{weight_or_bias}"] = torch.cat(
                [query, key_value], dim=0
            )

    # Extract the other ops. Raw string: "\." and "\d" are regex escapes, not
    # Python string escapes.
    layer_re = re.compile(r"_layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")
    for name, value in remaining.items():
        m = layer_re.match(name)
        if m is None:
            raise ValueError(f"Invalid layer name: {name}")

        # The index of the layer.
        layer_index = int(m.group(1))
        # The name of the operation.
        op_name = m.group(2)
        # Is it a weight or a bias?
        weight_or_bias = m.group(3)

        # Final layernorm
        if op_name == "final_layernorm":
            if layer_index != hf_config.n_layer + 1:
                raise ValueError(
                    f"final_layernorm found at layer index {layer_index}, "
                    f"expected {hf_config.n_layer + 1}"
                )
            output_state_dict[f"transformer.ln_f.{weight_or_bias}"] = value
        else:
            try:
                mapped_name = NAME_MAP[op_name]
            except KeyError:
                raise KeyError(
                    f"No conversion rule for operation {op_name!r} (parameter {name!r})"
                ) from None
            output_state_dict[f"transformer.h.{layer_index - 1}.{mapped_name}.{weight_or_bias}"] = value

    # For the LM head, transformers wants the (truncated) word-embedding
    # matrix; the same tensor object is reused.
    output_state_dict["lm_head.weight"] = word_embeddings

    return output_state_dict, hf_config
7104

8105

9106
def main(argv=None):
    """Command-line entry point: merge, convert and save a checkpoint.

    Parses ``--checkpoint_dir`` and ``--save_dir``, merges the checkpoint with
    ``merge_checkpoint``, converts it with ``convert_fast_llm_checkpoint``, then
    writes the config via ``save_pretrained`` and the weights to
    ``pytorch_model.bin`` in the save directory.

    Args:
        argv: Optional list of command-line arguments; ``None`` lets argparse
            read ``sys.argv``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--checkpoint_dir",
        type=Path,
        # Always needed below; fail at parse time with a usage message rather
        # than passing None on to merge_checkpoint.
        required=True,
        help="Path to the experiment directory",
    )
    parser.add_argument(
        "--save_dir",
        type=Path,
        help="Path where the converted model is saved",
    )
    args = parser.parse_args(argv)

    state_dict, config = merge_checkpoint(
        args.checkpoint_dir,
        dummy_experiment_dir=None,
    )

    output_state_dict, output_config = convert_fast_llm_checkpoint(state_dict, config)

    print("Saving config")
    # Default to a "converted" sub-directory of the checkpoint directory.
    save_dir = args.save_dir or args.checkpoint_dir / "converted"
    output_config.save_pretrained(save_dir)

    # Store the state_dict to file.
    output_checkpoint_file = os.path.join(save_dir, "pytorch_model.bin")
    print(f'Saving checkpoint to "{output_checkpoint_file}"')
    torch.save(output_state_dict, output_checkpoint_file)
    print("Done!")
39-
40-
# # Compare
41-
# def compare_state_dicts(dict1, dict2):
42-
# # Compare keys
43-
# if set(dict1.keys()) != set(dict2.keys()):
44-
# return "Different keys"
45-
46-
# # Compare shapes and values
47-
# for key in dict1:
48-
# if dict1[key].shape != dict2[key].shape:
49-
# return f"Different shape for key: {key}"
50-
# if not torch.allclose(dict1[key], dict2[key]):
51-
# return f"Different values for key: {key}"
52-
53-
# return "State dictionaries are identical"
54-
55-
# ref_state_dict = torch.load("/fsx/phuc/projects/starcoder/transformers-starcoder/src/transformers/models/gpt_bigcode/merged_checkpoint.pth")
56-
# result = compare_state_dicts(state_dict, ref_state_dict)
57-
# print(result)
58-
59136

60137

61138
if __name__ == "__main__":

0 commit comments

Comments
 (0)