77#include < cstdio>
88#include < sstream>
99
10+ // prime number used for LCG hash function (32 bit), it is near (sqrt(5) - 1)/2 * 2^32.
11+ #define LCG_FACTOR 2654435761UL
12+
13+ // Compute the LCG hash of a n-gram of size len at offset start.
14+ static uint32_t common_ngram_map_hash (const llama_tokens & tokens, size_t start, size_t len) {
15+ uint32_t hash = 0 ;
16+ for (size_t i = 0 ; i < len; ++i) {
17+ hash = hash * LCG_FACTOR + tokens[start + i];
18+ }
19+ return hash;
20+ }
21+
1022// Print the values of a sublist of `llama_tokens & inp` to a string in the form [v0, v1, v2, ...].
1123static std::string common_tokens_to_str (const llama_tokens & inp, size_t start, size_t length) {
1224 std::ostringstream oss;
@@ -115,6 +127,100 @@ llama_tokens common_ngram_simple_draft(
115127// maximum number of counted values of a ngram map value.
116128#define COMMON_NGRAM_MAX_VALUE_COUNT 16380
117129
130+ void common_ngram_map_begin (
131+ common_ngram_map & map, const llama_tokens & tokens) {
132+ size_t size_begin = tokens.size ();
133+
134+ LOG_DBG (" %s: begin, idx_last_draft=%zu, new begin=%zu, #keys=%zu\n " , __func__,
135+ map.idx_last_check , size_begin, map.keys .size ());
136+
137+ size_t count_map_entries_upd = 0 ;
138+ if (!map.key_map .empty () && size_begin < map.idx_last_check ) {
139+ if (map.show_key_map_stats ) {
140+ // Print statistics of hash map map_key.
141+ size_t count_nonzero = 0 ;
142+ uint32_t min_idx = UINT32_MAX;
143+ uint32_t max_idx = 0 ;
144+ for (size_t i = 0 ; i < map.key_map .size (); ++i) {
145+ uint32_t key_idx = map.key_map [i];
146+ if (key_idx != 0 ) {
147+ ++count_nonzero;
148+ if (key_idx < min_idx) min_idx = key_idx;
149+ if (key_idx > max_idx) max_idx = key_idx;
150+ }
151+ }
152+ if (count_nonzero == 0 ) {
153+ min_idx = 0 ;
154+ }
155+ LOG_INF (" %s: key_map stats: entries=%zu, min_idx=%u, max_idx=%u, key_map_last_idx=%u\n " ,
156+ __func__, count_nonzero, min_idx, max_idx, map.key_map_last_idx );
157+ }
158+
159+ // Update the map from hash to key index (clear outdated entries).
160+ for (size_t i = 0 ; i < map.key_map .size (); ++i) {
161+ uint32_t key_idx = map.key_map [i];
162+ if (key_idx >= map.size_last_begin ) {
163+ map.key_map [i] = 0 ;
164+ count_map_entries_upd++;
165+ }
166+ }
167+ map.key_map_last_idx = (map.size_last_begin > 0 ) ? map.size_last_begin - 1 : 0 ;
168+ }
169+
170+ if (size_begin < map.idx_last_check && !map.keys .empty ()) {
171+ // The next token generation will start at index size_begin.
172+ // The tokens between map.size_last_begin and size_begin are no longer valid.
173+ //
174+ // Refresh map: Remove all entries with index >= map.size_last_begin.
175+ size_t count_keys = map.keys .size ();
176+ size_t count_keys_del = 0 ;
177+ size_t count_values_del = 0 ;
178+ for (int32_t i = map.keys .size () - 1 ; i >= 0 ; --i) {
179+ common_ngram_map_key & key = map.keys [i];
180+ if (key.key_idx >= map.size_last_begin ) {
181+ // Delete the key.
182+ LOG_DBG (" %s: delete key %d at index %zu (>= size_last_begin=%zu)\n " , __func__, i, key.key_idx , map.size_last_begin );
183+ map.keys .erase (map.keys .begin () + i);
184+ count_keys_del++;
185+ continue ;
186+ }
187+ if (map.key_only ) {
188+ continue ;
189+ }
190+
191+ // Check the indices of the values.
192+ for (int16_t j = COMMON_NGRAM_MAX_VALUES - 1 ; j >= 0 ; --j) {
193+ common_ngram_map_value & value = key.values [j];
194+ if (value.value_idx >= map.size_last_begin ) {
195+ // Delete the value.
196+ count_values_del++;
197+
198+ // Move all values after this value to the left.
199+ for (uint16_t k = j; k < COMMON_NGRAM_MAX_VALUES - 1 ; ++k) {
200+ key.values [k] = key.values [k + 1 ];
201+ }
202+ // Clear the last value.
203+ key.values [COMMON_NGRAM_MAX_VALUES - 1 ].value_idx = 0 ;
204+ key.values [COMMON_NGRAM_MAX_VALUES - 1 ].value_num = 0 ;
205+ }
206+ }
207+ if (key.values [0 ].value_idx == 0 ) {
208+ // No values left, delete the key.
209+ LOG_DBG (" %s: delete key %d at index %zu (no values left)\n " , __func__, i, key.key_idx );
210+ map.keys .erase (map.keys .begin () + i);
211+ count_keys_del++;
212+ }
213+ }
214+
215+ LOG_INF (" %s: refresh map: idx_last_draft=%zu, new begin=%zu, #keys_checked=%zu, #keys_del=%zu, #values_del=%zu, #hashes_upd=%zu\n " , __func__,
216+ map.idx_last_check , size_begin,
217+ count_keys, count_keys_del, count_values_del, count_map_entries_upd);
218+ }
219+
220+ map.idx_last_check = (map.size_last_begin > 0 ) ? map.size_last_begin - 1 : 0 ;
221+ map.size_last_begin = size_begin;
222+ }
223+
118224void common_ngram_map_draft (common_ngram_map & map,
119225 const llama_tokens & inp, llama_token sampled,
120226 llama_tokens & draft) {
@@ -129,6 +235,10 @@ void common_ngram_map_draft(common_ngram_map & map,
129235 if (cur_len < static_cast <size_t >(2 * n + m)) {
130236 return ;
131237 }
238+ if (cur_len >= static_cast <size_t >(UINT32_MAX)) {
239+ // key_map uses uint32_t instead of size_t.
240+ GGML_ABORT (" %s: cur_len exceeds UINT32_MAX: %zu" , __func__, cur_len);
241+ }
132242
133243 // Only check every check_rate tokens to save compute
134244 // i.e., perform check if (cur_len - idx_last_check) >= check_rate
@@ -147,24 +257,92 @@ void common_ngram_map_draft(common_ngram_map & map,
147257
148258 // search for the key in the map
149259 size_t match_pos = 0 ;
150- for (size_t j = cur_len - n - m - 1 ; j > 0 ; --j) {
151- bool match = true ;
152- for (size_t k = 0 ; k < n; ++k) {
153- if (inp[j + k] != key_tokens[k]) {
154- match = false ;
155- break ;
260+ if (map.size_last_begin > cur_len) {
261+ GGML_ABORT (" %s: map.size_last_begin > cur_len: %zu > %zu" , __func__, map.size_last_begin , cur_len);
262+ }
263+ if (!map.key_map .empty ()) {
264+ // Search for the key in the map key_map from hash of ngrams to index of ngram.
265+ uint32_t idx_hash = (common_ngram_map_hash (key_tokens, 0 , n) % map.key_map .size ());
266+ uint32_t idx_key = map.key_map [idx_hash];
267+ if (idx_key != 0 && idx_key < cur_len - n - m - 1 ) {
268+ // Check if the key matches the key at idx_key (because of possible collisions).
269+ bool match = true ;
270+ for (size_t k = 0 ; k < n; ++k) {
271+ if (inp[idx_key + k] != key_tokens[k]) {
272+ match = false ;
273+ break ;
274+ }
275+ }
276+ LOG_DBG (" %s: key hash %x -> idx_key %d: match %d\n " , __func__, idx_hash, idx_key, match ? 1 : 0 );
277+ if (match) {
278+ match_pos = idx_key;
156279 }
157280 }
158- if (match) {
159- match_pos = j;
160- break ;
281+ }
282+ if (match_pos == 0 && map.size_last_begin > (size_t ) (n + m + 1 )) {
283+ // Search for the key in [1, map.size_last_begin - n - m -1], descending.
284+ for (size_t j = map.size_last_begin - n - m - 1 ; j > map.key_map_last_idx ; --j) {
285+ // Check if the key matches the key.
286+ bool match = true ;
287+ for (size_t k = 0 ; k < n; ++k) {
288+ if (inp[j + k] != key_tokens[k]) {
289+ match = false ;
290+ break ;
291+ }
292+ }
293+ if (match) {
294+ match_pos = j;
295+ break ;
296+ }
297+ }
298+ }
299+ if (match_pos == 0 ) {
300+ // In case of a reasoning chat, the part after size_last_begin may be deleted/reordered later.
301+ //
302+ // Search in [size_last_begin, cur_len - n - m - 1], descending.
303+ for (size_t j = cur_len - n - m - 1 ; j > map.size_last_begin && j > map.key_map_last_idx ; --j) {
304+ bool match = true ;
305+ for (size_t k = 0 ; k < n; ++k) {
306+ if (inp[j + k] != key_tokens[k]) {
307+ match = false ;
308+ break ;
309+ }
310+ }
311+ if (match) {
312+ match_pos = j;
313+ break ;
314+ }
161315 }
162316 }
163317 if (match_pos > 0 ) {
164- LOG_INF (" %s: cur_len = %zu, n = %d, m = %d, sz_tkns = %zu, sampled = %d, match_pos = %zu\n " , __func__,
318+ LOG_DBG (" %s: cur_len = %zu, n = %d, m = %d, sz_tkns = %zu, sampled = %d, match_pos = %zu\n " , __func__,
165319 cur_len, n, m, key_tokens.size (), sampled, match_pos);
166320 }
167321
322+ if (!map.key_map .empty ()) {
323+ // Add hashes of new ngrams in key_map.
324+ //
325+ // Use the same order as above.
326+ if (map.size_last_begin > (size_t ) (n + m + 1 )) {
327+ for (size_t j = map.size_last_begin - n - m - 1 ; j > map.key_map_last_idx ; --j) {
328+ // compute hash and store index of ngram at idx j in the map.
329+ uint32_t idx_hash = (common_ngram_map_hash (inp, j, n) % map.key_map .size ());
330+ if (map.key_map [idx_hash] == 0 ) {
331+ map.key_map [idx_hash] = j; // collisions may occur
332+ }
333+ }
334+ }
335+
336+ for (size_t j = cur_len - n - m - 1 ; j > map.size_last_begin && j > map.key_map_last_idx ; --j) {
337+ // compute hash and store index of ngram at idx j in the map.
338+ uint32_t idx_hash = (common_ngram_map_hash (inp, j, n) % map.key_map .size ());
339+ if (map.key_map [idx_hash] == 0 ) {
340+ map.key_map [idx_hash] = j;
341+ }
342+ }
343+ map.key_map_last_idx = std::max (static_cast <uint32_t >(cur_len - n - m - 1 ), map.key_map_last_idx );
344+ }
345+
168346 if (match_pos == 0 ) {
169347 return ;
170348 }
@@ -215,8 +393,8 @@ void common_ngram_map_draft(common_ngram_map & map,
215393 draft.push_back (inp[match_pos + n + i]);
216394 }
217395
218- LOG_INF (" %s: key_offset = %zu, key_num = %d, draft.size = %zu\n " , __func__,
219- key_offset, curr_key.key_num , draft.size ());
396+ LOG_DBG (" %s: key_idx = %zu, key_offset = %zu, key_num = %d, draft.size = %zu\n " , __func__,
397+ curr_key. key_idx , key_offset, curr_key.key_num , draft.size ());
220398
221399 map.last_draft_created = false ;
222400 map.last_draft_key_idx = key_offset;
@@ -318,7 +496,7 @@ void common_ngram_map_draft(common_ngram_map & map,
318496 }
319497 }
320498
321- if (sum_occur > 0 && max_occur < 3 * sum_occur) {
499+ if (sum_occur > 0 && max_occur < 2 * sum_occur) {
322500 // The most frequent value is not much more frequent than the other values.
323501 // We do not use the draft.
324502 return ;
0 commit comments