-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
167 lines (125 loc) · 5.1 KB
/
model.py
File metadata and controls
167 lines (125 loc) · 5.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
from parameters.hyper import DEVICE
from parameters.nlp import END_TOKEN_IDX
import torch
from torch import nn
from torch.nn import functional as F
from typing import cast
from parameters.model import ModelParams
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding table.

    Precomputes a ``(max_len, D)`` table whose even columns hold
    ``sin(pos / denom_base ** (2i / D))`` and whose odd columns hold the
    matching cosine. ``forward`` returns the first ``seq_len`` rows, where
    ``seq_len`` is the second-to-last dimension of its input; the caller is
    responsible for adding the result to the embeddings.
    """

    pe: torch.Tensor  # registered buffer of shape (max_len, D)

    def __init__(self, D: int, denom_base=10_000, max_len=1024):
        super().__init__()
        pos = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        # One frequency per sin/cos column pair: ceil(D / 2) frequencies.
        denom = denom_base ** (torch.arange(0, D, 2) / D)
        angles = pos / denom  # (max_len, ceil(D / 2))
        pe = torch.empty(max_len, D)
        pe[:, 0::2] = torch.sin(angles)
        # For odd D there is one fewer odd column than there are frequencies,
        # so drop the last cosine column to make the assignment shapes agree.
        pe[:, 1::2] = torch.cos(angles)[:, : D // 2]
        self.register_buffer("pe", pe)

    def forward(self, X):
        """Return the encodings for the first ``X.size(-2)`` positions.

        Only the size of ``X`` is read, not its values. The result has shape
        ``(seq_len, D)`` and broadcasts over any leading (e.g. batch) dims.

        Raises:
            ValueError: if the sequence is longer than the precomputed table.
        """
        seq_len = X.size(-2)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds positional table size "
                f"{self.pe.size(0)}"
            )
        return self.pe[:seq_len]
class AttentionHead(nn.Module):
    """A single causal self-attention head.

    Queries, keys and values are each an ``nn.Linear`` projection from
    ``params.D`` to ``params.D_v`` features; the head then applies
    ``F.scaled_dot_product_attention`` with a causal mask. The output has
    ``params.D_v`` features in its last dimension.
    """

    def __init__(self, params: ModelParams):
        super().__init__()
        d_in, d_head = params.D, params.D_v
        self.W_q = nn.Linear(d_in, d_head)
        self.W_k = nn.Linear(d_in, d_head)
        self.W_v = nn.Linear(d_in, d_head)

    def forward(self, X):
        queries, keys, values = (
            proj(X) for proj in (self.W_q, self.W_k, self.W_v)
        )
        return F.scaled_dot_product_attention(
            queries, keys, values, is_causal=True
        )
class MHANaive(nn.Module):
    """Multi-head causal self-attention built from independent heads.

    Runs ``params.H`` separate ``AttentionHead`` modules on the same input,
    concatenates their ``params.D_v``-wide outputs along the last dimension,
    and projects the ``H * D_v`` result back to ``params.D`` features.
    See ``MHAFused`` for a variant that computes all heads with a single
    projection per Q/K/V.
    """

    def __init__(self, params: ModelParams):
        super().__init__()
        self.attn_heads = nn.ModuleList(
            [AttentionHead(params) for _ in range(params.H)]
        )
        # The concatenated head output is H * D_v wide. Sizing W_o from that,
        # rather than assuming it equals D, keeps the module valid when
        # D != H * D_v; when D == H * D_v this is the same Linear(D, D).
        self.W_o = nn.Linear(params.H * params.D_v, params.D)

    def forward(self, X):
        return self.W_o(torch.cat([head(X) for head in self.attn_heads], dim=-1))
class MHAFused(nn.Module):
    """Multi-head causal self-attention with one projection per Q, K and V.

    Each of Q, K and V is produced for all heads at once by a single
    ``nn.Linear`` of output width ``H * D_v``. The result is viewed as ``H``
    heads of size ``D_v``, so one ``F.scaled_dot_product_attention`` call
    attends over every head. ``X`` has ``D`` features in its last dimension
    and the sequence along its second-to-last dimension; leading dimensions
    (e.g. batch) are optional. The output has the same shape as ``X``.
    """

    def __init__(self, params: "ModelParams"):
        super().__init__()
        self.H = params.H
        self.D_v = params.D_v
        inner = params.H * params.D_v  # combined width of all heads
        # Sizing by H * D_v (not D) supports D != H * D_v. When D == H * D_v
        # the parameter shapes are exactly the former Linear(D, D), so
        # existing checkpoints still load.
        self.W_Q = nn.Linear(params.D, inner)
        self.W_K = nn.Linear(params.D, inner)
        self.W_V = nn.Linear(params.D, inner)
        self.W_O = nn.Linear(inner, params.D)

    def _split_heads(self, Y: torch.Tensor) -> torch.Tensor:
        """Reshape ``(..., N, H * D_v)`` into ``(..., H, N, D_v)`` as a view."""
        # The transpose is a strided view with the head dim still contiguous;
        # SDPA accepts such inputs, so no explicit copy is needed.
        return Y.view(*Y.shape[:-1], self.H, self.D_v).transpose(-2, -3)

    def forward(self, X):
        Q = self._split_heads(cast(torch.Tensor, self.W_Q(X)))
        K = self._split_heads(cast(torch.Tensor, self.W_K(X)))
        V = self._split_heads(cast(torch.Tensor, self.W_V(X)))
        A = F.scaled_dot_product_attention(Q, K, V, is_causal=True)  # (..., H, N, D_v)
        # Merge heads: (..., H, N, D_v) -> (..., N, H, D_v) -> (..., N, H * D_v).
        # flatten copies here because the transposed tensor is not contiguous.
        return self.W_O(A.transpose(-2, -3).flatten(-2))
class TransformerLayer(nn.Module):
    """Pre-norm decoder block: causal multi-head attention, then a SiLU MLP.

    Each sub-layer LayerNorms its input, applies the sub-layer and dropout,
    and adds the result back onto the un-normalized residual stream.
    """

    def __init__(self, params: ModelParams):
        super().__init__()
        width = params.D
        hidden = 4 * width  # MLP hidden width
        # Attention sub-layer (MHANaive in this module takes the same params
        # and is the unfused alternative).
        self.mha = MHAFused(params)
        self.dropout_mha = nn.Dropout(params.dropout)
        self.layer_norm_mha = nn.LayerNorm(width)
        # Feed-forward sub-layer.
        self.ffn = nn.Sequential(
            nn.Linear(width, hidden),
            nn.SiLU(),
            nn.Linear(hidden, width),
        )
        self.dropout_ffn = nn.Dropout(params.dropout)
        self.layer_norm_ffn = nn.LayerNorm(width)

    def forward(self, X):
        attn_out = self.dropout_mha(self.mha(self.layer_norm_mha(X)))
        residual = X + attn_out
        ffn_out = self.dropout_ffn(self.ffn(self.layer_norm_ffn(residual)))
        return residual + ffn_out
class DecoderTransformer(nn.Module):
    """Decoder-only Transformer language model.

    Token ids are embedded, summed with sinusoidal positional encodings,
    passed through ``params.L`` ``TransformerLayer`` blocks and a final
    LayerNorm, then projected to ``params.N`` logits per position. The output
    projection shares its weight matrix with the input embedding.
    """

    def __init__(self, params: ModelParams):
        super().__init__()
        self.embedding = nn.Embedding(params.N, params.D)
        self.pe = PositionalEncoding(params.D)
        self.dropout = nn.Dropout(params.dropout)
        self.tf_stack = nn.Sequential(
            *(TransformerLayer(params) for _ in range(params.L))
        )
        self.layer_norm = nn.LayerNorm(params.D)
        self.out_proj = nn.Linear(params.D, params.N)
        # Weight tying: Linear(D, N).weight is (N, D), the same shape as the
        # embedding table, so both modules share one Parameter.
        self.out_proj.weight = self.embedding.weight

    def forward(self, X):
        """Map integer token ids ``X`` to logits of shape ``(*X.shape, N)``.

        The sequence runs along the last dimension of ``X`` (presumably
        ``(batch, seq)`` — confirm at call sites).
        """
        X = self.embedding(X)
        X = X + self.pe(X)
        X = self.dropout(X)
        X = self.tf_stack(X)
        X = self.layer_norm(X)
        return self.out_proj(X)

    def save(self, file_name: str):
        """Write this model's ``state_dict`` to ``file_name``."""
        torch.save(self.state_dict(), file_name)

    def load(self, file_name: str):
        """Load a ``state_dict`` written by ``save`` into this model.

        Tensors are deserialized onto the CPU and then copied into the
        existing parameters, so the model keeps its current device.
        ``weights_only=True`` restricts unpickling to tensors and plain
        containers, so a tampered checkpoint cannot run arbitrary code.
        """
        state = torch.load(
            file_name, map_location=torch.device("cpu"), weights_only=True
        )
        self.load_state_dict(state)

    @torch.no_grad()
    def generate(self, prefix, max_tokens=512):
        """Sample a continuation of ``prefix`` autoregressively.

        Each step samples the next token from the softmax of the last
        position's logits. Sampling stops after ``max_tokens`` new tokens, or
        as soon as ``END_TOKEN_IDX`` is sampled; that token is kept in the
        output. The model runs in eval mode (dropout off) for the duration,
        and its previous train/eval mode is restored afterwards.

        Args:
            prefix: non-empty sequence of token ids.
            max_tokens: maximum number of tokens to append.

        Returns:
            list of token ids: the prefix followed by the sampled tokens.

        Raises:
            ValueError: if ``prefix`` is empty.
        """
        if len(prefix) == 0:
            raise ValueError("generate() requires a non-empty prefix")
        # Build inputs on the model's own device, which can differ from any
        # global default (e.g. after loading onto CPU).
        device = self.embedding.weight.device
        # Positions beyond the positional-encoding table cannot be encoded,
        # so only the most recent `context` tokens are fed to the model.
        context = self.pe.pe.size(0)
        was_training = self.training
        self.eval()
        try:
            tokens = torch.tensor(prefix, dtype=torch.long, device=device).unsqueeze(0)
            for _ in range(max_tokens):
                logits = self(tokens[:, -context:])[:, -1, :]
                next_token = torch.multinomial(torch.softmax(logits, dim=-1), 1)  # (1, 1)
                tokens = torch.cat([tokens, next_token], dim=-1)
                if next_token.item() == END_TOKEN_IDX:
                    break
        finally:
            self.train(was_training)
        return tokens[0].tolist()