Skip to content

Commit fb123b0

Browse files
upgrade triton moe config.
1 parent ac24fcc commit fb123b0

1 file changed

Lines changed: 308 additions & 41 deletions

File tree

fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py

Lines changed: 308 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1879,51 +1879,318 @@ def process_loaded_weights(self, layer: nn.Layer, state_dict):
18791879

18801880
def _get_default_config(self, M: int, E: int) -> dict:
    """
    Choose a Triton tile configuration for the BF16 MoE kernel based on
    the GPU generation and the token count.

    The lookup tables are derived from SGLang's per-device tuned JSON
    configs for E=64, N=1856:
      - SM100 (B200): triton_3_5_1/E=64,N=1856,device_name=NVIDIA_B200.json
      - SM90 (H100): triton_3_5_1/E=64,N=1856,device_name=NVIDIA_H100_80GB_HBM3.json

    Selection mirrors SGLang's try_get_optimal_moe_config: the table entry
    whose batch-size key has the smallest absolute distance to ``M`` is
    used. When two keys are equally close, the smaller key is chosen,
    because the keys are listed in ascending order and ``min`` keeps the
    first minimum it encounters.

    Args:
        M: number of tokens (pre-expansion token count).
        E: number of (local) experts. Kept for interface compatibility;
            this implementation never reads it.

    Returns:
        A newly built dict with the keys ``BLOCK_SIZE_M``, ``BLOCK_SIZE_N``,
        ``BLOCK_SIZE_K``, ``GROUP_SIZE_M``, ``num_warps`` and ``num_stages``.
    """
    from fastdeploy.model_executor.utils import get_sm_version

    # Column order of every row in the tables below; also the key order of
    # the returned dict.
    field_names = (
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_N",
        "BLOCK_SIZE_K",
        "GROUP_SIZE_M",
        "num_warps",
        "num_stages",
    )

    # Tuned rows for SM100 (B200), keyed by the benchmarked token count.
    sm100_rows = {
        1: (16, 64, 64, 32, 4, 5),
        2: (16, 64, 128, 32, 4, 3),
        4: (16, 64, 128, 64, 4, 4),
        8: (16, 64, 128, 32, 4, 3),
        16: (16, 64, 128, 1, 4, 3),
        24: (16, 128, 128, 16, 4, 4),
        32: (16, 64, 128, 16, 4, 4),
        48: (16, 64, 128, 1, 4, 4),
        64: (16, 64, 128, 1, 4, 4),
        96: (16, 64, 128, 1, 4, 3),
        128: (16, 64, 128, 1, 4, 3),
        256: (32, 128, 64, 1, 4, 5),
        512: (64, 256, 64, 1, 8, 5),
        1024: (128, 256, 64, 1, 8, 4),
        1536: (256, 256, 64, 1, 8, 3),
        2048: (256, 256, 64, 1, 8, 3),
        3072: (128, 256, 64, 1, 8, 4),
        4096: (256, 256, 64, 1, 8, 3),
    }

    # Tuned rows for SM90 (H100), keyed by the benchmarked token count.
    sm90_rows = {
        1: (16, 32, 128, 32, 4, 3),
        2: (16, 64, 128, 16, 4, 5),
        4: (16, 64, 128, 1, 8, 2),
        8: (16, 32, 128, 64, 4, 5),
        16: (16, 64, 256, 1, 4, 5),
        24: (16, 64, 256, 1, 4, 5),
        32: (16, 32, 256, 16, 4, 5),
        48: (16, 32, 256, 1, 4, 5),
        64: (16, 32, 256, 1, 4, 5),
        96: (16, 32, 256, 1, 4, 5),
        128: (32, 32, 128, 1, 4, 5),
        256: (32, 128, 128, 1, 8, 2),
        512: (64, 128, 128, 1, 8, 4),
        1024: (128, 128, 64, 1, 8, 5),
        1536: (128, 256, 64, 1, 8, 4),
        2048: (128, 256, 64, 1, 8, 4),
        3072: (128, 256, 64, 1, 8, 4),
        4096: (128, 256, 64, 1, 8, 4),
    }

    # Any SM version of 100 or above uses the B200 table; everything else,
    # including unrecognised GPUs, falls back to the H100 table.
    # NOTE(review): these tiles were tuned on B200/H100 only - confirm they
    # fit the resource limits of other architectures before relying on the
    # fallback there.
    rows = sm100_rows if get_sm_version() >= 100 else sm90_rows

    # Nearest benchmarked token count by absolute difference.
    nearest_m = min(rows, key=lambda bucket: abs(bucket - M))
    return dict(zip(field_names, rows[nearest_m]))
2193+
19272194
def apply_tp(
19282195
self,
19292196
layer: nn.Layer,
@@ -2094,4 +2361,4 @@ def apply_ep_prefill(
20942361
def apply_ep_decode(
    self, layer, x, gate, topk_ids_hookfunc=None, shared_experts=None, fc1_latent_proj=None, fc2_latent_proj=None
):
    """
    Expert-parallel decode entry point.

    Delegates to ``self._apply_ep_no_deepep`` with ``layer``, ``x``,
    ``gate`` and ``topk_ids_hookfunc`` and returns its result unchanged.

    ``shared_experts``, ``fc1_latent_proj`` and ``fc2_latent_proj`` are
    accepted but are not forwarded to the delegate.
    NOTE(review): confirm that ignoring these three arguments is intended.
    """
    result = self._apply_ep_no_deepep(layer, x, gate, topk_ids_hookfunc)
    return result

0 commit comments

Comments
 (0)