@@ -31,36 +31,21 @@ namespace dlib
3131 vocab[i] = std::vector<uint8_t >{ static_cast <uint8_t >(i) };
3232
3333 // Initialize special tokens with sequential IDs
34- special_tokens =
35- {
36- {" <text>" , BASE_VOCAB_SIZE},
37- {" </text>" , BASE_VOCAB_SIZE + 1 },
38- {" <url>" , BASE_VOCAB_SIZE + 2 },
39- {" </url>" , BASE_VOCAB_SIZE + 3 },
40- {" <image>" , BASE_VOCAB_SIZE + 4 },
41- {" </image>" , BASE_VOCAB_SIZE + 5 },
42- {" <video>" , BASE_VOCAB_SIZE + 6 },
43- {" </video>" , BASE_VOCAB_SIZE + 7 },
44- {" <audio>" , BASE_VOCAB_SIZE + 8 },
45- {" </audio>" , BASE_VOCAB_SIZE + 9 },
46- {" <file>" , BASE_VOCAB_SIZE + 10 },
47- {" </file>" , BASE_VOCAB_SIZE + 11 },
48- {" <code>" , BASE_VOCAB_SIZE + 12 },
49- {" </code>" , BASE_VOCAB_SIZE + 13 },
50- {" <summary>" , BASE_VOCAB_SIZE + 14 },
51- {" </summary>" , BASE_VOCAB_SIZE + 15 },
52- {" <think>" , BASE_VOCAB_SIZE + 16 },
53- {" </think>" , BASE_VOCAB_SIZE + 17 },
54- {" <start>" , BASE_VOCAB_SIZE + 18 },
55- {" <end>" , BASE_VOCAB_SIZE + 19 },
56- {" <user>" , BASE_VOCAB_SIZE + 20 },
57- {" <bot>" , BASE_VOCAB_SIZE + 21 },
58- {" <system>" , BASE_VOCAB_SIZE + 22 },
59- {" <question>" , BASE_VOCAB_SIZE + 23 },
60- {" <answer>" , BASE_VOCAB_SIZE + 24 },
61- {" <search>" , BASE_VOCAB_SIZE + 25 },
62- {" <unk>" , BASE_VOCAB_SIZE + 26 },
63- {" <pad>" , BASE_VOCAB_SIZE + 27 }
34+ special_tokens = {
35+ {" <text>" , BASE_VOCAB_SIZE}, {" </text>" , BASE_VOCAB_SIZE + 1 },
36+ {" <url>" , BASE_VOCAB_SIZE + 2 }, {" </url>" , BASE_VOCAB_SIZE + 3 },
37+ {" <image>" , BASE_VOCAB_SIZE + 4 }, {" </image>" , BASE_VOCAB_SIZE + 5 },
38+ {" <video>" , BASE_VOCAB_SIZE + 6 }, {" </video>" , BASE_VOCAB_SIZE + 7 },
39+ {" <audio>" , BASE_VOCAB_SIZE + 8 }, {" </audio>" , BASE_VOCAB_SIZE + 9 },
40+ {" <file>" , BASE_VOCAB_SIZE + 10 }, {" </file>" , BASE_VOCAB_SIZE + 11 },
41+ {" <code>" , BASE_VOCAB_SIZE + 12 }, {" </code>" , BASE_VOCAB_SIZE + 13 },
42+ {" <summary>" , BASE_VOCAB_SIZE + 14 }, {" </summary>" , BASE_VOCAB_SIZE + 15 },
43+ {" <think>" , BASE_VOCAB_SIZE + 16 }, {" </think>" , BASE_VOCAB_SIZE + 17 },
44+ {" <start>" , BASE_VOCAB_SIZE + 18 }, {" <end>" , BASE_VOCAB_SIZE + 19 },
45+ {" <user>" , BASE_VOCAB_SIZE + 20 }, {" <bot>" , BASE_VOCAB_SIZE + 21 },
46+ {" <system>" , BASE_VOCAB_SIZE + 22 }, {" <question>" , BASE_VOCAB_SIZE + 23 },
47+ {" <answer>" , BASE_VOCAB_SIZE + 24 }, {" <search>" , BASE_VOCAB_SIZE + 25 },
48+ {" <unk>" , BASE_VOCAB_SIZE + 26 }, {" <pad>" , BASE_VOCAB_SIZE + 27 }
6449 };
6550
6651 // Initialize the vector of special token IDs
@@ -79,6 +64,7 @@ namespace dlib
7964
8065 // Convert text to byte IDs
8166 std::vector<int > ids;
67+ ids.reserve (text.size ());
8268 for (char c : text) ids.push_back (static_cast <uint8_t >(c));
8369
8470 // Perform BPE merges
@@ -88,38 +74,31 @@ namespace dlib
8874
8975 // Find the most frequent pair that does not exceed MAX_TOKEN_LENGTH
9076 auto pair = get_most_frequent_pair (stats);
77+ if (pair.first == -1 ) break ;
9178
9279 // Check if the resulting token would exceed MAX_TOKEN_LENGTH
9380 size_t new_token_length = vocab[pair.first ].size () + vocab[pair.second ].size ();
9481 if (new_token_length > MAX_TOKEN_LENGTH) {
9582 if (verbose)
96- {
97- std::cout << " \r "
98- << std::setw (100 ) << std::flush
99- << " \r skipping merge " << std::to_string (i + 1 ) << " /" << std::to_string (num_merges) << " : ("
100- << std::to_string (pair.first ) << " ," << std::to_string (pair.second ) << " ) -> new token length "
101- << std::to_string (new_token_length) << " exceeds limit of " << std::to_string (MAX_TOKEN_LENGTH)
102- << std::flush;
103- }
83+ std::cout << " \r " << std::setw (100 ) << std::flush << " \r [skip] merge " << (i + 1 )
84+ << " : token too long: " << new_token_length << " /" << MAX_TOKEN_LENGTH << std::flush;
10485 continue ; // Skip this merge
10586 }
10687
107- int idx = (BASE_VOCAB_SIZE + (int )special_tokens.size ()) + i;
108- ids = merge (ids, pair, idx);
109- merges[pair] = idx;
110- vocab[idx].insert (vocab[idx].end (), vocab[pair.first ].begin (), vocab[pair.first ].end ());
111- vocab[idx].insert (vocab[idx].end (), vocab[pair.second ].begin (), vocab[pair.second ].end ());
88+ int new_id = current_base + i;
89+ merges[pair] = new_id;
90+
91+ std::vector<uint8_t >& new_token = vocab[new_id];
92+ new_token.reserve (new_token_length);
93+ new_token.insert (new_token.end (), vocab[pair.first ].begin (), vocab[pair.first ].end ());
94+ new_token.insert (new_token.end (), vocab[pair.second ].begin (), vocab[pair.second ].end ());
95+
96+ ids = merge (ids, pair, new_id);
11297
11398 if (verbose)
114- {
115- std::cout << " \r "
116- << std::setw (100 ) << std::flush
117- << " \r merge " << std::to_string (i + 1 ) << " /" << std::to_string (num_merges) << " : ("
118- << std::to_string (pair.first ) << " ," << std::to_string (pair.second ) << " ) -> " << std::to_string (idx)
119- << " (" << bytes_to_string (vocab[idx]) << " ) had "
120- << std::to_string (stats[pair]) << " occurrences"
121- << std::endl;
122- }
99+ std::cout << " \r " << std::setw (100 ) << std::flush << " \r [merge] " << (i + 1 ) << " /" << num_merges
100+ << " : (" << pair.first << " ," << pair.second << " ) -> " << new_id
101+ << " (" << bytes_to_string (vocab[new_id]) << " )" << std::endl;
123102 }
124103 }
125104
0 commit comments