Skip to content

Commit cef6bb7

Browse files
srogmannshaofeiqi
authored andcommitted
spec : various improvements to ngram-map + docs (ggml-org#19253)
* spec: ngram-map and reasoning chats
* spec: add t_begin and t_accept
* ngram-map : add internal hash map
* docs : update ngram-map, add ngram-mod
* docs : fix ngram-map-k
* docs : differences between implementations
1 parent fb4193a commit cef6bb7

4 files changed

Lines changed: 307 additions & 31 deletions

File tree

common/ngram-map.cpp

Lines changed: 191 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,18 @@
77
#include <cstdio>
88
#include <sstream>
99

10+
// prime number used for LCG hash function (32 bit), it is near (sqrt(5) - 1)/2 * 2^32.
11+
#define LCG_FACTOR 2654435761UL
12+
13+
// Compute the LCG hash of a n-gram of size len at offset start.
14+
static uint32_t common_ngram_map_hash(const llama_tokens & tokens, size_t start, size_t len) {
15+
uint32_t hash = 0;
16+
for (size_t i = 0; i < len; ++i) {
17+
hash = hash * LCG_FACTOR + tokens[start + i];
18+
}
19+
return hash;
20+
}
21+
1022
// Print the values of a sublist of `llama_tokens & inp` to a string in the form [v0, v1, v2, ...].
1123
static std::string common_tokens_to_str(const llama_tokens & inp, size_t start, size_t length) {
1224
std::ostringstream oss;
@@ -115,6 +127,100 @@ llama_tokens common_ngram_simple_draft(
115127
// maximum number of counted values of a ngram map value.
116128
#define COMMON_NGRAM_MAX_VALUE_COUNT 16380
117129

130+
// Prepare the n-gram map for a new generation that starts with `tokens` as history.
//
// If the history is now shorter than the position of the last check
// (size_begin < map.idx_last_check, e.g. because a reasoning block was removed from the
// context), indices at or beyond the previous start of generation (map.size_last_begin)
// may refer to tokens that are no longer valid. Such entries are removed from key_map,
// from keys and from the value slots of the remaining keys.
//
// map:    the ngram map to refresh (modified in place).
// tokens: the current token history; only its size is used here.
void common_ngram_map_begin(
        common_ngram_map & map, const llama_tokens & tokens) {
    size_t size_begin = tokens.size();

    LOG_DBG("%s: begin, idx_last_draft=%zu, new begin=%zu, #keys=%zu\n", __func__,
            map.idx_last_check, size_begin, map.keys.size());

    size_t count_map_entries_upd = 0;
    if (!map.key_map.empty() && size_begin < map.idx_last_check) {
        if (map.show_key_map_stats) {
            // Print statistics of hash map key_map.
            size_t   count_nonzero = 0;
            uint32_t min_idx       = UINT32_MAX;
            uint32_t max_idx       = 0;
            for (size_t i = 0; i < map.key_map.size(); ++i) {
                uint32_t key_idx = map.key_map[i];
                if (key_idx != 0) {
                    ++count_nonzero;
                    if (key_idx < min_idx) min_idx = key_idx;
                    if (key_idx > max_idx) max_idx = key_idx;
                }
            }
            if (count_nonzero == 0) {
                min_idx = 0;
            }
            LOG_INF("%s: key_map stats: entries=%zu, min_idx=%u, max_idx=%u, key_map_last_idx=%u\n",
                    __func__, count_nonzero, min_idx, max_idx, map.key_map_last_idx);
        }

        // Update the map from hash to key index (clear outdated entries).
        // A slot value of 0 marks an empty slot; empty slots are skipped so that
        // count_map_entries_upd only reports slots that were actually cleared.
        for (size_t i = 0; i < map.key_map.size(); ++i) {
            uint32_t key_idx = map.key_map[i];
            if (key_idx != 0 && key_idx >= map.size_last_begin) {
                map.key_map[i] = 0;
                count_map_entries_upd++;
            }
        }
        map.key_map_last_idx = (map.size_last_begin > 0) ? map.size_last_begin - 1 : 0;
    }

    if (size_begin < map.idx_last_check && !map.keys.empty()) {
        // The next token generation will start at index size_begin.
        // The tokens between map.size_last_begin and size_begin are no longer valid.
        //
        // Refresh map: Remove all entries with index >= map.size_last_begin.
        size_t count_keys       = map.keys.size();
        size_t count_keys_del   = 0;
        size_t count_values_del = 0;
        // Iterate backwards so that erasing element i does not shift the elements still to be visited.
        for (int32_t i = map.keys.size() - 1; i >= 0; --i) {
            common_ngram_map_key & key = map.keys[i];
            if (key.key_idx >= map.size_last_begin) {
                // Delete the key.
                LOG_DBG("%s: delete key %d at index %zu (>= size_last_begin=%zu)\n", __func__, i, key.key_idx, map.size_last_begin);
                map.keys.erase(map.keys.begin() + i);
                count_keys_del++;
                continue;
            }
            if (map.key_only) {
                // No values are stored in key-only mode, nothing else to check.
                continue;
            }

            // Check the indices of the values.
            for (int16_t j = COMMON_NGRAM_MAX_VALUES - 1; j >= 0; --j) {
                common_ngram_map_value & value = key.values[j];
                if (value.value_idx >= map.size_last_begin) {
                    // Delete the value.
                    count_values_del++;

                    // Move all values after this value to the left.
                    for (uint16_t k = j; k < COMMON_NGRAM_MAX_VALUES - 1; ++k) {
                        key.values[k] = key.values[k + 1];
                    }
                    // Clear the last value: reset every field to its "unused" default
                    // (value_idx = 0, value_num = 0, n_accepted = -1) so no stale
                    // n_accepted is left behind in the freed slot.
                    key.values[COMMON_NGRAM_MAX_VALUES - 1] = common_ngram_map_value{};
                }
            }
            if (key.values[0].value_idx == 0) {
                // No values left, delete the key.
                LOG_DBG("%s: delete key %d at index %zu (no values left)\n", __func__, i, key.key_idx);
                map.keys.erase(map.keys.begin() + i);
                count_keys_del++;
            }
        }

        LOG_INF("%s: refresh map: idx_last_draft=%zu, new begin=%zu, #keys_checked=%zu, #keys_del=%zu, #values_del=%zu, #hashes_upd=%zu\n", __func__,
                map.idx_last_check, size_begin,
                count_keys, count_keys_del, count_values_del, count_map_entries_upd);
    }

    // idx_last_check is derived from the previous size_last_begin, so it has to be
    // computed before size_last_begin is overwritten with the new history size.
    map.idx_last_check  = (map.size_last_begin > 0) ? map.size_last_begin - 1 : 0;
    map.size_last_begin = size_begin;
}
223+
118224
void common_ngram_map_draft(common_ngram_map & map,
119225
const llama_tokens & inp, llama_token sampled,
120226
llama_tokens & draft) {
@@ -129,6 +235,10 @@ void common_ngram_map_draft(common_ngram_map & map,
129235
if (cur_len < static_cast<size_t>(2 * n + m)) {
130236
return;
131237
}
238+
if (cur_len >= static_cast<size_t>(UINT32_MAX)) {
239+
// key_map uses uint32_t instead of size_t.
240+
GGML_ABORT("%s: cur_len exceeds UINT32_MAX: %zu", __func__, cur_len);
241+
}
132242

133243
// Only check every check_rate tokens to save compute
134244
// i.e., perform check if (cur_len - idx_last_check) >= check_rate
@@ -147,24 +257,92 @@ void common_ngram_map_draft(common_ngram_map & map,
147257

148258
// search for the key in the map
149259
size_t match_pos = 0;
150-
for (size_t j = cur_len - n - m - 1; j > 0; --j) {
151-
bool match = true;
152-
for (size_t k = 0; k < n; ++k) {
153-
if (inp[j + k] != key_tokens[k]) {
154-
match = false;
155-
break;
260+
if (map.size_last_begin > cur_len) {
261+
GGML_ABORT("%s: map.size_last_begin > cur_len: %zu > %zu", __func__, map.size_last_begin, cur_len);
262+
}
263+
if (!map.key_map.empty()) {
264+
// Search for the key in the map key_map from hash of ngrams to index of ngram.
265+
uint32_t idx_hash = (common_ngram_map_hash(key_tokens, 0, n) % map.key_map.size());
266+
uint32_t idx_key = map.key_map[idx_hash];
267+
if (idx_key != 0 && idx_key < cur_len - n - m - 1) {
268+
// Check if the key matches the key at idx_key (because of possible collisions).
269+
bool match = true;
270+
for (size_t k = 0; k < n; ++k) {
271+
if (inp[idx_key + k] != key_tokens[k]) {
272+
match = false;
273+
break;
274+
}
275+
}
276+
LOG_DBG("%s: key hash %x -> idx_key %d: match %d\n", __func__, idx_hash, idx_key, match ? 1 : 0);
277+
if (match) {
278+
match_pos = idx_key;
156279
}
157280
}
158-
if (match) {
159-
match_pos = j;
160-
break;
281+
}
282+
if (match_pos == 0 && map.size_last_begin > (size_t) (n + m + 1)) {
283+
// Search for the key in [1, map.size_last_begin - n - m -1], descending.
284+
for (size_t j = map.size_last_begin - n - m - 1; j > map.key_map_last_idx; --j) {
285+
// Check if the key matches the key.
286+
bool match = true;
287+
for (size_t k = 0; k < n; ++k) {
288+
if (inp[j + k] != key_tokens[k]) {
289+
match = false;
290+
break;
291+
}
292+
}
293+
if (match) {
294+
match_pos = j;
295+
break;
296+
}
297+
}
298+
}
299+
if (match_pos == 0) {
300+
// In case of a reasoning chat, the part after size_last_begin may be deleted/reordered later.
301+
//
302+
// Search in [size_last_begin, cur_len - n - m - 1], descending.
303+
for (size_t j = cur_len - n - m - 1; j > map.size_last_begin && j > map.key_map_last_idx; --j) {
304+
bool match = true;
305+
for (size_t k = 0; k < n; ++k) {
306+
if (inp[j + k] != key_tokens[k]) {
307+
match = false;
308+
break;
309+
}
310+
}
311+
if (match) {
312+
match_pos = j;
313+
break;
314+
}
161315
}
162316
}
163317
if (match_pos > 0) {
164-
LOG_INF("%s: cur_len = %zu, n = %d, m = %d, sz_tkns = %zu, sampled = %d, match_pos = %zu\n", __func__,
318+
LOG_DBG("%s: cur_len = %zu, n = %d, m = %d, sz_tkns = %zu, sampled = %d, match_pos = %zu\n", __func__,
165319
cur_len, n, m, key_tokens.size(), sampled, match_pos);
166320
}
167321

322+
if (!map.key_map.empty()) {
323+
// Add hashes of new ngrams in key_map.
324+
//
325+
// Use the same order as above.
326+
if (map.size_last_begin > (size_t) (n + m + 1)) {
327+
for (size_t j = map.size_last_begin - n - m - 1; j > map.key_map_last_idx; --j) {
328+
// compute hash and store index of ngram at idx j in the map.
329+
uint32_t idx_hash = (common_ngram_map_hash(inp, j, n) % map.key_map.size());
330+
if (map.key_map[idx_hash] == 0) {
331+
map.key_map[idx_hash] = j; // collisions may occur
332+
}
333+
}
334+
}
335+
336+
for (size_t j = cur_len - n - m - 1; j > map.size_last_begin && j > map.key_map_last_idx; --j) {
337+
// compute hash and store index of ngram at idx j in the map.
338+
uint32_t idx_hash = (common_ngram_map_hash(inp, j, n) % map.key_map.size());
339+
if (map.key_map[idx_hash] == 0) {
340+
map.key_map[idx_hash] = j;
341+
}
342+
}
343+
map.key_map_last_idx = std::max(static_cast<uint32_t>(cur_len - n - m - 1), map.key_map_last_idx);
344+
}
345+
168346
if (match_pos == 0) {
169347
return;
170348
}
@@ -215,8 +393,8 @@ void common_ngram_map_draft(common_ngram_map & map,
215393
draft.push_back(inp[match_pos + n + i]);
216394
}
217395

218-
LOG_INF("%s: key_offset = %zu, key_num = %d, draft.size = %zu\n", __func__,
219-
key_offset, curr_key.key_num, draft.size());
396+
LOG_DBG("%s: key_idx = %zu, key_offset = %zu, key_num = %d, draft.size = %zu\n", __func__,
397+
curr_key.key_idx, key_offset, curr_key.key_num, draft.size());
220398

221399
map.last_draft_created = false;
222400
map.last_draft_key_idx = key_offset;
@@ -318,7 +496,7 @@ void common_ngram_map_draft(common_ngram_map & map,
318496
}
319497
}
320498

321-
if (sum_occur > 0 && max_occur < 3 * sum_occur) {
499+
if (sum_occur > 0 && max_occur < 2 * sum_occur) {
322500
// The most frequent value is not much more frequent than the other values.
323501
// We do not use the draft.
324502
return;

common/ngram-map.h

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
// 2. ngram_map: lookup of n-grams followed by m-grams in token history using a map.
1010
// The map is a vector of key n-grams, and for each key n-gram there is a list of value m-grams.
1111
//
12+
// ref: https://github.com/ggml-org/llama.cpp/pull/18471
13+
//
1214

1315
#include "llama.h"
1416
#include "common.h"
@@ -51,10 +53,13 @@ llama_tokens common_ngram_simple_draft(
5153
// maximum number of m-gram values stored for each key n-gram.
5254
#define COMMON_NGRAM_MAX_VALUES 4
5355

56+
// number of entries in the (optional, size 0 to disable) map from ngram-hash to ngram-index.
57+
#define COMMON_NGRAM_HASH_MAP_SIZE 262144
58+
5459
// statistics of a m-gram after a known n-gram
5560
// One value slot of a key n-gram: where the value m-gram is located in the token history
// and how often it was seen. Unused slots keep the defaults below (0, 0, -1).
struct common_ngram_map_value {
56-
size_t value_idx = 0; // index of value m-gram in token-history (0 if unused)
57-
uint16_t value_num = 0; // number of occurrences of this value m-gram after the key n-gram (0 in an unused values-slot)
61+
size_t value_idx = 0; // index of value m-gram in token-history (0 if unused)
62+
uint16_t value_num = 0; // number of occurrences of this value m-gram after the key n-gram (0 in an unused values-slot)
5863
int16_t n_accepted = -1; // number of accepted tokens at last draft (-1 if unused)
5964
};
6065

@@ -74,23 +79,43 @@ struct common_ngram_map {
7479

7580
bool key_only; // true if only key n-grams are used, no values.
7681

77-
// first draft: vector only, no map.
7882
std::vector<common_ngram_map_key> keys; // key n-grams which occur several times in token-history
7983
uint16_t check_rate; // check for speculative decoding without draft model for each check_rate token
8084
uint16_t min_hits; // minimum number of key hits to consider a draft
8185

86+
bool show_key_map_stats = false; // true, if statistics of the key_map should be printed.
87+
8288
common_ngram_map(uint16_t sz_key, uint16_t sz_value, bool only_keys,
8389
uint16_t check_rate, uint16_t min_hits)
8490
: size_key(sz_key), size_value(sz_value), key_only(only_keys),
85-
check_rate(check_rate), min_hits(min_hits) {}
91+
check_rate(check_rate), min_hits(min_hits) {
92+
key_map.resize(COMMON_NGRAM_HASH_MAP_SIZE); // 2^18 hash entries, 0 entries if key_map shouldn't be used
93+
}
94+
95+
// In reasoning chats the previous reasoning block will be removed from context history.
96+
// A rebuild of the ngram map is needed after that.
97+
98+
size_t size_last_begin = 0; // number of tokens at previous start of generation
8699

87100
bool last_draft_created = false; // true if a draft was created at last call.
88-
size_t last_draft_key_idx = 0; // index of last key used for draft generation.
101+
size_t last_draft_key_idx = 0; // index of last key used for draft generation (0 = no draft)
89102
uint16_t last_draft_value_idx = 0; // index of last value used for draft generation.
90103

91104
size_t idx_last_check = 0; // index of last check in context history
105+
106+
// optional map "hash to ngram-index" for faster lookup of n-grams. map is empty if unused.
107+
//
108+
// uint32_t instead of size_t (size of current histories is << UINT32_MAX)
109+
std::vector<uint32_t> key_map; // key_map[hash] = index of ngram in context window
110+
uint32_t key_map_last_idx = 0; // index of the last ngram added to key_map
92111
};
93112

113+
// Prepare the n-gram map for a new generation starting from the given token history (removes entries that refer to tokens which are no longer valid).
114+
// map: the ngram map to initialize.
115+
// tokens: the token history to base the map on.
116+
void common_ngram_map_begin(
117+
common_ngram_map & map,
118+
const llama_tokens & tokens);
94119

95120
// Searches for the n-gram in the history and checks whether a draft sequence should be generated.
96121
// map: the ngram map to search in.

0 commit comments

Comments
 (0)