Skip to content

Commit 61361ba

Browse files
committed
Optimize MoE scale loading: broadcast when BLOCK_SIZE_K <= group_size
When the Triton K-tile lies entirely within a single quantization group (BLOCK_SIZE_K <= group_size), load one scale per N element instead of one per (K, N) element. Reduces scale memory traffic in both the GEMM1 and GEMM2 vec-mat kernels. Note: the single-group fast path assumes a K-tile never straddles a group boundary, i.e. that BLOCK_SIZE_K evenly divides group_size — this holds when both are powers of two, but the guard itself only checks <=. This PR was authored with the assistance of Claude.
1 parent 579bec2 commit 61361ba

1 file changed

Lines changed: 44 additions & 19 deletions

File tree

backends/cuda/triton/kernels/fused_moe.py

Lines changed: 44 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -144,16 +144,29 @@ def _fused_moe_kernel(
144144
b = tl.load(b_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0)
145145
b = (b >> b_shifter) & 0xF
146146

147-
# Load per-group scales [BLOCK_SIZE_K, BLOCK_SIZE_N]
148-
scale_ptrs = (
149-
B_scale
150-
+ expert_id * stride_bse
151-
+ offs_n[None, :] * stride_bsn
152-
+ ((offs_k[:, None] + BLOCK_SIZE_K * k_step) // group_size) * stride_bsk
153-
)
154-
b_scale = tl.load(
155-
scale_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0
156-
).to(tl.float32)
147+
# Load per-group scales and dequantize
148+
if BLOCK_SIZE_K <= group_size:
149+
# All K values in this tile share one scale group — load [1, N].
# NOTE(review): correct only if the tile never straddles a group boundary,
# i.e. group_size % BLOCK_SIZE_K == 0; the guard checks only <=. Holds when
# both are powers of two — confirm group_size is always a power of two.
150+
group_idx = (BLOCK_SIZE_K * k_step) // group_size
151+
scale_ptrs = (
152+
B_scale
153+
+ expert_id * stride_bse
154+
+ offs_n[None, :] * stride_bsn
155+
+ group_idx * stride_bsk
156+
)
157+
b_scale = tl.load(scale_ptrs, mask=n_mask[None, :], other=0.0).to(
158+
tl.float32
159+
)
160+
else:
161+
scale_ptrs = (
162+
B_scale
163+
+ expert_id * stride_bse
164+
+ offs_n[None, :] * stride_bsn
165+
+ ((offs_k[:, None] + BLOCK_SIZE_K * k_step) // group_size) * stride_bsk
166+
)
167+
b_scale = tl.load(
168+
scale_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0
169+
).to(tl.float32)
157170

158171
# Dequantize and accumulate: vector-matrix multiply
159172
b_dequant = ((b.to(tl.float32) - 8.0) * b_scale).to(compute_type)
@@ -252,15 +265,27 @@ def _fused_moe_silu_kernel(
252265
b = tl.load(b_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0)
253266
b = (b >> b_shifter) & 0xF
254267

255-
scale_ptrs = (
256-
B_scale
257-
+ expert_id * stride_bse
258-
+ offs_n[None, :] * stride_bsn
259-
+ ((offs_k[:, None] + BLOCK_SIZE_K * k_step) // group_size) * stride_bsk
260-
)
261-
b_scale = tl.load(
262-
scale_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0
263-
).to(tl.float32)
268+
if BLOCK_SIZE_K <= group_size:
269+
group_idx = (BLOCK_SIZE_K * k_step) // group_size
270+
scale_ptrs = (
271+
B_scale
272+
+ expert_id * stride_bse
273+
+ offs_n[None, :] * stride_bsn
274+
+ group_idx * stride_bsk
275+
)
276+
b_scale = tl.load(scale_ptrs, mask=n_mask[None, :], other=0.0).to(
277+
tl.float32
278+
)
279+
else:
280+
scale_ptrs = (
281+
B_scale
282+
+ expert_id * stride_bse
283+
+ offs_n[None, :] * stride_bsn
284+
+ ((offs_k[:, None] + BLOCK_SIZE_K * k_step) // group_size) * stride_bsk
285+
)
286+
b_scale = tl.load(
287+
scale_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0
288+
).to(tl.float32)
264289

265290
b_dequant = ((b.to(tl.float32) - 8.0) * b_scale).to(compute_type)
266291
acc += tl.sum(a[:, None].to(compute_type) * b_dequant, axis=0)

0 commit comments

Comments
 (0)