Skip to content

Commit 003c903

Browse files
committed
ngram-map : take into account the input can become shorter
1 parent 9f8401a commit 003c903

2 files changed

Lines changed: 14 additions & 14 deletions

File tree

common/ngram-map.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ llama_tokens common_ngram_simple_draft(
2828
const size_t cur_len = tokens.size();
2929
// Only check every check_rate tokens to save compute
3030
// i.e., perform check if (cur_len - idx_last_check) >= check_rate
31-
if (state.idx_last_check + state.config.check_rate > cur_len) {
31+
if (state.idx_last_check + state.config.check_rate > cur_len && cur_len > state.idx_last_check) {
3232
llama_tokens draft_tokens;
3333
return draft_tokens;
3434
}
@@ -54,7 +54,7 @@ llama_tokens common_ngram_simple_draft(
5454
pattern.push_back(sampled); // add the last token to the pattern
5555

5656
// We do a search in the token history.
57-
state.idx_last_check = tokens.size();
57+
state.idx_last_check = cur_len;
5858

5959
size_t match_pos = 0; // we ignore position 0, position 0 == no match
6060
// search backwards, but skip the current match (we are currently there)
@@ -100,15 +100,15 @@ llama_tokens common_ngram_simple_draft(
100100
// maximum number of counted values of a ngram map value.
101101
#define COMMON_NGRAM_MAX_VALUE_COUNT 16380
102102

103-
std::string common_tokens_to_str(const llama_tokens & inp, size_t start, size_t length);
103+
static std::string common_tokens_to_str(const llama_tokens & inp, size_t start, size_t length);
104104

105105
void common_ngram_map_draft(common_ngram_map & map,
106106
const llama_tokens & inp, llama_token sampled,
107107
llama_tokens & draft) {
108108
// reset last key and value.
109-
map.last_draft_created = false;
110-
map.last_draft_key_idx = 0;
111-
map.last_draft_value_idx = 0;
109+
map.last_draft_created = false;
110+
map.last_draft_key_idx = 0;
111+
map.last_draft_value_idx = 0;
112112

113113
const size_t cur_len = inp.size();
114114
const uint16_t n = map.size_key;
@@ -119,7 +119,7 @@ void common_ngram_map_draft(common_ngram_map & map,
119119

120120
// Only check every check_rate tokens to save compute
121121
// i.e., perform check if (cur_len - idx_last_check) >= check_rate
122-
if (map.idx_last_check + map.check_rate > cur_len) {
122+
if (map.idx_last_check + map.check_rate > cur_len && cur_len > map.idx_last_check) {
123123
return;
124124
}
125125
map.idx_last_check = cur_len;
@@ -205,9 +205,9 @@ void common_ngram_map_draft(common_ngram_map & map,
205205
LOG_INF("%s: key_offset = %zu, key_num = %d, draft.size = %zu\n", __func__,
206206
key_offset, curr_key.key_num, draft.size());
207207

208-
map.last_draft_created = false;
209-
map.last_draft_key_idx = key_offset;
210-
map.last_draft_value_idx = 0; // value 0 is used for simple mode
208+
map.last_draft_created = false;
209+
map.last_draft_key_idx = key_offset;
210+
map.last_draft_value_idx = 0; // value 0 is used for simple mode
211211
return;
212212
}
213213

@@ -323,9 +323,9 @@ void common_ngram_map_draft(common_ngram_map & map,
323323
key_offset, slot_max,
324324
curr_key.key_num, draft.size());
325325

326-
map.last_draft_created = true;
327-
map.last_draft_key_idx = key_offset;
328-
map.last_draft_value_idx = slot_max; // value used for draft generation.
326+
map.last_draft_created = true;
327+
map.last_draft_key_idx = key_offset;
328+
map.last_draft_value_idx = slot_max; // value used for draft generation.
329329
}
330330

331331
void common_ngram_map_accept(common_ngram_map & map, uint16_t n_accepted) {

common/ngram-map.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ struct common_ngram_map {
8080

8181
common_ngram_map(uint16_t sz_key, uint16_t sz_value, bool only_keys,
8282
uint16_t check_rate, uint16_t min_hits)
83-
: size_key(sz_key), size_value(sz_value), key_only(only_keys), keys(std::vector<common_ngram_map_key>{}),
83+
: size_key(sz_key), size_value(sz_value), key_only(only_keys),
8484
check_rate(check_rate), min_hits(min_hits) {}
8585

8686
bool last_draft_created = false; // true if a draft was created at last call.

0 commit comments

Comments
 (0)