11#include " gated_delta_net.cuh"
22
3- template <int S_v, bool KDA, bool keep_intermediates_t >
3+ template <int S_v, bool KDA, bool keep_rs_t >
44__global__ void __launch_bounds__ ((ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v) * 4, 2)
55gated_delta_net_cuda(const float * q,
66 const float * k,
@@ -145,7 +145,7 @@ gated_delta_net_cuda(const float * q,
145145
146146 attn_data += S_v * H;
147147
148- if constexpr (keep_intermediates_t ) {
148+ if constexpr (keep_rs_t ) {
149149 const int target_slot = t - shift;
150150 if (target_slot >= 0 && target_slot < K) {
151151 float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset;
@@ -158,7 +158,7 @@ gated_delta_net_cuda(const float * q,
158158 }
159159 }
160160
161- if constexpr (!keep_intermediates_t ) {
161+ if constexpr (!keep_rs_t ) {
162162#pragma unroll
163163 for (int r = 0 ; r < rows_per_lane; r++) {
164164 const int i = r * warp_size + lane;
@@ -167,7 +167,7 @@ gated_delta_net_cuda(const float * q,
167167 }
168168}
169169
170- template <bool KDA, bool keep_intermediates_t >
170+ template <bool KDA, bool keep_rs_t >
171171static void launch_gated_delta_net (
172172 const float * q_d, const float * k_d, const float * v_d,
173173 const float * g_d, const float * b_d, const float * s_d,
@@ -191,26 +191,26 @@ static void launch_gated_delta_net(
191191
192192 switch (S_v) {
193193 case 16 :
194- gated_delta_net_cuda<16 , KDA, keep_intermediates_t ><<<grid_dims, block_dims, 0 , stream>>> (
194+ gated_delta_net_cuda<16 , KDA, keep_rs_t ><<<grid_dims, block_dims, 0 , stream>>> (
195195 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
196196 n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
197197 sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
198198 break ;
199199 case 32 :
200- gated_delta_net_cuda<32 , KDA, keep_intermediates_t ><<<grid_dims, block_dims, 0 , stream>>> (
200+ gated_delta_net_cuda<32 , KDA, keep_rs_t ><<<grid_dims, block_dims, 0 , stream>>> (
201201 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
202202 n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
203203 sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
204204 break ;
205205 case 64 : {
206- gated_delta_net_cuda<64 , KDA, keep_intermediates_t ><<<grid_dims, block_dims, 0 , stream>>> (
206+ gated_delta_net_cuda<64 , KDA, keep_rs_t ><<<grid_dims, block_dims, 0 , stream>>> (
207207 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
208208 n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
209209 sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
210210 break ;
211211 }
212212 case 128 : {
213- gated_delta_net_cuda<128 , KDA, keep_intermediates_t ><<<grid_dims, block_dims, 0 , stream>>> (
213+ gated_delta_net_cuda<128 , KDA, keep_rs_t ><<<grid_dims, block_dims, 0 , stream>>> (
214214 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
215215 n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
216216 sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
@@ -285,10 +285,10 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
285285
286286 // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count.
287287 const int K = (int ) src_state->ne [1 ];
288- const bool keep_intermediates = K > 1 ;
288+ const bool keep_rs = K > 1 ;
289289
290290 if (kda) {
291- if (keep_intermediates ) {
291+ if (keep_rs ) {
292292 launch_gated_delta_net<true , true >(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
293293 S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
294294 sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
@@ -298,7 +298,7 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
298298 sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
299299 }
300300 } else {
301- if (keep_intermediates ) {
301+ if (keep_rs ) {
302302 launch_gated_delta_net<false , true >(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
303303 S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
304304 sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
0 commit comments