11#version 450
22
33#extension GL_EXT_control_flow_attributes : require
4+ #extension GL_KHR_shader_subgroup_basic : enable
45#if USE_SUBGROUP_ADD
56#extension GL_KHR_shader_subgroup_arithmetic : enable
67#endif
910
1011layout(constant_id = 0) const uint D_STATE = 128;
1112layout(constant_id = 1) const uint SUBGROUP_SIZE = 32;
12- layout(constant_id = 2) const uint SPLIT_H = 16;
13+
14+ const uint32_t c_factor = D_STATE / SUBGROUP_SIZE;
1315
1416layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
1517
@@ -41,22 +43,28 @@ float softplus(float x) {
4143 }
4244}
4345
44- shared float stateC[SPLIT_H * D_STATE];
46+ #if !USE_SUBGROUP_ADD
47+ shared float temp[D_STATE];
48+ #endif
4549
4650void main() {
47- const uint tid = gl_LocalInvocationID.x;
48- const uint head_idx = (gl_WorkGroupID.x * SPLIT_H) / d_head;
49- const uint head_off = ((gl_WorkGroupID.x * SPLIT_H) % d_head) * 4;
50- const uint seq_idx = gl_WorkGroupID.y;
51+ const uint subgroup = gl_SubgroupID;
52+ const uint lane = gl_SubgroupInvocationID;
53+ const uint tid = gl_SubgroupID * SUBGROUP_SIZE + lane;
54+ const uint subgroup_idx = gl_WorkGroupID.x * c_factor + subgroup;
55+
56+ const uint head_idx = subgroup_idx / d_head;
57+ const uint head_off = (subgroup_idx % d_head) * 4;
58+ const uint seq_idx = gl_WorkGroupID.y;
5159
5260 const uint group_off = (head_idx / (n_head / n_group)) * D_STATE * 4;
5361 const uint s0_base_idx = (uint(ids[seq_idx]) * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;
54- const uint x_base_idx = (seq_idx * nb13 + gl_WorkGroupID.x * SPLIT_H * 4) / 4;
62+ const uint x_base_idx = (seq_idx * nb13 + subgroup_idx * 4) / 4;
5563 const uint dt_base_idx = (seq_idx * nb22 + head_idx * 4) / 4;
5664 const uint A_base_idx = (head_idx * nb31) / 4;
5765 const uint B_base_idx = (seq_idx * nb43 + group_off) / 4;
5866 const uint C_base_idx = (seq_idx * nb53 + group_off) / 4;
59- const uint y_base_idx = seq_idx * n_tok * n_head * d_head + gl_WorkGroupID.x * SPLIT_H ;
67+ const uint y_base_idx = seq_idx * n_tok * n_head * d_head + subgroup_idx ;
6068 const uint s_base_idx = (s_off + seq_idx * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;
6169
6270 const uint stride_x = nb12 / 4;
@@ -65,76 +73,52 @@ void main() {
6573 const uint stride_C = nb52 / 4;
6674 const uint stride_y = n_head * d_head;
6775
68- float state[SPLIT_H];
69- [[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
70- state[j] = s0[s0_base_idx + j * D_STATE + tid];
71- }
76+ float state[c_factor];
7277
73- for (uint i = 0; i < n_tok; i++) {
74- const float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]);
78+ [[unroll]] for (uint j = 0; j < c_factor; j++) {
79+ state[j] = s0[s0_base_idx + SUBGROUP_SIZE * j + lane];
80+ }
7581
76- const float dA = exp(dt_soft_plus * A[A_base_idx]) ;
82+ float a = A[A_base_idx];
7783
78- const float B_val = B[B_base_idx + i * stride_B + tid];
79- const float C_val = C[C_base_idx + i * stride_C + tid] ;
84+ for (uint i = 0; i < n_tok; i++) {
85+ float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]) ;
8086
81- [[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
82- const float x_dt = x[x_base_idx + i * stride_x + j] * dt_soft_plus;
87+ float state_sum = 0.0f;
8388
89+ const float dA = exp(dt_soft_plus * a);
90+ const float x_dt = x[x_base_idx + i * stride_x] * dt_soft_plus;
91+ [[unroll]] for (uint j = 0; j < c_factor; j++) {
92+ float B_val = B[B_base_idx + i * stride_B + SUBGROUP_SIZE * j + lane];
93+ float C_val = C[C_base_idx + i * stride_C + SUBGROUP_SIZE * j + lane];
8494 state[j] = (state[j] * dA) + (B_val * x_dt);
85-
86- stateC[j * D_STATE + tid] = state[j] * C_val;
95+ state_sum += state[j] * C_val;
8796 }
8897
98+ #if USE_SUBGROUP_ADD
99+ state_sum = subgroupAdd(state_sum);
100+ #else
101+ temp[tid] = state_sum;
89102 barrier();
90- [[unroll]]
91- for (uint w = D_STATE / 2; w >= SUBGROUP_SIZE; w >>= 1) {
92- [[unroll]] for (uint j = 0; j < (w * SPLIT_H + D_STATE - 1) / D_STATE; j++) {
93- const uint k = (tid % w) + (D_STATE * (tid / w)) + j * D_STATE * (D_STATE / w);
94- if (k < SPLIT_H * D_STATE && (k + w) < SPLIT_H * D_STATE) {
95- stateC[k] += stateC[k + w];
96- }
103+ [[unroll]] for (uint s = SUBGROUP_SIZE / 2; s > 0; s >>= 1) {
104+ if (lane < s) {
105+ temp[tid] += temp[tid + s];
97106 }
98107 barrier();
99108 }
100-
101- [[unroll]] for (uint j = 0; j < max(1, SPLIT_H / (D_STATE / SUBGROUP_SIZE)); j++) {
102- const uint idx = (tid % SUBGROUP_SIZE) +
103- D_STATE * (tid / SUBGROUP_SIZE) +
104- j * D_STATE * (D_STATE / SUBGROUP_SIZE);
105- const uint max_idx = SUBGROUP_SIZE - 1 +
106- D_STATE * ((D_STATE - 1) / SUBGROUP_SIZE) +
107- j * D_STATE * (D_STATE / SUBGROUP_SIZE);
108-
109- if (idx < SPLIT_H * D_STATE ||
110- max_idx < SPLIT_H * D_STATE) {
111- float sc;
112- #if USE_SUBGROUP_ADD
113- sc = stateC[idx];
114- sc = subgroupAdd(sc);
115- #else
116- [[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) {
117- if (idx + offset < SPLIT_H * D_STATE) {
118- stateC[idx] += stateC[idx + offset];
119- }
120- barrier();
121- }
122- if (tid % SUBGROUP_SIZE == 0) {
123- sc = stateC[idx];
124- }
109+ // get the value from lane 0
110+ state_sum = temp[subgroup * SUBGROUP_SIZE];
111+ barrier();
125112#endif
126113
127- if (tid % SUBGROUP_SIZE == 0) {
128- const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE);
129- d[y_base_idx + i * stride_y + k] = sc;
130- }
131- }
114+ if (lane == 0) {
115+ d[y_base_idx + i * stride_y] = state_sum;
132116 }
133-
134- barrier();
135117 }
136118
137- [[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
138- d[s_base_idx + j * D_STATE + tid] = state[j];
119+ // write back the state
120+ [[unroll]]
121+ for (int j = 0; j < c_factor; j++) {
122+ d[s_base_idx + SUBGROUP_SIZE * j + lane] = state[j];
139123 }
140124}
0 commit comments