Skip to content

Commit c797834

Browse files
committed
feat: support Transformer LM
1 parent 99601cf commit c797834

3 files changed

Lines changed: 55 additions & 4 deletions

File tree

cs336_basics/transformer.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import torch
2+
from cs336_basics.embedding import Embedding
3+
from cs336_basics.transformer_block import TransformerBlock
4+
from cs336_basics.rmsnorm import RMSNorm
5+
from cs336_basics.linear import Linear
6+
from jaxtyping import Float
7+
8+
class Transformer(torch.nn.Module):
9+
def __init__(
10+
self,
11+
vocab_size: int,
12+
context_length: int,
13+
num_layers: int,
14+
d_model: int,
15+
num_heads: int,
16+
d_ff: int,
17+
rope_theta: float,
18+
device=None,
19+
dtype=None,
20+
):
21+
super().__init__()
22+
self.token_embedding = Embedding(vocab_size, d_model, device=device, dtype=dtype)
23+
self.layers = torch.nn.ModuleList([TransformerBlock(d_model, num_heads, d_ff, context_length, rope_theta, device=device, dtype=dtype) for _ in range(num_layers)])
24+
self.ln_final = RMSNorm(d_model, device=device, dtype=dtype)
25+
self.lm_head = Linear(d_model, vocab_size)
26+
27+
def forward(self, x: Float[torch.Tensor, "batch_size sequence_length"]) -> Float[torch.Tensor, "batch_size sequence_length vocab_size"]:
28+
step_result = self.token_embedding(x)
29+
for layer in self.layers:
30+
step_result = layer(step_result)
31+
step_result = self.ln_final(step_result)
32+
step_result = self.lm_head(step_result)
33+
return step_result

cs336_basics/transformer_block.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from cs336_basics.swiglu import SwiGLU
55
from jaxtyping import Float
66

7-
class TranformerBlock(torch.nn.Module):
7+
class TransformerBlock(torch.nn.Module):
88
def __init__(
99
self,
1010
d_model: int,

tests/adapters.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
from cs336_basics.softmax import softmax
2020
from cs336_basics.scaled_dot_product_attention import scaled_dot_product_attention
2121
from cs336_basics.multihead_self_attention import MultiHeadSelfAttention, MultiHeadSelfAttentionWithRoPE
22-
from cs336_basics.transformer_block import TranformerBlock
22+
from cs336_basics.transformer_block import TransformerBlock
23+
from cs336_basics.transformer import Transformer
2324

2425
def run_linear(
2526
d_in: int,
@@ -305,7 +306,7 @@ class and pass that instead.
305306
Float[Tensor, "batch sequence_length d_model"] Tensor with the output of
306307
running the Transformer block on the input features while using RoPE.
307308
"""
308-
transformer_block = TranformerBlock(d_model, num_heads, d_ff, max_seq_len, theta, device=in_features.device, dtype=in_features.dtype)
309+
transformer_block = TransformerBlock(d_model, num_heads, d_ff, max_seq_len, theta, device=in_features.device, dtype=in_features.dtype)
309310
transformer_block.load_state_dict({
310311
"mhsa.W_q": weights["attn.q_proj.weight"],
311312
"mhsa.W_k": weights["attn.k_proj.weight"],
@@ -399,7 +400,24 @@ def run_transformer_lm(
399400
Float[Tensor, "batch_size sequence_length vocab_size"]: Tensor with the predicted unnormalized
400401
next-word distribution for each token.
401402
"""
402-
raise NotImplementedError
403+
transformer = Transformer(vocab_size, context_length, num_layers, d_model, num_heads, d_ff, rope_theta, device=in_indices.device)
404+
state_dict = {
405+
'token_embedding.weights': weights['token_embeddings.weight'],
406+
'ln_final.g': weights['ln_final.weight'],
407+
'lm_head.weights': weights['lm_head.weight']
408+
}
409+
for i in range(num_layers):
410+
state_dict[f"layers.{i}.mhsa.W_q"] = weights[f'layers.{i}.attn.q_proj.weight']
411+
state_dict[f"layers.{i}.mhsa.W_k"] = weights[f'layers.{i}.attn.k_proj.weight']
412+
state_dict[f"layers.{i}.mhsa.W_v"] = weights[f'layers.{i}.attn.v_proj.weight']
413+
state_dict[f"layers.{i}.mhsa.W_o"] = weights[f'layers.{i}.attn.output_proj.weight']
414+
state_dict[f"layers.{i}.ln1.g"] = weights[f'layers.{i}.ln1.weight']
415+
state_dict[f"layers.{i}.ln2.g"] = weights[f'layers.{i}.ln2.weight']
416+
state_dict[f"layers.{i}.ffn.w1"] = weights[f"layers.{i}.ffn.w1.weight"]
417+
state_dict[f"layers.{i}.ffn.w2"] = weights[f"layers.{i}.ffn.w2.weight"]
418+
state_dict[f"layers.{i}.ffn.w3"] = weights[f"layers.{i}.ffn.w3.weight"]
419+
transformer.load_state_dict(state_dict)
420+
return transformer(in_indices)
403421

404422

405423
def run_rmsnorm(

0 commit comments

Comments
 (0)