Commit 593f5fc

feat: support decoding

1 parent 208ef82 commit 593f5fc

5 files changed: 206 additions & 7 deletions


cs336_basics/rope.py

Lines changed: 13 additions & 3 deletions
@@ -1,14 +1,20 @@
 import torch
 
+
 class RoPE(torch.nn.Module):
     def __init__(self, theta: float, d_k: int, max_seq_len: int, device=None):
         super().__init__()
         self.theta = theta
         self.d_k = d_k
         self.max_seq_len = max_seq_len
 
-        d_k_half= d_k // 2
-        inv_freq = 1 / (theta ** (torch.arange(0, d_k_half, device=device, dtype=torch.float32) / d_k_half))
+        d_k_half = d_k // 2
+        inv_freq = 1 / (
+            theta
+            ** (
+                torch.arange(0, d_k_half, device=device, dtype=torch.float32) / d_k_half
+            )
+        )
         t = torch.arange(max_seq_len, device=device, dtype=torch.float32)
         freqs = torch.einsum("i,j->ij", t, inv_freq)
         cos = torch.cos(freqs)
@@ -19,6 +25,10 @@ def __init__(self, theta: float, d_k: int, max_seq_len: int, device=None):
     def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor:
         # x: (..., seq_len, d_k)
         # token_position: (..., seq_len)
+        while token_positions.dim() < x.dim() - 1:
+            # Match missing batch-like dims (e.g. head dimension) so broadcasting works
+            token_positions = token_positions.unsqueeze(-2)
+
         cos = self.cos[token_positions]
         sin = self.sin[token_positions]
 
@@ -32,4 +42,4 @@ def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor:
         x2_rot = x1 * sin + x2 * cos
 
         out = torch.stack([x1_rot, x2_rot], dim=-1).reshape(*x.shape[:-2], d_k)
-        return out
+        return out
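Worth spelling out why the new unsqueeze(-2) loop matters. In multi-head attention x arrives as (batch, heads, seq_len, d_k) while token_positions is typically (batch, seq_len); indexing the cached tables with it then yields (batch, seq_len, d_k/2), whose batch axis collides with the head axis under broadcasting. A minimal sketch, with illustrative shapes borrowed from the model config in decoding.py, assuming the unseen parts of forward follow the shapes in its comments:

import torch
from cs336_basics.rope import RoPE

rope = RoPE(theta=10000.0, d_k=32, max_seq_len=256)

x = torch.randn(2, 16, 256, 32)                     # (batch, heads, seq_len, d_k)
token_positions = torch.arange(256).expand(2, 256)  # (batch, seq_len): no head dim

# The while-loop promotes token_positions to (batch, 1, seq_len), so the
# looked-up cos/sin become (batch, 1, seq_len, d_k // 2) and broadcast across
# the 16 heads instead of clashing with them.
out = rope(x, token_positions)
print(out.shape)  # torch.Size([2, 16, 256, 32])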

cs336_basics/tokenizer.py

Lines changed: 63 additions & 2 deletions
@@ -3,6 +3,21 @@
 
 pre_tokenization_pattern = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
 
+def _bpe_worker(args):
+    tokenizer, chunk = args
+    return tokenizer.encode(chunk)
+
+def _parse_merge_line(line: str):
+    assert line.startswith("['") and line.endswith("']")
+
+    inner = line[2:-2]
+
+    split_index = inner.find("', '")
+    left = inner[:split_index]
+    right = inner[split_index + 4:]
+
+    return left, right
+
 class BPETokenizer:
     def __init__(self, vocab: dict[int, bytes], merges: list[tuple[bytes, bytes]], special_tokens: list[str] | None=None) -> None:
         self.vocab = vocab
@@ -21,8 +36,35 @@ def __init__(self, vocab: dict[int, bytes], merges: list[tuple[bytes, bytes]], special_tokens: list[str] | None=None) -> None:
             self.byte_to_token[byte_val] = token_id
         self.special_tokens = special_tokens
 
-    def from_files(cls, vocab_filepath, merge_filepath, special_tokens=None):
-        return NotImplemented
+    @classmethod
+    def from_files(cls, vocab_filepath, merges_filepath, special_tokens=None):
+        import json
+        from tests.common import gpt2_bytes_to_unicode
+
+        byte_decoder = {v: k for k, v in gpt2_bytes_to_unicode().items()}
+
+        with open(vocab_filepath, "r", encoding="utf-8") as f:
+            raw_vocab = json.load(f)
+
+        vocab = {}
+        for unicode_token, token_id in raw_vocab.items():
+            token_bytes = bytes([byte_decoder[c] for c in unicode_token])
+            vocab[int(token_id)] = token_bytes
+
+        merges = []
+        with open(merges_filepath, "r", encoding="utf-8") as f:
+            for line in f:
+                line = line.strip()
+                if not line:
+                    continue
+
+                left_unicode, right_unicode = _parse_merge_line(line)
+                left_bytes = bytes([byte_decoder[c] for c in left_unicode])
+                right_bytes = bytes([byte_decoder[c] for c in right_unicode])
+
+                merges.append((left_bytes, right_bytes))
+
+        return cls(vocab, merges, special_tokens)
 
     def encode(self, text: str) -> list[int]:
         splitted_text = []
@@ -61,6 +103,25 @@ def encode_iterable(self, iterable: Iterable[str]) -> Iterable[int]:
         for text in iterable:
             yield from self.encode(text)
 
+    def encode_parallel(self, text: str, num_workers: int = 4) -> list[int]:
+        import multiprocessing as mp
+
+        length = len(text)
+        if length == 0:
+            return []
+
+        chunk_size = max(1, length // num_workers)
+        chunks = [text[i:i + chunk_size] for i in range(0, length, chunk_size)]
+
+        with mp.Pool(processes=num_workers) as pool:
+            results = pool.map(_bpe_worker, [(self, c) for c in chunks])
+
+        merged: list[int] = []
+        for r in results:
+            merged.extend(r)
+        return merged
+
+
     def decode(self, ids: list[int]) -> str:
         list = []
         for id in ids:
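Taken together, from_files and the new helpers give a full load/encode/decode round trip; _parse_merge_line implies the merges file stores one ['left', 'right'] pair per line. A usage sketch, assuming artifact paths matching what decoding.py loads:

from cs336_basics.tokenizer import BPETokenizer

tokenizer = BPETokenizer.from_files(
    vocab_filepath="data/tinystories_valid_tokenizer/tinystories_vocab.json",
    merges_filepath="data/tinystories_valid_tokenizer/tinystories_merges.txt",
    special_tokens=["<|endoftext|>"],
)

ids = tokenizer.encode("Once upon a time")
assert tokenizer.decode(ids) == "Once upon a time"

# encode_parallel pickles (tokenizer, chunk) pairs out to a multiprocessing
# pool via the module-level _bpe_worker. Note the chunks are cut at raw
# character offsets, so a pre-token straddling a boundary can tokenize
# differently than it would in a single-process encode.
ids_parallel = tokenizer.encode_parallel("Once upon a time. " * 1000, num_workers=4)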

decoding.py

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
+import torch
+import argparse
+from cs336_basics.tokenizer import BPETokenizer
+
+
+def load_tokenizer():
+    tokenizer = BPETokenizer.from_files(
+        vocab_filepath="data/tinystories_valid_tokenizer/tinystories_vocab.json",
+        merges_filepath="data/tinystories_valid_tokenizer/tinystories_merges.txt",
+        special_tokens=["<|endoftext|>"],
+    )
+    return tokenizer
+
+
+def load_model(checkpoint_path: str, device: str):
+    from cs336_basics.transformer import Transformer
+
+    ckpt = torch.load(checkpoint_path)
+    config = ckpt["model_state"]
+
+    model = Transformer(
+        vocab_size=10000,
+        context_length=256,
+        num_layers=4,
+        d_model=512,
+        num_heads=16,
+        d_ff=1344,
+        rope_theta=10000.0,
+        device=device,
+    )
+
+    model.load_state_dict(config)
+    model.to(device)
+    return model
+
+
+def sample_next_token(logits, temperature=1.0, top_p=1.0):
+    from cs336_basics.softmax import softmax
+
+    if temperature <= 0:
+        return int(torch.argmax(logits).item())
+
+    logits = logits / temperature
+    probs = softmax(logits, -1)
+
+    if top_p is None or top_p >= 1.0:
+        return int(torch.multinomial(probs, num_samples=1).item())
+
+    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
+    cumulative = torch.cumsum(sorted_probs, dim=-1)
+
+    mask = cumulative <= top_p
+    if not torch.any(mask):
+        mask[0] = True
+
+    cutoff = torch.nonzero(mask)[-1].item()
+    mask[: cutoff + 1] = True
+
+    truncated_probs = sorted_probs * mask
+    truncated_probs /= truncated_probs.sum()
+
+    sampled = torch.multinomial(truncated_probs, 1)
+    next_id = sorted_idx[sampled]
+
+    return int(next_id.item())
+
+
+@torch.no_grad()
+def decode(
+    model: torch.nn.Module,
+    tokenizer: BPETokenizer,
+    prompt_ids: torch.Tensor,
+    max_tokens: int,
+    device,
+    temperature=1.0,
+    top_p=1.0,
+):
+    model.eval()
+    ids = prompt_ids.to(device)
+
+    eos_id = tokenizer.vocab_reverse[b"<|endoftext|>"]
+
+    for _ in range(max_tokens):
+        logits = model(ids.unsqueeze(0))
+        last_logits = logits[0, -1]
+
+        next_id = sample_next_token(last_logits, temperature=temperature, top_p=top_p)
+        next_id_tensor = torch.tensor([next_id], dtype=torch.long, device=device)
+
+        ids = torch.cat([ids, next_id_tensor], dim=0)
+
+        if next_id == eos_id:
+            break
+    return ids
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--checkpoint", type=str, required=True)
+    parser.add_argument("--prompt", type=str, required=True)
+    parser.add_argument("--max-tokens", type=int, default=128)
+    parser.add_argument("--temperature", type=float, default=1.0)
+    parser.add_argument("--top-p", type=float, default=1.0)
+    parser.add_argument("--device", type=str, default="mps")
+
+    args = parser.parse_args()
+    device = args.device
+
+    tokenizer = load_tokenizer()
+
+    prompt_ids = tokenizer.encode(args.prompt)
+    prompt_ids = torch.tensor(prompt_ids, dtype=torch.long)
+
+    model = load_model(args.checkpoint, device)
+    full_ids = decode(
+        model=model,
+        tokenizer=tokenizer,
+        prompt_ids=prompt_ids,
+        max_tokens=args.max_tokens,
+        device=device,
+        temperature=args.temperature,
+        top_p=args.top_p,
+    )
+
+    text = tokenizer.decode(full_ids.tolist())
+    print(text)
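Two things worth noting in sample_next_token: temperature at or below zero falls back to greedy argmax, and the top-p branch keeps only the tokens whose cumulative probability (sorted descending) stays within top_p, always retaining at least the top token, then renormalizes before sampling. A worked mini-example of that truncation:

import torch

probs = torch.tensor([0.5, 0.3, 0.15, 0.05])  # already sorted, descending
cumulative = torch.cumsum(probs, dim=-1)      # [0.50, 0.80, 0.95, 1.00]

mask = cumulative <= 0.9                      # [True, True, False, False]
truncated = probs * mask                      # [0.50, 0.30, 0.00, 0.00]
truncated /= truncated.sum()                  # [0.625, 0.375, 0.0, 0.0]

# Only the top two tokens can now be drawn; the tail is cut off entirely.
next_token = torch.multinomial(truncated, 1)

End to end, the script would be invoked along these lines (checkpoint path hypothetical; the flags are the ones registered above):

python decoding.py --checkpoint checkpoints/model.pt --prompt "Once upon a time" --max-tokens 128 --temperature 0.8 --top-p 0.9 --device mps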

prepare_dataset.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ def main():
     print("Trained BPE")
 
     tokenizer = BPETokenizer(vocab, merges, special_tokens=args.special)
-    ids = tokenizer.encode(text)
+    ids = tokenizer.encode_parallel(text, num_workers=14)
     print("Tokenized")
 
     arr = np.array(ids, dtype=np.uint16)
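One detail to keep in mind with this change: the ids are stored as np.uint16, which only holds because the vocabulary is trained to 10000 tokens (see the test below), well under 2^16. A hypothetical guard one could add, were the vocab size ever raised:

import numpy as np

assert max(ids) <= np.iinfo(np.uint16).max, "token ids overflow uint16 storage"
arr = np.array(ids, dtype=np.uint16)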

tests/test_train_bpe.py

Lines changed: 3 additions & 1 deletion
@@ -90,6 +90,7 @@ def test_train_bpe_special_tokens(snapshot):
 import pathlib
 import pytest
 data_folder = (pathlib.Path(__file__).resolve().parent.parent) / "data"
+TINYSTORIES_VALID_DIR = data_folder / "tinystories_valid_tokenizer"
 TINYSTORIES_ARTIFACTS_DIR = data_folder / "tinystories_tokenizer"
 
 
@@ -126,10 +127,11 @@ def _save_tokenizer_artifacts(vocab, merges, output_dir):
 def test_train_bpe_on_tiny_story_valid():
     start_time = time.time()
     input_path = data_folder / "TinyStoriesV2-GPT4-valid.txt"
-    _, _ = run_train_bpe(
+    vocab, merges = run_train_bpe(
         input_path=input_path,
         vocab_size=10000,
         special_tokens=["<|endoftext|>"])
+    _save_tokenizer_artifacts(vocab, merges, TINYSTORIES_VALID_DIR)
     end_time = time.time()
 
     assert(end_time - start_time <= 120)
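Since load_tokenizer in decoding.py reads its vocab and merges from data/tinystories_valid_tokenizer/, this test now doubles as the artifact generator; regenerating the files (exact names depend on _save_tokenizer_artifacts, which this diff does not show) is presumably just:

pytest tests/test_train_bpe.py::test_train_bpe_on_tiny_story_valid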
