Skip to content

Commit cb14b06

Browse files
authored
vulkan: optimize ssm_scan (ggml-org#18630)
* vulkan: optimize ssm_scan
* fix warp vs subgroup naming
1 parent 55abc39 commit cb14b06

2 files changed

Lines changed: 59 additions & 69 deletions

File tree

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,7 @@ struct vk_device_struct {
570570
bool uma;
571571
bool prefer_host_memory;
572572
bool float_controls_rte_fp16;
573+
bool subgroup_basic;
573574
bool subgroup_arithmetic;
574575
bool subgroup_shuffle;
575576
bool subgroup_ballot;
@@ -4301,8 +4302,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
43014302
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
43024303

43034304
if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
4304-
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
4305-
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
4305+
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size}, 1, true, true);
4306+
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size}, 1, true, true);
43064307
} else {
43074308
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
43084309
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
@@ -4638,6 +4639,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
46384639
}
46394640
device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16;
46404641

4642+
device->subgroup_basic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
4643+
(vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eBasic);
46414644
device->subgroup_arithmetic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
46424645
(vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);
46434646
#ifdef __APPLE__
@@ -9870,8 +9873,9 @@ static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx,
98709873

98719874
std::array<uint32_t, 3> elements;
98729875

9873-
const int splitH = 16;
9874-
const uint32_t num_workgroups_x = CEIL_DIV(n_head * head_dim, splitH);
9876+
const uint32_t d_state = src0->ne[0];
9877+
uint32_t num_subgroups = d_state / ctx->device->subgroup_size;
9878+
const uint32_t num_workgroups_x = CEIL_DIV(n_head * head_dim, num_subgroups);
98759879
const uint32_t num_workgroups_y = n_seq;
98769880
elements = { num_workgroups_x, num_workgroups_y, 1 };
98779881

@@ -14777,11 +14781,13 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1477714781
return false;
1477814782
}
1477914783

14780-
const uint32_t SPLIT_H = 16;
14784+
size_t shmem_size = d_state * sizeof(float);
1478114785

14782-
size_t stateC_size = SPLIT_H * d_state * sizeof(float);
14786+
if (shmem_size > device->properties.limits.maxComputeSharedMemorySize) {
14787+
return false;
14788+
}
1478314789

14784-
if (stateC_size > device->properties.limits.maxComputeSharedMemorySize) {
14790+
if (!device->subgroup_basic) {
1478514791
return false;
1478614792
}
1478714793

Lines changed: 46 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#version 450
22

33
#extension GL_EXT_control_flow_attributes : require
4+
#extension GL_KHR_shader_subgroup_basic : enable
45
#if USE_SUBGROUP_ADD
56
#extension GL_KHR_shader_subgroup_arithmetic : enable
67
#endif
@@ -9,7 +10,8 @@
910

1011
layout(constant_id = 0) const uint D_STATE = 128;
1112
layout(constant_id = 1) const uint SUBGROUP_SIZE = 32;
12-
layout(constant_id = 2) const uint SPLIT_H = 16;
13+
14+
const uint32_t c_factor = D_STATE / SUBGROUP_SIZE;
1315

1416
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
1517

@@ -41,22 +43,28 @@ float softplus(float x) {
4143
}
4244
}
4345

44-
shared float stateC[SPLIT_H * D_STATE];
46+
#if !USE_SUBGROUP_ADD
47+
shared float temp[D_STATE];
48+
#endif
4549

