Skip to content

Commit c402b3d

Browse files
author
Petros Sideris
committed
spec: save the dynamic/static ngram cache file
* fix todo on providing n_draft, save_static and save_dynamic from common/common.h * implement the functionality by saving the cache at the common_speculative_state_ngram_cache destruction
1 parent 52f1096 commit c402b3d

2 files changed

Lines changed: 28 additions & 12 deletions

File tree

common/common.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,7 @@ struct common_params_speculative {
308308

309309
// ngram-based speculative decoding
310310

311+
uint16_t ngram_n_draft = 8; // ngram n tokens to draft
311312
uint16_t ngram_size_n = 12; // ngram size for lookup
312313
uint16_t ngram_size_m = 48; // mgram size for speculative tokens
313314
uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed
@@ -317,6 +318,9 @@ struct common_params_speculative {
317318
std::string lookup_cache_static; // path of static ngram cache file for lookup decoding // NOLINT
318319
std::string lookup_cache_dynamic; // path of dynamic ngram cache file for lookup decoding // NOLINT
319320

321+
bool save_lookup_cache_static = false; // whether or not we should save the static ngram cache file // NOLINT
322+
bool save_lookup_cache_dynamic = false; // whether or not we should save the dynamic ngram cache file // NOLINT
323+
320324
// draft-model speculative decoding
321325

322326
struct common_params_model mparams_dft;

common/speculative.cpp

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -762,6 +762,9 @@ struct common_speculative_state_ngram_cache : public common_speculative_state {
762762
bool save_dynamic;
763763
bool save_static;
764764

765+
const std::string path_static;
766+
const std::string path_dynamic;
767+
765768
common_ngram_cache ngram_cache_context;
766769
common_ngram_cache ngram_cache_dynamic;
767770
common_ngram_cache ngram_cache_static;
@@ -770,15 +773,17 @@ struct common_speculative_state_ngram_cache : public common_speculative_state {
770773

771774
common_speculative_state_ngram_cache(
772775
const enum common_speculative_type type,
776+
uint16_t n_draft,
773777
const std::string & path_static,
774778
const std::string & path_dynamic,
775-
uint16_t n_draft,
776779
bool save_dynamic,
777780
bool save_static)
778781
: common_speculative_state(type)
779782
, n_draft(n_draft)
780783
, save_dynamic(save_dynamic)
781784
, save_static(save_static)
785+
, path_static(path_static)
786+
, path_dynamic(path_dynamic)
782787
{
783788
if (!path_static.empty()) {
784789
try {
@@ -799,6 +804,15 @@ struct common_speculative_state_ngram_cache : public common_speculative_state {
799804
}
800805
}
801806

807+
~common_speculative_state_ngram_cache() override {
808+
if (save_static) {
809+
common_ngram_cache_save(ngram_cache_static, path_static);
810+
}
811+
if (save_dynamic) {
812+
common_ngram_cache_save(ngram_cache_dynamic, path_dynamic);
813+
}
814+
}
815+
802816
void begin(const llama_tokens & prompt) override {
803817
GGML_UNUSED(prompt);
804818
}
@@ -865,16 +879,15 @@ static common_ngram_map get_common_ngram_map(const common_speculative_config & c
865879
return common_ngram_map(size_key, size_value, key_only, min_hits);
866880
}
867881

868-
static common_speculative_state_ngram_cache create_state_ngram_cache(
869-
const std::string & path_static, const std::string & path_dynamic,
870-
const common_speculative_config & config) {
871-
uint16_t n_draft = 8; // TODO get from config?
872-
873-
// TODO bool param in common/common.h to set save_static/save_dynamic?
874-
bool save_static = false;
875-
bool save_dynamic = false;
882+
static common_speculative_state_ngram_cache create_state_ngram_cache(const common_speculative_config & config) {
876883

877-
common_speculative_state_ngram_cache state(config.type, path_static, path_dynamic, n_draft, save_static, save_dynamic);
884+
common_speculative_state_ngram_cache state(
885+
config.type,
886+
config.params.ngram_n_draft,
887+
config.params.lookup_cache_static,
888+
config.params.lookup_cache_dynamic,
889+
config.params.save_lookup_cache_static,
890+
config.params.save_lookup_cache_dynamic);
878891

879892
return state;
880893
}
@@ -1031,8 +1044,7 @@ common_speculative * common_speculative_init(
10311044
break;
10321045
}
10331046
case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: {
1034-
auto state = create_state_ngram_cache(
1035-
params.lookup_cache_static, params.lookup_cache_dynamic, config);
1047+
auto state = create_state_ngram_cache(config);
10361048
impls.push_back(std::make_unique<common_speculative_state_ngram_cache>(state));
10371049
break;
10381050
}

0 commit comments

Comments
 (0)