Skip to content

Commit 34818ea

Browse files
authored
CUDA: GDN hide memory latency (#20537)
1 parent 9e2e219 commit 34818ea

1 file changed

Lines changed: 21 additions & 11 deletions

File tree

ggml/src/ggml-cuda/gated_delta_net.cu

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
#include "gated_delta_net.cuh"
22

33
template <int S_v, bool KDA>
4-
__global__ void gated_delta_net_cuda(const float * q,
4+
__global__ void __launch_bounds__((ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v) * 4, 2)
5+
gated_delta_net_cuda(const float * q,
56
const float * k,
67
const float * v,
78
const float * g,
@@ -38,18 +39,19 @@ __global__ void gated_delta_net_cuda(const float * q,
3839

3940
const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v;
4041
state += state_offset;
41-
curr_state += state_offset;
42+
curr_state += state_offset + col * S_v;
4243
attn_data += (sequence * n_tokens * H + h_idx) * S_v;
4344

4445
constexpr int warp_size = ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v;
4546
static_assert(S_v % warp_size == 0, "S_v must be a multiple of warp_size");
4647
constexpr int rows_per_lane = (S_v + warp_size - 1) / warp_size;
4748
float s_shard[rows_per_lane];
4849
// state is stored transposed: M[col][i] = S[i][col], row col is contiguous
50+
4951
#pragma unroll
5052
for (int r = 0; r < rows_per_lane; r++) {
5153
const int i = r * warp_size + lane;
52-
s_shard[r] = curr_state[col * S_v + i];
54+
s_shard[r] = curr_state[i];
5355
}
5456

5557
for (int t = 0; t < n_tokens; t++) {
@@ -63,15 +65,24 @@ __global__ void gated_delta_net_cuda(const float * q,
6365

6466
const float beta_val = *beta_t;
6567

68+
// Cache k and q in registers
69+
float k_reg[rows_per_lane];
70+
float q_reg[rows_per_lane];
71+
#pragma unroll
72+
for (int r = 0; r < rows_per_lane; r++) {
73+
const int i = r * warp_size + lane;
74+
k_reg[r] = k_t[i];
75+
q_reg[r] = q_t[i];
76+
}
77+
6678
if constexpr (!KDA) {
6779
const float g_val = expf(*g_t);
6880

6981
// kv[col] = (S^T @ k)[col] = sum_i S[i][col] * k[i]
7082
float kv_shard = 0.0f;
7183
#pragma unroll
7284
for (int r = 0; r < rows_per_lane; r++) {
73-
const int i = r * warp_size + lane;
74-
kv_shard += s_shard[r] * k_t[i];
85+
kv_shard += s_shard[r] * k_reg[r];
7586
}
7687
float kv_col = warp_reduce_sum<warp_size>(kv_shard);
7788

@@ -83,9 +94,8 @@ __global__ void gated_delta_net_cuda(const float * q,
8394
float attn_partial = 0.0f;
8495
#pragma unroll
8596
for (int r = 0; r < rows_per_lane; r++) {
86-
const int i = r * warp_size + lane;
87-
s_shard[r] = g_val * s_shard[r] + k_t[i] * delta_col;
88-
attn_partial += s_shard[r] * q_t[i];
97+
s_shard[r] = g_val * s_shard[r] + k_reg[r] * delta_col;
98+
attn_partial += s_shard[r] * q_reg[r];
8999
}
90100

91101
float attn_col = warp_reduce_sum<warp_size>(attn_partial);
@@ -99,7 +109,7 @@ __global__ void gated_delta_net_cuda(const float * q,
99109
#pragma unroll
100110
for (int r = 0; r < rows_per_lane; r++) {
101111
const int i = r * warp_size + lane;
102-
kv_shard += expf(g_t[i]) * s_shard[r] * k_t[i];
112+
kv_shard += expf(g_t[i]) * s_shard[r] * k_reg[r];
103113
}
104114

105115
float kv_col = warp_reduce_sum<warp_size>(kv_shard);
@@ -113,8 +123,8 @@ __global__ void gated_delta_net_cuda(const float * q,
113123
#pragma unroll
114124
for (int r = 0; r < rows_per_lane; r++) {
115125
const int i = r * warp_size + lane;
116-
s_shard[r] = expf(g_t[i]) * s_shard[r] + k_t[i] * delta_col;
117-
attn_partial += s_shard[r] * q_t[i];
126+
s_shard[r] = expf(g_t[i]) * s_shard[r] + k_reg[r] * delta_col;
127+
attn_partial += s_shard[r] * q_reg[r];
118128
}
119129

120130
float attn_col = warp_reduce_sum<warp_size>(attn_partial);

0 commit comments

Comments
 (0)