Skip to content

Commit c5e0227

Browse files
committed
review: rename n_rollback_max to n_rs_seq and remove the public llama_model_supports_recurrent_partial_rollback API
1 parent 589490f commit c5e0227

15 files changed

Lines changed: 47 additions & 65 deletions

common/common.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1420,11 +1420,7 @@ common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) {
14201420
goto done;
14211421
}
14221422

1423-
// bounded-rollback architectures: classify before the seq_rm probe, since
1424-
// the probe (distance = 1) would silently take the rollback path and look
1425-
// like unbounded PART support
1426-
if (llama_n_rollback_max(ctx) > 0 &&
1427-
llama_model_supports_recurrent_partial_rollback(llama_get_model(ctx))) {
1423+
if (llama_n_rs_seq(ctx) > 0) {
14281424
res = COMMON_CONTEXT_SEQ_RM_TYPE_PART_BOUNDED;
14291425
goto done;
14301426
}
@@ -1503,7 +1499,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
15031499
// TODO: add for MTP
15041500
const bool has_spec = (params.speculative.type != COMMON_SPECULATIVE_TYPE_NONE)
15051501
|| params.speculative.has_dft();
1506-
cparams.n_rollback_max = has_spec ? (uint32_t) params.speculative.draft.n_max : 0u;
1502+
cparams.n_rs_seq = has_spec ? (uint32_t) params.speculative.draft.n_max : 0u;
15071503
}
15081504
cparams.n_batch = params.n_batch;
15091505
cparams.n_ubatch = params.n_ubatch;

common/common.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -882,7 +882,7 @@ enum common_context_seq_rm_type {
882882
COMMON_CONTEXT_SEQ_RM_TYPE_NO = 0, // seq_rm not supported (e.g. no memory module)
883883
COMMON_CONTEXT_SEQ_RM_TYPE_PART = 1, // can seq_rm partial sequences
884884
COMMON_CONTEXT_SEQ_RM_TYPE_FULL = 2, // can seq_rm full sequences only
885-
COMMON_CONTEXT_SEQ_RM_TYPE_PART_BOUNDED = 3, // can seq_rm partial sequences, bounded by n_rollback_max
885+
COMMON_CONTEXT_SEQ_RM_TYPE_PART_BOUNDED = 3, // can seq_rm partial sequences, bounded by n_rs_seq
886886
};
887887

888888
// check if the llama_context can remove sequences

