Skip to content

Commit 25702dd

Browse files
upgrade triton moe config in sm100.
1 parent 529ec9e commit 25702dd

2 files changed

Lines changed: 314 additions & 45 deletions

File tree

fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py

Lines changed: 159 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1920,17 +1920,169 @@ def process_loaded_weights(self, layer: nn.Layer, state_dict):
19201920

19211921
def _get_default_config(self, M: int, E: int) -> dict:
19221922
"""
1923-
Heuristic tile config for BF16 MoE, ported verbatim from vLLM's
1924-
`get_default_config` (bf16/fp16 non-block_shape branch).
1925-
See vllm/model_executor/layers/fused_moe/fused_moe.py:1273-1319.
1923+
GPU-aware heuristic tile config for BF16 MoE.
19261924
1927-
M: number of tokens (A.size(0) in vLLM), i.e. pre-expansion token count.
1925+
SM100 (B200): nearest-key lookup from SGLang tuned config
1926+
(triton_3_5_1/E=64,N=1856,device_name=NVIDIA_B200.json).
1927+
Others: original vLLM-ported heuristic.
1928+
1929+
M: number of tokens (pre-expansion token count).
19281930
E: number of (local) experts.
19291931
"""
1932+
from fastdeploy.model_executor.utils import get_sm_version
1933+
1934+
if get_sm_version() >= 100:
1935+
# SM100 (B200): use SGLang tuned lookup, nearest key by abs diff
1936+
_SM100_CONFIGS = {
1937+
1: {
1938+
"BLOCK_SIZE_M": 16,
1939+
"BLOCK_SIZE_N": 64,
1940+
"BLOCK_SIZE_K": 64,
1941+
"GROUP_SIZE_M": 32,
1942+
"num_warps": 4,
1943+
"num_stages": 5,
1944+
},
1945+
2: {
1946+
"BLOCK_SIZE_M": 16,
1947+
"BLOCK_SIZE_N": 64,
1948+
"BLOCK_SIZE_K": 128,
1949+
"GROUP_SIZE_M": 32,
1950+
"num_warps": 4,
1951+
"num_stages": 3,
1952+
},
1953+
4: {
1954+
"BLOCK_SIZE_M": 16,
1955+
"BLOCK_SIZE_N": 64,
1956+
"BLOCK_SIZE_K": 128,
1957+
"GROUP_SIZE_M": 64,
1958+
"num_warps": 4,
1959+
"num_stages": 4,
1960+
},
1961+
8: {
1962+
"BLOCK_SIZE_M": 16,
1963+
"BLOCK_SIZE_N": 64,
1964+
"BLOCK_SIZE_K": 128,
1965+
"GROUP_SIZE_M": 32,
1966+
"num_warps": 4,
1967+
"num_stages": 3,
1968+
},
1969+
16: {
1970+
"BLOCK_SIZE_M": 16,
1971+
"BLOCK_SIZE_N": 64,
1972+
"BLOCK_SIZE_K": 128,
1973+
"GROUP_SIZE_M": 1,
1974+
"num_warps": 4,
1975+
"num_stages": 3,
1976+
},
1977+
24: {
1978+
"BLOCK_SIZE_M": 16,
1979+
"BLOCK_SIZE_N": 128,
1980+
"BLOCK_SIZE_K": 128,
1981+
"GROUP_SIZE_M": 16,
1982+
"num_warps": 4,
1983+
"num_stages": 4,
1984+
},
1985+
32: {
1986+
"BLOCK_SIZE_M": 16,
1987+
"BLOCK_SIZE_N": 64,
1988+
"BLOCK_SIZE_K": 128,
1989+
"GROUP_SIZE_M": 16,
1990+
"num_warps": 4,
1991+
"num_stages": 4,
1992+
},
1993+
48: {
1994+
"BLOCK_SIZE_M": 16,
1995+
"BLOCK_SIZE_N": 64,
1996+
"BLOCK_SIZE_K": 128,
1997+
"GROUP_SIZE_M": 1,
1998+
"num_warps": 4,
1999+
"num_stages": 4,
2000+
},
2001+
64: {
2002+
"BLOCK_SIZE_M": 16,
2003+
"BLOCK_SIZE_N": 64,
2004+
"BLOCK_SIZE_K": 128,
2005+
"GROUP_SIZE_M": 1,
2006+
"num_warps": 4,
2007+
"num_stages": 4,
2008+
},
2009+
96: {
2010+
"BLOCK_SIZE_M": 16,
2011+
"BLOCK_SIZE_N": 64,
2012+
"BLOCK_SIZE_K": 128,
2013+
"GROUP_SIZE_M": 1,
2014+
"num_warps": 4,
2015+
"num_stages": 3,
2016+
},
2017+
128: {
2018+
"BLOCK_SIZE_M": 16,
2019+
"BLOCK_SIZE_N": 64,
2020+
"BLOCK_SIZE_K": 128,
2021+
"GROUP_SIZE_M": 1,
2022+
"num_warps": 4,
2023+
"num_stages": 3,
2024+
},
2025+
256: {
2026+
"BLOCK_SIZE_M": 32,
2027+
"BLOCK_SIZE_N": 128,
2028+
"BLOCK_SIZE_K": 64,
2029+
"GROUP_SIZE_M": 1,
2030+
"num_warps": 4,
2031+
"num_stages": 5,
2032+
},
2033+
512: {
2034+
"BLOCK_SIZE_M": 64,
2035+
"BLOCK_SIZE_N": 256,
2036+
"BLOCK_SIZE_K": 64,
2037+
"GROUP_SIZE_M": 1,
2038+
"num_warps": 8,
2039+
"num_stages": 5,
2040+
},
2041+
1024: {
2042+
"BLOCK_SIZE_M": 128,
2043+
"BLOCK_SIZE_N": 256,
2044+
"BLOCK_SIZE_K": 64,
2045+
"GROUP_SIZE_M": 1,
2046+
"num_warps": 8,
2047+
"num_stages": 4,
2048+
},
2049+
1536: {
2050+
"BLOCK_SIZE_M": 256,
2051+
"BLOCK_SIZE_N": 256,
2052+
"BLOCK_SIZE_K": 64,
2053+
"GROUP_SIZE_M": 1,
2054+
"num_warps": 8,
2055+
"num_stages": 3,
2056+
},
2057+
2048: {
2058+
"BLOCK_SIZE_M": 256,
2059+
"BLOCK_SIZE_N": 256,
2060+
"BLOCK_SIZE_K": 64,
2061+
"GROUP_SIZE_M": 1,
2062+
"num_warps": 8,
2063+
"num_stages": 3,
2064+
},
2065+
3072: {
2066+
"BLOCK_SIZE_M": 128,
2067+
"BLOCK_SIZE_N": 256,
2068+
"BLOCK_SIZE_K": 64,
2069+
"GROUP_SIZE_M": 1,
2070+
"num_warps": 8,
2071+
"num_stages": 4,
2072+
},
2073+
4096: {
2074+
"BLOCK_SIZE_M": 256,
2075+
"BLOCK_SIZE_N": 256,
2076+
"BLOCK_SIZE_K": 64,
2077+
"GROUP_SIZE_M": 1,
2078+
"num_warps": 8,
2079+
"num_stages": 3,
2080+
},
2081+
}
2082+
best_key = min(_SM100_CONFIGS.keys(), key=lambda x: abs(x - M))
2083+
return _SM100_CONFIGS[best_key]
19302084

1931-
# Tile sizes scale with batch: small batches are memory-bound
1932-
# (favor tall-K tiles), large batches are compute-bound (favor
1933-
# large M/N tiles with more warps).
2085+
# Default heuristic for all other GPUs (ported from vLLM)
19342086
if M <= 32:
19352087
block_m = 16
19362088
elif M <= 96:
@@ -1941,19 +2093,12 @@ def _get_default_config(self, M: int, E: int) -> dict:
19412093
block_m = 128
19422094

19432095
block_n = 64 if M <= 64 else 128
1944-
19452096
block_k = 64
19462097

1947-
# Grouping adjacent M-blocks lets them share weight tiles in L2.
1948-
# Only helps when there are enough M-blocks per expert to group;
1949-
# with many experts each one sees few tokens so grouping is useless.
19502098
tokens_per_expert = M // max(E, 1)
19512099
group_m = 16 if tokens_per_expert > 128 else 1
19522100

1953-
# Large batches have enough blocks to saturate the GPU, so we
1954-
# use more warps per block to increase arithmetic intensity.
19552101
num_warps = 4 if M <= 128 else 8
1956-
19572102
num_stages = 4 if M <= 32 else 3
19582103

19592104
return {

0 commit comments

Comments
 (0)