Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 161 additions & 15 deletions fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from fastdeploy.model_executor.utils import (
TensorTracker,
free_tensor,
get_sm_version,
process_weight_transpose,
set_weight_attrs,
weight_fully_copied,
Expand Down Expand Up @@ -1909,6 +1910,154 @@ def apply(
)


_SM100_CONFIGS = {
1: {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 5,
},
2: {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3,
},
4: {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4,
},
8: {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3,
},
16: {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3,
},
24: {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4,
},
32: {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4,
},
48: {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4,
},
64: {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4,
},
96: {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3,
},
128: {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3,
},
256: {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5,
},
512: {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5,
},
1024: {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4,
},
1536: {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3,
},
2048: {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3,
},
3072: {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4,
},
4096: {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3,
},
}


class TritonMoEMethod(UnquantizedFusedMoEMethod):
"""
Use Triton Group Gemm (BF16 unquantized) to compute Fused MoE.
Expand All @@ -1929,17 +2078,21 @@ def process_loaded_weights(self, layer: nn.Layer, state_dict):

def _get_default_config(self, M: int, E: int) -> dict:
"""
Heuristic tile config for BF16 MoE, ported verbatim from vLLM's
`get_default_config` (bf16/fp16 non-block_shape branch).
See vllm/model_executor/layers/fused_moe/fused_moe.py:1273-1319.
GPU-aware heuristic tile config for BF16 MoE.

SM100 (B200): nearest-key lookup from SGLang tuned config
(triton_3_5_1/E=64,N=1856,device_name=NVIDIA_B200.json).
Others: original vLLM-ported heuristic.

M: number of tokens (A.size(0) in vLLM), i.e. pre-expansion token count.
E: number of (local) experts.
M: number of tokens (pre-expansion token count).
E: number of (local) experts, SM100 not need.
"""

# Tile sizes scale with batch: small batches are memory-bound
# (favor tall-K tiles), large batches are compute-bound (favor
# large M/N tiles with more warps).
if get_sm_version() >= 100:
best_key = min(_SM100_CONFIGS.keys(), key=lambda x: abs(x - M))

This comment was marked as outdated.

return _SM100_CONFIGS[best_key]

# Default heuristic for all other GPUs (SM80 & SM90) (ported from vLLM)
if M <= 32:
block_m = 16
elif M <= 96:
Expand All @@ -1950,19 +2103,12 @@ def _get_default_config(self, M: int, E: int) -> dict:
block_m = 128

block_n = 64 if M <= 64 else 128

block_k = 64

# Grouping adjacent M-blocks lets them share weight tiles in L2.
# Only helps when there are enough M-blocks per expert to group;
# with many experts each one sees few tokens so grouping is useless.
tokens_per_expert = M // max(E, 1)
group_m = 16 if tokens_per_expert > 128 else 1

# Large batches have enough blocks to saturate the GPU, so we
# use more warps per block to increase arithmetic intensity.
num_warps = 4 if M <= 128 else 8

num_stages = 4 if M <= 32 else 3

return {
Expand Down
Loading
Loading