
Commit a8a33f6

review: rename, add asserts

1 parent 66e47d1 commit a8a33f6

7 files changed: 22 additions & 24 deletions

ggml/src/ggml-cuda/gated_delta_net.cu

Lines changed: 11 additions & 11 deletions
@@ -1,6 +1,6 @@
 #include "gated_delta_net.cuh"

-template <int S_v, bool KDA, bool keep_intermediates_t>
+template <int S_v, bool KDA, bool keep_rs_t>
 __global__ void __launch_bounds__((ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v) * 4, 2)
     gated_delta_net_cuda(const float * q,
                          const float * k,
@@ -145,7 +145,7 @@ gated_delta_net_cuda(const float * q,

     attn_data += S_v * H;

-    if constexpr (keep_intermediates_t) {
+    if constexpr (keep_rs_t) {
         const int target_slot = t - shift;
         if (target_slot >= 0 && target_slot < K) {
             float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset;
@@ -158,7 +158,7 @@ gated_delta_net_cuda(const float * q,
         }
     }

-    if constexpr (!keep_intermediates_t) {
+    if constexpr (!keep_rs_t) {
 #pragma unroll
         for (int r = 0; r < rows_per_lane; r++) {
             const int i = r * warp_size + lane;
@@ -167,7 +167,7 @@ gated_delta_net_cuda(const float * q,
     }
 }

-template <bool KDA, bool keep_intermediates_t>
+template <bool KDA, bool keep_rs_t>
 static void launch_gated_delta_net(
     const float * q_d, const float * k_d, const float * v_d,
     const float * g_d, const float * b_d, const float * s_d,
@@ -191,26 +191,26 @@ static void launch_gated_delta_net(

     switch (S_v) {
         case 16:
-            gated_delta_net_cuda<16, KDA, keep_intermediates_t><<<grid_dims, block_dims, 0, stream>>>(
+            gated_delta_net_cuda<16, KDA, keep_rs_t><<<grid_dims, block_dims, 0, stream>>>(
                 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
                 n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
                 sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
             break;
         case 32:
-            gated_delta_net_cuda<32, KDA, keep_intermediates_t><<<grid_dims, block_dims, 0, stream>>>(
+            gated_delta_net_cuda<32, KDA, keep_rs_t><<<grid_dims, block_dims, 0, stream>>>(
                 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
                 n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
                 sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
             break;
         case 64: {
-            gated_delta_net_cuda<64, KDA, keep_intermediates_t><<<grid_dims, block_dims, 0, stream>>>(
+            gated_delta_net_cuda<64, KDA, keep_rs_t><<<grid_dims, block_dims, 0, stream>>>(
                 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
                 n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
                 sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
             break;
         }
         case 128: {
-            gated_delta_net_cuda<128, KDA, keep_intermediates_t><<<grid_dims, block_dims, 0, stream>>>(
+            gated_delta_net_cuda<128, KDA, keep_rs_t><<<grid_dims, block_dims, 0, stream>>>(
                 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
                 n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
                 sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
@@ -285,10 +285,10 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *

     // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count.
     const int K = (int) src_state->ne[1];
-    const bool keep_intermediates = K > 1;
+    const bool keep_rs = K > 1;

     if (kda) {
-        if (keep_intermediates) {
+        if (keep_rs) {
             launch_gated_delta_net<true, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
                 S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
                 sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
@@ -298,7 +298,7 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
                 sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
         }
     } else {
-        if (keep_intermediates) {
+        if (keep_rs) {
             launch_gated_delta_net<false, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
                 S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
                 sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
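Note: keep_rs_t is a compile-time template parameter rather than a runtime branch, so the if constexpr snapshot-store path is compiled out entirely when rollback is disabled; ggml_cuda_op_gated_delta_net above maps its two runtime flags (kda, keep_rs) onto the four template instantiations the same way. A minimal standalone sketch of this dispatch pattern (hypothetical names, not the actual llama.cpp kernel):

    // Sketch only: fold a runtime flag into a template parameter so that
    // `if constexpr` removes the snapshot-store path from the compiled kernel.
    #include <cuda_runtime.h>

    template <bool keep_rs_t>
    __global__ void demo_kernel(float * dst, const float * state, int state_size, int K) {
        const int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i >= state_size) {
            return;
        }
        const float v = state[i]; // stand-in for the real recurrence update
        if constexpr (keep_rs_t) {
            // per-token snapshots: compiled only into the keep_rs_t=true instantiation
            for (int slot = 0; slot < K; slot++) {
                dst[slot * state_size + i] = v;
            }
        } else {
            dst[i] = v; // single final state, no extra global stores
        }
    }

    static void demo_launch(bool keep_rs, float * dst, const float * state,
                            int state_size, int K, cudaStream_t stream) {
        const int n_blocks = (state_size + 255) / 256;
        if (keep_rs) {
            demo_kernel<true><<<n_blocks, 256, 0, stream>>>(dst, state, state_size, K);
        } else {
            demo_kernel<false><<<n_blocks, 256, 0, stream>>>(dst, state, state_size, K);
        }
    }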

src/llama-arch.cpp

Lines changed: 1 addition & 1 deletion
@@ -878,7 +878,7 @@ bool llm_arch_is_diffusion(const llm_arch & arch) {
     }
 }

-bool llm_arch_supports_recurrent_partial_rollback(const llm_arch & arch) {
+bool llm_arch_supports_rs_rollback(const llm_arch & arch) {
     switch (arch) {
         case LLM_ARCH_QWEN35:
         case LLM_ARCH_QWEN35MOE:

src/llama-arch.h

Lines changed: 1 addition & 1 deletion
@@ -637,4 +637,4 @@ bool llm_arch_is_recurrent (const llm_arch & arch);
 bool llm_arch_is_hybrid    (const llm_arch & arch);
 bool llm_arch_is_diffusion (const llm_arch & arch);
 bool llm_arch_supports_sm_tensor(const llm_arch & arch);
-bool llm_arch_supports_recurrent_partial_rollback(const llm_arch & arch);
+bool llm_arch_supports_rs_rollback(const llm_arch & arch);

src/llama-context.cpp

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ llama_context::llama_context(
     }

     cparams.n_rs_seq = params.n_rs_seq;
-    if (cparams.n_rs_seq > 0 && !llm_arch_supports_recurrent_partial_rollback(model.arch)) {
+    if (cparams.n_rs_seq > 0 && !llm_arch_supports_rs_rollback(model.arch)) {
         LLAMA_LOG_DEBUG("%s: n_rs_seq=%u requested but model arch does not support recurrent partial rollback; clamping to 0\n",
             __func__, cparams.n_rs_seq);
         cparams.n_rs_seq = 0;

src/llama-memory-recurrent.cpp

Lines changed: 4 additions & 6 deletions
@@ -170,12 +170,10 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
     // partial rollback via per-token snapshot index (bounded by n_rs_seq)
     if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) {
         const llama_pos rollback = cell.pos - (p0 - 1);
-        if (rollback >= 1 && rollback <= (llama_pos) n_rs_seq) {
-            set_rs_idx(seq_id, (uint32_t) rollback);
-            cell.pos = p0 - 1;
-            return true;
-        }
-        return false;
+        GGML_ASSERT(rollback >= 1 && rollback <= (llama_pos) n_rs_seq);
+        set_rs_idx(seq_id, (uint32_t) rollback);
+        cell.pos = p0 - 1;
+        return true;
     }
     // invalidate tails which will be cleared
     if (p0 <= cell.pos && cell.pos < p1) {
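The new GGML_ASSERT turns the old soft bounds check into a hard invariant: with n_rs_seq snapshots kept per sequence, any partial removal that reaches this branch must be answerable from the snapshot ring. A worked example of the arithmetic with illustrative values (standalone C++, not the real llama.cpp API):

    #include <cassert>

    int main() {
        const int cell_pos = 10;  // last position held for this sequence
        const int p0       = 8;   // seq_rm removes positions in [8, p1)
        const int n_rs_seq = 4;   // per-token snapshot slots per sequence

        // rolling back to p0 - 1 = 7 means restoring the snapshot taken
        // 3 tokens before the tail
        const int rollback = cell_pos - (p0 - 1);       // = 3
        assert(rollback >= 1 && rollback <= n_rs_seq);  // mirrors the new GGML_ASSERT
        return 0;
    }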

src/models/delta-net-base.cpp

Lines changed: 3 additions & 3 deletions
@@ -447,7 +447,7 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
     return build_delta_net_chunking(q, k, v, g, b, s, il);
 }

-bool llm_build_delta_net_base::keep_intermediates() const {
+bool llm_build_delta_net_base::keep_rs() const {
     const int64_t n_seq_tokens = ubatch.n_seq_tokens;
     return cparams.n_rs_seq > 0
         && n_seq_tokens > 1
@@ -466,7 +466,7 @@ ggml_tensor * llm_build_delta_net_base::build_conv_state(
     const uint32_t mem_size = mctx_cur->get_size();
     const int64_t n_seqs = ubatch.n_seqs;
     const int64_t n_seq_tokens = ubatch.n_seq_tokens;
-    const bool keep = keep_intermediates();
+    const bool keep = keep_rs();

     ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
     cb(conv_states, "conv_states", il);
@@ -531,7 +531,7 @@ ggml_tensor * llm_build_delta_net_base::build_recurrent_attn(
     const int64_t n_seqs = s->ne[3];
     const int64_t n_seq_tokens = q->ne[2];

-    if (!keep_intermediates()) {
+    if (!keep_rs()) {
         auto attn_out = build_delta_net(q, k, v, g, b, s, il);
         ggml_tensor * output = attn_out.first;
         ggml_tensor * new_state = attn_out.second;

src/models/models.h

Lines changed: 1 addition & 1 deletion
@@ -67,7 +67,7 @@ struct llm_build_delta_net_base : public llm_graph_context {
         int il);

     // true when speculative rollback is enabled and the batch fits in the rs cache
-    bool keep_intermediates() const;
+    bool keep_rs() const;

     // read conv state from cache, concat with qkv_mixed, write back (single slot or per-token)
     // qkv_mixed: (qkv_dim, n_seq_tokens, n_seqs); returns conv_input: (kernel_size + n_seq_tokens - 1, channels, n_seqs)
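The shape contract in that last comment follows from causal convolution: the cache contributes the kernel_size - 1 trailing positions of the previous step, and concatenating them in front of the current n_seq_tokens yields kernel_size + n_seq_tokens - 1 input positions, exactly enough to produce one output per token. A quick check with illustrative numbers (not the model's real dimensions):

    #include <cassert>

    int main() {
        const int kernel_size  = 4;   // conv window (illustrative)
        const int n_seq_tokens = 16;  // tokens in this ubatch (illustrative)

        // cached tail (kernel_size - 1 positions) + current tokens
        const int conv_input_len = (kernel_size - 1) + n_seq_tokens;
        assert(conv_input_len == kernel_size + n_seq_tokens - 1);  // = 19

        // a causal conv over that span yields one output per current token
        const int n_outputs = conv_input_len - (kernel_size - 1);
        assert(n_outputs == n_seq_tokens);
        return 0;
    }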

0 commit comments
