11#include " gated_delta_net.cuh"
22
3- template <int S_v, bool KDA>
3+ template <int S_v, bool KDA, bool EMIT >
44__global__ void __launch_bounds__ ((ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v) * 4, 2)
55gated_delta_net_cuda(const float * q,
66 const float * k,
@@ -37,7 +37,8 @@ gated_delta_net_cuda(const float * q,
3737 float * attn_data = dst;
3838 float * state = dst + attn_score_elems;
3939
40- const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v;
40+ const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v;
41+ const int64_t state_size_per_snap = S_v * S_v * H * n_seqs; // EMIT only
4142 state += state_offset;
4243 curr_state += state_offset + col * S_v;
4344 attn_data += (sequence * n_tokens * H + h_idx) * S_v;
@@ -135,17 +136,30 @@ gated_delta_net_cuda(const float * q,
135136 }
136137
137138 attn_data += S_v * H;
139+
140+ // EMIT: snapshot post-token-t state. Slot t holds state after token t;
141+ // slot T-1 ends up holding the final state (matches CPU emit semantics).
142+ if constexpr (EMIT) {
143+ float * snap_t = (dst + attn_score_elems) + t * state_size_per_snap + state_offset;
144+ #pragma unroll
145+ for (int r = 0 ; r < rows_per_lane; r++) {
146+ const int i = r * warp_size + lane;
147+ snap_t [col * S_v + i] = s_shard[r];
148+ }
149+ }
138150 }
139151
140- // Write state back to global memory (transposed layout)
152+ // Non-emit: write final state. (Emit mode already wrote it as snap T-1.)
153+ if constexpr (!EMIT) {
141154#pragma unroll
142- for (int r = 0 ; r < rows_per_lane; r++) {
143- const int i = r * warp_size + lane;
144- state[col * S_v + i] = s_shard[r];
155+ for (int r = 0 ; r < rows_per_lane; r++) {
156+ const int i = r * warp_size + lane;
157+ state[col * S_v + i] = s_shard[r];
158+ }
145159 }
146160}
147161
148- template <bool KDA>
162+ template <bool KDA, bool EMIT >
149163static void launch_gated_delta_net (
150164 const float * q_d, const float * k_d, const float * v_d,
151165 const float * g_d, const float * b_d, const float * s_d,
@@ -169,26 +183,26 @@ static void launch_gated_delta_net(
169183
170184 switch (S_v) {
171185 case 16 :
172- gated_delta_net_cuda<16 , KDA><<<grid_dims, block_dims, 0 , stream>>> (
186+ gated_delta_net_cuda<16 , KDA, EMIT ><<<grid_dims, block_dims, 0 , stream>>> (
173187 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
174188 n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
175189 sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
176190 break ;
177191 case 32 :
178- gated_delta_net_cuda<32 , KDA><<<grid_dims, block_dims, 0 , stream>>> (
192+ gated_delta_net_cuda<32 , KDA, EMIT ><<<grid_dims, block_dims, 0 , stream>>> (
179193 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
180194 n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
181195 sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
182196 break ;
183197 case 64 : {
184- gated_delta_net_cuda<64 , KDA><<<grid_dims, block_dims, 0 , stream>>> (
198+ gated_delta_net_cuda<64 , KDA, EMIT ><<<grid_dims, block_dims, 0 , stream>>> (
185199 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
186200 n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
187201 sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
188202 break ;
189203 }
190204 case 128 : {
191- gated_delta_net_cuda<128 , KDA><<<grid_dims, block_dims, 0 , stream>>> (
205+ gated_delta_net_cuda<128 , KDA, EMIT ><<<grid_dims, block_dims, 0 , stream>>> (
192206 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
193207 n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
194208 sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
@@ -261,13 +275,27 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
261275
262276 cudaStream_t stream = ctx.stream ();
263277
278+ const bool emit = (((const int32_t *) dst->op_params )[0 ] != 0 );
279+
264280 if (kda) {
265- launch_gated_delta_net<true >(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
266- S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
267- sb1, sb2, sb3, neqk1, rq3, scale, stream);
281+ if (emit) {
282+ launch_gated_delta_net<true , true >(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
283+ S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
284+ sb1, sb2, sb3, neqk1, rq3, scale, stream);
285+ } else {
286+ launch_gated_delta_net<true , false >(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
287+ S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
288+ sb1, sb2, sb3, neqk1, rq3, scale, stream);
289+ }
268290 } else {
269- launch_gated_delta_net<false >(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
270- S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
271- sb1, sb2, sb3, neqk1, rq3, scale, stream);
291+ if (emit) {
292+ launch_gated_delta_net<false , true >(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
293+ S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
294+ sb1, sb2, sb3, neqk1, rq3, scale, stream);
295+ } else {
296+ launch_gated_delta_net<false , false >(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
297+ S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
298+ sb1, sb2, sb3, neqk1, rq3, scale, stream);
299+ }
272300 }
273301}
0 commit comments