Skip to content

Commit a0f316d

Browse files
Merge pull request #95 from google:conversion
PiperOrigin-RevId: 615448039
2 parents 5fa2eb1 + f520e5c commit a0f316d

1 file changed

Lines changed: 212 additions & 0 deletions

File tree

util/convert_weights.py

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
# Copyright 2024 Google LLC
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
17+
from collections import defaultdict
18+
import torch
19+
from gemma import config
20+
from gemma import model as gemma_model
21+
import numpy as np
22+
import argparse
23+
import os
24+
25+
# Requires torch 2.2 and gemma package from https://github.com/google/gemma_pytorch
26+
27+
def check_file_exists(value):
28+
if not os.path.exists(str(value)):
29+
raise argparse.ArgumentTypeError("The file %s does not appear to exist." % value)
30+
return value
31+
32+
33+
def check_model_types(value):
34+
if str(value).lower() not in ["2b", "7b"]:
35+
raise argparse.ArgumentTypeError("Model type value %s is not in [2b, 7b]." % value)
36+
return value
37+
38+
39+
parser = argparse.ArgumentParser()
40+
parser.add_argument(
41+
"--tokenizer",
42+
dest="tokenizer",
43+
default="models/tokenizer.spm",
44+
help="Location of tokenizer file (.model or .spm)",
45+
type=check_file_exists,
46+
)
47+
48+
parser.add_argument(
49+
"--weights",
50+
dest="weights",
51+
default="models/gemma-2b-it.ckpt",
52+
help="Location of input checkpoint file (.ckpt)",
53+
type=check_file_exists,
54+
)
55+
56+
parser.add_argument(
57+
"--output_file",
58+
dest="output_file",
59+
default="2bit-f32.sbs",
60+
help="Location to write converted weights",
61+
type=str,
62+
)
63+
64+
parser.add_argument(
65+
"--model_type",
66+
dest="model_type",
67+
default="2b",
68+
help="Model size / type (2b, 7b)",
69+
type=check_model_types,
70+
)
71+
72+
args = parser.parse_args()
73+
74+
75+
def expand_qkv(qkv_proj: np.array) -> np.array:
76+
"""This won't be needed anymore when MQA is implemented"""
77+
assert qkv_proj.shape == (2560, 2048)
78+
qkv = qkv_proj.reshape((10, 256, 2048))
79+
80+
q_proj = qkv[:8].reshape((1,8,256,2048))
81+
kv_proj = qkv[8:]
82+
kv_proj = kv_proj[:, np.newaxis, :, :]
83+
kv_proj = np.repeat(kv_proj, 8, axis=1)
84+
85+
qkv = np.concatenate([q_proj, kv_proj])
86+
qkv = np.transpose(qkv, axes=[1,0,2,3])
87+
return qkv
88+
89+
TRANSFORMATIONS = {
90+
"2b":defaultdict(
91+
lambda: lambda x: x,
92+
{
93+
"embedder.weight": lambda x: np.concatenate([x, np.zeros([128, 2048])], 0),
94+
"self_attn.qkv_proj.weight": expand_qkv,
95+
"self_attn.o_proj.weight": lambda x: x.reshape((2048, 8, 256)).transpose([1,0,2]),
96+
"mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
97+
"mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
98+
"mlp.down_proj.weight": lambda x: x,
99+
}
100+
),
101+
"7b":defaultdict(
102+
lambda: lambda x: x,
103+
{
104+
"embedder.weight": lambda x: np.concatenate([x, np.zeros([128, 3072])], 0),
105+
"self_attn.qkv_proj.weight": lambda x: x.reshape((3, 16, 256, 3072)).transpose([1,0,2,3]),
106+
"self_attn.o_proj.weight": lambda x: x.reshape((3072, 16, 256)).transpose([1,0,2]),
107+
"mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
108+
"mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
109+
"mlp.down_proj.weight": lambda x: x,
110+
}
111+
),
112+
}
113+
114+
VALIDATIONS = {
115+
"2b": {
116+
"embedder.weight": lambda x: x.shape == (256128, 2048),
117+
"model.norm.weight": lambda x: x.shape == (2048,),
118+
"self_attn.qkv_proj.weight": lambda x: x.shape == (8, 3, 256, 2048),
119+
"self_attn.o_proj.weight": lambda x: x.shape == (8, 2048, 256),
120+
"mlp.gate_proj.weight": lambda x: x.shape == (1, 16384, 2048),
121+
"mlp.up_proj.weight": lambda x: x.shape == (1, 16384, 2048),
122+
"mlp.down_proj.weight": lambda x: x.shape == (2048, 16384),
123+
"input_layernorm.weight": lambda x: x.shape == (2048,),
124+
"post_attention_layernorm.weight": lambda x: x.shape == (2048,),
125+
},
126+
"7b": {
127+
"embedder.weight": lambda x: x.shape == (256128, 3072),
128+
"model.norm.weight": lambda x: x.shape == (3072,),
129+
"self_attn.qkv_proj.weight": lambda x: x.shape == (16, 3, 256, 3072),
130+
"self_attn.o_proj.weight": lambda x: x.shape == (16, 3072, 256),
131+
"mlp.gate_proj.weight": lambda x: x.shape == (1, 24576, 3072),
132+
"mlp.up_proj.weight": lambda x: x.shape == (1, 24576, 3072),
133+
"mlp.down_proj.weight": lambda x: x.shape == (3072, 24576),
134+
"input_layernorm.weight": lambda x: x.shape == (3072,),
135+
"post_attention_layernorm.weight": lambda x: x.shape == (3072,),
136+
},
137+
}
138+
139+
140+
def param_names(num_hidden_layers: int):
141+
"""Return parameter names in the order they are expected for deserialization."""
142+
143+
# note *weight_scaler params are ignored in the forward computation unless
144+
# quantization is being used.
145+
#
146+
# since we are working with the full precision weights as input, don't
147+
# include these in the parameters being iterated over.
148+
149+
# fmt: off
150+
names = [
151+
("embedder.weight", ) * 2, # embedder_input_embedding
152+
("model.norm.weight", ) * 2 # final_norm_scale
153+
]
154+
layer_params = [
155+
"self_attn.o_proj.weight", # attn_vec_einsum_w
156+
"self_attn.qkv_proj.weight", # qkv_einsum_w
157+
"mlp.gate_proj.weight", # gating_einsum_w
158+
"mlp.up_proj.weight",
159+
"mlp.down_proj.weight", # linear_w
160+
"input_layernorm.weight", # pre_attention_norm_scale
161+
"post_attention_layernorm.weight", # pre_ffw_norm_scale
162+
]
163+
# fmt: on
164+
for layer in range(num_hidden_layers):
165+
for layer_param in layer_params:
166+
names = names + [(f"model.layers.{layer}.{layer_param}", layer_param)]
167+
return names
168+
169+
170+
def convert_weights():
171+
model_type = args.model_type
172+
output_file = args.output_file
173+
174+
model_config = config.get_model_config(model_type)
175+
model_config.dtype = "float32"
176+
model_config.tokenizer = args.tokenizer
177+
device = torch.device("cpu")
178+
torch.set_default_dtype(torch.float)
179+
model = gemma_model.GemmaForCausalLM(model_config)
180+
181+
model.load_weights(args.weights)
182+
model.to(device).eval()
183+
184+
model_dict = dict(model.named_parameters())
185+
param_order = param_names(model_config.num_hidden_layers)
186+
187+
all_ok = True
188+
print("Checking transformations ...")
189+
for name, layer_name in param_order:
190+
arr = model_dict[name].detach().numpy()
191+
arr = TRANSFORMATIONS[model_type][layer_name](arr)
192+
check = "OK" if VALIDATIONS[model_type][layer_name](arr) else "FAILED"
193+
194+
if check == "FAILED":
195+
all_ok = False
196+
print(f" {name : <60}{str(arr.shape) : <20}{check}")
197+
198+
if all_ok:
199+
print("Writing parameters ...")
200+
gate = None
201+
with open(output_file, "wb") as bin_handle:
202+
for name, layer_name in param_order:
203+
arr = model_dict[name].detach().numpy()
204+
arr = TRANSFORMATIONS[model_type][layer_name](arr)
205+
check = "OK" if VALIDATIONS[model_type][layer_name](arr) else "FAILED"
206+
print(f" {name : <60}{str(arr.shape) : <20}{check}")
207+
arr.flatten().astype(np.float32).tofile(bin_handle)
208+
209+
210+
if __name__ == "__main__":
211+
convert_weights()
212+
print("Done")

0 commit comments

Comments
 (0)