@@ -144,16 +144,29 @@ def _fused_moe_kernel(
144144 b = tl .load (b_ptrs , mask = k_mask [:, None ] & n_mask [None , :], other = 0 )
145145 b = (b >> b_shifter ) & 0xF
146146
147- # Load per-group scales [BLOCK_SIZE_K, BLOCK_SIZE_N]
148- scale_ptrs = (
149- B_scale
150- + expert_id * stride_bse
151- + offs_n [None , :] * stride_bsn
152- + ((offs_k [:, None ] + BLOCK_SIZE_K * k_step ) // group_size ) * stride_bsk
153- )
154- b_scale = tl .load (
155- scale_ptrs , mask = k_mask [:, None ] & n_mask [None , :], other = 0.0
156- ).to (tl .float32 )
147+ # Load per-group scales and dequantize
148+ if BLOCK_SIZE_K <= group_size :
149+ # All K values in this tile share one scale group — load [1, N]
150+ group_idx = (BLOCK_SIZE_K * k_step ) // group_size
151+ scale_ptrs = (
152+ B_scale
153+ + expert_id * stride_bse
154+ + offs_n [None , :] * stride_bsn
155+ + group_idx * stride_bsk
156+ )
157+ b_scale = tl .load (scale_ptrs , mask = n_mask [None , :], other = 0.0 ).to (
158+ tl .float32
159+ )
160+ else :
161+ scale_ptrs = (
162+ B_scale
163+ + expert_id * stride_bse
164+ + offs_n [None , :] * stride_bsn
165+ + ((offs_k [:, None ] + BLOCK_SIZE_K * k_step ) // group_size ) * stride_bsk
166+ )
167+ b_scale = tl .load (
168+ scale_ptrs , mask = k_mask [:, None ] & n_mask [None , :], other = 0.0
169+ ).to (tl .float32 )
157170
158171 # Dequantize and accumulate: vector-matrix multiply
159172 b_dequant = ((b .to (tl .float32 ) - 8.0 ) * b_scale ).to (compute_type )
@@ -252,15 +265,27 @@ def _fused_moe_silu_kernel(
252265 b = tl .load (b_ptrs , mask = k_mask [:, None ] & n_mask [None , :], other = 0 )
253266 b = (b >> b_shifter ) & 0xF
254267
255- scale_ptrs = (
256- B_scale
257- + expert_id * stride_bse
258- + offs_n [None , :] * stride_bsn
259- + ((offs_k [:, None ] + BLOCK_SIZE_K * k_step ) // group_size ) * stride_bsk
260- )
261- b_scale = tl .load (
262- scale_ptrs , mask = k_mask [:, None ] & n_mask [None , :], other = 0.0
263- ).to (tl .float32 )
268+ if BLOCK_SIZE_K <= group_size :
269+ group_idx = (BLOCK_SIZE_K * k_step ) // group_size
270+ scale_ptrs = (
271+ B_scale
272+ + expert_id * stride_bse
273+ + offs_n [None , :] * stride_bsn
274+ + group_idx * stride_bsk
275+ )
276+ b_scale = tl .load (scale_ptrs , mask = n_mask [None , :], other = 0.0 ).to (
277+ tl .float32
278+ )
279+ else :
280+ scale_ptrs = (
281+ B_scale
282+ + expert_id * stride_bse
283+ + offs_n [None , :] * stride_bsn
284+ + ((offs_k [:, None ] + BLOCK_SIZE_K * k_step ) // group_size ) * stride_bsk
285+ )
286+ b_scale = tl .load (
287+ scale_ptrs , mask = k_mask [:, None ] & n_mask [None , :], other = 0.0
288+ ).to (tl .float32 )
264289
265290 b_dequant = ((b .to (tl .float32 ) - 8.0 ) * b_scale ).to (compute_type )
266291 acc += tl .sum (a [:, None ].to (compute_type ) * b_dequant , axis = 0 )
0 commit comments