@@ -101,6 +101,10 @@ void main() {
101101 const uint lane = gl_SubgroupInvocationID;
102102
103103 float probs[experts_per_thread];
104+ [[unroll]]
105+ for (int i = 0; i < experts_per_thread; i++) {
106+ probs[i] = -INFINITY;
107+ }
104108
105109 [[unroll]]
106110 for (uint i = 0; i < n_experts; i += WARP_SIZE) {
@@ -112,8 +116,9 @@ void main() {
112116 softmax_warp_inplace(probs, n_experts, lane, nexperts_use_push);
113117 } else if (gating_func == GATING_FUNC_SIGMOID) {
114118 [[unroll]]
115- for (int i = 0; i < experts_per_thread; i++) {
116- probs[i] = 1.f / (1.f + exp(-probs[i]));
119+ for (uint i = 0; i < n_experts; i += WARP_SIZE) {
120+ const uint expert = i + lane;
121+ probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? 1.f / (1.f + exp(-probs[i / WARP_SIZE])) : -INFINITY;
117122 }
118123 }
119124
@@ -150,11 +155,11 @@ void main() {
150155 uint max_expert = lane;
151156
152157 [[unroll]]
153- for (int i = 1 ; i < experts_per_thread ; i++ ) {
154- const uint expert = lane + i * WARP_SIZE ;
155- if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_probs[i] > max_val_s) {
156- max_val = probs[i];
157- max_val_s = selection_probs[i];
158+ for (uint i = WARP_SIZE ; i < n_experts ; i += WARP_SIZE ) {
159+ const uint expert = i + lane ;
160+ if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_probs[i / WARP_SIZE ] > max_val_s) {
161+ max_val = probs[i / WARP_SIZE ];
162+ max_val_s = selection_probs[i / WARP_SIZE ];
158163 max_expert = expert;
159164 }
160165 }
0 commit comments