Skip to content

Commit 89b10b8

Browse files
author
Petros Sideris
committed
spec : save the dynamic/static ngram cache file
* fix todo on providing n_draft, save_static and save_dynamic from common/common.h * implement the functionality by saving the cache at the common_speculative_state_ngram_cache destruction
1 parent 0f1bb60 commit 89b10b8

2 files changed

Lines changed: 28 additions & 12 deletions

File tree

common/common.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,7 @@ struct common_params_speculative {
309309

310310
// ngram-based speculative decoding
311311

312+
uint16_t ngram_n_draft = 8; // ngram n tokens to draft
312313
uint16_t ngram_size_n = 12; // ngram size for lookup
313314
uint16_t ngram_size_m = 48; // mgram size for speculative tokens
314315
uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed
@@ -318,6 +319,9 @@ struct common_params_speculative {
318319
std::string lookup_cache_static; // path of static ngram cache file for lookup decoding // NOLINT
319320
std::string lookup_cache_dynamic; // path of dynamic ngram cache file for lookup decoding // NOLINT
320321

322+
bool save_lookup_cache_static = false; // whether or not we should save the static ngram cache file // NOLINT
323+
bool save_lookup_cache_dynamic = false; // whether or not we should save the dynamic ngram cache file // NOLINT
324+
321325
// draft-model speculative decoding
322326

323327
struct common_params_model mparams_dft;

common/speculative.cpp

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -771,6 +771,9 @@ struct common_speculative_state_ngram_cache : public common_speculative_state {
771771
bool save_dynamic;
772772
bool save_static;
773773

774+
const std::string path_static;
775+
const std::string path_dynamic;
776+
774777
common_ngram_cache ngram_cache_context;
775778
common_ngram_cache ngram_cache_dynamic;
776779
common_ngram_cache ngram_cache_static;
@@ -779,15 +782,17 @@ struct common_speculative_state_ngram_cache : public common_speculative_state {
779782

780783
common_speculative_state_ngram_cache(
781784
const enum common_speculative_type type,
785+
uint16_t n_draft,
782786
const std::string & path_static,
783787
const std::string & path_dynamic,
784-
uint16_t n_draft,
785788
bool save_dynamic,
786789
bool save_static)
787790
: common_speculative_state(type)
788791
, n_draft(n_draft)
789792
, save_dynamic(save_dynamic)
790793
, save_static(save_static)
794+
, path_static(path_static)
795+
, path_dynamic(path_dynamic)
791796
{
792797
if (!path_static.empty()) {
793798
try {
@@ -808,6 +813,15 @@ struct common_speculative_state_ngram_cache : public common_speculative_state {
808813
}
809814
}
810815

816+
~common_speculative_state_ngram_cache() override {
817+
if (save_static) {
818+
common_ngram_cache_save(ngram_cache_static, path_static);
819+
}
820+
if (save_dynamic) {
821+
common_ngram_cache_save(ngram_cache_dynamic, path_dynamic);
822+
}
823+
}
824+
811825
void begin(const llama_tokens & prompt) override {
812826
GGML_UNUSED(prompt);
813827
}
@@ -874,16 +888,15 @@ static common_ngram_map get_common_ngram_map(const common_speculative_config & c
874888
return common_ngram_map(size_key, size_value, key_only, min_hits);
875889
}
876890

877-
static common_speculative_state_ngram_cache create_state_ngram_cache(
878-
const std::string & path_static, const std::string & path_dynamic,
879-
const common_speculative_config & config) {
880-
uint16_t n_draft = 8; // TODO get from config?
881-
882-
// TODO bool param in common/common.h to set save_static/save_dynamic?
883-
bool save_static = false;
884-
bool save_dynamic = false;
891+
static common_speculative_state_ngram_cache create_state_ngram_cache(const common_speculative_config & config) {
885892

886-
common_speculative_state_ngram_cache state(config.type, path_static, path_dynamic, n_draft, save_static, save_dynamic);
893+
common_speculative_state_ngram_cache state(
894+
config.type,
895+
config.params.ngram_n_draft,
896+
config.params.lookup_cache_static,
897+
config.params.lookup_cache_dynamic,
898+
config.params.save_lookup_cache_static,
899+
config.params.save_lookup_cache_dynamic);
887900

888901
return state;
889902
}
@@ -1040,8 +1053,7 @@ common_speculative * common_speculative_init(
10401053
break;
10411054
}
10421055
case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: {
1043-
auto state = create_state_ngram_cache(
1044-
params.lookup_cache_static, params.lookup_cache_dynamic, config);
1056+
auto state = create_state_ngram_cache(config);
10451057
impls.push_back(std::make_unique<common_speculative_state_ngram_cache>(state));
10461058
break;
10471059
}

0 commit comments

Comments
 (0)