@@ -161,6 +161,10 @@ struct common_speculative_impl {
161161
162162 virtual void accept (llama_seq_id seq_id, uint16_t n_accepted, bool is_other) = 0;
163163
164+ // (optional) serialize/restore per-seq internal state (e.g. eagle3's deferred boundary).
165+ virtual bool get_state (llama_seq_id /* seq_id*/ , std::vector<uint8_t > & /* data*/ ) const { return false ; }
166+ virtual void set_state (llama_seq_id /* seq_id*/ , const std::vector<uint8_t > & /* data*/ ) {}
167+
164168 // true if this implementation requires the target context to extract post-norm embeddings
165169 virtual bool need_embd () const = 0;
166170
@@ -841,6 +845,49 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
841845 (size_t ) n_embd_dec * sizeof (float ));
842846 }
843847
848+ // we only need to stash the deferred boundary's g_embd row for recurrent/hybrid targets:
849+ // their single-position checkpoints drop it on restore
850+ bool need_boundary_stash () const {
851+ const llama_model * model_tgt = llama_get_model (params.ctx_tgt );
852+ return llama_model_is_recurrent (model_tgt) || llama_model_is_hybrid (model_tgt);
853+ }
854+
855+ bool get_state (llama_seq_id seq_id, std::vector<uint8_t > & data) const override {
856+ if (!need_boundary_stash ()) {
857+ return false ;
858+ }
859+ if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq || pending_pos_last[seq_id] < 0 ) {
860+ return false ;
861+ }
862+
863+ const llama_pos pos = pending_pos_last[seq_id];
864+ const std::vector<float > & g = pending_g_last[seq_id];
865+
866+ data.resize (sizeof (llama_pos) + g.size () * sizeof (float ));
867+ std::memcpy (data.data (), &pos, sizeof (llama_pos));
868+ std::memcpy (data.data () + sizeof (llama_pos), g.data (), g.size () * sizeof (float ));
869+ return true ;
870+ }
871+
872+ void set_state (llama_seq_id seq_id, const std::vector<uint8_t > & data) override {
873+ if (!need_boundary_stash ()) {
874+ return ;
875+ }
876+ if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq) {
877+ return ;
878+ }
879+ if (data.size () != sizeof (llama_pos) + (size_t ) n_embd_dec * sizeof (float )) {
880+ return ;
881+ }
882+
883+ llama_pos pos = -1 ;
884+ std::memcpy (&pos, data.data (), sizeof (llama_pos));
885+
886+ pending_pos_last[seq_id] = pos;
887+ pending_g_last[seq_id].resize (n_embd_dec);
888+ std::memcpy (pending_g_last[seq_id].data (), data.data () + sizeof (llama_pos), (size_t ) n_embd_dec * sizeof (float ));
889+ }
890+
844891 bool need_embd () const override {
845892 return false ;
846893 }
@@ -2118,6 +2165,31 @@ void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, u
21182165 }
21192166}
21202167
2168+ // TODO: support the case of more than one speculative implementations having a state
2169+ bool common_speculative_get_state (common_speculative * spec, llama_seq_id seq_id, std::vector<uint8_t > & data) {
2170+ if (spec == nullptr ) {
2171+ return false ;
2172+ }
2173+
2174+ for (auto & impl : spec->impls ) {
2175+ if (impl->get_state (seq_id, data)) {
2176+ return true ;
2177+ }
2178+ }
2179+
2180+ return false ;
2181+ }
2182+
2183+ void common_speculative_set_state (common_speculative * spec, llama_seq_id seq_id, const std::vector<uint8_t > & data) {
2184+ if (spec == nullptr ) {
2185+ return ;
2186+ }
2187+
2188+ for (auto & impl : spec->impls ) {
2189+ impl->set_state (seq_id, data);
2190+ }
2191+ }
2192+
21212193void common_speculative_print_stats (const common_speculative * spec) {
21222194 if (spec == nullptr ) {
21232195 return ;
0 commit comments