Skip to content

Commit 9293647

Browse files
committed
common: ngram map, config self-speculative decoding
1 parent a7e8bc1 commit 9293647

8 files changed

Lines changed: 538 additions & 41 deletions

File tree

common/arg.cpp

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3393,10 +3393,46 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
33933393
}
33943394
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
33953395
add_opt(common_arg(
3396-
{"--spec-self"}, "<0|1>",
3397-
"use self-speculation without a draft model (default: 0, no self speculation without draft model)",
3396+
{"--spec-self"}, "N",
3397+
"mode of self-speculation without a draft model: disabled(0), fixed(1), keys-only(2), key-values(3) (default: %d)\n",
33983398
[](common_params & params, int value) {
3399-
params.speculative.use_self = value;
3399+
if (value < 0 || value > 3) {
3400+
throw std::invalid_argument("invalid value");
3401+
}
3402+
params.speculative.self_mode = value;
3403+
}
3404+
).set_examples({LLAMA_EXAMPLE_SERVER}));
3405+
add_opt(common_arg(
3406+
{"--spec-self-config"}, "N0,N1,N2,...",
3407+
"speculative self decoding config: ngram size (key), mgram size (value), check rate, min hits (default: %d,%d,%d,%d)",
3408+
[](common_params & params, const std::string & value) {
3409+
std::string arg_next = value;
3410+
3411+
// split string by , and /
3412+
const std::regex regex{ R"([,/]+)" };
3413+
std::sregex_token_iterator it{ arg_next.begin(), arg_next.end(), regex, -1 };
3414+
std::vector<std::string> split_arg{ it, {} };
3415+
if (split_arg.size() > 4) {
3416+
throw std::invalid_argument(
3417+
string_format("got %d input configs, but self-speculative decoding config require at most 4 values", (int)split_arg.size())
3418+
);
3419+
}
3420+
for (size_t i = 0; i < split_arg.size(); ++i) {
3421+
int val = std::stoi(split_arg[i]);
3422+
if (i == 0 && (val < 1 || val > 255)) {
3423+
throw std::invalid_argument("ngram size must be between 1 and 255");
3424+
}
3425+
if (i == 1 && (val < 1 || val > 255)) {
3426+
throw std::invalid_argument("mgram size must be between 1 and 255");
3427+
}
3428+
if (i == 2 && val == 0) {
3429+
throw std::invalid_argument("check rate must be greater than 0");
3430+
}
3431+
if (i == 3 && (val < 1 || val > 255)) {
3432+
throw std::invalid_argument("min hits must be between 1 and 255");
3433+
}
3434+
params.speculative.self_cfg[i] = (uint16_t) val;
3435+
}
34003436
}
34013437
).set_examples({LLAMA_EXAMPLE_SERVER}));
34023438
add_opt(common_arg(

common/common.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,8 @@ struct common_params_speculative {
249249
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
250250
float p_split = 0.1f; // speculative decoding split probability
251251
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
252-
int32_t use_self = 0; // use self-speculative decoding without draft model (default: 0 = off)
252+
int32_t self_mode = 0; // mode of self-speculative decoding without draft model (default: 0 = off)
253+
std::vector<uint16_t> self_cfg = {12, 48, 2, 1}; // self-speculative decoding config (n-gram size, m-gram size, check rate, min hits)
253254
std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
254255
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
255256

common/ngram-map.cpp

Lines changed: 296 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,296 @@
1+
#include "ngram-map.h"
2+
#include "common.h"
3+
#include "log.h"
4+
5+
#include <cinttypes>
6+
#include <cstdint>
7+
#include <cstdio>
8+
9+
// maximum number of counted values of a ngram map value.
10+
#define COMMON_NGRAM_MAX_VALUE_COUNT 16380
11+
12+
std::string common_tokens_to_str(const llama_tokens & inp, size_t start, size_t length);
13+
14+
void common_ngram_map_draft(common_ngram_map & map,
15+
const llama_tokens & inp, llama_token sampled,
16+
llama_tokens & draft) {
17+
// reset last key and value.
18+
map.last_draft_created = false;
19+
map.last_draft_key_idx = 0;
20+
map.last_draft_value_idx = 0;
21+
22+
const size_t cur_len = inp.size();
23+
const uint16_t n = map.size_key;
24+
const uint16_t m = map.size_value;
25+
if (cur_len < static_cast<size_t>(2 * n + m)) {
26+
return;
27+
}
28+
29+
// Only check every check_rate tokens to save compute
30+
// i.e., perform check if (cur_len - idx_last_check) >= check_rate
31+
if (map.idx_last_check + map.check_rate > cur_len) {
32+
return;
33+
}
34+
map.idx_last_check = cur_len;
35+
36+
// search pattern, the key n-gram
37+
std::vector<llama_token> key_tokens;
38+
key_tokens.reserve(n);
39+
for (size_t j = cur_len - n + 1; j < cur_len; ++j) {
40+
key_tokens.push_back(inp[j]);
41+
}
42+
key_tokens.push_back(sampled);
43+
44+
// search for the key in the map
45+
size_t match_pos = 0;
46+
for (size_t j = cur_len - n - m - 1; j > 0; --j) {
47+
bool match = true;
48+
for (size_t k = 0; k < n; ++k) {
49+
if (inp[j + k] != key_tokens[k]) {
50+
match = false;
51+
break;
52+
}
53+
}
54+
if (match) {
55+
match_pos = j;
56+
break;
57+
}
58+
}
59+
if (match_pos > 0) {
60+
LOG_INF("%s: cur_len = %zu, n = %d, m = %d, sz_tkns = %zu, sampled = %d, match_pos = %zu\n", __func__,
61+
cur_len, n, m, key_tokens.size(), sampled, match_pos);
62+
}
63+
64+
if (match_pos == 0) {
65+
return;
66+
}
67+
68+
// We have a match, now we look for the statistics of the key.
69+
size_t key_offset = map.keys.size(); // offset in the map
70+
// We iterate through the std::vector<common_ngram_map_key> map->keys.
71+
for (size_t i = 0; i < map.keys.size(); ++i) {
72+
bool match = true;
73+
for (size_t j = 0; j < n; ++j) {
74+
if (inp[map.keys[i].key_idx + j] != key_tokens[j]) {
75+
match = false;
76+
break;
77+
}
78+
}
79+
if (match) {
80+
key_offset = i;
81+
break;
82+
}
83+
}
84+
if (key_offset == map.keys.size()) {
85+
// We create a new key-entry, it will get offset key_offset.
86+
common_ngram_map_key new_key;
87+
new_key.key_idx = match_pos;
88+
new_key.stat_idx = 0;
89+
new_key.key_num = 0;
90+
for (int i = 0; i < COMMON_NGRAM_MAX_VALUES; ++i) {
91+
new_key.values[i].value_num = 0;
92+
new_key.values[i].n_accepted = m;
93+
}
94+
map.keys.push_back(new_key);
95+
}
96+
97+
// our key n-gram:
98+
common_ngram_map_key & curr_key = map.keys[key_offset];
99+
100+
// update number of key hits
101+
curr_key.key_num = (uint16_t) std::min((int) map.keys[key_offset].key_num + 1,
102+
(int) COMMON_NGRAM_MAX_VALUE_COUNT);
103+
104+
if (map.key_only) {
105+
// simple mode:
106+
// Fill in the draft with the m tokens following the key.
107+
// We work with value values[0] only.
108+
int n_draft_tokens = std::min((int) m, (int) curr_key.values[0].n_accepted);
109+
110+
for (int i = 0; i < n_draft_tokens; ++i) {
111+
draft.push_back(inp[match_pos + n + i]);
112+
}
113+
114+
LOG_INF("%s: key_offset = %zu, key_num = %d, draft.size = %zu\n", __func__,
115+
key_offset, curr_key.key_num, draft.size());
116+
117+
map.last_draft_created = false;
118+
map.last_draft_key_idx = key_offset;
119+
map.last_draft_value_idx = 0; // value 0 is used for simple mode
120+
map.drafts_generated_tokens += draft.size();
121+
return;
122+
}
123+
124+
if (curr_key.key_num < map.min_hits) {
125+
// not enough hits to consider this a good draft
126+
LOG_DBG("%s: key_offset = %zu, key_num = %d, min_hits = %d, no draft\n", __func__,
127+
key_offset, curr_key.key_num, map.min_hits);
128+
return;
129+
}
130+
131+
// complex mode: examine the different m-grams after this key n-gram.
132+
//
133+
134+
// determine all (max COMMON_NGRAM_MAX_VALUES) m-grams after the key n-gram.
135+
for (size_t i = curr_key.stat_idx; i <= match_pos; ++i) {
136+
// begins the key n-gram at index i?
137+
bool match_key = true;
138+
for (size_t k = 0; k < n; ++k) {
139+
if (inp[i + k] != key_tokens[k]) {
140+
match_key = false;
141+
break;
142+
}
143+
}
144+
if (!match_key) {
145+
continue;
146+
}
147+
148+
// Do we haven a existing value m-gram or a new one after the key at index i?
149+
size_t idx_begin_value_key = i + n;
150+
int idx_value = -1;
151+
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
152+
size_t idx_begin_value_v = curr_key.values[v].value_idx;
153+
if (idx_begin_value_v == 0) {
154+
// We found an empty value slot => we found a new value m-gram after the key n-gram.
155+
curr_key.values[v].value_idx = idx_begin_value_key;
156+
curr_key.values[v].value_num = 0;
157+
curr_key.values[v].n_accepted = m;
158+
idx_value = v;
159+
break;
160+
}
161+
bool match = true;
162+
for (size_t j = 0; j < m; ++j) {
163+
if (inp[idx_begin_value_key + j] != inp[idx_begin_value_v + j]) {
164+
match = false;
165+
break;
166+
}
167+
}
168+
if (match) {
169+
// We found an existing value m-gram after the key n-gram.
170+
idx_value = v;
171+
break;
172+
}
173+
}
174+
if (idx_value >= 0) {
175+
// We found a value m-gram of the key n-gram.
176+
curr_key.values[idx_value].value_num = (uint16_t) std::min((int) curr_key.values[idx_value].value_num + 1,
177+
(int) COMMON_NGRAM_MAX_VALUE_COUNT);
178+
}
179+
}
180+
// the statistics are updated up to match_pos.
181+
curr_key.stat_idx = match_pos;
182+
183+
// Do we have a value we could use for the draft?
184+
uint16_t max_occur = 0;
185+
int slot_max = 0;
186+
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
187+
uint16_t curr_occur = curr_key.values[v].value_num;
188+
if (curr_occur > max_occur) {
189+
max_occur = curr_occur;
190+
slot_max = v;
191+
}
192+
}
193+
// What is sum of the other occurences?
194+
uint32_t sum_occur = 0;
195+
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
196+
if (v == slot_max) {
197+
continue;
198+
}
199+
uint16_t curr_occur = curr_key.values[v].value_num;
200+
sum_occur += curr_occur;
201+
}
202+
203+
LOG_INF("%s: key_offset = %zu, max_occur = %d, sum_occur = %d, slot_max = %d [%zu/%d, %zu/%d, %zu/%d, %zu/%d]\n", __func__,
204+
key_offset,
205+
max_occur, sum_occur, slot_max,
206+
curr_key.values[0].value_idx, curr_key.values[0].value_num,
207+
curr_key.values[1].value_idx, curr_key.values[1].value_num,
208+
curr_key.values[2].value_idx, curr_key.values[2].value_num,
209+
curr_key.values[3].value_idx, curr_key.values[3].value_num
210+
);
211+
// Print the tokens of the four values (if idx != 0), use LOG_INF
212+
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
213+
if (curr_key.values[v].value_idx != 0) {
214+
LOG_INF("%s: value[%d] = %s\n", __func__, v, common_tokens_to_str(inp, curr_key.values[v].value_idx, m).c_str());
215+
}
216+
}
217+
218+
if (sum_occur > 0 && max_occur < 3 * sum_occur) {
219+
// The most frequent value is not much more frequent than the other values.
220+
// We do not use the draft.
221+
return;
222+
}
223+
224+
// We use the most frequent value values[slot_max] for the draft.
225+
// Fill in the draft with the m tokens following the key.
226+
int n_draft_tokens = std::min((int) m, (int) curr_key.values[slot_max].n_accepted);
227+
228+
for (int i = 0; i < n_draft_tokens; ++i) {
229+
draft.push_back(inp[match_pos + n + i]);
230+
}
231+
232+
LOG_INF("%s: key_offset = %zu, slot_max = %d, key_num = %d, draft.size = %zu\n", __func__,
233+
key_offset, slot_max,
234+
curr_key.key_num, draft.size());
235+
236+
map.last_draft_created = true;
237+
map.last_draft_key_idx = key_offset;
238+
map.last_draft_value_idx = slot_max; // value used for draft generation.
239+
map.drafts_generated_tokens += draft.size();
240+
}
241+
242+
void common_ngram_map_send_accepted(common_ngram_map & map, uint16_t n_accepted) {
243+
if (!map.last_draft_created) {
244+
return;
245+
}
246+
247+
// find the key and its chosen value.
248+
const size_t key_idx = map.last_draft_key_idx;
249+
const size_t val_idx = map.last_draft_value_idx;
250+
251+
// find key corresponding to key_idx.
252+
common_ngram_map_key & curr_key = map.keys[key_idx];
253+
// find value corresponding to val_idx.
254+
struct common_ngram_map_value & curr_value = curr_key.values[val_idx]; // value used for draft generation.
255+
256+
// update the value statistics
257+
LOG_INF("common_ngram_map_send_accepted: n_accepted = %d, prev value_num = %d\n",
258+
n_accepted, curr_value.n_accepted);
259+
curr_value.n_accepted = n_accepted;
260+
261+
// draft statistics update
262+
if (n_accepted > 0) {
263+
map.drafts_accepted_count++;
264+
} else {
265+
map.drafts_rejected_count++;
266+
}
267+
map.drafts_accepted_tokens += n_accepted;
268+
}
269+
270+
// Display statistics of the ngram map.
271+
void common_ngram_map_print_stats(const common_ngram_map & map) {
272+
LOG_INF("ngram map: size_key = %d, size_value = %d, key_only = %s, min_hits = %d\n",
273+
map.size_key, map.size_value,
274+
map.key_only ? "true" : "false",
275+
map.min_hits);
276+
LOG_INF("drafts_accepted_count = %zu, drafts_rejected_count = %zu, drafts_generated_tokens = %zu, drafts_accepted_tokens = %zu\n",
277+
map.drafts_accepted_count, map.drafts_rejected_count,
278+
map.drafts_generated_tokens, map.drafts_accepted_tokens);
279+
}
280+
281+
// Helper functions.
282+
//
283+
284+
// Print the values of a sublist of `llama_tokens & inp` to a string in the form [v0, v1, v2, ...].
285+
std::string common_tokens_to_str(const llama_tokens & inp, size_t start, size_t length) {
286+
std::string result = "[";
287+
for (size_t i = 0; i < length; ++i) {
288+
if (i > 0) {
289+
result += ", ";
290+
}
291+
result += std::to_string(inp[start + i]);
292+
}
293+
result += "]";
294+
return result;
295+
}
296+

0 commit comments

Comments
 (0)