include/llama.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ extern "C" {
333333
uint32_t n_batch; // logical maximum batch size that can be submitted to llama_decode
334334
uint32_t n_ubatch; // physical maximum batch size
335335
uint32_t n_seq_max; // max number of sequences (i.e. distinct states for recurrent models)
336-
uint32_t n_rollback_max; // max recurrent-state rollback distance (0 = no rollback support)
336+
uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback (0 = no rollback)
337337
int32_t n_threads; // number of threads to use for generation
338338
int32_t n_threads_batch; // number of threads to use for batch processing
339339

@@ -531,7 +531,7 @@ extern "C" {
531531
LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
532532
LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
533533
LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
534-
LLAMA_API uint32_t llama_n_rollback_max (const struct llama_context * ctx);
534+
LLAMA_API uint32_t llama_n_rs_seq (const struct llama_context * ctx);
535535

536536
DEPRECATED(LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model), "use llama_model_n_ctx_train instead");
537537
DEPRECATED(LLAMA_API int32_t llama_n_embd (const struct llama_model * model), "use llama_model_n_embd instead");
@@ -623,8 +623,6 @@ extern "C" {
623623
// Returns true if the model is diffusion-based (like LLaDA, Dream, etc.)
624624
LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model);
625625

626-
LLAMA_API bool llama_model_supports_recurrent_partial_rollback(const struct llama_model * model);
627-
628626
// Returns 0 on success
629627
LLAMA_API uint32_t llama_model_quantize(
630628
const char * fname_inp,

src/llama-context.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,11 @@ llama_context::llama_context(
4242
throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
4343
}
4444

45-
cparams.n_rollback_max = params.n_rollback_max;
46-
if (cparams.n_rollback_max > 0 && !llm_arch_supports_recurrent_partial_rollback(model.arch)) {
47-
LLAMA_LOG_WARN("%s: n_rollback_max=%u requested but model arch does not support recurrent partial rollback; clamping to 0\n",
48-
__func__, cparams.n_rollback_max);
49-
cparams.n_rollback_max = 0;
45+
cparams.n_rs_seq = params.n_rs_seq;
46+
if (cparams.n_rs_seq > 0 && !llm_arch_supports_recurrent_partial_rollback(model.arch)) {
47+
LLAMA_LOG_DEBUG("%s: n_rs_seq=%u requested but model arch does not support recurrent partial rollback; clamping to 0\n",
48+
__func__, cparams.n_rs_seq);
49+
cparams.n_rs_seq = 0;
5050
}
5151

5252
cparams.n_threads = params.n_threads;
@@ -2953,7 +2953,7 @@ llama_context_params llama_context_default_params() {
29532953
/*.n_batch =*/ 2048,
29542954
/*.n_ubatch =*/ 512,
29552955
/*.n_seq_max =*/ 1,
2956-
/*.n_rollback_max =*/ 0,
2956+
/*.n_rs_seq =*/ 0,
29572957
/*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
29582958
/*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
29592959
/*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
@@ -3100,8 +3100,8 @@ uint32_t llama_n_seq_max(const llama_context * ctx) {
31003100
return ctx->n_seq_max();
31013101
}
31023102

3103-
uint32_t llama_n_rollback_max(const llama_context * ctx) {
3104-
return ctx->get_cparams().n_rollback_max;
3103+
uint32_t llama_n_rs_seq(const llama_context * ctx) {
3104+
return ctx->get_cparams().n_rs_seq;
31053105
}
31063106

31073107
const llama_model * llama_get_model(const llama_context * ctx) {

src/llama-cparams.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ struct llama_cparams {
1212
uint32_t n_batch;
1313
uint32_t n_ubatch;
1414
uint32_t n_seq_max;
15-
uint32_t n_rollback_max; // max recurrent-state rollback distance
15+
uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback
1616
int32_t n_threads; // number of threads to use for generation
1717
int32_t n_threads_batch; // number of threads to use for batch processing
1818

src/llama-memory-hybrid-iswa.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa(
2424
uint32_t rs_size,
2525
/* common */
2626
uint32_t n_seq_max,
27-
uint32_t n_rollback_max,
27+
uint32_t n_rs_seq,
2828
bool offload,
2929
bool unified,
3030
/* layer filters */
@@ -55,7 +55,7 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa(
5555
offload,
5656
rs_size,
5757
n_seq_max,
58-
n_rollback_max,
58+
n_rs_seq,
5959
filter_recr == nullptr ?
6060
[&](int32_t il) { return hparams.is_recurrent(il); }
6161
: filter_recr

src/llama-memory-hybrid-iswa.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class llama_memory_hybrid_iswa : public llama_memory_i {
3434
uint32_t rs_size,
3535
/* common */
3636
uint32_t n_seq_max,
37-
uint32_t n_rollback_max,
37+
uint32_t n_rs_seq,
3838
bool offload,
3939
bool unified,
4040
/* layer filters */

src/llama-memory-hybrid.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ llama_memory_hybrid::llama_memory_hybrid(
2424
uint32_t rs_size,
2525
/* common */
2626
uint32_t n_seq_max,
27-
uint32_t n_rollback_max,
27+
uint32_t n_rs_seq,
2828
bool offload,
2929
bool unified,
3030
/* layer filters */
@@ -55,7 +55,7 @@ llama_memory_hybrid::llama_memory_hybrid(
5555
offload,
5656
rs_size,
5757
n_seq_max,
58-
n_rollback_max,
58+
n_rs_seq,
5959
filter_recr == nullptr ?
6060
[&](int32_t il) { return hparams.is_recurrent(il); }
6161
: filter_recr

src/llama-memory-hybrid.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class llama_memory_hybrid : public llama_memory_i {
3434
uint32_t rs_size,
3535
/* common */
3636
uint32_t n_seq_max,
37-
uint32_t n_rollback_max,
37+
uint32_t n_rs_seq,
3838
bool offload,
3939
bool unified,
4040
/* layer filters */

src/llama-memory-recurrent.cpp

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,16 @@ llama_memory_recurrent::llama_memory_recurrent(
2424
bool offload,
2525
uint32_t mem_size,
2626
uint32_t n_seq_max,
27-
uint32_t n_rollback_max,
27+
uint32_t n_rs_seq,
2828
const layer_filter_cb & filter) : hparams(model.hparams), n_seq_max(n_seq_max) {
2929
const int32_t n_layer = hparams.n_layer;
3030

3131
head = 0;
3232
size = mem_size;
3333
used = 0;
3434

35-
this->n_rollback_max = n_rollback_max;
36-
recurrent_rollback_idx.assign(n_seq_max, 0);
35+
this->n_rs_seq = n_rs_seq;
36+
rs_idx.assign(n_seq_max, 0);
3737

3838
cells.clear();
3939
cells.resize(mem_size);
@@ -96,7 +96,7 @@ llama_memory_recurrent::llama_memory_recurrent(
9696
throw std::runtime_error("failed to create ggml context for rs cache");
9797
}
9898

99-
const uint32_t n_rows = mem_size * (1 + n_rollback_max);
99+
const uint32_t n_rows = mem_size * (1 + n_rs_seq);
100100
ggml_tensor * r = ggml_new_tensor_2d(ctx, type_r, hparams.n_embd_r(), n_rows);
101101
ggml_tensor * s = ggml_new_tensor_2d(ctx, type_s, hparams.n_embd_s(), n_rows);
102102
ggml_format_name(r, "cache_r_l%d", i);
@@ -167,11 +167,11 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
167167
if (tail_id >= 0) {
168168
auto & cell = cells[tail_id];
169169

170-
// partial rollback via per-token snapshot index (bounded by n_rollback_max)
170+
// partial rollback via per-token snapshot index (bounded by n_rs_seq)
171171
if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) {
172172
const llama_pos rollback = cell.pos - (p0 - 1);
173-
if (rollback >= 1 && rollback <= (llama_pos) n_rollback_max) {
174-
set_recurrent_rollback_idx(seq_id, (uint32_t) rollback);
173+
if (rollback >= 1 && rollback <= (llama_pos) n_rs_seq) {
174+
set_rs_idx(seq_id, (uint32_t) rollback);
175175
cell.pos = p0 - 1;
176176
return true;
177177
}
@@ -378,18 +378,11 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
378378
return result;
379379
}
380380

381-
void llama_memory_recurrent::set_recurrent_rollback_idx(llama_seq_id seq_id, uint32_t idx) {
382-
if (seq_id < 0 || (size_t) seq_id >= recurrent_rollback_idx.size()) {
381+
void llama_memory_recurrent::set_rs_idx(llama_seq_id seq_id, uint32_t idx) {
382+
if (seq_id < 0 || (size_t) seq_id >= rs_idx.size()) {
383383
return;
384384
}
385-
recurrent_rollback_idx[seq_id] = (idx > n_rollback_max) ? n_rollback_max : idx;
386-
}
387-
388-
uint32_t llama_memory_recurrent::get_recurrent_rollback_idx(llama_seq_id seq_id) const {
389-
if (seq_id < 0 || (size_t) seq_id >= recurrent_rollback_idx.size()) {
390-
return 0;
391-
}
392-
return recurrent_rollback_idx[seq_id];
385+
rs_idx[seq_id] = (idx > n_rs_seq) ? n_rs_seq : idx;
393386
}
394387

395388
std::map<ggml_backend_buffer_type_t, size_t> llama_memory_recurrent::memory_breakdown() const {
@@ -1186,17 +1179,17 @@ int32_t llama_memory_recurrent_context::s_copy(int i) const {
11861179
const uint32_t cell_idx = i + mem->head;
11871180
const int32_t src0 = mem->cells[cell_idx].src0;
11881181

1189-
if (mem->n_rollback_max == 0) {
1182+
if (mem->n_rs_seq == 0) {
11901183
return src0;
11911184
}
11921185

11931186
uint32_t idx = 0;
11941187
if (!mem->cells[cell_idx].seq_id.empty()) {
11951188
const llama_seq_id seq = *mem->cells[cell_idx].seq_id.begin();
1196-
if (seq >= 0 && (size_t) seq < mem->recurrent_rollback_idx.size()) {
1197-
idx = mem->recurrent_rollback_idx[seq];
1189+
if (seq >= 0 && (size_t) seq < mem->rs_idx.size()) {
1190+
idx = mem->rs_idx[seq];
11981191
// reset rollback idx
1199-
mem->recurrent_rollback_idx[seq] = 0;
1192+
mem->rs_idx[seq] = 0;
12001193
}
12011194
}
12021195
return (int32_t)(idx * mem->size) + src0;

0 commit comments

Comments (0)