
Commit 9e07cb0

feat: support SwiGLU
1 parent b553124 commit 9e07cb0

4 files changed: 33 additions & 3 deletions


cs336_basics/embedding.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -6,7 +6,7 @@ def __init__(self, num_embeddings, embedding_dim, device=None, dtype=None):
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
         self.weights = torch.nn.Parameter(torch.empty(num_embeddings, embedding_dim))
-        self.weights = torch.nn.init.trunc_normal_(self.weights, mean=0.0, std=1.0, a=-3.0, b=3.0)
+        torch.nn.init.trunc_normal_(self.weights, mean=0.0, std=1.0, a=-3.0, b=3.0)
         self.device = device
         self.dtype = dtype
```

cs336_basics/linear.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -7,7 +7,7 @@ def __init__(self, in_features, out_features, device=None, dtype=None):
         self.out_features = out_features
         self.weights = torch.nn.Parameter(torch.empty(out_features, in_features, device=device, dtype=dtype))
         std = (2.0 / (in_features + out_features)) ** 0.5
-        self.weights = torch.nn.init.trunc_normal_(self.weights, mean=0.0, std=std, a=-3*std, b=3*std)
+        torch.nn.init.trunc_normal_(self.weights, mean=0.0, std=std, a=-3*std, b=3*std)
         self.device = device
         self.dtype = dtype
```
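
Both init changes above are the same cleanup: `torch.nn.init.trunc_normal_`, like every `torch.nn.init.*_` function with a trailing underscore, fills the tensor in place and returns the very object it was given, so re-assigning the result to `self.weights` was redundant. A minimal sketch of the in-place pattern:

```python
import torch

# Trailing underscore = in-place: the returned object is the tensor passed in.
w = torch.nn.Parameter(torch.empty(4, 3))
out = torch.nn.init.trunc_normal_(w, mean=0.0, std=1.0, a=-3.0, b=3.0)

assert out is w                            # same object; no re-assignment needed
assert isinstance(w, torch.nn.Parameter)   # still a registered Parameter
```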

cs336_basics/swiglu.py

Lines changed: 27 additions & 0 deletions
```diff
@@ -0,0 +1,27 @@
+import torch
+
+class SwiGLU(torch.nn.Module):
+    def __init__(self, d_model: int, d_ff=None, device=None, dtype=None):
+        super().__init__()
+        self.d_model = d_model
+        if d_ff is None:
+            d_ff_raw = int((8/3) * d_model)
+            d_ff = ((d_ff_raw + 32) // 64) * 64
+        self.d_ff = d_ff
+
+        self.w1 = torch.nn.Parameter(torch.empty((self.d_ff, self.d_model), device=device, dtype=dtype))
+        self.w3 = torch.nn.Parameter(torch.empty((self.d_ff, self.d_model), device=device, dtype=dtype))
+        self.w2 = torch.nn.Parameter(torch.empty((self.d_model, self.d_ff), device=device, dtype=dtype))
+        std_w1_w3 = (2.0 / (self.d_model + self.d_ff)) ** 0.5
+        std_w2 = (2.0 / (self.d_ff + self.d_model)) ** 0.5
+        torch.nn.init.trunc_normal_(self.w1, mean=0.0, std=std_w1_w3, a=-3*std_w1_w3, b=3*std_w1_w3)
+        torch.nn.init.trunc_normal_(self.w3, mean=0.0, std=std_w1_w3, a=-3*std_w1_w3, b=3*std_w1_w3)
+        torch.nn.init.trunc_normal_(self.w2, mean=0.0, std=std_w2, a=-3*std_w2, b=3*std_w2)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        w1x = x @ self.w1.T
+        w3x = x @ self.w3.T
+        silu_w1x = w1x * torch.sigmoid(w1x)
+        gated = silu_w1x * w3x
+        output = gated @ self.w2.T
+        return output
```
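
The forward pass computes SwiGLU(x) = (SiLU(x W1ᵀ) ⊙ x W3ᵀ) W2ᵀ with SiLU(z) = z · sigmoid(z), and when `d_ff` is omitted it defaults to roughly (8/3)·d_model rounded to the nearest multiple of 64 (the `+ 32` before the floor division turns round-down into round-to-nearest). A quick sketch of how the module might be exercised; the `d_model` value and shapes here are illustrative, not from the commit:

```python
import torch
from cs336_basics.swiglu import SwiGLU

# d_model = 512: int((8/3) * 512) = 1365, then (1365 + 32) // 64 * 64 = 1344,
# the multiple of 64 nearest to (8/3) * d_model.
ffn = SwiGLU(d_model=512)
assert ffn.d_ff == 1344

x = torch.randn(2, 16, 512)    # (batch, seq_len, d_model)
y = ffn(x)
assert y.shape == x.shape      # d_model -> d_ff (gated) -> d_model

# The same computation spelled out, mirroring forward():
w1x, w3x = x @ ffn.w1.T, x @ ffn.w3.T
ref = (w1x * torch.sigmoid(w1x) * w3x) @ ffn.w2.T
assert torch.allclose(y, ref)
```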

tests/adapters.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -14,6 +14,7 @@
 from cs336_basics.linear import Linear
 from cs336_basics.embedding import Embedding
 from cs336_basics.rmsnorm import RMSNorm
+from cs336_basics.swiglu import SwiGLU
 
 
 def run_linear(
@@ -91,7 +92,9 @@ def run_swiglu(
     # swiglu.w1.weight.data = w1_weight
     # swiglu.w2.weight.data = w2_weight
     # swiglu.w3.weight.data = w3_weight
-    raise NotImplementedError
+    swiglu = SwiGLU(d_model, d_ff=d_ff)
+    swiglu.load_state_dict({'w1': w1_weight, 'w2': w2_weight, 'w3': w3_weight})
+    return swiglu.forward(in_features)
 
 
 def run_scaled_dot_product_attention(
```
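
Since `SwiGLU` stores `w1`/`w2`/`w3` directly as `nn.Parameter`s rather than as `nn.Linear` submodules (which is why the commented-out `*.weight.data` assignments were abandoned), the state-dict keys are simply `'w1'`, `'w2'`, `'w3'`. A hedged sketch of driving the adapter with random weights; the sizes are assumptions for illustration, and only the keyword names are taken from the diff body:

```python
import torch
from tests.adapters import run_swiglu

d_model, d_ff = 64, 128              # illustrative sizes, not from the test suite
w1 = torch.randn(d_ff, d_model)      # (d_ff, d_model), like SwiGLU.w1
w2 = torch.randn(d_model, d_ff)      # (d_model, d_ff), like SwiGLU.w2
w3 = torch.randn(d_ff, d_model)      # (d_ff, d_model), like SwiGLU.w3
x = torch.randn(4, d_model)

out = run_swiglu(d_model=d_model, d_ff=d_ff, w1_weight=w1,
                 w2_weight=w2, w3_weight=w3, in_features=x)

# run_swiglu builds a SwiGLU, loads the weights, and applies it:
expected = ((x @ w1.T) * torch.sigmoid(x @ w1.T) * (x @ w3.T)) @ w2.T
assert torch.allclose(out, expected)
```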
