|
19 | 19 | from cs336_basics.softmax import softmax |
20 | 20 | from cs336_basics.scaled_dot_product_attention import scaled_dot_product_attention |
21 | 21 | from cs336_basics.multihead_self_attention import MultiHeadSelfAttention, MultiHeadSelfAttentionWithRoPE |
22 | | -from cs336_basics.transformer_block import TranformerBlock |
| 22 | +from cs336_basics.transformer_block import TransformerBlock |
| 23 | +from cs336_basics.transformer import Transformer |
23 | 24 |
|
24 | 25 | def run_linear( |
25 | 26 | d_in: int, |
@@ -305,7 +306,7 @@ class and pass that instead. |
305 | 306 | Float[Tensor, "batch sequence_length d_model"] Tensor with the output of |
306 | 307 | running the Transformer block on the input features while using RoPE. |
307 | 308 | """ |
308 | | - transformer_block = TranformerBlock(d_model, num_heads, d_ff, max_seq_len, theta, device=in_features.device, dtype=in_features.dtype) |
| 309 | + transformer_block = TransformerBlock(d_model, num_heads, d_ff, max_seq_len, theta, device=in_features.device, dtype=in_features.dtype) |
309 | 310 | transformer_block.load_state_dict({ |
310 | 311 | "mhsa.W_q": weights["attn.q_proj.weight"], |
311 | 312 | "mhsa.W_k": weights["attn.k_proj.weight"], |
@@ -399,7 +400,24 @@ def run_transformer_lm( |
399 | 400 | Float[Tensor, "batch_size sequence_length vocab_size"]: Tensor with the predicted unnormalized |
400 | 401 | next-word distribution for each token. |
401 | 402 | """ |
402 | | - raise NotImplementedError |
| 403 | + transformer = Transformer(vocab_size, context_length, num_layers, d_model, num_heads, d_ff, rope_theta, device=in_indices.device) |
| 404 | + state_dict = { |
| 405 | + 'token_embedding.weights': weights['token_embeddings.weight'], |
| 406 | + 'ln_final.g': weights['ln_final.weight'], |
| 407 | + 'lm_head.weights': weights['lm_head.weight'] |
| 408 | + } |
| 409 | + for i in range(num_layers): |
| 410 | + state_dict[f"layers.{i}.mhsa.W_q"] = weights[f'layers.{i}.attn.q_proj.weight'] |
| 411 | + state_dict[f"layers.{i}.mhsa.W_k"] = weights[f'layers.{i}.attn.k_proj.weight'] |
| 412 | + state_dict[f"layers.{i}.mhsa.W_v"] = weights[f'layers.{i}.attn.v_proj.weight'] |
| 413 | + state_dict[f"layers.{i}.mhsa.W_o"] = weights[f'layers.{i}.attn.output_proj.weight'] |
| 414 | + state_dict[f"layers.{i}.ln1.g"] = weights[f'layers.{i}.ln1.weight'] |
| 415 | + state_dict[f"layers.{i}.ln2.g"] = weights[f'layers.{i}.ln2.weight'] |
| 416 | + state_dict[f"layers.{i}.ffn.w1"] = weights[f"layers.{i}.ffn.w1.weight"] |
| 417 | + state_dict[f"layers.{i}.ffn.w2"] = weights[f"layers.{i}.ffn.w2.weight"] |
| 418 | + state_dict[f"layers.{i}.ffn.w3"] = weights[f"layers.{i}.ffn.w3.weight"] |
| 419 | + transformer.load_state_dict(state_dict) |
| 420 | + return transformer(in_indices) |
403 | 421 |
|
404 | 422 |
|
405 | 423 | def run_rmsnorm( |
|
0 commit comments