|
| 1 | +# Copyright 2024 Google LLC |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | + |
| 16 | + |
| 17 | +from collections import defaultdict |
| 18 | +import torch |
| 19 | +from gemma import config |
| 20 | +from gemma import model as gemma_model |
| 21 | +import numpy as np |
| 22 | +import argparse |
| 23 | +import os |
| 24 | + |
| 25 | +# Requires torch 2.2 and gemma package from https://github.com/google/gemma_pytorch |
| 26 | + |
def check_file_exists(value):
    """argparse ``type=`` callback that accepts only paths present on disk.

    Args:
        value: Raw command-line value. It is converted with ``str`` before
            being tested with ``os.path.exists``.

    Returns:
        ``value`` unchanged when something exists at that path.

    Raises:
        argparse.ArgumentTypeError: If nothing exists at the given path.
    """
    path_found = os.path.exists(str(value))
    if path_found:
        return value
    message = "The file %s does not appear to exist." % value
    raise argparse.ArgumentTypeError(message)
| 31 | + |
| 32 | + |
def check_model_types(value):
    """argparse ``type=`` callback that validates and normalizes the model type.

    The comparison is case-insensitive, so the value is returned in its
    lowercase form. The rest of the script uses it as a dictionary key
    ("2b" / "7b"); returning the raw spelling (e.g. "2B") would pass this
    check and then fail on the key lookup.

    Args:
        value: Raw command-line value.

    Returns:
        The lowercase model type, either "2b" or "7b".

    Raises:
        argparse.ArgumentTypeError: If the value is not one of the supported
            model types.
    """
    normalized = str(value).lower()
    if normalized not in ("2b", "7b"):
        raise argparse.ArgumentTypeError("Model type value %s is not in [2b, 7b]." % value)
    return normalized
| 37 | + |
| 38 | + |
def _build_parser():
    """Create the command-line parser for this conversion script."""
    cli = argparse.ArgumentParser()
    # One row per option: (flag, default, help text, type callback).
    option_table = (
        (
            "--tokenizer",
            "models/tokenizer.spm",
            "Location of tokenizer file (.model or .spm)",
            check_file_exists,
        ),
        (
            "--weights",
            "models/gemma-2b-it.ckpt",
            "Location of input checkpoint file (.ckpt)",
            check_file_exists,
        ),
        (
            "--output_file",
            "2bit-f32.sbs",
            "Location to write converted weights",
            str,
        ),
        (
            "--model_type",
            "2b",
            "Model size / type (2b, 7b)",
            check_model_types,
        ),
    )
    for flag, fallback, description, converter in option_table:
        # The destination attribute is the flag name without its leading "--".
        cli.add_argument(
            flag,
            dest=flag[2:],
            default=fallback,
            help=description,
            type=converter,
        )
    return cli


# Arguments are parsed at module level; convert_weights() reads the resulting
# ``args`` namespace.
parser = _build_parser()
args = parser.parse_args()
| 73 | + |
| 74 | + |
def expand_qkv(
    qkv_proj: np.ndarray,
    num_heads: int = 8,
    head_dim: int = 256,
    model_dim: int = 2048,
) -> np.ndarray:
    """Expand a fused QKV weight so every head has its own three blocks.

    This won't be needed anymore when MQA is implemented.

    The input is viewed as ``num_heads + 2`` blocks of shape
    ``(head_dim, model_dim)``. The first ``num_heads`` blocks are kept as-is
    and the two trailing blocks (presumably the shared key and value
    projections -- confirm against the checkpoint layout) are repeated once
    per head.

    Args:
        qkv_proj: Array of shape ``((num_heads + 2) * head_dim, model_dim)``.
        num_heads: Number of leading per-head blocks. Defaults to 8.
        head_dim: Rows per block. Defaults to 256.
        model_dim: Number of columns. Defaults to 2048.

    Returns:
        Array of shape ``(num_heads, 3, head_dim, model_dim)`` where
        ``out[h, 0]`` is the h-th leading block and ``out[h, 1]`` /
        ``out[h, 2]`` are the two trailing blocks.

    Raises:
        AssertionError: If ``qkv_proj`` does not have the expected shape.
    """
    expected_shape = ((num_heads + 2) * head_dim, model_dim)
    # Raised explicitly (rather than via ``assert``) so the check still runs
    # when Python is started with -O.
    if qkv_proj.shape != expected_shape:
        raise AssertionError(
            f"qkv_proj has shape {qkv_proj.shape}, expected {expected_shape}"
        )

    blocks = qkv_proj.reshape((num_heads + 2, head_dim, model_dim))

    # (1, num_heads, head_dim, model_dim)
    q_proj = blocks[:num_heads].reshape((1, num_heads, head_dim, model_dim))
    # (2, head_dim, model_dim) -> (2, num_heads, head_dim, model_dim)
    kv_proj = blocks[num_heads:][:, np.newaxis, :, :]
    kv_proj = np.repeat(kv_proj, num_heads, axis=1)

    # Stack to (3, num_heads, ...) and then put the head axis first.
    stacked = np.concatenate([q_proj, kv_proj])
    return np.transpose(stacked, axes=[1, 0, 2, 3])
