Skip to content

Commit 4e95702

Browse files
committed
llama: allow partial seq_rm for GDN models for speculative decoding
Currently, speculative decoding needs to restart from a checkpoint after some draft tokens are not accepted, which leads to wasted work re-running the target. This PR adds the ability to roll back up to `draft_max` tokens by storing the GDN intermediates.
1 parent f42e29f commit 4e95702

26 files changed

Lines changed: 412 additions & 93 deletions

common/common.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1422,6 +1422,11 @@ common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) {
14221422

14231423
// try to remove the last tokens
14241424
if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
1425+
if (llama_n_rollback_max(ctx) > 0 &&
1426+
llama_model_supports_recurrent_partial_rollback(llama_get_model(ctx))) {
1427+
res = COMMON_CONTEXT_SEQ_RM_TYPE_PART;
1428+
goto done;
1429+
}
14251430
LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
14261431
res = COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
14271432
goto done;
@@ -1490,6 +1495,12 @@ struct llama_context_params common_context_params_to_llama(const common_params &
14901495

14911496
cparams.n_ctx = params.n_ctx;
14921497
cparams.n_seq_max = params.n_parallel;
1498+
{
1499+
// TODO: add for MTP
1500+
const bool has_spec = (params.speculative.type != COMMON_SPECULATIVE_TYPE_NONE)
1501+
|| params.speculative.has_dft();
1502+
cparams.n_rollback_max = has_spec ? (uint32_t) params.speculative.draft.n_max : 0u;
1503+
}
14931504
cparams.n_batch = params.n_batch;
14941505
cparams.n_ubatch = params.n_ubatch;
14951506
cparams.n_threads = params.cpuparams.n_threads;

ggml/include/ggml.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2537,7 +2537,8 @@ extern "C" {
25372537
struct ggml_tensor * v,
25382538
struct ggml_tensor * g,
25392539
struct ggml_tensor * beta,
2540-
struct ggml_tensor * state);
2540+
struct ggml_tensor * state,
2541+
bool keep_intermediates);
25412542

25422543
// custom operators
25432544

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2933,7 +2933,9 @@ struct ggml_cplan ggml_graph_plan(
29332933
case GGML_OP_GATED_DELTA_NET:
29342934
{
29352935
const int64_t S_v = node->src[2]->ne[0];
2936-
cur = S_v * sizeof(float) * n_tasks;
2936+
const bool keep_intermediates = (((const int32_t *) node->op_params)[0] != 0);
2937+
const int64_t per_thread = S_v + (keep_intermediates ? S_v * S_v : 0);
2938+
cur = per_thread * sizeof(float) * n_tasks;
29372939
} break;
29382940
case GGML_OP_COUNT:
29392941
{

ggml/src/ggml-cpu/ops.cpp

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10467,16 +10467,20 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
1046710467

1046810468
const bool kda = (neg0 == S_v);
1046910469

10470-
// scratch layout per thread: [delta(S_v)]
10471-
const int64_t scratch_per_thread = S_v;
10470+
const bool keep_intermediates = (const bool) ggml_get_op_params_i32(dst, 0);
10471+
10472+
const int64_t per_thread = S_v + (keep_intermediates ? S_v * S_v : 0);
1047210473
const int ith = params->ith;
1047310474

10474-
float * delta = (float *)params->wdata + ith * scratch_per_thread + CACHE_LINE_SIZE_F32;
10475+
float * delta = (float *)params->wdata + ith * per_thread + CACHE_LINE_SIZE_F32;
10476+
float * state_work = keep_intermediates ? (delta + S_v) : nullptr;
1047510477

1047610478
// output layout: [attn_scores | new_states]
1047710479
// attn_scores: S_v * H * n_tokens * n_seqs floats
10478-
// new_states: S_v * S_v * H * n_seqs floats
10479-
const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
10480+
// new_states: S_v * S_v * H * n_seqs floats (final only)
10481+
// S_v * S_v * H * n_seqs * T floats (T snaps, keep_intermediates)
10482+
const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
10483+
const int64_t state_size_per_snap = S_v * S_v * H * n_seqs;
1048010484
float * attn_out_base = (float *)dst->data;
1048110485
float * state_out_base = (float *)dst->data + attn_score_elems;
1048210486

@@ -10499,9 +10503,11 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
1049910503
const int64_t iq3 = iv3 / rq3;
1050010504
const int64_t ik3 = iv3 / rk3;
1050110505

10502-
float * s_out = state_out_base + (iv3 * H + iv1) * S_v * S_v;
10506+
float * s_out = keep_intermediates
10507+
? state_work
10508+
: state_out_base + (iv3 * H + iv1) * S_v * S_v;
1050310509

10504-
// copy input state into output buffer and operate in-place
10510+
// copy input state into the working buffer and operate in-place
1050510511
const float * s_in = state_in_base + (iv3 * H + iv1) * S_v * S_v;
1050610512
memcpy(s_out, s_in, S_v * S_v * sizeof(float));
1050710513

@@ -10552,6 +10558,12 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
1055210558
}
1055310559

1055410560
attn_data += S_v * H; // advance to next token
10561+
10562+
if (keep_intermediates) {
10563+
float * curr_state_o = state_out_base + t * state_size_per_snap +
10564+
(iv3 * H + iv1) * S_v * S_v;
10565+
memcpy(curr_state_o, s_out, S_v * S_v * sizeof(float));
10566+
}
1055510567
}
1055610568
}
1055710569
}

