11#include " gated_delta_net.cuh"
22
33template <int S_v, bool KDA>
4- __global__ void gated_delta_net_cuda (const float * q,
4+ __global__ void __launch_bounds__ ((ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v) * 4, 2)
5+ gated_delta_net_cuda(const float * q,
56 const float * k,
67 const float * v,
78 const float * g,
@@ -38,18 +39,19 @@ __global__ void gated_delta_net_cuda(const float * q,
3839
3940 const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v;
4041 state += state_offset;
41- curr_state += state_offset;
42+ curr_state += state_offset + col * S_v ;
4243 attn_data += (sequence * n_tokens * H + h_idx) * S_v;
4344
4445 constexpr int warp_size = ggml_cuda_get_physical_warp_size () < S_v ? ggml_cuda_get_physical_warp_size () : S_v;
4546 static_assert (S_v % warp_size == 0 , " S_v must be a multiple of warp_size" );
4647 constexpr int rows_per_lane = (S_v + warp_size - 1 ) / warp_size;
4748 float s_shard[rows_per_lane];
4849 // state is stored transposed: M[col][i] = S[i][col], row col is contiguous
50+
4951#pragma unroll
5052 for (int r = 0 ; r < rows_per_lane; r++) {
5153 const int i = r * warp_size + lane;
52- s_shard[r] = curr_state[col * S_v + i];
54+ s_shard[r] = curr_state[i];
5355 }
5456
5557 for (int t = 0 ; t < n_tokens; t++) {
@@ -63,15 +65,24 @@ __global__ void gated_delta_net_cuda(const float * q,
6365
6466 const float beta_val = *beta_t ;
6567
68+ // Cache k and q in registers
69+ float k_reg[rows_per_lane];
70+ float q_reg[rows_per_lane];
71+ #pragma unroll
72+ for (int r = 0 ; r < rows_per_lane; r++) {
73+ const int i = r * warp_size + lane;
74+ k_reg[r] = k_t [i];
75+ q_reg[r] = q_t [i];
76+ }
77+
6678 if constexpr (!KDA) {
6779 const float g_val = expf (*g_t );
6880
6981 // kv[col] = (S^T @ k)[col] = sum_i S[i][col] * k[i]
7082 float kv_shard = 0 .0f ;
7183#pragma unroll
7284 for (int r = 0 ; r < rows_per_lane; r++) {
73- const int i = r * warp_size + lane;
74- kv_shard += s_shard[r] * k_t [i];
85+ kv_shard += s_shard[r] * k_reg[r];
7586 }
7687 float kv_col = warp_reduce_sum<warp_size>(kv_shard);
7788
@@ -83,9 +94,8 @@ __global__ void gated_delta_net_cuda(const float * q,
8394 float attn_partial = 0 .0f ;
8495#pragma unroll
8596 for (int r = 0 ; r < rows_per_lane; r++) {
86- const int i = r * warp_size + lane;
87- s_shard[r] = g_val * s_shard[r] + k_t [i] * delta_col;
88- attn_partial += s_shard[r] * q_t [i];
97+ s_shard[r] = g_val * s_shard[r] + k_reg[r] * delta_col;
98+ attn_partial += s_shard[r] * q_reg[r];
8999 }
90100
91101 float attn_col = warp_reduce_sum<warp_size>(attn_partial);
@@ -99,7 +109,7 @@ __global__ void gated_delta_net_cuda(const float * q,
99109#pragma unroll
100110 for (int r = 0 ; r < rows_per_lane; r++) {
101111 const int i = r * warp_size + lane;
102- kv_shard += expf (g_t [i]) * s_shard[r] * k_t [i ];
112+ kv_shard += expf (g_t [i]) * s_shard[r] * k_reg[r ];
103113 }
104114
105115 float kv_col = warp_reduce_sum<warp_size>(kv_shard);
@@ -113,8 +123,8 @@ __global__ void gated_delta_net_cuda(const float * q,
113123#pragma unroll
114124 for (int r = 0 ; r < rows_per_lane; r++) {
115125 const int i = r * warp_size + lane;
116- s_shard[r] = expf (g_t [i]) * s_shard[r] + k_t [i ] * delta_col;
117- attn_partial += s_shard[r] * q_t [i ];
126+ s_shard[r] = expf (g_t [i]) * s_shard[r] + k_reg[r ] * delta_col;
127+ attn_partial += s_shard[r] * q_reg[r ];
118128 }
119129
120130 float attn_col = warp_reduce_sum<warp_size>(attn_partial);
0 commit comments