Skip to content

Commit b91661a

Browse files
reeselevin authored and ersenthilkumar6 committed
ggml-webgpu : extend GDN for K>1 (ggml-org#23299)
1 parent fda79a8 commit b91661a

2 files changed

Lines changed: 22 additions & 4 deletions

File tree

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1234,6 +1234,7 @@ static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx,
12341234
const uint32_t h = (uint32_t) src2->ne[1];
12351235
const uint32_t n_tokens = (uint32_t) src2->ne[2];
12361236
const uint32_t n_seqs = (uint32_t) src2->ne[3];
1237+
const uint32_t K = (uint32_t) src5->ne[1];
12371238
const float scale = 1.0f / sqrtf((float) s_v);
12381239
uint32_t scale_u32;
12391240
memcpy(&scale_u32, &scale, sizeof(scale_u32));
@@ -1258,6 +1259,7 @@ static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx,
12581259

12591260
(uint32_t) src0->ne[1],
12601261
(uint32_t) (src2->ne[3] / src0->ne[3]),
1262+
K,
12611263
scale_u32,
12621264
};
12631265

ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl

Lines changed: 20 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,7 @@ struct Params {
3939

4040
neq1: u32,
4141
rq3: u32,
42+
K: u32,
4243
scale: f32,
4344
};
4445

@@ -62,11 +63,14 @@ fn main(
6263
let iq3 = seq_id / params.rq3;
6364

6465
let state_size = S_V * S_V;
65-
let state_base = (seq_id * params.h + head_id) * state_size;
66+
let state_in_base = (seq_id * params.K * params.h + head_id) * state_size;
67+
let state_out_base = (seq_id * params.h + head_id) * state_size;
68+
let state_size_per_snap = state_size * params.h * params.n_seqs;
69+
let shift = i32(params.n_tokens) - i32(params.K);
6670

6771
var state: array<f32, S_V>;
6872
for (var i = 0u; i < S_V; i++) {
69-
state[i] = src_state[state_base + col * S_V + i];
73+
state[i] = src_state[state_in_base + col * S_V + i];
7074
}
7175

7276
var attn_off = (seq_id * params.n_tokens * params.h + head_id) * S_V;
@@ -123,10 +127,22 @@ fn main(
123127
dst[attn_off + col] = attn_col * params.scale;
124128
attn_off += S_V * params.h;
125129

130+
if (params.K > 1u) {
131+
let target_slot = i32(t) - shift;
132+
if (target_slot >= 0 && target_slot < i32(params.K)) {
133+
let slot_base = params.s_off + u32(target_slot) * state_size_per_snap + state_out_base;
134+
for (var i = 0u; i < S_V; i++) {
135+
dst[slot_base + col * S_V + i] = state[i];
136+
}
137+
}
138+
}
139+
126140
workgroupBarrier();
127141
}
128142

129-
for (var i = 0u; i < S_V; i++) {
130-
dst[params.s_off + state_base + col * S_V + i] = state[i];
143+
if (params.K == 1u) {
144+
for (var i = 0u; i < S_V; i++) {
145+
dst[params.s_off + state_out_base + col * S_V + i] = state[i];
146+
}
131147
}
132148
}

0 commit comments

Comments
 (0)