Skip to content

Commit 66e0c9a

Browse files
committed
cont : simplify
1 parent 0c6b7a0 commit 66e0c9a

4 files changed

Lines changed: 90 additions & 234 deletions

File tree

common/speculative.cpp

Lines changed: 2 additions & 144 deletions
Original file line numberDiff line numberDiff line change
@@ -1081,7 +1081,8 @@ common_speculative * common_speculative_init(
10811081
}
10821082

10831083
auto * result = new common_speculative {
1084-
/* .impls = */ std::move(impls)
1084+
/* .impls = */ std::move(impls),
1085+
/* .curr_impl = */ nullptr,
10851086
};
10861087

10871088
return result;
@@ -1187,146 +1188,3 @@ void common_speculative_print_stats(const common_speculative * spec) {
11871188
str_perf.c_str());
11881189
}
11891190
}
1190-
1191-
struct common_speculative_session::impl {
1192-
common_params_speculative params;
1193-
1194-
common_speculative * spec = nullptr;
1195-
1196-
bool has_partial = false;
1197-
1198-
llama_tokens draft;
1199-
1200-
impl(
1201-
const common_params_speculative & params,
1202-
llama_context * ctx_tgt) : params(params) {
1203-
spec = common_speculative_init(this->params, ctx_tgt);
1204-
}
1205-
1206-
void begin(const llama_tokens & prompt_history) const {
1207-
common_speculative_begin(spec, prompt_history);
1208-
}
1209-
1210-
bool generate_draft(
1211-
const llama_tokens & tokens,
1212-
llama_token id_last,
1213-
const int n_draft_max) {
1214-
GGML_ASSERT(spec);
1215-
1216-
if (n_draft_max == 0) {
1217-
this->clear();
1218-
return false;
1219-
}
1220-
1221-
if (has_partial) {
1222-
if (draft.empty()) {
1223-
this->clear();
1224-
}
1225-
1226-
LOG_DBG("%s: reuse shortened draft, #tokens=%zu, id_last=%d, size=%zu\n", __func__, tokens.size(), id_last, draft.size());
1227-
1228-
return false;
1229-
}
1230-
1231-
// call the speculative implementation to create a draft
1232-
draft = common_speculative_draft(spec, params, tokens, id_last);
1233-
LOG_DBG("draft: id_last=%d, #draft=%zu\n", id_last, draft.size());
1234-
1235-
if (draft.empty()) {
1236-
this->clear();
1237-
return false;
1238-
}
1239-
1240-
if (draft.size() > (size_t) n_draft_max) {
1241-
LOG_WRN("draft size %d exceeds max %d, truncating\n", (int) draft.size(), n_draft_max);
1242-
draft.resize(n_draft_max);
1243-
}
1244-
1245-
if (draft.size() < (size_t) params.n_min) {
1246-
LOG_DBG("ignoring small draft: %d < %d\n", (int) draft.size(), params.n_min);
1247-
this->clear();
1248-
return false;
1249-
}
1250-
1251-
return true;
1252-
}
1253-
1254-
bool accept(llama_tokens ids) {
1255-
LOG_WRN("%s: n_draft=%zu, ids.size=%zu\n", __func__, draft.size(), ids.size());
1256-
1257-
has_partial = false;
1258-
1259-
if (ids.size() < draft.size() + 1) {
1260-
// the main model rejected some tokens
1261-
if (params.use_checkpoints) {
1262-
// shorten the draft to the number of accepted tokens
1263-
draft.resize(ids.size() - 1);
1264-
1265-
has_partial = true;
1266-
1267-
return false;
1268-
}
1269-
1270-
LOG_DBG("%s: partial acceptance: %zu < %zu\n", __func__, draft.size(), draft.size());
1271-
}
1272-
1273-
draft = std::move(ids);
1274-
1275-
common_speculative_accept(spec, draft.size());
1276-
1277-
return true;
1278-
}
1279-
1280-
void print_stats() const {
1281-
GGML_ASSERT(spec);
1282-
1283-
common_speculative_print_stats(spec);
1284-
}
1285-
1286-
void clear() {
1287-
GGML_ASSERT(spec);
1288-
1289-
has_partial = false;
1290-
draft.clear();
1291-
}
1292-
};
1293-
1294-
common_speculative_session::common_speculative_session(
1295-
const common_params_speculative & params,
1296-
llama_context * ctx_tgt) : pimpl(new impl{params, ctx_tgt}) {
1297-
}
1298-
1299-
common_speculative_session::~common_speculative_session() {
1300-
common_speculative_free(pimpl->spec);
1301-
}
1302-
1303-
bool common_speculative_session::fail() const {
1304-
return pimpl->spec == nullptr;
1305-
}
1306-
1307-
void common_speculative_session::begin(const llama_tokens & prompt_history) {
1308-
pimpl->begin(prompt_history);
1309-
}
1310-
1311-
bool common_speculative_session::generate_draft(
1312-
const llama_tokens & prompt,
1313-
llama_token id_last,
1314-
int n_draft_max) {
1315-
return pimpl->generate_draft(prompt, id_last, n_draft_max);
1316-
}
1317-
1318-
bool common_speculative_session::accept(llama_tokens ids) {
1319-
return pimpl->accept(std::move(ids));
1320-
}
1321-
1322-
const llama_tokens & common_speculative_session::get_draft() const {
1323-
return pimpl->draft;
1324-
}
1325-
1326-
void common_speculative_session::print_stats() const {
1327-
pimpl->print_stats();
1328-
}
1329-
1330-
void common_speculative_session::clear() {
1331-
pimpl->clear();
1332-
}

common/speculative.h

Lines changed: 4 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,6 @@
33
#include "llama.h"
44
#include "common.h"
55

6-
// common/speculative.h has two interfaces:
7-
//
8-
// 1) struct common_speculative with init, begin, draft, accept and print_stats
9-
// Simple interface, see examples/speculative/speculative.cpp
10-
//
11-
// 2) struct common_speculative_session with struct common_speculative_callback
12-
// Complex interface which supports checkpoints, see tools/server/server-context.cpp
13-
//
14-
156
struct common_speculative;
167

178
// comma separated list of all types
@@ -55,37 +46,8 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted);
5546
// print statistics about the speculative decoding
5647
void common_speculative_print_stats(const common_speculative * spec);
5748

58-
// speculative decoding which may use checkpoints to rewind in tokens history
59-
struct common_speculative_session {
60-
common_speculative_session(
61-
const common_params_speculative & params,
62-
llama_context * ctx_tgt);
63-
64-
~common_speculative_session();
65-
66-
// no implementations available
67-
bool fail() const;
68-
69-
// call once at the beginning of a new generation
70-
// some spec implementations use the prompt history to initialize lookup maps
71-
void begin(const llama_tokens & prompt_history);
72-
73-
// do speculative decoding to compute a draft of tokens
74-
bool generate_draft(
75-
const llama_tokens & prompt,
76-
llama_token id_last,
77-
int n_draft_max);
78-
79-
// check if and how far the current draft is accepted
80-
bool accept(llama_tokens ids);
81-
82-
const llama_tokens & get_draft() const;
83-
84-
void print_stats() const;
85-
86-
void clear();
87-
88-
private:
89-
struct impl;
90-
std::unique_ptr<impl> pimpl;
49+
struct common_speculative_deleter {
50+
void operator()(common_speculative * s) { common_speculative_free(s); }
9151
};
52+
53+
typedef std::unique_ptr<common_speculative, common_speculative_deleter> common_speculative_ptr;

0 commit comments

Comments
 (0)