4650
void main() {
47-
const uint tid = gl_LocalInvocationID.x;
48-
const uint head_idx = (gl_WorkGroupID.x * SPLIT_H) / d_head;
49-
const uint head_off = ((gl_WorkGroupID.x * SPLIT_H) % d_head) * 4;
50-
const uint seq_idx = gl_WorkGroupID.y;
51+
const uint subgroup = gl_SubgroupID;
52+
const uint lane = gl_SubgroupInvocationID;
53+
const uint tid = gl_SubgroupID * SUBGROUP_SIZE + lane;
54+
const uint subgroup_idx = gl_WorkGroupID.x * c_factor + subgroup;
55+
56+
const uint head_idx = subgroup_idx / d_head;
57+
const uint head_off = (subgroup_idx % d_head) * 4;
58+
const uint seq_idx = gl_WorkGroupID.y;
5159

5260
const uint group_off = (head_idx / (n_head / n_group)) * D_STATE * 4;
5361
const uint s0_base_idx = (uint(ids[seq_idx]) * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;
54-
const uint x_base_idx = (seq_idx * nb13 + gl_WorkGroupID.x * SPLIT_H * 4) / 4;
62+
const uint x_base_idx = (seq_idx * nb13 + subgroup_idx * 4) / 4;
5563
const uint dt_base_idx = (seq_idx * nb22 + head_idx * 4) / 4;
5664
const uint A_base_idx = (head_idx * nb31) / 4;
5765
const uint B_base_idx = (seq_idx * nb43 + group_off) / 4;
5866
const uint C_base_idx = (seq_idx * nb53 + group_off) / 4;
59-
const uint y_base_idx = seq_idx * n_tok * n_head * d_head + gl_WorkGroupID.x * SPLIT_H;
67+
const uint y_base_idx = seq_idx * n_tok * n_head * d_head + subgroup_idx;
6068
const uint s_base_idx = (s_off + seq_idx * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;
6169

6270
const uint stride_x = nb12 / 4;
@@ -65,76 +73,52 @@ void main() {
6573
const uint stride_C = nb52 / 4;
6674
const uint stride_y = n_head * d_head;
6775

68-
float state[SPLIT_H];
69-
[[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
70-
state[j] = s0[s0_base_idx + j * D_STATE + tid];
71-
}
76+
float state[c_factor];
7277

73-
for (uint i = 0; i < n_tok; i++) {
74-
const float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]);
78+
[[unroll]] for (uint j = 0; j < c_factor; j++) {
79+
state[j] = s0[s0_base_idx + SUBGROUP_SIZE * j + lane];
80+
}
7581

76-
const float dA = exp(dt_soft_plus * A[A_base_idx]);
82+
float a = A[A_base_idx];
7783

78-
const float B_val = B[B_base_idx + i * stride_B + tid];
79-
const float C_val = C[C_base_idx + i * stride_C + tid];
84+
for (uint i = 0; i < n_tok; i++) {
85+
float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]);
8086

81-
[[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
82-
const float x_dt = x[x_base_idx + i * stride_x + j] * dt_soft_plus;
87+
float state_sum = 0.0f;
8388

89+
const float dA = exp(dt_soft_plus * a);
90+
const float x_dt = x[x_base_idx + i * stride_x] * dt_soft_plus;
91+
[[unroll]] for (uint j = 0; j < c_factor; j++) {
92+
float B_val = B[B_base_idx + i * stride_B + SUBGROUP_SIZE * j + lane];
93+
float C_val = C[C_base_idx + i * stride_C + SUBGROUP_SIZE * j + lane];
8494
state[j] = (state[j] * dA) + (B_val * x_dt);
85-
86-
stateC[j * D_STATE + tid] = state[j] * C_val;
95+
state_sum += state[j] * C_val;
8796
}
8897

98+
#if USE_SUBGROUP_ADD
99+
state_sum = subgroupAdd(state_sum);
100+
#else
101+
temp[tid] = state_sum;
89102
barrier();
90-
[[unroll]]
91-
for (uint w = D_STATE / 2; w >= SUBGROUP_SIZE; w >>= 1) {
92-
[[unroll]] for (uint j = 0; j < (w * SPLIT_H + D_STATE - 1) / D_STATE; j++) {
93-
const uint k = (tid % w) + (D_STATE * (tid / w)) + j * D_STATE * (D_STATE / w);
94-
if (k < SPLIT_H * D_STATE && (k + w) < SPLIT_H * D_STATE) {
95-
stateC[k] += stateC[k + w];
96-
}
103+
[[unroll]] for (uint s = SUBGROUP_SIZE / 2; s > 0; s >>= 1) {
104+
if (lane < s) {
105+
temp[tid] += temp[tid + s];
97106
}
98107
barrier();
99108
}
100-
101-
[[unroll]] for (uint j = 0; j < max(1, SPLIT_H / (D_STATE / SUBGROUP_SIZE)); j++) {
102-
const uint idx = (tid % SUBGROUP_SIZE) +
103-
D_STATE * (tid / SUBGROUP_SIZE) +
104-
j * D_STATE * (D_STATE / SUBGROUP_SIZE);
105-
const uint max_idx = SUBGROUP_SIZE - 1 +
106-
D_STATE * ((D_STATE - 1) / SUBGROUP_SIZE) +
107-
j * D_STATE * (D_STATE / SUBGROUP_SIZE);
108-
109-
if (idx < SPLIT_H * D_STATE ||
110-
max_idx < SPLIT_H * D_STATE) {
111-
float sc;
112-
#if USE_SUBGROUP_ADD
113-
sc = stateC[idx];
114-
sc = subgroupAdd(sc);
115-
#else
116-
[[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) {
117-
if (idx + offset < SPLIT_H * D_STATE) {
118-
stateC[idx] += stateC[idx + offset];
119-
}
120-
barrier();
121-
}
122-
if (tid % SUBGROUP_SIZE == 0) {
123-
sc = stateC[idx];
124-
}
109+
// get the value from lane 0
110+
state_sum = temp[subgroup * SUBGROUP_SIZE];
111+
barrier();
125112
#endif
126113

127-
if (tid % SUBGROUP_SIZE == 0) {
128-
const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE);
129-
d[y_base_idx + i * stride_y + k] = sc;
130-
}
131-
}
114+
if (lane == 0) {
115+
d[y_base_idx + i * stride_y] = state_sum;
132116
}
133-
134-
barrier();
135117
}
136118

137-
[[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
138-
d[s_base_idx + j * D_STATE + tid] = state[j];
119+
// write back the state
120+
[[unroll]]
121+
for (int j = 0; j < c_factor; j++) {
122+
d[s_base_idx + SUBGROUP_SIZE * j + lane] = state[j];
139123
}
140124
}

0 commit comments

Comments
 (0)