Skip to content

Commit a816e85

Browse files
upgrade triton moe config in sm100.
1 parent ac24fcc commit a816e85

2 files changed

Lines changed: 314 additions & 45 deletions

File tree

fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py

Lines changed: 159 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1879,17 +1879,169 @@ def process_loaded_weights(self, layer: nn.Layer, state_dict):
18791879

18801880
def _get_default_config(self, M: int, E: int) -> dict:
18811881
"""
1882-
Heuristic tile config for BF16 MoE, ported verbatim from vLLM's
1883-
`get_default_config` (bf16/fp16 non-block_shape branch).
1884-
See vllm/model_executor/layers/fused_moe/fused_moe.py:1273-1319.
1882+
GPU-aware heuristic tile config for BF16 MoE.
18851883
1886-
M: number of tokens (A.size(0) in vLLM), i.e. pre-expansion token count.
1884+
SM100 (B200): nearest-key lookup from SGLang tuned config
1885+
(triton_3_5_1/E=64,N=1856,device_name=NVIDIA_B200.json).
1886+
Others: original vLLM-ported heuristic.
1887+
1888+
M: number of tokens (pre-expansion token count).
18871889
E: number of (local) experts.
18881890
"""
1891+
from fastdeploy.model_executor.utils import get_sm_version
1892+
1893+
if get_sm_version() >= 100:
1894+
# SM100 (B200): use SGLang tuned lookup, nearest key by abs diff
1895+
_SM100_CONFIGS = {
1896+
1: {
1897+
"BLOCK_SIZE_M": 16,
1898+
"BLOCK_SIZE_N": 64,
1899+
"BLOCK_SIZE_K": 64,
1900+
"GROUP_SIZE_M": 32,
1901+
"num_warps": 4,
1902+
"num_stages": 5,
1903+
},
1904+
2: {
1905+
"BLOCK_SIZE_M": 16,
1906+
"BLOCK_SIZE_N": 64,
1907+
"BLOCK_SIZE_K": 128,
1908+
"GROUP_SIZE_M": 32,
1909+
"num_warps": 4,
1910+
"num_stages": 3,
1911+
},
1912+
4: {
1913+
"BLOCK_SIZE_M": 16,
1914+
"BLOCK_SIZE_N": 64,
1915+
"BLOCK_SIZE_K": 128,
1916+
"GROUP_SIZE_M": 64,
1917+
"num_warps": 4,
1918+
"num_stages": 4,
1919+
},
1920+
8: {
1921+
"BLOCK_SIZE_M": 16,
1922+
"BLOCK_SIZE_N": 64,
1923+
"BLOCK_SIZE_K": 128,
1924+
"GROUP_SIZE_M": 32,
1925+
"num_warps": 4,
1926+
"num_stages": 3,
1927+
},
1928+
16: {
1929+
"BLOCK_SIZE_M": 16,
1930+
"BLOCK_SIZE_N": 64,
1931+
"BLOCK_SIZE_K": 128,
1932+
"GROUP_SIZE_M": 1,
1933+
"num_warps": 4,
1934+
"num_stages": 3,
1935+
},
1936+
24: {
1937+
"BLOCK_SIZE_M": 16,
1938+
"BLOCK_SIZE_N": 128,
1939+
"BLOCK_SIZE_K": 128,
1940+
"GROUP_SIZE_M": 16,
1941+
"num_warps": 4,
1942+
"num_stages": 4,
1943+
},
1944+
32: {
1945+
"BLOCK_SIZE_M": 16,
1946+
"BLOCK_SIZE_N": 64,
1947+
"BLOCK_SIZE_K": 128,
1948+
"GROUP_SIZE_M": 16,
1949+
"num_warps": 4,
1950+
"num_stages": 4,
1951+
},
1952+
48: {
1953+
"BLOCK_SIZE_M": 16,
1954+
"BLOCK_SIZE_N": 64,
1955+
"BLOCK_SIZE_K": 128,
1956+
"GROUP_SIZE_M": 1,
1957+
"num_warps": 4,
1958+
"num_stages": 4,
1959+
},
1960+
64: {
1961+
"BLOCK_SIZE_M": 16,
1962+
"BLOCK_SIZE_N": 64,
1963+
"BLOCK_SIZE_K": 128,
1964+
"GROUP_SIZE_M": 1,
1965+
"num_warps": 4,
1966+
"num_stages": 4,
1967+
},
1968+
96: {
1969+
"BLOCK_SIZE_M": 16,
1970+
"BLOCK_SIZE_N": 64,
1971+
"BLOCK_SIZE_K": 128,
1972+
"GROUP_SIZE_M": 1,
1973+
"num_warps": 4,
1974+
"num_stages": 3,
1975+
},
1976+
128: {
1977+
"BLOCK_SIZE_M": 16,
1978+
"BLOCK_SIZE_N": 64,
1979+
"BLOCK_SIZE_K": 128,
1980+
"GROUP_SIZE_M": 1,
1981+
"num_warps": 4,
1982+
"num_stages": 3,
1983+
},
1984+
256: {
1985+
"BLOCK_SIZE_M": 32,
1986+
"BLOCK_SIZE_N": 128,
1987+
"BLOCK_SIZE_K": 64,
1988+
"GROUP_SIZE_M": 1,
1989+
"num_warps": 4,
1990+
"num_stages": 5,
1991+
},
1992+
512: {
1993+
"BLOCK_SIZE_M": 64,
1994+
"BLOCK_SIZE_N": 256,
1995+
"BLOCK_SIZE_K": 64,
1996+
"GROUP_SIZE_M": 1,
1997+
"num_warps": 8,
1998+
"num_stages": 5,
1999+
},
2000+
1024: {
2001+
"BLOCK_SIZE_M": 128,
2002+
"BLOCK_SIZE_N": 256,
2003+
"BLOCK_SIZE_K": 64,
2004+
"GROUP_SIZE_M": 1,
2005+
"num_warps": 8,
2006+
"num_stages": 4,
2007+
},
2008+
1536: {
2009+
"BLOCK_SIZE_M": 256,
2010+
"BLOCK_SIZE_N": 256,
2011+
"BLOCK_SIZE_K": 64,
2012+
"GROUP_SIZE_M": 1,
2013+
"num_warps": 8,
2014+
"num_stages": 3,
2015+
},
2016+
2048: {
2017+
"BLOCK_SIZE_M": 256,
2018+
"BLOCK_SIZE_N": 256,
2019+
"BLOCK_SIZE_K": 64,
2020+
"GROUP_SIZE_M": 1,
2021+
"num_warps": 8,
2022+
"num_stages": 3,
2023+
},
2024+
3072: {
2025+
"BLOCK_SIZE_M": 128,
2026+
"BLOCK_SIZE_N": 256,
2027+
"BLOCK_SIZE_K": 64,
2028+
"GROUP_SIZE_M": 1,
2029+
"num_warps": 8,
2030+
"num_stages": 4,
2031+
},
2032+
4096: {
2033+
"BLOCK_SIZE_M": 256,
2034+
"BLOCK_SIZE_N": 256,
2035+
"BLOCK_SIZE_K": 64,
2036+
"GROUP_SIZE_M": 1,
2037+
"num_warps": 8,
2038+
"num_stages": 3,
2039+
},
2040+
}
2041+
best_key = min(_SM100_CONFIGS.keys(), key=lambda x: abs(x - M))
2042+
return _SM100_CONFIGS[best_key]
18892043

1890-
# Tile sizes scale with batch: small batches are memory-bound
1891-
# (favor tall-K tiles), large batches are compute-bound (favor
1892-
# large M/N tiles with more warps).
2044+
# Default heuristic for all other GPUs (ported from vLLM)
18932045
if M <= 32:
18942046
block_m = 16
18952047
elif M <= 96:
@@ -1900,19 +2052,12 @@ def _get_default_config(self, M: int, E: int) -> dict:
19002052
block_m = 128
19012053

19022054
block_n = 64 if M <= 64 else 128
1903-
19042055
block_k = 64
19052056

1906-
# Grouping adjacent M-blocks lets them share weight tiles in L2.
1907-
# Only helps when there are enough M-blocks per expert to group;
1908-
# with many experts each one sees few tokens so grouping is useless.
19092057
tokens_per_expert = M // max(E, 1)
19102058
group_m = 16 if tokens_per_expert > 128 else 1
19112059

1912-
# Large batches have enough blocks to saturate the GPU, so we
1913-
# use more warps per block to increase arithmetic intensity.
19142060
num_warps = 4 if M <= 128 else 8
1915-
19162061
num_stages = 4 if M <= 32 else 3
19172062

19182063
return {

0 commit comments

Comments
 (0)