@@ -38,8 +38,12 @@ int main(int argc, char ** argv) {
3838 std::string result0;
3939 std::string result1;
4040 std::string result2;
41+ std::string result3;
4142
4243 // init
44+
45+ ggml_backend_load_all ();
46+
4347 auto llama_init = common_init_from_params (params);
4448
4549 auto * model = llama_init->model ();
@@ -213,11 +217,83 @@ int main(int argc, char ** argv) {
213217 n_past += 1 ;
214218 }
215219
220+ // test on-device state save/load
221+ auto params_ctx4 = common_context_params_to_llama (params);
222+ params_ctx4.n_seq_max = 2 ;
223+ llama_context * ctx4 = llama_init_from_model (model, params_ctx4);
224+
225+ llama_sampler * smpl4 = llama_sampler_chain_init (sparams);
226+
227+ llama_sampler_chain_add (smpl4, llama_sampler_init_dist (params.sampling .seed ));
228+
229+ printf (" \n single seq run: %s" , params.prompt .c_str ());
230+
231+ // load state (rng, logits, embedding and kv_cache) from file
232+ n_token_count_out = 0 ;
233+
234+ if (!llama_state_load_file (ctx4, state_file.data (), unused_sts.data (), unused_sts.size (), &n_token_count_out)) {
235+ fprintf (stderr, " \n %s : failed to load state\n " , __func__);
236+ return 1 ;
237+ }
238+
239+ fprintf (stderr, " %s : loaded state with %zu tokens\n " , __func__, n_token_count_out);
240+
241+ // restore state (last tokens)
242+ n_past = n_token_count_out;
243+ if (!common_replay_last_token (ctx4, tokens.back (), n_past)) {
244+ return 1 ;
245+ }
246+ ++n_past;
247+
248+ // save seq 0 and load into seq 1
249+ {
250+ // save kv of seq 0
251+ std::vector<uint8_t > seq_store (llama_state_seq_get_size_ext (ctx4, 0 , LLAMA_STATE_SEQ_FLAGS_ON_DEVICE ));
252+ const size_t ncopy = llama_state_seq_get_data_ext (ctx4, seq_store.data (), seq_store.size (), 0 , LLAMA_STATE_SEQ_FLAGS_ON_DEVICE );
253+ if (ncopy != seq_store.size ()) {
254+ fprintf (stderr, " \n %s : seq copy data length %zd does not match expected length %zd\n " , __func__, ncopy, seq_store.size ());
255+ return 1 ;
256+ }
257+ fprintf (stderr, " %s : seq 0 copied, %zd bytes\n " , __func__, ncopy);
258+
259+ // erase whole kv
260+ llama_memory_clear (llama_get_memory (ctx4), true );
261+ fprintf (stderr, " %s : kv cache cleared\n " , __func__);
262+
263+ // restore kv into seq 1
264+ const size_t nset = llama_state_seq_set_data_ext (ctx4, seq_store.data (), seq_store.size (), 1 , LLAMA_STATE_SEQ_FLAGS_ON_DEVICE );
265+ if (nset != seq_store.size ()) {
266+ fprintf (stderr, " \n %s : seq set data length %zd does not match expected length %zd\n " , __func__, nset, seq_store.size ());
267+ return 1 ;
268+ }
269+ fprintf (stderr, " %s : seq 1 restored, %zd bytes\n " , __func__, nset);
270+ }
271+
272+ // fourth run
273+ for (auto i = 0 ; i < params.n_predict ; i++) {
274+ auto next_token = llama_sampler_sample (smpl4, ctx4, -1 );
275+ auto next_token_str = common_token_to_piece (ctx4, next_token);
276+
277+ printf (" %s" , next_token_str.c_str ());
278+ result3 += next_token_str;
279+
280+ common_batch_clear (batch);
281+ common_batch_add (batch, next_token, n_past, {1 }, true );
282+
283+ if (llama_decode (ctx4, batch)) {
284+ fprintf (stderr, " \n %s : failed to evaluate\n " , __func__);
285+ llama_batch_free (batch);
286+ return 1 ;
287+ }
288+ n_past += 1 ;
289+ }
290+
216291 printf (" \n " );
217292
218293 llama_sampler_free (smpl);
219294 llama_sampler_free (smpl2);
220295 llama_sampler_free (smpl3);
296+ llama_sampler_free (smpl4);
221297
222298 llama_batch_free (batch);
223299
@@ -226,12 +302,18 @@ int main(int argc, char ** argv) {
226302
227303 llama_free (ctx2);
228304 llama_free (ctx3);
305+ llama_free (ctx4);
229306
230307 if (result0 != result2) {
231308 fprintf (stderr, " \n %s : error : the seq restore generation is different\n " , __func__);
232309 return 1 ;
233310 }
234311
312+ if (result0 != result3) {
313+ fprintf (stderr, " \n %s : error : the seq restore generation is different\n " , __func__);
314+ return 1 ;
315+ }
316+
235317 fprintf (stderr, " \n %s : success\n " , __func__);
236318
237319 return 0 ;
0 commit comments