Skip to content

Commit 3ca8136

Browse files
upgrade triton moe config.
1 parent 529ec9e commit 3ca8136

1 file changed

Lines changed: 307 additions & 40 deletions

File tree

fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py

Lines changed: 307 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1920,51 +1920,318 @@ def process_loaded_weights(self, layer: nn.Layer, state_dict):
19201920

19211921
def _get_default_config(self, M: int, E: int) -> dict:
19221922
"""
1923-
Heuristic tile config for BF16 MoE, ported verbatim from vLLM's
1924-
`get_default_config` (bf16/fp16 non-block_shape branch).
1925-
See vllm/model_executor/layers/fused_moe/fused_moe.py:1273-1319.
1923+
GPU-aware heuristic tile config for BF16 MoE.
19261924
1927-
M: number of tokens (A.size(0) in vLLM), i.e. pre-expansion token count.
1928-
E: number of (local) experts.
1929-
"""
1930-
1931-
# Tile sizes scale with batch: small batches are memory-bound
1932-
# (favor tall-K tiles), large batches are compute-bound (favor
1933-
# large M/N tiles with more warps).
1934-
if M <= 32:
1935-
block_m = 16
1936-
elif M <= 96:
1937-
block_m = 32
1938-
elif M <= 512:
1939-
block_m = 64
1940-
else:
1941-
block_m = 128
1942-
1943-
block_n = 64 if M <= 64 else 128
1944-
1945-
block_k = 64
1925+
Derived from SGLang's per-device tuned JSON configs for E=64, N=1856:
1926+
- SM100 (B200): triton_3_5_1/E=64,N=1856,device_name=NVIDIA_B200.json
1927+
- SM90 (H100): triton_3_5_1/E=64,N=1856,device_name=NVIDIA_H100_80GB_HBM3.json
19461928
1947-
# Grouping adjacent M-blocks lets them share weight tiles in L2.
1948-
# Only helps when there are enough M-blocks per expert to group;
1949-
# with many experts each one sees few tokens so grouping is useless.
1950-
tokens_per_expert = M // max(E, 1)
1951-
group_m = 16 if tokens_per_expert > 128 else 1
1929+
Config selection mirrors SGLang's try_get_optimal_moe_config:
1930+
pick the entry whose key is closest to M by absolute difference.
19521931
1953-
# Large batches have enough blocks to saturate the GPU, so we
1954-
# use more warps per block to increase arithmetic intensity.
1955-
num_warps = 4 if M <= 128 else 8
1956-
1957-
num_stages = 4 if M <= 32 else 3
1958-
1959-
return {
1960-
"BLOCK_SIZE_M": block_m,
1961-
"BLOCK_SIZE_N": block_n,
1962-
"BLOCK_SIZE_K": block_k,
1963-
"GROUP_SIZE_M": group_m,
1964-
"num_warps": num_warps,
1965-
"num_stages": num_stages,
1932+
M: number of tokens (pre-expansion token count).
1933+
E: number of (local) experts.
1934+
"""
1935+
from fastdeploy.model_executor.utils import get_sm_version
1936+
1937+
# SM100=B200 (sm_version>=100), SM90=H100, default to H100 on unknown GPU
1938+
_SM100_CONFIGS = {
1939+
1: {
1940+
"BLOCK_SIZE_M": 16,
1941+
"BLOCK_SIZE_N": 64,
1942+
"BLOCK_SIZE_K": 64,
1943+
"GROUP_SIZE_M": 32,
1944+
"num_warps": 4,
1945+
"num_stages": 5,
1946+
},
1947+
2: {
1948+
"BLOCK_SIZE_M": 16,
1949+
"BLOCK_SIZE_N": 64,
1950+
"BLOCK_SIZE_K": 128,
1951+
"GROUP_SIZE_M": 32,
1952+
"num_warps": 4,
1953+
"num_stages": 3,
1954+
},
1955+
4: {
1956+
"BLOCK_SIZE_M": 16,
1957+
"BLOCK_SIZE_N": 64,
1958+
"BLOCK_SIZE_K": 128,
1959+
"GROUP_SIZE_M": 64,
1960+
"num_warps": 4,
1961+
"num_stages": 4,
1962+
},
1963+
8: {
1964+
"BLOCK_SIZE_M": 16,
1965+
"BLOCK_SIZE_N": 64,
1966+
"BLOCK_SIZE_K": 128,
1967+
"GROUP_SIZE_M": 32,
1968+
"num_warps": 4,
1969+
"num_stages": 3,
1970+
},
1971+
16: {
1972+
"BLOCK_SIZE_M": 16,
1973+
"BLOCK_SIZE_N": 64,
1974+
"BLOCK_SIZE_K": 128,
1975+
"GROUP_SIZE_M": 1,
1976+
"num_warps": 4,
1977+
"num_stages": 3,
1978+
},
1979+
24: {
1980+
"BLOCK_SIZE_M": 16,
1981+
"BLOCK_SIZE_N": 128,
1982+
"BLOCK_SIZE_K": 128,
1983+
"GROUP_SIZE_M": 16,
1984+
"num_warps": 4,
1985+
"num_stages": 4,
1986+
},
1987+
32: {
1988+
"BLOCK_SIZE_M": 16,
1989+
"BLOCK_SIZE_N": 64,
1990+
"BLOCK_SIZE_K": 128,
1991+
"GROUP_SIZE_M": 16,
1992+
"num_warps": 4,
1993+
"num_stages": 4,
1994+
},
1995+
48: {
1996+
"BLOCK_SIZE_M": 16,
1997+
"BLOCK_SIZE_N": 64,
1998+
"BLOCK_SIZE_K": 128,
1999+
"GROUP_SIZE_M": 1,
2000+
"num_warps": 4,
2001+
"num_stages": 4,
2002+
},
2003+
64: {
2004+
"BLOCK_SIZE_M": 16,
2005+
"BLOCK_SIZE_N": 64,
2006+
"BLOCK_SIZE_K": 128,
2007+
"GROUP_SIZE_M": 1,
2008+
"num_warps": 4,
2009+
"num_stages": 4,
2010+
},
2011+
96: {
2012+
"BLOCK_SIZE_M": 16,
2013+
"BLOCK_SIZE_N": 64,
2014+
"BLOCK_SIZE_K": 128,
2015+
"GROUP_SIZE_M": 1,
2016+
"num_warps": 4,
2017+
"num_stages": 3,
2018+
},
2019+
128: {
2020+
"BLOCK_SIZE_M": 16,
2021+
"BLOCK_SIZE_N": 64,
2022+
"BLOCK_SIZE_K": 128,
2023+
"GROUP_SIZE_M": 1,
2024+
"num_warps": 4,
2025+
"num_stages": 3,
2026+
},
2027+
256: {
2028+
"BLOCK_SIZE_M": 32,
2029+
"BLOCK_SIZE_N": 128,
2030+
"BLOCK_SIZE_K": 64,
2031+
"GROUP_SIZE_M": 1,
2032+
"num_warps": 4,
2033+
"num_stages": 5,
2034+
},
2035+
512: {
2036+
"BLOCK_SIZE_M": 64,
2037+
"BLOCK_SIZE_N": 256,
2038+
"BLOCK_SIZE_K": 64,
2039+
"GROUP_SIZE_M": 1,
2040+
"num_warps": 8,
2041+
"num_stages": 5,
2042+
},
2043+
1024: {
2044+
"BLOCK_SIZE_M": 128,
2045+
"BLOCK_SIZE_N": 256,
2046+
"BLOCK_SIZE_K": 64,
2047+
"GROUP_SIZE_M": 1,
2048+
"num_warps": 8,
2049+
"num_stages": 4,
2050+
},
2051+
1536: {
2052+
"BLOCK_SIZE_M": 256,
2053+
"BLOCK_SIZE_N": 256,
2054+
"BLOCK_SIZE_K": 64,
2055+
"GROUP_SIZE_M": 1,
2056+
"num_warps": 8,
2057+
"num_stages": 3,
2058+
},
2059+
2048: {
2060+
"BLOCK_SIZE_M": 256,
2061+
"BLOCK_SIZE_N": 256,
2062+
"BLOCK_SIZE_K": 64,
2063+
"GROUP_SIZE_M": 1,
2064+
"num_warps": 8,
2065+
"num_stages": 3,
2066+
},
2067+
3072: {
2068+
"BLOCK_SIZE_M": 128,
2069+
"BLOCK_SIZE_N": 256,
2070+
"BLOCK_SIZE_K": 64,
2071+
"GROUP_SIZE_M": 1,
2072+
"num_warps": 8,
2073+
"num_stages": 4,
2074+
},
2075+
4096: {
2076+
"BLOCK_SIZE_M": 256,
2077+
"BLOCK_SIZE_N": 256,
2078+
"BLOCK_SIZE_K": 64,
2079+
"GROUP_SIZE_M": 1,
2080+
"num_warps": 8,
2081+
"num_stages": 3,
2082+
},
2083+
}
2084+
_SM90_CONFIGS = {
2085+
1: {
2086+
"BLOCK_SIZE_M": 16,
2087+
"BLOCK_SIZE_N": 32,
2088+
"BLOCK_SIZE_K": 128,
2089+
"GROUP_SIZE_M": 32,
2090+
"num_warps": 4,
2091+
"num_stages": 3,
2092+
},
2093+
2: {
2094+
"BLOCK_SIZE_M": 16,
2095+
"BLOCK_SIZE_N": 64,
2096+
"BLOCK_SIZE_K": 128,
2097+
"GROUP_SIZE_M": 16,
2098+
"num_warps": 4,
2099+
"num_stages": 5,
2100+
},
2101+
4: {
2102+
"BLOCK_SIZE_M": 16,
2103+
"BLOCK_SIZE_N": 64,
2104+
"BLOCK_SIZE_K": 128,
2105+
"GROUP_SIZE_M": 1,
2106+
"num_warps": 8,
2107+
"num_stages": 2,
2108+
},
2109+
8: {
2110+
"BLOCK_SIZE_M": 16,
2111+
"BLOCK_SIZE_N": 32,
2112+
"BLOCK_SIZE_K": 128,
2113+
"GROUP_SIZE_M": 64,
2114+
"num_warps": 4,
2115+
"num_stages": 5,
2116+
},
2117+
16: {
2118+
"BLOCK_SIZE_M": 16,
2119+
"BLOCK_SIZE_N": 64,
2120+
"BLOCK_SIZE_K": 256,
2121+
"GROUP_SIZE_M": 1,
2122+
"num_warps": 4,
2123+
"num_stages": 5,
2124+
},
2125+
24: {
2126+
"BLOCK_SIZE_M": 16,
2127+
"BLOCK_SIZE_N": 64,
2128+
"BLOCK_SIZE_K": 256,
2129+
"GROUP_SIZE_M": 1,
2130+
"num_warps": 4,
2131+
"num_stages": 5,
2132+
},
2133+
32: {
2134+
"BLOCK_SIZE_M": 16,
2135+
"BLOCK_SIZE_N": 32,
2136+
"BLOCK_SIZE_K": 256,
2137+
"GROUP_SIZE_M": 16,
2138+
"num_warps": 4,
2139+
"num_stages": 5,
2140+
},
2141+
48: {
2142+
"BLOCK_SIZE_M": 16,
2143+
"BLOCK_SIZE_N": 32,
2144+
"BLOCK_SIZE_K": 256,
2145+
"GROUP_SIZE_M": 1,
2146+
"num_warps": 4,
2147+
"num_stages": 5,
2148+
},
2149+
64: {
2150+
"BLOCK_SIZE_M": 16,
2151+
"BLOCK_SIZE_N": 32,
2152+
"BLOCK_SIZE_K": 256,
2153+
"GROUP_SIZE_M": 1,
2154+
"num_warps": 4,
2155+
"num_stages": 5,
2156+
},
2157+
96: {
2158+
"BLOCK_SIZE_M": 16,
2159+
"BLOCK_SIZE_N": 32,
2160+
"BLOCK_SIZE_K": 256,
2161+
"GROUP_SIZE_M": 1,
2162+
"num_warps": 4,
2163+
"num_stages": 5,
2164+
},
2165+
128: {
2166+
"BLOCK_SIZE_M": 32,
2167+
"BLOCK_SIZE_N": 32,
2168+
"BLOCK_SIZE_K": 128,
2169+
"GROUP_SIZE_M": 1,
2170+
"num_warps": 4,
2171+
"num_stages": 5,
2172+
},
2173+
256: {
2174+
"BLOCK_SIZE_M": 32,
2175+
"BLOCK_SIZE_N": 128,
2176+
"BLOCK_SIZE_K": 128,
2177+
"GROUP_SIZE_M": 1,
2178+
"num_warps": 8,
2179+
"num_stages": 2,
2180+
},
2181+
512: {
2182+
"BLOCK_SIZE_M": 64,
2183+
"BLOCK_SIZE_N": 128,
2184+
"BLOCK_SIZE_K": 128,
2185+
"GROUP_SIZE_M": 1,
2186+
"num_warps": 8,
2187+
"num_stages": 4,
2188+
},
2189+
1024: {
2190+
"BLOCK_SIZE_M": 128,
2191+
"BLOCK_SIZE_N": 128,
2192+
"BLOCK_SIZE_K": 64,
2193+
"GROUP_SIZE_M": 1,
2194+
"num_warps": 8,
2195+
"num_stages": 5,
2196+
},
2197+
1536: {
2198+
"BLOCK_SIZE_M": 128,
2199+
"BLOCK_SIZE_N": 256,
2200+
"BLOCK_SIZE_K": 64,
2201+
"GROUP_SIZE_M": 1,
2202+
"num_warps": 8,
2203+
"num_stages": 4,
2204+
},
2205+
2048: {
2206+
"BLOCK_SIZE_M": 128,
2207+
"BLOCK_SIZE_N": 256,
2208+
"BLOCK_SIZE_K": 64,
2209+
"GROUP_SIZE_M": 1,
2210+
"num_warps": 8,
2211+
"num_stages": 4,
2212+
},
2213+
3072: {
2214+
"BLOCK_SIZE_M": 128,
2215+
"BLOCK_SIZE_N": 256,
2216+
"BLOCK_SIZE_K": 64,
2217+
"GROUP_SIZE_M": 1,
2218+
"num_warps": 8,
2219+
"num_stages": 4,
2220+
},
2221+
4096: {
2222+
"BLOCK_SIZE_M": 128,
2223+
"BLOCK_SIZE_N": 256,
2224+
"BLOCK_SIZE_K": 64,
2225+
"GROUP_SIZE_M": 1,
2226+
"num_warps": 8,
2227+
"num_stages": 4,
2228+
},
19662229
}
19672230

2231+
configs = _SM100_CONFIGS if get_sm_version() >= 100 else _SM90_CONFIGS
2232+
best_key = min(configs.keys(), key=lambda x: abs(x - M))
2233+
return configs[best_key]
2234+
19682235
def apply_tp(
19692236
self,
19702237
layer: nn.Layer,

0 commit comments

Comments
 (0)