|
1 | 1 | import argparse |
2 | 2 | import os |
3 | 3 | from pathlib import Path |
| 4 | +import re |
4 | 5 |
|
5 | 6 | import torch |
6 | 7 | from transformers.models.gpt_bigcode.merge_fast_llm_checkpoint import merge_checkpoint |
| 8 | +from transformers.models.gpt_bigcode import GPTBigCodeConfig, GPTBigCodeForCausalLM, GPTBigCodeModel |
| 9 | + |
| 10 | + |
| 11 | +# The simple map of names for "automated" rules. |
| 12 | +NAME_MAP = { |
| 13 | + "_mlp._layer_1": "mlp.c_fc", |
| 14 | + "_mlp._layer_2": "mlp.c_proj", |
| 15 | + "layer_norm_1": "ln_1", |
| 16 | + "layer_norm_2": "ln_2", |
| 17 | + # "attention.dense": "attn.c_proj", |
| 18 | + "self_attn.dense": "attn.c_proj", |
| 19 | + # "self_attention.query_key_value": "attn.c_attn", |
| 20 | +} |
| 21 | + |
| 22 | + |
| 23 | +def convert_fast_llm_checkpoint(state_dict, config): |
| 24 | + # The converted output model. |
| 25 | + output_state_dict = {} |
| 26 | + if "window_size" in config: |
| 27 | + attention_window_size = config["window_size"] |
| 28 | + else: |
| 29 | + attention_window_size = config.get("attention_window_size", None) |
| 30 | + |
| 31 | + config = GPTBigCodeConfig( |
| 32 | + architectures=["GPTBigCodeLMHeadModel"], |
| 33 | + vocab_size=config["vocab_size"], |
| 34 | + n_positions=config["max_position_embeddings"], |
| 35 | + n_embd=config["hidden_size"], |
| 36 | + n_layer=config["num_layers"], |
| 37 | + n_head=config["num_attention_heads"], |
| 38 | + n_inner=config["ffn_hidden_size"], |
| 39 | + activation_function="gelu", # TODO |
| 40 | + multi_query=True, # TODO |
| 41 | + resid_pdrop=0.1, |
| 42 | + embd_pdrop=0.1, |
| 43 | + attn_pdrop=0.1, |
| 44 | + layer_norm_epsilon=1e-5, |
| 45 | + initializer_range=0.02, |
| 46 | + summary_type="cls_index", |
| 47 | + summary_use_proj=True, |
| 48 | + summary_activation=None, |
| 49 | + summary_proj_to_labels=True, |
| 50 | + summary_first_dropout=0.1, |
| 51 | + scale_attn_weights=True, |
| 52 | + use_cache=True, |
| 53 | + bos_token_id=0, # TODO: can we remove these? |
| 54 | + eos_token_id=0, |
| 55 | + attention_softmax_in_fp32=True, |
| 56 | + scale_attention_softmax_in_fp32=True, |
| 57 | + use_rotary_embeddings=config["use_rotary_embeddings"], |
| 58 | + rotary_embedding_scale=config["rotary_embedding_scale"], |
| 59 | + use_position_embeddings=config["use_position_embeddings"], |
| 60 | + attention_window_size=attention_window_size |
| 61 | + ) |
| 62 | + |
| 63 | + # Truncate the word embeddings to the vocab-size |
| 64 | + word_embeddings = state_dict.pop("_layers.0._word_embeddings_weight")[:config.vocab_size, :] |
| 65 | + output_state_dict["transformer.wte.weight"] = word_embeddings |
| 66 | + if config.use_position_embeddings: |
| 67 | + output_state_dict["transformer.wpe.weight"] = state_dict.pop("_layers.0._position_embeddings_weight") |
| 68 | + |
| 69 | + # Layer-0 is the word/position embeddings |
| 70 | + # Layers 1 to n_layer need to be re-mapped from 0 to n_layer-1. |
| 71 | + # _layers.{layer_index}.{op}.{w/b} |
| 72 | + |
| 73 | + # Concatenate QKV matrix |
| 74 | + for layer_index in range(1, config.n_layer + 1): |
| 75 | + for weight_or_bias in ["weight", "bias"]: |
| 76 | + query = state_dict.pop(f"_layers.{layer_index}.self_attn.query.{weight_or_bias}") |
| 77 | + key_value = state_dict.pop(f"_layers.{layer_index}.self_attn.key_value.{weight_or_bias}") |
| 78 | + output_state_dict[f"transformer.h.{layer_index - 1}.attn.c_attn.{weight_or_bias}"] = torch.cat([query, key_value], dim=0) |
| 79 | + |
| 80 | + # Extract the other ops |
| 81 | + layer_re = re.compile("_layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)") |
| 82 | + for name, value in state_dict.items(): |
| 83 | + m = layer_re.match(name) |
| 84 | + assert m is not None, f"Invalid layer name: {name}" |
| 85 | + |
| 86 | + # The index of the layer. |
| 87 | + layer_index = int(m.group(1)) |
| 88 | + # The name of the operation. |
| 89 | + op_name = m.group(2) |
| 90 | + # Is it a weight or a bias? |
| 91 | + weight_or_bias = m.group(3) |
| 92 | + |
| 93 | + # Final layernorm |
| 94 | + if op_name == "final_layernorm": |
| 95 | + assert layer_index == config.n_layer + 1 |
| 96 | + output_state_dict[f"transformer.ln_f.{weight_or_bias}"] = value |
| 97 | + else: |
| 98 | + output_state_dict[f"transformer.h.{layer_index-1}.{NAME_MAP[op_name]}.{weight_or_bias}"] = value |
| 99 | + |
| 100 | + # For LM head, transformers' wants the matrix to weight embeddings. |
| 101 | + output_state_dict["lm_head.weight"] = word_embeddings |
| 102 | + |
| 103 | + return output_state_dict, config |
7 | 104 |
|
8 | 105 |
|
9 | 106 | def main(argv=None): |
10 | 107 | parser = argparse.ArgumentParser() |
11 | 108 | parser.add_argument( |
12 | 109 | "--checkpoint_dir", |
13 | 110 | type=Path, |
14 | | - # default="/admin/home/phuc_nguyen/.cache/huggingface/hub/models--HuggingFaceBR4--starcoder2_7b_4k_smol_data_580000/snapshots/92b6c25cab25f07c367bcc6d773635700a8a287d", |
15 | | - help="Path where the converted model is saved" |
| 111 | + help="Path to the experiment directory", |
16 | 112 | ) |
17 | 113 | parser.add_argument( |
18 | 114 | "--save_dir", |
19 | 115 | type=Path, |
20 | | - # default="./", |
21 | 116 | help="Path where the converted model is saved" |
22 | 117 | ) |
23 | 118 | args = parser.parse_args(argv) |
24 | | - |
25 | | - print("start") |
26 | | - |
27 | | - # TODO(xrsrke): auto convert checkpoint_dir to Path |
28 | | - # checkpoint_dir = "/admin/home/phuc_nguyen/.cache/huggingface/hub/models--HuggingFaceBR4--starcoder2_7b_4k_smol_data_580000/snapshots/92b6c25cab25f07c367bcc6d773635700a8a287d" |
29 | | - # checkpoint_dir = Path(checkpoint_dir) |
30 | 119 |
|
31 | | - state_dict = merge_checkpoint(args.checkpoint_dir) |
| 120 | + state_dict, config = merge_checkpoint( |
| 121 | + args.checkpoint_dir, |
| 122 | + dummy_experiment_dir=None |
| 123 | + ) |
| 124 | + |
| 125 | + output_state_dict, output_config = convert_fast_llm_checkpoint(state_dict, config) |
32 | 126 |
|
| 127 | + print("Saving config") |
33 | 128 | save_dir = args.save_dir or args.checkpoint_dir / "converted" |
| 129 | + output_config.save_pretrained(save_dir) |
| 130 | + |
| 131 | + # Store the state_dict to file. |
34 | 132 | output_checkpoint_file = os.path.join(save_dir, "pytorch_model.bin") |
35 | | - |
36 | 133 | print(f'Saving checkpoint to "{output_checkpoint_file}"') |
37 | | - torch.save(state_dict, output_checkpoint_file) |
| 134 | + torch.save(output_state_dict, output_checkpoint_file) |
38 | 135 | print(f'Done!') |
39 | | - |
40 | | - # # Compare |
41 | | - # def compare_state_dicts(dict1, dict2): |
42 | | - # # Compare keys |
43 | | - # if set(dict1.keys()) != set(dict2.keys()): |
44 | | - # return "Different keys" |
45 | | - |
46 | | - # # Compare shapes and values |
47 | | - # for key in dict1: |
48 | | - # if dict1[key].shape != dict2[key].shape: |
49 | | - # return f"Different shape for key: {key}" |
50 | | - # if not torch.allclose(dict1[key], dict2[key]): |
51 | | - # return f"Different values for key: {key}" |
52 | | - |
53 | | - # return "State dictionaries are identical" |
54 | | - |
55 | | - # ref_state_dict = torch.load("/fsx/phuc/projects/starcoder/transformers-starcoder/src/transformers/models/gpt_bigcode/merged_checkpoint.pth") |
56 | | - # result = compare_state_dicts(state_dict, ref_state_dict) |
57 | | - # print(result) |
58 | | - |
59 | 136 |
|
60 | 137 |
|
61 | 138 | if __name__ == "__main__": |
|
0 commit comments