Skip to content

Commit 4272adb

Browse files
jeffbolznvLostRuins
authored andcommitted
cherry pick f1768d8
1 parent 7e134c4 commit 4272adb

1 file changed

Lines changed: 12 additions & 7 deletions

File tree

ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,10 @@ void main() {
101101
const uint lane = gl_SubgroupInvocationID;
102102

103103
float probs[experts_per_thread];
104+
[[unroll]]
105+
for (int i = 0; i < experts_per_thread; i++) {
106+
probs[i] = -INFINITY;
107+
}
104108

105109
[[unroll]]
106110
for (uint i = 0; i < n_experts; i += WARP_SIZE) {
@@ -112,8 +116,9 @@ void main() {
112116
softmax_warp_inplace(probs, n_experts, lane, nexperts_use_push);
113117
} else if (gating_func == GATING_FUNC_SIGMOID) {
114118
[[unroll]]
115-
for (int i = 0; i < experts_per_thread; i++) {
116-
probs[i] = 1.f / (1.f + exp(-probs[i]));
119+
for (uint i = 0; i < n_experts; i += WARP_SIZE) {
120+
const uint expert = i + lane;
121+
probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? 1.f / (1.f + exp(-probs[i / WARP_SIZE])) : -INFINITY;
117122
}
118123
}
119124

@@ -150,11 +155,11 @@ void main() {
150155
uint max_expert = lane;
151156

152157
[[unroll]]
153-
for (int i = 1; i < experts_per_thread; i++) {
154-
const uint expert = lane + i * WARP_SIZE;
155-
if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_probs[i] > max_val_s) {
156-
max_val = probs[i];
157-
max_val_s = selection_probs[i];
158+
for (uint i = WARP_SIZE; i < n_experts; i += WARP_SIZE) {
159+
const uint expert = i + lane;
160+
if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_probs[i / WARP_SIZE] > max_val_s) {
161+
max_val = probs[i / WARP_SIZE];
162+
max_val_s = selection_probs[i / WARP_SIZE];
158163
max_expert = expert;
159164
}
160165
}

0 commit comments

Comments
 (0)