ggml/src/ggml-cuda/gated_delta_net.cu

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#include "gated_delta_net.cuh"
22

3-
template <int S_v, bool KDA>
3+
template <int S_v, bool KDA, bool keep_intermediates_t>
44
__global__ void __launch_bounds__((ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v) * 4, 2)
55
gated_delta_net_cuda(const float * q,
66
const float * k,
@@ -37,7 +37,8 @@ gated_delta_net_cuda(const float * q,
3737
float * attn_data = dst;
3838
float * state = dst + attn_score_elems;
3939

40-
const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v;
40+
const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v;
41+
const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // keep_intermediates_t only
4142
state += state_offset;
4243
curr_state += state_offset + col * S_v;
4344
attn_data += (sequence * n_tokens * H + h_idx) * S_v;
@@ -135,17 +136,27 @@ gated_delta_net_cuda(const float * q,
135136
}
136137

137138
attn_data += S_v * H;
139+
140+
if constexpr (keep_intermediates_t) {
141+
float * curr_state = (dst + attn_score_elems) + t * state_size_per_token + state_offset;
142+
#pragma unroll
143+
for (int r = 0; r < rows_per_lane; r++) {
144+
const int i = r * warp_size + lane;
145+
curr_state[col * S_v + i] = s_shard[r];
146+
}
147+
}
138148
}
139149

140-
// Write state back to global memory (transposed layout)
150+
if constexpr (!keep_intermediates_t) {
141151
#pragma unroll
142-
for (int r = 0; r < rows_per_lane; r++) {
143-
const int i = r * warp_size + lane;
144-
state[col * S_v + i] = s_shard[r];
152+
for (int r = 0; r < rows_per_lane; r++) {
153+
const int i = r * warp_size + lane;
154+
state[col * S_v + i] = s_shard[r];
155+
}
145156
}
146157
}
147158

148-
template <bool KDA>
159+
template <bool KDA, bool keep_intermediates_t>
149160
static void launch_gated_delta_net(
150161
const float * q_d, const float * k_d, const float * v_d,
151162
const float * g_d, const float * b_d, const float * s_d,
@@ -169,26 +180,26 @@ static void launch_gated_delta_net(
169180

170181
switch (S_v) {
171182
case 16:
172-
gated_delta_net_cuda<16, KDA><<<grid_dims, block_dims, 0, stream>>>(
183+
gated_delta_net_cuda<16, KDA, keep_intermediates_t><<<grid_dims, block_dims, 0, stream>>>(
173184
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
174185
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
175186
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
176187
break;
177188
case 32:
178-
gated_delta_net_cuda<32, KDA><<<grid_dims, block_dims, 0, stream>>>(
189+
gated_delta_net_cuda<32, KDA, keep_intermediates_t><<<grid_dims, block_dims, 0, stream>>>(
179190
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
180191
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
181192
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
182193
break;
183194
case 64: {
184-
gated_delta_net_cuda<64, KDA><<<grid_dims, block_dims, 0, stream>>>(
195+
gated_delta_net_cuda<64, KDA, keep_intermediates_t><<<grid_dims, block_dims, 0, stream>>>(
185196
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
186197
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
187198
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
188199
break;
189200
}
190201
case 128: {
191-
gated_delta_net_cuda<128, KDA><<<grid_dims, block_dims, 0, stream>>>(
202+
gated_delta_net_cuda<128, KDA, keep_intermediates_t><<<grid_dims, block_dims, 0, stream>>>(
192203
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
193204
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
194205
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
@@ -261,13 +272,27 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
261272

262273
cudaStream_t stream = ctx.stream();
263274

275+
const bool keep_intermediates = (((const int32_t *) dst->op_params)[0] != 0);
276+
264277
if (kda) {
265-
launch_gated_delta_net<true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
266-
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
267-
sb1, sb2, sb3, neqk1, rq3, scale, stream);
278+
if (keep_intermediates) {
279+
launch_gated_delta_net<true, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
280+
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
281+
sb1, sb2, sb3, neqk1, rq3, scale, stream);
282+
} else {
283+
launch_gated_delta_net<true, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
284+
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
285+
sb1, sb2, sb3, neqk1, rq3, scale, stream);
286+
}
268287
} else {
269-
launch_gated_delta_net<false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
270-
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
271-
sb1, sb2, sb3, neqk1, rq3, scale, stream);
288+
if (keep_intermediates) {
289+
launch_gated_delta_net<false, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
290+
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
291+
sb1, sb2, sb3, neqk1, rq3, scale, stream);
292+
} else {
293+
launch_gated_delta_net<false, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
294+
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
295+
sb1, sb2, sb3, neqk1, rq3, scale, stream);
296+
}
272297
}
273298
}

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16717,8 +16717,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1671716717
tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3],
1671816718
src_clone[4], src_clone[5], src_clone[6]);
1671916719
} else if (tensor->op == GGML_OP_GATED_DELTA_NET) {
16720+
const bool keep_intermediates = (((const int32_t *) tensor->op_params)[0] != 0);
1672016721
tensor_clone = ggml_gated_delta_net(ggml_ctx, src_clone[0], src_clone[1],
16721-
src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
16722+
src_clone[2], src_clone[3], src_clone[4], src_clone[5], keep_intermediates);
1672216723
} else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
1672316724
src_clone[0]->flags = tensor->src[0]->flags;
1672416725
tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],

ggml/src/ggml.c

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6171,7 +6171,8 @@ struct ggml_tensor * ggml_gated_delta_net(
61716171
struct ggml_tensor * v,
61726172
struct ggml_tensor * g,
61736173
struct ggml_tensor * beta,
6174-
struct ggml_tensor * state) {
6174+
struct ggml_tensor * state,
6175+
bool keep_intermediates) {
61756176
GGML_ASSERT(ggml_is_contiguous_rows(q));
61766177
GGML_ASSERT(ggml_is_contiguous_rows(k));
61776178
GGML_ASSERT(ggml_is_contiguous_rows(v));
@@ -6197,9 +6198,8 @@ struct ggml_tensor * ggml_gated_delta_net(
61976198

61986199
GGML_ASSERT(ggml_nelements(state) == S_v * S_v * H * n_seqs);
61996200

6200-
// concat output and new_state into a single tensor
6201-
// output: S_v * H * n_tokens * n_seqs, state: S_v * S_v * H * n_seqs
6202-
const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + S_v * n_seqs, 1, 1 };
6201+
const int64_t state_rows = keep_intermediates ? n_tokens * S_v * n_seqs : S_v * n_seqs;
6202+
const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + state_rows, 1, 1 };
62036203
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
62046204

62056205
result->op = GGML_OP_GATED_DELTA_NET;
@@ -6210,6 +6210,9 @@ struct ggml_tensor * ggml_gated_delta_net(
62106210
result->src[4] = beta;
62116211
result->src[5] = state;
62126212

6213+
int32_t flag = keep_intermediates ? 1 : 0;
6214+
ggml_set_op_params(result, &flag, sizeof(flag));
6215+
62136216
return result;
62146217
}
62156218

include/llama.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,7 @@ extern "C" {
333333
uint32_t n_batch; // logical maximum batch size that can be submitted to llama_decode
334334
uint32_t n_ubatch; // physical maximum batch size
335335
uint32_t n_seq_max; // max number of sequences (i.e. distinct states for recurrent models)
336+
uint32_t n_rollback_max; // max recurrent-state rollback distance (0 = no rollback support)
336337
int32_t n_threads; // number of threads to use for generation
337338
int32_t n_threads_batch; // number of threads to use for batch processing
338339

@@ -530,6 +531,7 @@ extern "C" {
530531
LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
531532
LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
532533
LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
534+
LLAMA_API uint32_t llama_n_rollback_max (const struct llama_context * ctx);
533535

534536
DEPRECATED(LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model), "use llama_model_n_ctx_train instead");
535537
DEPRECATED(LLAMA_API int32_t llama_n_embd (const struct llama_model * model), "use llama_model_n_embd instead");
@@ -621,6 +623,8 @@ extern "C" {
621623
// Returns true if the model is diffusion-based (like LLaDA, Dream, etc.)
622624
LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model);
623625

626+
LLAMA_API bool llama_model_supports_recurrent_partial_rollback(const struct llama_model * model);
627+
624628
// Returns 0 on success
625629
LLAMA_API uint32_t llama_model_quantize(
626630
const char * fname_inp,

src/llama-arch.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -876,6 +876,16 @@ bool llm_arch_is_diffusion(const llm_arch & arch) {
876876
}
877877
}
878878

879+
bool llm_arch_supports_recurrent_partial_rollback(const llm_arch & arch) {
880+
switch (arch) {
881+
case LLM_ARCH_QWEN35:
882+
case LLM_ARCH_QWEN35MOE:
883+
return true;
884+
default:
885+
return false;
886+
}
887+
}
888+
879889
bool llm_arch_supports_sm_tensor(const llm_arch & arch) {
880890
switch (arch) {
881891
case LLM_ARCH_GROK:

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -636,3 +636,4 @@ bool llm_arch_is_recurrent (const llm_arch & arch);
636636
bool llm_arch_is_hybrid (const llm_arch & arch);
637637
bool llm_arch_is_diffusion (const llm_arch & arch);
638638
bool llm_arch_supports_sm_tensor(const llm_arch & arch);
639+
bool llm_arch_supports_recurrent_partial_rollback(const llm_arch & arch);

0 commit comments

Comments (0)