pre_tokenization_pattern = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
num_processes = 4

-def process_chunk(start: int, end: int, input_path: str, special_tokens: list[str]) -> dict:
-    words_counts = defaultdict(int)
-    with open(input_path, "rb") as f:
-        f.seek(start)
-        chunk = f.read(end - start)
-
-    sentences = re.split(b"|".join(re.escape(token).encode("utf-8") for token in special_tokens), chunk)
-
-    words_counts = defaultdict(int)
-    for sentence in sentences:
-        matches = re.finditer(pre_tokenization_pattern, sentence.decode('utf-8', errors='replace'))
-        for match in matches:
-            word_bytes = tuple(match.group().encode("utf-8"))
-            words_counts[word_bytes] += 1
-
-    return dict(words_counts)
-
# return vocab and merges
def train_bpe(
    input_path: str,
@@ -40,17 +23,12 @@ def train_bpe(
        vocab[next_index] = token.encode("utf-8")
        next_index += 1

-    merges: list[tuple[bytes, bytes]] = []
-
    # --------------------------------
    # | Parallelizing Pre tokenization |
    # --------------------------------
    parallel_pre_tokenziation_start = time.time()
-    all_boundaries = set([0])
    with open(input_path, "rb") as f:
-        for special_token in special_tokens:
-            all_boundaries.update(find_chunk_boundaries(f, num_processes, special_token.encode("utf-8")))
-    boundaries = sorted(all_boundaries)
+        boundaries = find_chunk_boundaries(f, num_processes, "<|endoftext|>".encode("utf-8"))

    chunk_pairs = list(zip(boundaries[:-1], boundaries[1:]))

@@ -90,55 +68,52 @@ def train_bpe(
    #             words_counts[word_bytes] += 1
    # print(f"Normal Pre Tokenization cost: {time.time() - normal_pre_tokenziation_start}")

-    # next_index = len(vocab)
+    merges: list[tuple[bytes, bytes]] = []

    # -----------------
    # | Optimized Merge |
    # -----------------
    merge_start = time.time()
    pair_counts = defaultdict(int)
+    pair_to_word_bytes = defaultdict(set)  # inverted index: pair -> set of words that currently contain it
    for word_bytes, count in words_counts.items():
        pairs = list(zip(word_bytes[:-1], word_bytes[1:]))
        for pair in pairs:
            pair_counts[pair] += count
+            pair_to_word_bytes[pair].add(word_bytes)
    while next_index < vocab_size:
        max_pair = max(pair_counts, key=lambda x: (pair_counts[x], (vocab[x[0]], vocab[x[1]])))
        index1, index2 = max_pair
        merges.append((vocab[index1], vocab[index2]))
        vocab[next_index] = vocab[index1] + vocab[index2]
-
-        new_words_counts = defaultdict(int)
-        new_pair_counts = defaultdict(int, pair_counts)
-        for word_bytes, count in words_counts.items():
-            old_pairs = list(zip(word_bytes[:-1], word_bytes[1:]))
-
-            if max_pair not in old_pairs:
-                new_words_counts[word_bytes] += count
-                continue
-
+        affected_word_bytes = pair_to_word_bytes[max_pair].copy()  # only words that contain max_pair need updating
+
+        for affected in affected_word_bytes:
+            count = words_counts[affected]
            new_word_bytes = []
            i = 0
-            while i < len(word_bytes):
-                if i < len(word_bytes) - 1 and word_bytes[i] == index1 and word_bytes[i + 1] == index2:
+            while i < len(affected):
+                if i < len(affected) - 1 and affected[i] == index1 and affected[i + 1] == index2:
                    new_word_bytes.append(next_index)
                    i += 2
                else:
-                    new_word_bytes.append(word_bytes[i])
+                    new_word_bytes.append(affected[i])
                    i += 1
            new_word = tuple(new_word_bytes)
-            new_words_counts[new_word] += count
+            words_counts[new_word] += count

-            for pair in old_pairs:
-                new_pair_counts[pair] -= count
-                if new_pair_counts[pair] <= 0:
-                    del new_pair_counts[pair]
+            for pair in zip(affected[:-1], affected[1:]):
+                pair_counts[pair] -= count
+                pair_to_word_bytes[pair].discard(affected)
+                if pair_counts[pair] <= 0:
+                    del pair_counts[pair]
+                    del pair_to_word_bytes[pair]

            new_pairs = list(zip(new_word[:-1], new_word[1:]))
            for pair in new_pairs:
-                new_pair_counts[pair] += count
+                pair_counts[pair] += count
+                pair_to_word_bytes[pair].add(new_word)

-        words_counts = new_words_counts
-        pair_counts = new_pair_counts
        next_index += 1
    print(f"Merge cost: {time.time() - merge_start}")

@@ -174,3 +149,20 @@ def train_bpe(
    # next_index += 1
    # print(f"Merge cost: {time.time() - merge_start}")
    return vocab, merges
+
+def process_chunk(start: int, end: int, input_path: str, special_tokens: list[str]) -> dict:
+    words_counts = defaultdict(int)
+    with open(input_path, "rb") as f:
+        f.seek(start)
+        chunk = f.read(end - start)
+
+    sentences = re.split(b"|".join(re.escape(token).encode("utf-8") for token in special_tokens), chunk)
+
+    words_counts = defaultdict(int)
+    for sentence in sentences:
+        matches = re.finditer(pre_tokenization_pattern, sentence.decode('utf-8', errors='replace'))
+        for match in matches:
+            word_bytes = tuple(match.group().encode("utf-8"))
+            words_counts[word_bytes] += 1
+
+    return words_counts
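
The commit moves process_chunk to module level (so worker processes can pickle it) and computes chunk boundaries once on "<|endoftext|>", but the code that actually fans chunk_pairs out to workers falls between the hunks shown above. Below is a minimal sketch of that dispatch, assuming a multiprocessing.Pool with starmap and a simple dict merge; the wrapper name parallel_pretokenize is hypothetical, while find_chunk_boundaries and process_chunk are the helpers already referenced in the diff.

from collections import defaultdict
from multiprocessing import Pool

def parallel_pretokenize(input_path: str, special_tokens: list[str], num_processes: int = 4) -> dict:
    # Sketch only: align chunk boundaries to "<|endoftext|>" so no document straddles two workers.
    with open(input_path, "rb") as f:
        boundaries = find_chunk_boundaries(f, num_processes, "<|endoftext|>".encode("utf-8"))

    chunk_pairs = list(zip(boundaries[:-1], boundaries[1:]))

    # Each worker re-opens the file, seeks to its slice, and returns its own word counts.
    with Pool(processes=num_processes) as pool:
        results = pool.starmap(
            process_chunk,
            [(start, end, input_path, special_tokens) for start, end in chunk_pairs],
        )

    # Fold the per-chunk counts into one table keyed by the word's byte tuple.
    words_counts = defaultdict(int)
    for partial in results:
        for word_bytes, count in partial.items():
            words_counts[word_bytes] += count
    return words_counts

The aggregated words_counts is what the optimized merge loop above consumes; the exact aggregation in the commit is not visible in these hunks, so treat this wiring as an illustration rather than the committed implementation.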