-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: pack.py
More file actions
97 lines (79 loc) · 2.85 KB
/
pack.py
File metadata and controls
97 lines (79 loc) · 2.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import struct
import numpy as np
# Build the GPT-2 byte-level BPE tables.
# bs starts as the 188 bytes that are printable in latin-1 and map to
# themselves; every remaining byte is remapped to an unused unicode
# codepoint (256 + n) so merges.txt can spell arbitrary bytes with
# printable characters (the standard GPT-2 byte-encoder trick).
bs = list(range(33, 127)) + list(range(161, 173)) + list(range(174, 256))
cs = bs[:]
n = 0
for b in range(256):
    if b not in bs:
        bs.append(b)
        cs.append(256 + n)
        n += 1
# u2b: unicode character (as written in merges.txt) -> raw byte value
u2b = {chr(c): b for b, c in zip(bs, cs)}
# ids: byte sequence -> token id; the 256 single-byte base tokens come first
ids = {bytes([b]): i for i, b in enumerate(bs)}
# byte_to_token[raw byte] = base token id; token_to_byte[base token id] = raw byte
byte_to_token = [0] * 256
token_to_byte = [0] * 256
for i, b in enumerate(bs):
    byte_to_token[b] = i
    token_to_byte[i] = b
# Parse merges.txt: each non-comment line "A B" merges two existing tokens;
# the merged token's id continues after the 256 base byte tokens.
merges = []
# FIX: the file handle was previously opened inline and never closed;
# a context manager guarantees it is released.
with open("merges.txt", encoding="utf-8") as mf:
    for line in mf:
        line = line.strip()
        if not line or line.startswith("#version:"):
            continue
        a, b = line.rsplit(" ", 1)
        a = bytes(u2b[c] for c in a)
        b = bytes(u2b[c] for c in b)
        merges.append((ids[a], ids[b]))
        ids[a + b] = 256 + len(merges) - 1
# Download/load the pretrained TinyStories-1M checkpoint and grab its raw
# parameter tensors for packing. The state-dict key names used below
# ("transformer.h.{i}.attn.attention.q_proj", ...) match this checkpoint's
# architecture — presumably GPT-Neo-style; verify against the model config.
from transformers import AutoModelForCausalLM
ID = "roneneldan/TinyStories-1M"
model = AutoModelForCausalLM.from_pretrained(ID)
sd = model.state_dict()
def q(key):
    """Symmetrically quantize the tensor stored under *key* in ``sd`` to int8.

    Returns a pair ``(values, scale)`` where ``values`` is an int8 array of
    the same shape and ``values * scale`` approximates the original tensor.
    """
    weights = sd[key].float().numpy()
    # Scale so the largest magnitude maps to 127; the floor guards an
    # all-zero tensor against division by zero.
    scale = np.abs(weights).max() / 127.0
    if scale < 1e-8:
        scale = 1e-8
    quantized = np.clip(np.round(weights / scale), -128, 127)
    return quantized.astype(np.int8), scale
def f32(key):
    """Serialize the tensor under *key* in ``sd`` as raw float32 bytes
    (C order, native byte order)."""
    arr = sd[key].float().numpy()
    return arr.reshape(-1).tobytes()
with open("emlm.bin", "wb") as f:
f.write(b"EMLM")
f.write(struct.pack("<H", 50257)) # vocab size
f.write(struct.pack("<H", 2048)) # max position
f.write(struct.pack("<H", 64)) # hidden size
f.write(struct.pack("<H", 256)) # intermediate size
f.write(struct.pack("<H", 8)) # num layers
f.write(struct.pack("<H", 16)) # num heads
f.write(struct.pack("<H", len(merges))) # num merges
for t in byte_to_token:
f.write(struct.pack("<H", t))
for t in token_to_byte:
f.write(struct.pack("<B", t))
for a, b in merges:
f.write(struct.pack("<HH", a, b))
qi, s = q("transformer.wte.weight")
f.write(struct.pack("<f", s))
f.write(qi.tobytes())
qi, s = q("transformer.wpe.weight")
f.write(struct.pack("<f", s))
f.write(qi.tobytes())
for i in range(8):
p = f"transformer.h.{i}"
f.write(f32(f"{p}.ln_1.weight"))
f.write(f32(f"{p}.ln_1.bias"))
f.write(f32(f"{p}.ln_2.weight"))
f.write(f32(f"{p}.ln_2.bias"))
for w in ["q_proj", "k_proj", "v_proj", "out_proj"]:
qi, s = q(f"{p}.attn.attention.{w}.weight")
f.write(struct.pack("<f", s))
f.write(qi.tobytes())
f.write(f32(f"{p}.attn.attention.out_proj.bias"))
for w in ["c_fc", "c_proj"]:
qi, s = q(f"{p}.mlp.{w}.weight")
f.write(struct.pack("<f", s))
f.write(qi.tobytes())
f.write(f32(f"{p}.mlp.c_fc.bias"))
f.write(f32(f"{p}.mlp.c_proj.bias"))
f.write(f32("transformer.ln_f.weight"))
f.write(f32("transformer.ln_f.bias"))
import os
print(f"emlm.bin: {os.path.getsize("emlm.bin") / 1024 / 1024:.2f} MB")