11#include " gated_delta_net.cuh"
22
3- template <int S_v, bool KDA>
3+ template <int S_v, bool KDA, bool keep_intermediates_t >
44__global__ void __launch_bounds__ ((ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v) * 4, 2)
55gated_delta_net_cuda(const float * q,
66 const float * k,
@@ -37,7 +37,8 @@ gated_delta_net_cuda(const float * q,
3737 float * attn_data = dst;
3838 float * state = dst + attn_score_elems;
3939
40- const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v;
40+ const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v;
41+ const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // keep_intermediates_t only
4142 state += state_offset;
4243 curr_state += state_offset + col * S_v;
4344 attn_data += (sequence * n_tokens * H + h_idx) * S_v;
@@ -135,17 +136,27 @@ gated_delta_net_cuda(const float * q,
135136 }
136137
137138 attn_data += S_v * H;
139+
140+ if constexpr (keep_intermediates_t ) {
141+ float * curr_state = (dst + attn_score_elems) + t * state_size_per_token + state_offset;
142+ #pragma unroll
143+ for (int r = 0 ; r < rows_per_lane; r++) {
144+ const int i = r * warp_size + lane;
145+ curr_state[col * S_v + i] = s_shard[r];
146+ }
147+ }
138148 }
139149
140- // Write state back to global memory (transposed layout)
150+ if constexpr (! keep_intermediates_t ) {
141151#pragma unroll
142- for (int r = 0 ; r < rows_per_lane; r++) {
143- const int i = r * warp_size + lane;
144- state[col * S_v + i] = s_shard[r];
152+ for (int r = 0 ; r < rows_per_lane; r++) {
153+ const int i = r * warp_size + lane;
154+ state[col * S_v + i] = s_shard[r];
155+ }
145156 }
146157}
147158
148- template <bool KDA>
159+ template <bool KDA, bool keep_intermediates_t >
149160static void launch_gated_delta_net (
150161 const float * q_d, const float * k_d, const float * v_d,
151162 const float * g_d, const float * b_d, const float * s_d,
@@ -169,26 +180,26 @@ static void launch_gated_delta_net(
169180
170181 switch (S_v) {
171182 case 16 :
172- gated_delta_net_cuda<16 , KDA><<<grid_dims, block_dims, 0 , stream>>> (
183+ gated_delta_net_cuda<16 , KDA, keep_intermediates_t ><<<grid_dims, block_dims, 0 , stream>>> (
173184 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
174185 n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
175186 sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
176187 break ;
177188 case 32 :
178- gated_delta_net_cuda<32 , KDA><<<grid_dims, block_dims, 0 , stream>>> (
189+ gated_delta_net_cuda<32 , KDA, keep_intermediates_t ><<<grid_dims, block_dims, 0 , stream>>> (
179190 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
180191 n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
181192 sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
182193 break ;
183194 case 64 : {
184- gated_delta_net_cuda<64 , KDA><<<grid_dims, block_dims, 0 , stream>>> (
195+ gated_delta_net_cuda<64 , KDA, keep_intermediates_t ><<<grid_dims, block_dims, 0 , stream>>> (
185196 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
186197 n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
187198 sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
188199 break ;
189200 }
190201 case 128 : {
191- gated_delta_net_cuda<128 , KDA><<<grid_dims, block_dims, 0 , stream>>> (
202+ gated_delta_net_cuda<128 , KDA, keep_intermediates_t ><<<grid_dims, block_dims, 0 , stream>>> (
192203 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
193204 n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
194205 sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
@@ -261,13 +272,27 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
261272
262273 cudaStream_t stream = ctx.stream ();
263274
275+ const bool keep_intermediates = (((const int32_t *) dst->op_params )[0 ] != 0 );
276+
264277 if (kda) {
265- launch_gated_delta_net<true >(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
266- S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
267- sb1, sb2, sb3, neqk1, rq3, scale, stream);
278+ if (keep_intermediates) {
279+ launch_gated_delta_net<true , true >(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
280+ S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
281+ sb1, sb2, sb3, neqk1, rq3, scale, stream);
282+ } else {
283+ launch_gated_delta_net<true , false >(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
284+ S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
285+ sb1, sb2, sb3, neqk1, rq3, scale, stream);
286+ }
268287 } else {
269- launch_gated_delta_net<false >(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
270- S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
271- sb1, sb2, sb3, neqk1, rq3, scale, stream);
288+ if (keep_intermediates) {
289+ launch_gated_delta_net<false , true >(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
290+ S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
291+ sb1, sb2, sb3, neqk1, rq3, scale, stream);
292+ } else {
293+ launch_gated_delta_net<false , false >(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
294+ S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
295+ sb1, sb2, sb3, neqk1, rq3, scale, stream);
296+ }
272297 }
273298}
0 commit comments