| 88 | + |
# Per-model table of array transformations, keyed by the short parameter name
# (the part after "model.layers.<n>."). Names that are not listed fall back to
# the identity function supplied by the defaultdict factory.
#
# The embedding padding rows are created with the input's own dtype: a plain
# np.zeros(...) is float64 and would make np.concatenate upcast the whole
# embedding table before it is later cast back to float32.
TRANSFORMATIONS = {
    "2b": defaultdict(
        lambda: lambda x: x,
        {
            # Append 128 all-zero rows to the embedding table.
            "embedder.weight": lambda x: np.concatenate(
                [x, np.zeros([128, 2048], dtype=x.dtype)], 0
            ),
            "self_attn.qkv_proj.weight": expand_qkv,
            # (2048, 8 * 256) -> (8, 2048, 256)
            "self_attn.o_proj.weight": lambda x: x.reshape((2048, 8, 256)).transpose([1, 0, 2]),
            # Add a leading axis of length 1.
            "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
            "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
            "mlp.down_proj.weight": lambda x: x,
        },
    ),
    "7b": defaultdict(
        lambda: lambda x: x,
        {
            # Append 128 all-zero rows to the embedding table.
            "embedder.weight": lambda x: np.concatenate(
                [x, np.zeros([128, 3072], dtype=x.dtype)], 0
            ),
            # (3 * 16 * 256, 3072) -> (16, 3, 256, 3072)
            "self_attn.qkv_proj.weight": lambda x: x.reshape((3, 16, 256, 3072)).transpose([1, 0, 2, 3]),
            # (3072, 16 * 256) -> (16, 3072, 256)
            "self_attn.o_proj.weight": lambda x: x.reshape((3072, 16, 256)).transpose([1, 0, 2]),
            # Add a leading axis of length 1.
            "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
            "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
            "mlp.down_proj.weight": lambda x: x,
        },
    ),
}
| 113 | + |
def _expects_shape(shape):
    """Return a predicate that is true when ``x.shape`` equals ``shape``."""
    return lambda x: x.shape == shape


# Shape each transformed parameter must have, per model type.
_EXPECTED_SHAPES = {
    "2b": {
        "embedder.weight": (256128, 2048),
        "model.norm.weight": (2048,),
        "self_attn.qkv_proj.weight": (8, 3, 256, 2048),
        "self_attn.o_proj.weight": (8, 2048, 256),
        "mlp.gate_proj.weight": (1, 16384, 2048),
        "mlp.up_proj.weight": (1, 16384, 2048),
        "mlp.down_proj.weight": (2048, 16384),
        "input_layernorm.weight": (2048,),
        "post_attention_layernorm.weight": (2048,),
    },
    "7b": {
        "embedder.weight": (256128, 3072),
        "model.norm.weight": (3072,),
        "self_attn.qkv_proj.weight": (16, 3, 256, 3072),
        "self_attn.o_proj.weight": (16, 3072, 256),
        "mlp.gate_proj.weight": (1, 24576, 3072),
        "mlp.up_proj.weight": (1, 24576, 3072),
        "mlp.down_proj.weight": (3072, 24576),
        "input_layernorm.weight": (3072,),
        "post_attention_layernorm.weight": (3072,),
    },
}

