Commit 6acd10a

feat(bpe): optimize the merge process to bring training time on TinyStory-train under 120 s

1 parent e21961f

1 file changed: cs336_basics/train_bpe.py (37 additions & 45 deletions)
@@ -8,23 +8,6 @@
 pre_tokenization_pattern = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
 num_processes = 4
 
-def process_chunk(start: int, end: int, input_path: str, special_tokens: list[str]) -> dict:
-    words_counts = defaultdict(int)
-    with open(input_path, "rb") as f:
-        f.seek(start)
-        chunk = f.read(end - start)
-
-    sentences = re.split(b"|".join(re.escape(token).encode("utf-8") for token in special_tokens), chunk)
-
-    words_counts = defaultdict(int)
-    for sentence in sentences:
-        matches = re.finditer(pre_tokenization_pattern, sentence.decode('utf-8', errors='replace'))
-        for match in matches:
-            word_bytes = tuple(match.group().encode("utf-8"))
-            words_counts[word_bytes] += 1
-
-    return dict(words_counts)
-
 # return vocab and merges
 def train_bpe(
     input_path: str,
@@ -40,17 +23,12 @@ def train_bpe(
         vocab[next_index] = token.encode("utf-8")
         next_index += 1
 
-    merges: list[tuple[bytes, bytes]] = []
-
     # --------------------------------
     # | Parallelizing Pre tokenziation |
    # --------------------------------
     parallel_pre_tokenziation_start = time.time()
-    all_boundaries = set([0])
     with open(input_path, "rb") as f:
-        for special_token in special_tokens:
-            all_boundaries.update(find_chunk_boundaries(f, num_processes, special_token.encode("utf-8")))
-    boundaries = sorted(all_boundaries)
+        boundaries = find_chunk_boundaries(f, num_processes, "<|endoftext|>".encode("utf-8"))
 
     chunk_pairs = list(zip(boundaries[:-1], boundaries[1:]))
 
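
This hunk simplifies chunking: instead of collecting boundaries for every special token, the file is split only at "<|endoftext|>" boundaries, and the resulting (start, end) pairs are handed to worker processes running process_chunk (moved to the bottom of the file in this commit). A minimal sketch of what the fan-out can look like, assuming process_chunk as defined in this file; the function name pretokenize_parallel and the Pool-based driver are illustrative assumptions, the real driver lives in unchanged lines of train_bpe.py:

# Illustrative driver (an assumption, not the commit's code): each worker
# pre-tokenizes one [start, end) slice; the parent sums per-chunk counts.
from collections import defaultdict
from multiprocessing import Pool

def pretokenize_parallel(input_path, chunk_pairs, special_tokens, num_processes=4):
    args = [(start, end, input_path, special_tokens) for start, end in chunk_pairs]
    with Pool(num_processes) as pool:
        partial_counts = pool.starmap(process_chunk, args)
    words_counts = defaultdict(int)
    for partial in partial_counts:
        for word_bytes, count in partial.items():
            words_counts[word_bytes] += count
    return words_counts
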
@@ -90,55 +68,52 @@ def train_bpe(
     #             words_counts[word_bytes] += 1
     # print(f"Normal Pre Tokenization cost: {time.time() - normal_pre_tokenziation_start}")
 
-    # next_index = len(vocab)
+    merges: list[tuple[bytes, bytes]] = []
 
     # -----------------
     # | Optimized Merge |
     # -----------------
     merge_start = time.time()
     pair_counts = defaultdict(int)
+    pair_to_word_bytes = defaultdict(set)
     for word_bytes, count in words_counts.items():
         pairs = list(zip(word_bytes[:-1], word_bytes[1:]))
         for pair in pairs:
             pair_counts[pair] += count
+            pair_to_word_bytes[pair].add(word_bytes)
     while next_index < vocab_size:
         max_pair = max(pair_counts, key=lambda x: (pair_counts[x], (vocab[x[0]], vocab[x[1]])))
         index1, index2 = max_pair
         merges.append((vocab[index1], vocab[index2]))
         vocab[next_index] = vocab[index1] + vocab[index2]
-
-        new_words_counts = defaultdict(int)
-        new_pair_counts = defaultdict(int, pair_counts)
-        for word_bytes, count in words_counts.items():
-            old_pairs = list(zip(word_bytes[:-1], word_bytes[1:]))
-
-            if max_pair not in old_pairs:
-                new_words_counts[word_bytes] += count
-                continue
-
+        affected_word_bytes = pair_to_word_bytes[max_pair].copy()
+
+        for affected in affected_word_bytes:
+            count = words_counts[affected]
             new_word_bytes = []
             i = 0
-            while i < len(word_bytes):
-                if i < len(word_bytes) - 1 and word_bytes[i] == index1 and word_bytes[i + 1] == index2:
+            while i < len(affected):
+                if i < len(affected) - 1 and affected[i] == index1 and affected[i + 1] == index2:
                     new_word_bytes.append(next_index)
                     i += 2
                 else:
-                    new_word_bytes.append(word_bytes[i])
+                    new_word_bytes.append(affected[i])
                     i += 1
             new_word = tuple(new_word_bytes)
-            new_words_counts[new_word] += count
+            words_counts[new_word] += count
 
-            for pair in old_pairs:
-                new_pair_counts[pair] -= count
-                if new_pair_counts[pair] <= 0:
-                    del new_pair_counts[pair]
+            for pair in zip(affected[:-1], affected[1:]):
+                pair_counts[pair] -= count
+                pair_to_word_bytes[pair].discard(affected)
+                if pair_counts[pair] <= 0:
+                    del pair_counts[pair]
+                    del pair_to_word_bytes[pair]
 
             new_pairs = list(zip(new_word[:-1], new_word[1:]))
             for pair in new_pairs:
-                new_pair_counts[pair] += count
+                pair_counts[pair] += count
+                pair_to_word_bytes[pair].add(new_word)
 
-        words_counts = new_words_counts
-        pair_counts = new_pair_counts
         next_index += 1
     print(f"Merge cost: {time.time() - merge_start}")
 
@@ -174,3 +149,20 @@ def train_bpe(
     #     next_index += 1
     # print(f"Merge cost: {time.time() - merge_start}")
     return vocab, merges
+
+def process_chunk(start: int, end: int, input_path: str, special_tokens: list[str]) -> dict:
+    words_counts = defaultdict(int)
+    with open(input_path, "rb") as f:
+        f.seek(start)
+        chunk = f.read(end - start)
+
+    sentences = re.split(b"|".join(re.escape(token).encode("utf-8") for token in special_tokens), chunk)
+
+    words_counts = defaultdict(int)
+    for sentence in sentences:
+        matches = re.finditer(pre_tokenization_pattern, sentence.decode('utf-8', errors='replace'))
+        for match in matches:
+            word_bytes = tuple(match.group().encode("utf-8"))
+            words_counts[word_bytes] += 1
+
+    return words_counts
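
process_chunk now lives below train_bpe and returns the defaultdict directly instead of converting it to a plain dict; a defaultdict(int) pickles fine across process boundaries, so the conversion was unnecessary (the duplicate words_counts initialization is carried over unchanged). A hypothetical one-off call, with a placeholder path and byte range; in train_bpe the offsets come from find_chunk_boundaries, so real chunks end on a special token rather than at an arbitrary byte:

# Hypothetical usage; the path and byte range below are placeholders.
counts = process_chunk(
    start=0,
    end=1 << 20,  # first 1 MiB of the corpus
    input_path="data/TinyStory-train.txt",
    special_tokens=["<|endoftext|>"],
)
print(len(counts), "distinct pre-tokens in this chunk")
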
