Skip to content

Commit a4dac0b

Browse files
committed
Fixing break condition in training method.
1 parent d4bf94b commit a4dac0b

1 file changed

Lines changed: 31 additions & 52 deletions

File tree

dlib/tokenizer/bpe_tokenizer.h

Lines changed: 31 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -31,36 +31,21 @@ namespace dlib
3131
vocab[i] = std::vector<uint8_t>{ static_cast<uint8_t>(i) };
3232

3333
// Initialize special tokens with sequential IDs
34-
special_tokens =
35-
{
36-
{"<text>", BASE_VOCAB_SIZE},
37-
{"</text>", BASE_VOCAB_SIZE + 1},
38-
{"<url>", BASE_VOCAB_SIZE + 2},
39-
{"</url>", BASE_VOCAB_SIZE + 3},
40-
{"<image>", BASE_VOCAB_SIZE + 4},
41-
{"</image>", BASE_VOCAB_SIZE + 5},
42-
{"<video>", BASE_VOCAB_SIZE + 6},
43-
{"</video>", BASE_VOCAB_SIZE + 7},
44-
{"<audio>", BASE_VOCAB_SIZE + 8},
45-
{"</audio>", BASE_VOCAB_SIZE + 9},
46-
{"<file>", BASE_VOCAB_SIZE + 10},
47-
{"</file>", BASE_VOCAB_SIZE + 11},
48-
{"<code>", BASE_VOCAB_SIZE + 12},
49-
{"</code>", BASE_VOCAB_SIZE + 13},
50-
{"<summary>", BASE_VOCAB_SIZE + 14},
51-
{"</summary>", BASE_VOCAB_SIZE + 15},
52-
{"<think>", BASE_VOCAB_SIZE + 16},
53-
{"</think>", BASE_VOCAB_SIZE + 17},
54-
{"<start>", BASE_VOCAB_SIZE + 18},
55-
{"<end>", BASE_VOCAB_SIZE + 19},
56-
{"<user>", BASE_VOCAB_SIZE + 20},
57-
{"<bot>", BASE_VOCAB_SIZE + 21},
58-
{"<system>", BASE_VOCAB_SIZE + 22},
59-
{"<question>", BASE_VOCAB_SIZE + 23},
60-
{"<answer>", BASE_VOCAB_SIZE + 24},
61-
{"<search>", BASE_VOCAB_SIZE + 25},
62-
{"<unk>", BASE_VOCAB_SIZE + 26},
63-
{"<pad>", BASE_VOCAB_SIZE + 27}
34+
special_tokens = {
35+
{"<text>", BASE_VOCAB_SIZE}, {"</text>", BASE_VOCAB_SIZE + 1},
36+
{"<url>", BASE_VOCAB_SIZE + 2}, {"</url>", BASE_VOCAB_SIZE + 3},
37+
{"<image>", BASE_VOCAB_SIZE + 4}, {"</image>", BASE_VOCAB_SIZE + 5},
38+
{"<video>", BASE_VOCAB_SIZE + 6}, {"</video>", BASE_VOCAB_SIZE + 7},
39+
{"<audio>", BASE_VOCAB_SIZE + 8}, {"</audio>", BASE_VOCAB_SIZE + 9},
40+
{"<file>", BASE_VOCAB_SIZE + 10}, {"</file>", BASE_VOCAB_SIZE + 11},
41+
{"<code>", BASE_VOCAB_SIZE + 12}, {"</code>", BASE_VOCAB_SIZE + 13},
42+
{"<summary>", BASE_VOCAB_SIZE + 14}, {"</summary>", BASE_VOCAB_SIZE + 15},
43+
{"<think>", BASE_VOCAB_SIZE + 16}, {"</think>", BASE_VOCAB_SIZE + 17},
44+
{"<start>", BASE_VOCAB_SIZE + 18}, {"<end>", BASE_VOCAB_SIZE + 19},
45+
{"<user>", BASE_VOCAB_SIZE + 20}, {"<bot>", BASE_VOCAB_SIZE + 21},
46+
{"<system>", BASE_VOCAB_SIZE + 22}, {"<question>", BASE_VOCAB_SIZE + 23},
47+
{"<answer>", BASE_VOCAB_SIZE + 24}, {"<search>", BASE_VOCAB_SIZE + 25},
48+
{"<unk>", BASE_VOCAB_SIZE + 26}, {"<pad>", BASE_VOCAB_SIZE + 27}
6449
};
6550

6651
// Initialize the vector of special token IDs
@@ -79,6 +64,7 @@ namespace dlib
7964

8065
// Convert text to byte IDs
8166
std::vector<int> ids;
67+
ids.reserve(text.size());
8268
for (char c : text) ids.push_back(static_cast<uint8_t>(c));
8369

8470
// Perform BPE merges
@@ -88,38 +74,31 @@ namespace dlib
8874

8975
// Find the most frequent pair that does not exceed MAX_TOKEN_LENGTH
9076
auto pair = get_most_frequent_pair(stats);
77+
if (pair.first == -1) break;
9178

9279
// Check if the resulting token would exceed MAX_TOKEN_LENGTH
9380
size_t new_token_length = vocab[pair.first].size() + vocab[pair.second].size();
9481
if (new_token_length > MAX_TOKEN_LENGTH) {
9582
if (verbose)
96-
{
97-
std::cout << "\r"
98-
<< std::setw(100) << std::flush
99-
<< "\rskipping merge " << std::to_string(i + 1) << "/" << std::to_string(num_merges) << ": ("
100-
<< std::to_string(pair.first) << "," << std::to_string(pair.second) << ") -> new token length "
101-
<< std::to_string(new_token_length) << " exceeds limit of " << std::to_string(MAX_TOKEN_LENGTH)
102-
<< std::flush;
103-
}
83+
std::cout << "\r" << std::setw(100) << std::flush << "\r[skip] merge " << (i + 1)
84+
<< ": token too long: " << new_token_length << "/" << MAX_TOKEN_LENGTH << std::flush;
10485
continue; // Skip this merge
10586
}
10687

107-
int idx = (BASE_VOCAB_SIZE + (int)special_tokens.size()) + i;
108-
ids = merge(ids, pair, idx);
109-
merges[pair] = idx;
110-
vocab[idx].insert(vocab[idx].end(), vocab[pair.first].begin(), vocab[pair.first].end());
111-
vocab[idx].insert(vocab[idx].end(), vocab[pair.second].begin(), vocab[pair.second].end());
88+
int new_id = current_base + i;
89+
merges[pair] = new_id;
90+
91+
std::vector<uint8_t>& new_token = vocab[new_id];
92+
new_token.reserve(new_token_length);
93+
new_token.insert(new_token.end(), vocab[pair.first].begin(), vocab[pair.first].end());
94+
new_token.insert(new_token.end(), vocab[pair.second].begin(), vocab[pair.second].end());
95+
96+
ids = merge(ids, pair, new_id);
11297

11398
if (verbose)
114-
{
115-
std::cout << "\r"
116-
<< std::setw(100) << std::flush
117-
<< "\rmerge " << std::to_string(i + 1) << "/" << std::to_string(num_merges) << ": ("
118-
<< std::to_string(pair.first) << "," << std::to_string(pair.second) << ") -> " << std::to_string(idx)
119-
<< " (" << bytes_to_string(vocab[idx]) << ") had "
120-
<< std::to_string(stats[pair]) << " occurrences"
121-
<< std::endl;
122-
}
99+
std::cout << "\r" << std::setw(100) << std::flush << "\r[merge] " << (i + 1) << "/" << num_merges
100+
<< ": (" << pair.first << "," << pair.second << ") -> " << new_id
101+
<< " (" << bytes_to_string(vocab[new_id]) << ")" << std::endl;
123102
}
124103
}
125104

0 commit comments

Comments
 (0)