@@ -24,16 +24,16 @@ llama_memory_recurrent::llama_memory_recurrent(
         bool offload,
         uint32_t mem_size,
         uint32_t n_seq_max,
-        uint32_t n_rollback_max,
+        uint32_t n_rs_seq,
         const layer_filter_cb & filter) : hparams(model.hparams), n_seq_max(n_seq_max) {
     const int32_t n_layer = hparams.n_layer;
 
     head = 0;
     size = mem_size;
     used = 0;
 
-    this->n_rollback_max = n_rollback_max;
-    recurrent_rollback_idx.assign(n_seq_max, 0);
+    this->n_rs_seq = n_rs_seq;
+    rs_idx.assign(n_seq_max, 0);
 
     cells.clear();
     cells.resize(mem_size);
@@ -96,7 +96,7 @@ llama_memory_recurrent::llama_memory_recurrent(
             throw std::runtime_error("failed to create ggml context for rs cache");
         }
 
-        const uint32_t n_rows = mem_size * (1 + n_rollback_max);
+        const uint32_t n_rows = mem_size * (1 + n_rs_seq);
         ggml_tensor * r = ggml_new_tensor_2d(ctx, type_r, hparams.n_embd_r(), n_rows);
         ggml_tensor * s = ggml_new_tensor_2d(ctx, type_s, hparams.n_embd_s(), n_rows);
         ggml_format_name(r, "cache_r_l%d", i);
@@ -167,11 +167,11 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
         if (tail_id >= 0) {
             auto & cell = cells[tail_id];
 
-            // partial rollback via per-token snapshot index (bounded by n_rollback_max)
+            // partial rollback via per-token snapshot index (bounded by n_rs_seq)
             if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) {
                 const llama_pos rollback = cell.pos - (p0 - 1);
-                if (rollback >= 1 && rollback <= (llama_pos) n_rollback_max) {
-                    set_recurrent_rollback_idx(seq_id, (uint32_t) rollback);
+                if (rollback >= 1 && rollback <= (llama_pos) n_rs_seq) {
+                    set_rs_idx(seq_id, (uint32_t) rollback);
                     cell.pos = p0 - 1;
                     return true;
                 }
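Reading the `seq_rm` hunk above: when a caller erases a suffix of a sequence, the tail cell is rewound instead of dropped. For example, with `cell.pos = 9` and a removal starting at `p0 = 7`, the patch computes `rollback = 9 - (7 - 1) = 3`; provided `n_rs_seq >= 3`, it records `rs_idx[seq_id] = 3` and moves `cell.pos` back to 6, so the next copy of recurrent state can restore the snapshot taken three tokens earlier rather than failing the removal.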
@@ -378,18 +378,11 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
     return result;
 }
 
-void llama_memory_recurrent::set_recurrent_rollback_idx(llama_seq_id seq_id, uint32_t idx) {
-    if (seq_id < 0 || (size_t) seq_id >= recurrent_rollback_idx.size()) {
+void llama_memory_recurrent::set_rs_idx(llama_seq_id seq_id, uint32_t idx) {
+    if (seq_id < 0 || (size_t) seq_id >= rs_idx.size()) {
         return;
     }
-    recurrent_rollback_idx[seq_id] = (idx > n_rollback_max) ? n_rollback_max : idx;
-}
-
-uint32_t llama_memory_recurrent::get_recurrent_rollback_idx(llama_seq_id seq_id) const {
-    if (seq_id < 0 || (size_t) seq_id >= recurrent_rollback_idx.size()) {
-        return 0;
-    }
-    return recurrent_rollback_idx[seq_id];
+    rs_idx[seq_id] = (idx > n_rs_seq) ? n_rs_seq : idx;
 }
 
 std::map<ggml_backend_buffer_type_t, size_t> llama_memory_recurrent::memory_breakdown() const {
@@ -1186,17 +1179,17 @@ int32_t llama_memory_recurrent_context::s_copy(int i) const {
     const uint32_t cell_idx = i + mem->head;
     const int32_t src0 = mem->cells[cell_idx].src0;
 
-    if (mem->n_rollback_max == 0) {
+    if (mem->n_rs_seq == 0) {
         return src0;
     }
 
     uint32_t idx = 0;
     if (!mem->cells[cell_idx].seq_id.empty()) {
         const llama_seq_id seq = *mem->cells[cell_idx].seq_id.begin();
-        if (seq >= 0 && (size_t) seq < mem->recurrent_rollback_idx.size()) {
-            idx = mem->recurrent_rollback_idx[seq];
+        if (seq >= 0 && (size_t) seq < mem->rs_idx.size()) {
+            idx = mem->rs_idx[seq];
             // reset rollback idx
-            mem->recurrent_rollback_idx[seq] = 0;
+            mem->rs_idx[seq] = 0;
         }
     }
     return (int32_t)(idx * mem->size) + src0;
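Putting the two ends of the patch together: each per-layer `r`/`s` tensor is allocated with `mem_size * (1 + n_rs_seq)` rows, so block 0 appears to hold the live states while block `k` (for `1 <= k <= n_rs_seq`) holds the snapshot from `k` tokens back, and `s_copy` resolves which row a cell's state is copied from. A minimal sketch of that index arithmetic, assuming this block layout (the helper name is illustrative, not part of the patch):

```cpp
#include <cstdint>

// Illustrative helper mirroring the index math in s_copy() above.
// Assumed layout: rows [0, mem_size) are the live states; rows
// [k * mem_size, (k + 1) * mem_size) hold the snapshots from k tokens back.
static int32_t rs_source_row(uint32_t rollback, // rs_idx[seq]; 0 = no rollback
                             uint32_t mem_size, // cells per block
                             int32_t  src0) {   // cell index within a block
    // rollback == 0 copies the current state in place; rollback == k
    // redirects the copy to the same cell inside snapshot block k.
    return (int32_t)(rollback * mem_size) + src0;
}
```

With `mem_size = 4`, a cell at `src0 = 2` and a pending `rs_idx` of 3 resolves to row `3 * 4 + 2 = 14`; `s_copy` then resets `rs_idx[seq]` to 0 so subsequent ubatches read the live state again.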