@@ -38,6 +38,7 @@ int main(int argc, char ** argv) {
3838 std::string result0;
3939 std::string result1;
4040 std::string result2;
41+ std::string result3;
4142
4243 // init
4344 auto llama_init = common_init_from_params (params);
@@ -213,11 +214,83 @@ int main(int argc, char ** argv) {
213214 n_past += 1 ;
214215 }
215216
217+ // test on-device state save/load
218+ auto params_ctx4 = common_context_params_to_llama (params);
219+ params_ctx4.n_seq_max = 2 ;
220+ llama_context * ctx4 = llama_init_from_model (model, params_ctx4);
221+
222+ llama_sampler * smpl4 = llama_sampler_chain_init (sparams);
223+
224+ llama_sampler_chain_add (smpl4, llama_sampler_init_dist (params.sampling .seed ));
225+
226+ printf (" \n single seq run: %s" , params.prompt .c_str ());
227+
228+ // load state (rng, logits, embedding and kv_cache) from file
229+ n_token_count_out = 0 ;
230+
231+ if (!llama_state_load_file (ctx4, state_file.data (), unused_sts.data (), unused_sts.size (), &n_token_count_out)) {
232+ fprintf (stderr, " \n %s : failed to load state\n " , __func__);
233+ return 1 ;
234+ }
235+
236+ fprintf (stderr, " %s : loaded state with %zu tokens\n " , __func__, n_token_count_out);
237+
238+ // restore state (last tokens)
239+ n_past = n_token_count_out;
240+ if (!common_replay_last_token (ctx4, tokens.back (), n_past)) {
241+ return 1 ;
242+ }
243+ ++n_past;
244+
245+ // save seq 0 and load into seq 1
246+ {
247+ // save kv of seq 0
248+ std::vector<uint8_t > seq_store (llama_state_seq_get_size_ext (ctx4, 0 , LLAMA_STATE_SEQ_FLAGS_ON_DEVICE));
249+ const size_t ncopy = llama_state_seq_get_data_ext (ctx4, seq_store.data (), seq_store.size (), 0 , LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
250+ if (ncopy != seq_store.size ()) {
251+ fprintf (stderr, " \n %s : seq copy data length %zd does not match expected length %zd\n " , __func__, ncopy, seq_store.size ());
252+ return 1 ;
253+ }
254+ fprintf (stderr, " %s : seq 0 copied, %zd bytes\n " , __func__, ncopy);
255+
256+ // erase whole kv
257+ llama_memory_clear (llama_get_memory (ctx4), true );
258+ fprintf (stderr, " %s : kv cache cleared\n " , __func__);
259+
260+ // restore kv into seq 0
261+ const size_t nset = llama_state_seq_set_data_ext (ctx4, seq_store.data (), seq_store.size (), 1 , LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
262+ if (nset != seq_store.size ()) {
263+ fprintf (stderr, " \n %s : seq set data length %zd does not match expected length %zd\n " , __func__, nset, seq_store.size ());
264+ return 1 ;
265+ }
266+ fprintf (stderr, " %s : seq 1 restored, %zd bytes\n " , __func__, nset);
267+ }
268+
269+ // forth run
270+ for (auto i = 0 ; i < params.n_predict ; i++) {
271+ auto next_token = llama_sampler_sample (smpl4, ctx4, -1 );
272+ auto next_token_str = common_token_to_piece (ctx4, next_token);
273+
274+ printf (" %s" , next_token_str.c_str ());
275+ result3 += next_token_str;
276+
277+ common_batch_clear (batch);
278+ common_batch_add (batch, next_token, n_past, {1 }, true );
279+
280+ if (llama_decode (ctx4, batch)) {
281+ fprintf (stderr, " \n %s : failed to evaluate\n " , __func__);
282+ llama_batch_free (batch);
283+ return 1 ;
284+ }
285+ n_past += 1 ;
286+ }
287+
216288 printf (" \n " );
217289
218290 llama_sampler_free (smpl);
219291 llama_sampler_free (smpl2);
220292 llama_sampler_free (smpl3);
293+ llama_sampler_free (smpl4);
221294
222295 llama_batch_free (batch);
223296
@@ -226,12 +299,18 @@ int main(int argc, char ** argv) {
226299
227300 llama_free (ctx2);
228301 llama_free (ctx3);
302+ llama_free (ctx4);
229303
230304 if (result0 != result2) {
231305 fprintf (stderr, " \n %s : error : the seq restore generation is different\n " , __func__);
232306 return 1 ;
233307 }
234308
309+ if (result0 != result3) {
310+ fprintf (stderr, " \n %s : error : the seq restore generation is different\n " , __func__);
311+ return 1 ;
312+ }
313+
235314 fprintf (stderr, " \n %s : success\n " , __func__);
236315
237316 return 0 ;
0 commit comments