@@ -1081,7 +1081,8 @@ common_speculative * common_speculative_init(
10811081 }
10821082
10831083 auto * result = new common_speculative {
1084- /* .impls = */ std::move (impls)
1084+ /* .impls = */ std::move (impls),
1085+ /* .curr_impl = */ nullptr ,
10851086 };
10861087
10871088 return result;
@@ -1187,146 +1188,3 @@ void common_speculative_print_stats(const common_speculative * spec) {
11871188 str_perf.c_str ());
11881189 }
11891190}
1190-
1191- struct common_speculative_session ::impl {
1192- common_params_speculative params;
1193-
1194- common_speculative * spec = nullptr ;
1195-
1196- bool has_partial = false ;
1197-
1198- llama_tokens draft;
1199-
1200- impl (
1201- const common_params_speculative & params,
1202- llama_context * ctx_tgt) : params(params) {
1203- spec = common_speculative_init (this ->params , ctx_tgt);
1204- }
1205-
1206- void begin (const llama_tokens & prompt_history) const {
1207- common_speculative_begin (spec, prompt_history);
1208- }
1209-
1210- bool generate_draft (
1211- const llama_tokens & tokens,
1212- llama_token id_last,
1213- const int n_draft_max) {
1214- GGML_ASSERT (spec);
1215-
1216- if (n_draft_max == 0 ) {
1217- this ->clear ();
1218- return false ;
1219- }
1220-
1221- if (has_partial) {
1222- if (draft.empty ()) {
1223- this ->clear ();
1224- }
1225-
1226- LOG_DBG (" %s: reuse shortened draft, #tokens=%zu, id_last=%d, size=%zu\n " , __func__, tokens.size (), id_last, draft.size ());
1227-
1228- return false ;
1229- }
1230-
1231- // call the speculative implementation to create a draft
1232- draft = common_speculative_draft (spec, params, tokens, id_last);
1233- LOG_DBG (" draft: id_last=%d, #draft=%zu\n " , id_last, draft.size ());
1234-
1235- if (draft.empty ()) {
1236- this ->clear ();
1237- return false ;
1238- }
1239-
1240- if (draft.size () > (size_t ) n_draft_max) {
1241- LOG_WRN (" draft size %d exceeds max %d, truncating\n " , (int ) draft.size (), n_draft_max);
1242- draft.resize (n_draft_max);
1243- }
1244-
1245- if (draft.size () < (size_t ) params.n_min ) {
1246- LOG_DBG (" ignoring small draft: %d < %d\n " , (int ) draft.size (), params.n_min );
1247- this ->clear ();
1248- return false ;
1249- }
1250-
1251- return true ;
1252- }
1253-
1254- bool accept (llama_tokens ids) {
1255- LOG_WRN (" %s: n_draft=%zu, ids.size=%zu\n " , __func__, draft.size (), ids.size ());
1256-
1257- has_partial = false ;
1258-
1259- if (ids.size () < draft.size () + 1 ) {
1260- // the main model rejected some tokens
1261- if (params.use_checkpoints ) {
1262- // shorten the draft to the number of accepted tokens
1263- draft.resize (ids.size () - 1 );
1264-
1265- has_partial = true ;
1266-
1267- return false ;
1268- }
1269-
1270- LOG_DBG (" %s: partial acceptance: %zu < %zu\n " , __func__, draft.size (), draft.size ());
1271- }
1272-
1273- draft = std::move (ids);
1274-
1275- common_speculative_accept (spec, draft.size ());
1276-
1277- return true ;
1278- }
1279-
1280- void print_stats () const {
1281- GGML_ASSERT (spec);
1282-
1283- common_speculative_print_stats (spec);
1284- }
1285-
1286- void clear () {
1287- GGML_ASSERT (spec);
1288-
1289- has_partial = false ;
1290- draft.clear ();
1291- }
1292- };
1293-
1294- common_speculative_session::common_speculative_session (
1295- const common_params_speculative & params,
1296- llama_context * ctx_tgt) : pimpl(new impl{params, ctx_tgt}) {
1297- }
1298-
1299- common_speculative_session::~common_speculative_session () {
1300- common_speculative_free (pimpl->spec );
1301- }
1302-
1303- bool common_speculative_session::fail () const {
1304- return pimpl->spec == nullptr ;
1305- }
1306-
1307- void common_speculative_session::begin (const llama_tokens & prompt_history) {
1308- pimpl->begin (prompt_history);
1309- }
1310-
1311- bool common_speculative_session::generate_draft (
1312- const llama_tokens & prompt,
1313- llama_token id_last,
1314- int n_draft_max) {
1315- return pimpl->generate_draft (prompt, id_last, n_draft_max);
1316- }
1317-
1318- bool common_speculative_session::accept (llama_tokens ids) {
1319- return pimpl->accept (std::move (ids));
1320- }
1321-
1322- const llama_tokens & common_speculative_session::get_draft () const {
1323- return pimpl->draft ;
1324- }
1325-
1326- void common_speculative_session::print_stats () const {
1327- pimpl->print_stats ();
1328- }
1329-
1330- void common_speculative_session::clear () {
1331- pimpl->clear ();
1332- }
0 commit comments