|
5 | 5 |
|
6 | 6 | struct common_speculative; |
7 | 7 |
|
| 8 | +// comma separated list of the provided types |
| 9 | +std::string common_speculative_type_name_str(const std::vector<enum common_speculative_type> & types); |
| 10 | + |
8 | 11 | // comma separated list of all types |
9 | | -std::string common_speculative_type_name_str(); |
| 12 | +const char * common_speculative_all_types_str(); |
| 13 | + |
| 14 | +// parse user provided types |
| 15 | +std::vector<enum common_speculative_type> common_speculative_types_from_names(const std::vector<std::string> & names); |
10 | 16 |
|
11 | 17 | // convert string to type |
12 | 18 | enum common_speculative_type common_speculative_type_from_name(const std::string & name); |
13 | 19 |
|
14 | 20 | // convert type to string |
15 | 21 | std::string common_speculative_type_to_str(enum common_speculative_type type); |
16 | 22 |
|
17 | | -common_speculative * common_speculative_init( |
18 | | - common_params_speculative & params, |
19 | | - llama_context * ctx_tgt); |
| 23 | +common_speculative * common_speculative_init(common_params_speculative & params, uint32_t n_seq); |
20 | 24 |
|
21 | 25 | void common_speculative_free(common_speculative * spec); |
22 | 26 |
|
| 27 | +struct common_speculative_draft_params { |
| 28 | + // this flag is used to chain the drafts through all the available implementations |
| 29 | + // after the first successful draft from an implementation, we set it |
| 30 | + // to false to prevent further drafts for that sequence |
| 31 | + // at the end of the draft() call, all drafting flags will be reset to false |
| 32 | + bool drafting = false; |
| 33 | + |
| 34 | + // overrides individual configurations (-1 disabled) |
| 35 | + // can be used to constrain the max draft based on the remaining context size |
| 36 | + int32_t n_max = -1; |
| 37 | + |
| 38 | + llama_pos n_past; |
| 39 | + llama_token id_last; |
| 40 | + |
| 41 | + // TODO: remove in the future by keeping track of the prompt from the _begin() call and the consecutive accept calls |
| 42 | + const llama_tokens * prompt; |
| 43 | + |
| 44 | + // the generated draft from the last _draft() call |
| 45 | + llama_tokens * result; |
| 46 | +}; |
| 47 | + |
| 48 | +common_speculative_draft_params & common_speculative_get_draft_params(common_speculative * spec, llama_seq_id seq_id); |
| 49 | + |
23 | 50 | // optionally call once at the beginning of a new generation |
24 | | -void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt); |
| 51 | +void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, const llama_tokens & prompt); |
25 | 52 |
|
26 | | -// sample up to n_draft tokens and add them to the batch using the draft model |
27 | | -llama_tokens common_speculative_draft( |
28 | | - common_speculative * spec, |
29 | | - const common_params_speculative & params, |
30 | | - const llama_tokens & prompt, |
31 | | - llama_token id_last); |
| 53 | +// process the batch and update the internal state of the speculative context |
| 54 | +bool common_speculative_process(common_speculative * spec, const llama_batch & batch); |
32 | 55 |
|
33 | | -// informs the speculative decoder that n_accepted tokens were accepted by the target model |
34 | | -void common_speculative_accept(common_speculative * spec, uint16_t n_accepted); |
| 56 | +// generate drafts for the sequences specified with `common_speculative_get_draft_params` |
| 57 | +void common_speculative_draft(common_speculative * spec); |
35 | 58 |
|
36 | | -int32_t common_speculative_n_max(const common_speculative * spec, const common_params_speculative & params); |
37 | | -int32_t common_speculative_n_min(const common_speculative * spec, const common_params_speculative & params); |
| 59 | +// informs the speculative context that n_accepted tokens were accepted by the target model |
| 60 | +void common_speculative_accept(common_speculative * spec, llama_seq_id, uint16_t n_accepted); |
38 | 61 |
|
39 | 62 | // print statistics about the speculative decoding |
40 | 63 | void common_speculative_print_stats(const common_speculative * spec); |
|
0 commit comments