# Per-model table of shape predicates, keyed by the short parameter name.
VALIDATIONS = {
    model_key: {
        param_key: _expects_shape(shape) for param_key, shape in shape_table.items()
    }
    for model_key, shape_table in _EXPECTED_SHAPES.items()
}
| 138 | + |
| 139 | + |
def param_names(num_hidden_layers: int):
    """List the parameters in the order the deserializer expects them.

    Args:
        num_hidden_layers: Number of "model.layers.<n>" groups to enumerate.

    Returns:
        A list of ``(full_name, short_name)`` tuples: the two global
        parameters first, then seven entries per layer, where ``full_name`` is
        ``"model.layers.<n>.<short_name>"``.
    """
    # The *weight_scaler params only matter in the forward computation when
    # quantization is used. The input here is the full-precision weights, so
    # they are deliberately left out of the list.

    # fmt: off
    per_layer = (
        "self_attn.o_proj.weight",          # attn_vec_einsum_w
        "self_attn.qkv_proj.weight",        # qkv_einsum_w
        "mlp.gate_proj.weight",             # gating_einsum_w
        "mlp.up_proj.weight",
        "mlp.down_proj.weight",             # linear_w
        "input_layernorm.weight",           # pre_attention_norm_scale
        "post_attention_layernorm.weight",  # pre_ffw_norm_scale
    )
    ordered = [
        ("embedder.weight", "embedder.weight"),      # embedder_input_embedding
        ("model.norm.weight", "model.norm.weight"),  # final_norm_scale
    ]
    # fmt: on
    ordered.extend(
        (f"model.layers.{index}.{suffix}", suffix)
        for index in range(num_hidden_layers)
        for suffix in per_layer
    )
    return ordered
| 168 | + |
| 169 | + |
def convert_weights():
    """Load a Gemma checkpoint and dump its parameters as raw float32 data.

    Configuration comes from the module-level ``args`` namespace
    (``model_type``, ``output_file``, ``tokenizer``, ``weights``).

    The work is done in two passes over the parameter list returned by
    ``param_names``:

    1. Every parameter is transformed with ``TRANSFORMATIONS`` and checked
       with ``VALIDATIONS``; one status row per parameter is printed.
    2. Only if every check passed, the parameters are transformed again and
       written back-to-back to ``output_file`` as flattened float32 values.

    The arrays are recomputed in the second pass instead of being kept from
    the first one, so only one transformed array is held at a time.

    Returns:
        None. When validation fails, a message is printed and no output file
        is created.
    """
    model_type = args.model_type
    output_file = args.output_file

    model_config = config.get_model_config(model_type)
    model_config.dtype = "float32"
    model_config.tokenizer = args.tokenizer
    device = torch.device("cpu")
    # NOTE: this changes torch's default dtype for the whole process.
    torch.set_default_dtype(torch.float)
    model = gemma_model.GemmaForCausalLM(model_config)

    model.load_weights(args.weights)
    model.to(device).eval()

    model_dict = dict(model.named_parameters())
    param_order = param_names(model_config.num_hidden_layers)
    transforms = TRANSFORMATIONS[model_type]
    validators = VALIDATIONS[model_type]

    def transform_and_report(name, layer_name):
        """Transform one parameter, print its status row, return (array, ok)."""
        arr = model_dict[name].detach().numpy()
        arr = transforms[layer_name](arr)
        passed = bool(validators[layer_name](arr))
        check = "OK" if passed else "FAILED"
        print(f" {name : <60}{str(arr.shape) : <20}{check}")
        return arr, passed

    print("Checking transformations ...")
    # A list (not a generator) so every parameter is checked and reported
    # even after the first failure.
    results = [
        transform_and_report(name, layer_name)[1]
        for name, layer_name in param_order
    ]

    if not all(results):
        # Make the outcome explicit; otherwise the only hint is a FAILED row
        # somewhere in the listing above.
        print("Validation failed; no output file was written.")
        return

    print("Writing parameters ...")
    with open(output_file, "wb") as bin_handle:
        for name, layer_name in param_order:
            arr, _ = transform_and_report(name, layer_name)
            arr.flatten().astype(np.float32).tofile(bin_handle)
| 208 | + |
| 209 | + |
def _main():
    """Script entry point: run the conversion, then report completion."""
    convert_weights()
    print("Done")


if __name__ == "__main__":
    _main()
0 commit comments