|
| 1 | +import torch |
| 2 | +import argparse |
| 3 | +from cs336_basics.tokenizer import BPETokenizer |
| 4 | + |
| 5 | + |
def load_tokenizer():
    """Build the BPE tokenizer from the TinyStories vocab/merges files on disk."""
    return BPETokenizer.from_files(
        vocab_filepath="data/tinystories_valid_tokenizer/tinystories_vocab.json",
        merges_filepath="data/tinystories_valid_tokenizer/tinystories_merges.txt",
        special_tokens=["<|endoftext|>"],
    )
| 13 | + |
| 14 | + |
def load_model(checkpoint_path: str, device: str):
    """Load a Transformer checkpoint and move the model to *device*.

    Args:
        checkpoint_path: Path to a torch checkpoint whose ``"model_state"``
            entry holds the model state dict.
        device: Device string (e.g. ``"cpu"``, ``"cuda"``, ``"mps"``).

    Returns:
        The Transformer with restored weights, placed on *device*.
    """
    from cs336_basics.transformer import Transformer

    # map_location keeps the load working when the checkpoint was saved on a
    # different device (e.g. a CUDA-trained checkpoint loaded on CPU/MPS).
    ckpt = torch.load(checkpoint_path, map_location=device)
    state_dict = ckpt["model_state"]  # renamed from `config` — it is a state dict

    # NOTE(review): hyperparameters are hard-coded and must match the
    # checkpoint; consider saving them into the checkpoint alongside weights.
    model = Transformer(
        vocab_size=10000,
        context_length=256,
        num_layers=4,
        d_model=512,
        num_heads=16,
        d_ff=1344,
        rope_theta=10000.0,
        device=device,
    )

    model.load_state_dict(state_dict)
    model.to(device)
    return model
| 35 | + |
| 36 | + |
def sample_next_token(logits, temperature=1.0, top_p=1.0):
    """Sample one token id from *logits* with temperature and nucleus (top-p) sampling.

    Args:
        logits: 1-D tensor of unnormalized scores over the vocabulary.
        temperature: Softmax temperature; ``<= 0`` means greedy argmax.
        top_p: Nucleus threshold in (0, 1]; ``None`` or ``>= 1.0`` disables truncation.

    Returns:
        The sampled token id as a plain int.
    """
    # Temperature <= 0 degenerates to greedy decoding.
    if temperature <= 0:
        return int(torch.argmax(logits).item())

    logits = logits / temperature
    # torch's built-in softmax replaces the project-local helper; it is
    # numerically stable and already in scope via the torch import.
    probs = torch.softmax(logits, dim=-1)

    if top_p is None or top_p >= 1.0:
        return int(torch.multinomial(probs, num_samples=1).item())

    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)

    # Nucleus sampling keeps the smallest prefix whose total mass reaches
    # top_p. A token is kept when the mass *before* it is still below top_p,
    # which always includes the token that crosses the threshold. (The
    # previous `cumulative <= top_p` mask dropped that crossing token and
    # under-covered the nucleus.)
    mask = (cumulative - sorted_probs) < top_p
    mask[0] = True  # always keep at least the most probable token

    truncated_probs = sorted_probs * mask
    truncated_probs = truncated_probs / truncated_probs.sum()  # renormalize

    sampled = torch.multinomial(truncated_probs, num_samples=1)
    return int(sorted_idx[sampled].item())
| 66 | + |
| 67 | + |
@torch.no_grad()
def decode(
    model: torch.nn.Module,
    tokenizer: BPETokenizer,
    prompt_ids: torch.Tensor,
    max_tokens: int,
    device,
    temperature=1.0,
    top_p=1.0,
    context_length=None,
):
    """Autoregressively generate up to *max_tokens* tokens after *prompt_ids*.

    Args:
        model: Trained language model mapping (1, T) ids to (1, T, V) logits.
        tokenizer: Tokenizer providing `vocab_reverse` to look up the EOS id.
        prompt_ids: 1-D LongTensor of prompt token ids.
        max_tokens: Maximum number of new tokens to generate.
        device: Device the model lives on; ids are moved there.
        temperature: Sampling temperature (see `sample_next_token`).
        top_p: Nucleus threshold (see `sample_next_token`).
        context_length: If set, feed the model only the most recent
            `context_length` tokens — the model cannot attend beyond its
            trained context window (256 for the checkpoint in `load_model`).
            Defaults to None (previous behavior: the full sequence).

    Returns:
        1-D LongTensor of prompt ids followed by generated ids (including
        the EOS token if generation stopped on it).
    """
    model.eval()
    ids = prompt_ids.to(device)

    # Stop generation when the end-of-text special token is produced.
    eos_id = tokenizer.vocab_reverse[b"<|endoftext|>"]

    for _ in range(max_tokens):
        # Crop the input to the model's context window when one is given.
        window = ids if context_length is None else ids[-context_length:]
        logits = model(window.unsqueeze(0))
        last_logits = logits[0, -1]

        next_id = sample_next_token(last_logits, temperature=temperature, top_p=top_p)
        next_id_tensor = torch.tensor([next_id], dtype=torch.long, device=device)

        ids = torch.cat([ids, next_id_tensor], dim=0)

        if next_id == eos_id:
            break
    return ids
| 95 | + |
| 96 | + |
| 97 | +if __name__ == "__main__": |
| 98 | + parser = argparse.ArgumentParser() |
| 99 | + parser.add_argument("--checkpoint", type=str, required=True) |
| 100 | + parser.add_argument("--prompt", type=str, required=True) |
| 101 | + parser.add_argument("--max-tokens", type=int, default=128) |
| 102 | + parser.add_argument("--temperature", type=float, default=1.0) |
| 103 | + parser.add_argument("--top-p", type=float, default=1.0) |
| 104 | + parser.add_argument("--device", type=str, default="mps") |
| 105 | + |
| 106 | + args = parser.parse_args() |
| 107 | + device = args.device |
| 108 | + |
| 109 | + tokenizer = load_tokenizer() |
| 110 | + |
| 111 | + prompt_ids = tokenizer.encode(args.prompt) |
| 112 | + prompt_ids = torch.tensor(prompt_ids, dtype=torch.long) |
| 113 | + |
| 114 | + model = load_model(args.checkpoint, device) |
| 115 | + full_ids = decode( |
| 116 | + model=model, |
| 117 | + tokenizer=tokenizer, |
| 118 | + prompt_ids=prompt_ids, |
| 119 | + max_tokens=args.max_tokens, |
| 120 | + device=device, |
| 121 | + temperature=args.temperature, |
| 122 | + top_p=args.top_p, |
| 123 | + ) |
| 124 | + |
| 125 | + text = tokenizer.decode(full_ids.tolist()) |
| 126 | + print(text) |
0 commit comments