@@ -1920,51 +1920,318 @@ def process_loaded_weights(self, layer: nn.Layer, state_dict):
19201920
19211921 def _get_default_config (self , M : int , E : int ) -> dict :
19221922 """
1923- Heuristic tile config for BF16 MoE, ported verbatim from vLLM's
1924- `get_default_config` (bf16/fp16 non-block_shape branch).
1925- See vllm/model_executor/layers/fused_moe/fused_moe.py:1273-1319.
1923+ GPU-aware heuristic tile config for BF16 MoE.
19261924
1927- M: number of tokens (A.size(0) in vLLM), i.e. pre-expansion token count.
1928- E: number of (local) experts.
1929- """
1930-
1931- # Tile sizes scale with batch: small batches are memory-bound
1932- # (favor tall-K tiles), large batches are compute-bound (favor
1933- # large M/N tiles with more warps).
1934- if M <= 32 :
1935- block_m = 16
1936- elif M <= 96 :
1937- block_m = 32
1938- elif M <= 512 :
1939- block_m = 64
1940- else :
1941- block_m = 128
1942-
1943- block_n = 64 if M <= 64 else 128
1944-
1945- block_k = 64
1925+ Derived from SGLang's per-device tuned JSON configs for E=64, N=1856:
1926+ - SM100 (B200): triton_3_5_1/E=64,N=1856,device_name=NVIDIA_B200.json
1927+ - SM90 (H100): triton_3_5_1/E=64,N=1856,device_name=NVIDIA_H100_80GB_HBM3.json
19461928
1947- # Grouping adjacent M-blocks lets them share weight tiles in L2.
1948- # Only helps when there are enough M-blocks per expert to group;
1949- # with many experts each one sees few tokens so grouping is useless.
1950- tokens_per_expert = M // max (E , 1 )
1951- group_m = 16 if tokens_per_expert > 128 else 1
1929+ Config selection mirrors SGLang's try_get_optimal_moe_config:
1930+ pick the entry whose key is closest to M by absolute difference.
19521931
1953- # Large batches have enough blocks to saturate the GPU, so we
1954- # use more warps per block to increase arithmetic intensity.
1955- num_warps = 4 if M <= 128 else 8
1956-
1957- num_stages = 4 if M <= 32 else 3
1958-
1959- return {
1960- "BLOCK_SIZE_M" : block_m ,
1961- "BLOCK_SIZE_N" : block_n ,
1962- "BLOCK_SIZE_K" : block_k ,
1963- "GROUP_SIZE_M" : group_m ,
1964- "num_warps" : num_warps ,
1965- "num_stages" : num_stages ,
1932+ M: number of tokens (pre-expansion token count).
1933+ E: number of (local) experts.
1934+ """
1935+ from fastdeploy .model_executor .utils import get_sm_version
1936+
1937+ # SM100=B200 (sm_version>=100), SM90=H100, default to H100 on unknown GPU
1938+ _SM100_CONFIGS = {
1939+ 1 : {
1940+ "BLOCK_SIZE_M" : 16 ,
1941+ "BLOCK_SIZE_N" : 64 ,
1942+ "BLOCK_SIZE_K" : 64 ,
1943+ "GROUP_SIZE_M" : 32 ,
1944+ "num_warps" : 4 ,
1945+ "num_stages" : 5 ,
1946+ },
1947+ 2 : {
1948+ "BLOCK_SIZE_M" : 16 ,
1949+ "BLOCK_SIZE_N" : 64 ,
1950+ "BLOCK_SIZE_K" : 128 ,
1951+ "GROUP_SIZE_M" : 32 ,
1952+ "num_warps" : 4 ,
1953+ "num_stages" : 3 ,
1954+ },
1955+ 4 : {
1956+ "BLOCK_SIZE_M" : 16 ,
1957+ "BLOCK_SIZE_N" : 64 ,
1958+ "BLOCK_SIZE_K" : 128 ,
1959+ "GROUP_SIZE_M" : 64 ,
1960+ "num_warps" : 4 ,
1961+ "num_stages" : 4 ,
1962+ },
1963+ 8 : {
1964+ "BLOCK_SIZE_M" : 16 ,
1965+ "BLOCK_SIZE_N" : 64 ,
1966+ "BLOCK_SIZE_K" : 128 ,
1967+ "GROUP_SIZE_M" : 32 ,
1968+ "num_warps" : 4 ,
1969+ "num_stages" : 3 ,
1970+ },
1971+ 16 : {
1972+ "BLOCK_SIZE_M" : 16 ,
1973+ "BLOCK_SIZE_N" : 64 ,
1974+ "BLOCK_SIZE_K" : 128 ,
1975+ "GROUP_SIZE_M" : 1 ,
1976+ "num_warps" : 4 ,
1977+ "num_stages" : 3 ,
1978+ },
1979+ 24 : {
1980+ "BLOCK_SIZE_M" : 16 ,
1981+ "BLOCK_SIZE_N" : 128 ,
1982+ "BLOCK_SIZE_K" : 128 ,
1983+ "GROUP_SIZE_M" : 16 ,
1984+ "num_warps" : 4 ,
1985+ "num_stages" : 4 ,
1986+ },
1987+ 32 : {
1988+ "BLOCK_SIZE_M" : 16 ,
1989+ "BLOCK_SIZE_N" : 64 ,
1990+ "BLOCK_SIZE_K" : 128 ,
1991+ "GROUP_SIZE_M" : 16 ,
1992+ "num_warps" : 4 ,
1993+ "num_stages" : 4 ,
1994+ },
1995+ 48 : {
1996+ "BLOCK_SIZE_M" : 16 ,
1997+ "BLOCK_SIZE_N" : 64 ,
1998+ "BLOCK_SIZE_K" : 128 ,
1999+ "GROUP_SIZE_M" : 1 ,
2000+ "num_warps" : 4 ,
2001+ "num_stages" : 4 ,
2002+ },
2003+ 64 : {
2004+ "BLOCK_SIZE_M" : 16 ,
2005+ "BLOCK_SIZE_N" : 64 ,
2006+ "BLOCK_SIZE_K" : 128 ,
2007+ "GROUP_SIZE_M" : 1 ,
2008+ "num_warps" : 4 ,
2009+ "num_stages" : 4 ,
2010+ },
2011+ 96 : {
2012+ "BLOCK_SIZE_M" : 16 ,
2013+ "BLOCK_SIZE_N" : 64 ,
2014+ "BLOCK_SIZE_K" : 128 ,
2015+ "GROUP_SIZE_M" : 1 ,
2016+ "num_warps" : 4 ,
2017+ "num_stages" : 3 ,
2018+ },
2019+ 128 : {
2020+ "BLOCK_SIZE_M" : 16 ,
2021+ "BLOCK_SIZE_N" : 64 ,
2022+ "BLOCK_SIZE_K" : 128 ,
2023+ "GROUP_SIZE_M" : 1 ,
2024+ "num_warps" : 4 ,
2025+ "num_stages" : 3 ,
2026+ },
2027+ 256 : {
2028+ "BLOCK_SIZE_M" : 32 ,
2029+ "BLOCK_SIZE_N" : 128 ,
2030+ "BLOCK_SIZE_K" : 64 ,
2031+ "GROUP_SIZE_M" : 1 ,
2032+ "num_warps" : 4 ,
2033+ "num_stages" : 5 ,
2034+ },
2035+ 512 : {
2036+ "BLOCK_SIZE_M" : 64 ,
2037+ "BLOCK_SIZE_N" : 256 ,
2038+ "BLOCK_SIZE_K" : 64 ,
2039+ "GROUP_SIZE_M" : 1 ,
2040+ "num_warps" : 8 ,
2041+ "num_stages" : 5 ,
2042+ },
2043+ 1024 : {
2044+ "BLOCK_SIZE_M" : 128 ,
2045+ "BLOCK_SIZE_N" : 256 ,
2046+ "BLOCK_SIZE_K" : 64 ,
2047+ "GROUP_SIZE_M" : 1 ,
2048+ "num_warps" : 8 ,
2049+ "num_stages" : 4 ,
2050+ },
2051+ 1536 : {
2052+ "BLOCK_SIZE_M" : 256 ,
2053+ "BLOCK_SIZE_N" : 256 ,
2054+ "BLOCK_SIZE_K" : 64 ,
2055+ "GROUP_SIZE_M" : 1 ,
2056+ "num_warps" : 8 ,
2057+ "num_stages" : 3 ,
2058+ },
2059+ 2048 : {
2060+ "BLOCK_SIZE_M" : 256 ,
2061+ "BLOCK_SIZE_N" : 256 ,
2062+ "BLOCK_SIZE_K" : 64 ,
2063+ "GROUP_SIZE_M" : 1 ,
2064+ "num_warps" : 8 ,
2065+ "num_stages" : 3 ,
2066+ },
2067+ 3072 : {
2068+ "BLOCK_SIZE_M" : 128 ,
2069+ "BLOCK_SIZE_N" : 256 ,
2070+ "BLOCK_SIZE_K" : 64 ,
2071+ "GROUP_SIZE_M" : 1 ,
2072+ "num_warps" : 8 ,
2073+ "num_stages" : 4 ,
2074+ },
2075+ 4096 : {
2076+ "BLOCK_SIZE_M" : 256 ,
2077+ "BLOCK_SIZE_N" : 256 ,
2078+ "BLOCK_SIZE_K" : 64 ,
2079+ "GROUP_SIZE_M" : 1 ,
2080+ "num_warps" : 8 ,
2081+ "num_stages" : 3 ,
2082+ },
2083+ }
2084+ _SM90_CONFIGS = {
2085+ 1 : {
2086+ "BLOCK_SIZE_M" : 16 ,
2087+ "BLOCK_SIZE_N" : 32 ,
2088+ "BLOCK_SIZE_K" : 128 ,
2089+ "GROUP_SIZE_M" : 32 ,
2090+ "num_warps" : 4 ,
2091+ "num_stages" : 3 ,
2092+ },
2093+ 2 : {
2094+ "BLOCK_SIZE_M" : 16 ,
2095+ "BLOCK_SIZE_N" : 64 ,
2096+ "BLOCK_SIZE_K" : 128 ,
2097+ "GROUP_SIZE_M" : 16 ,
2098+ "num_warps" : 4 ,
2099+ "num_stages" : 5 ,
2100+ },
2101+ 4 : {
2102+ "BLOCK_SIZE_M" : 16 ,
2103+ "BLOCK_SIZE_N" : 64 ,
2104+ "BLOCK_SIZE_K" : 128 ,
2105+ "GROUP_SIZE_M" : 1 ,
2106+ "num_warps" : 8 ,
2107+ "num_stages" : 2 ,
2108+ },
2109+ 8 : {
2110+ "BLOCK_SIZE_M" : 16 ,
2111+ "BLOCK_SIZE_N" : 32 ,
2112+ "BLOCK_SIZE_K" : 128 ,
2113+ "GROUP_SIZE_M" : 64 ,
2114+ "num_warps" : 4 ,
2115+ "num_stages" : 5 ,
2116+ },
2117+ 16 : {
2118+ "BLOCK_SIZE_M" : 16 ,
2119+ "BLOCK_SIZE_N" : 64 ,
2120+ "BLOCK_SIZE_K" : 256 ,
2121+ "GROUP_SIZE_M" : 1 ,
2122+ "num_warps" : 4 ,
2123+ "num_stages" : 5 ,
2124+ },
2125+ 24 : {
2126+ "BLOCK_SIZE_M" : 16 ,
2127+ "BLOCK_SIZE_N" : 64 ,
2128+ "BLOCK_SIZE_K" : 256 ,
2129+ "GROUP_SIZE_M" : 1 ,
2130+ "num_warps" : 4 ,
2131+ "num_stages" : 5 ,
2132+ },
2133+ 32 : {
2134+ "BLOCK_SIZE_M" : 16 ,
2135+ "BLOCK_SIZE_N" : 32 ,
2136+ "BLOCK_SIZE_K" : 256 ,
2137+ "GROUP_SIZE_M" : 16 ,
2138+ "num_warps" : 4 ,
2139+ "num_stages" : 5 ,
2140+ },
2141+ 48 : {
2142+ "BLOCK_SIZE_M" : 16 ,
2143+ "BLOCK_SIZE_N" : 32 ,
2144+ "BLOCK_SIZE_K" : 256 ,
2145+ "GROUP_SIZE_M" : 1 ,
2146+ "num_warps" : 4 ,
2147+ "num_stages" : 5 ,
2148+ },
2149+ 64 : {
2150+ "BLOCK_SIZE_M" : 16 ,
2151+ "BLOCK_SIZE_N" : 32 ,
2152+ "BLOCK_SIZE_K" : 256 ,
2153+ "GROUP_SIZE_M" : 1 ,
2154+ "num_warps" : 4 ,
2155+ "num_stages" : 5 ,
2156+ },
2157+ 96 : {
2158+ "BLOCK_SIZE_M" : 16 ,
2159+ "BLOCK_SIZE_N" : 32 ,
2160+ "BLOCK_SIZE_K" : 256 ,
2161+ "GROUP_SIZE_M" : 1 ,
2162+ "num_warps" : 4 ,
2163+ "num_stages" : 5 ,
2164+ },
2165+ 128 : {
2166+ "BLOCK_SIZE_M" : 32 ,
2167+ "BLOCK_SIZE_N" : 32 ,
2168+ "BLOCK_SIZE_K" : 128 ,
2169+ "GROUP_SIZE_M" : 1 ,
2170+ "num_warps" : 4 ,
2171+ "num_stages" : 5 ,
2172+ },
2173+ 256 : {
2174+ "BLOCK_SIZE_M" : 32 ,
2175+ "BLOCK_SIZE_N" : 128 ,
2176+ "BLOCK_SIZE_K" : 128 ,
2177+ "GROUP_SIZE_M" : 1 ,
2178+ "num_warps" : 8 ,
2179+ "num_stages" : 2 ,
2180+ },
2181+ 512 : {
2182+ "BLOCK_SIZE_M" : 64 ,
2183+ "BLOCK_SIZE_N" : 128 ,
2184+ "BLOCK_SIZE_K" : 128 ,
2185+ "GROUP_SIZE_M" : 1 ,
2186+ "num_warps" : 8 ,
2187+ "num_stages" : 4 ,
2188+ },
2189+ 1024 : {
2190+ "BLOCK_SIZE_M" : 128 ,
2191+ "BLOCK_SIZE_N" : 128 ,
2192+ "BLOCK_SIZE_K" : 64 ,
2193+ "GROUP_SIZE_M" : 1 ,
2194+ "num_warps" : 8 ,
2195+ "num_stages" : 5 ,
2196+ },
2197+ 1536 : {
2198+ "BLOCK_SIZE_M" : 128 ,
2199+ "BLOCK_SIZE_N" : 256 ,
2200+ "BLOCK_SIZE_K" : 64 ,
2201+ "GROUP_SIZE_M" : 1 ,
2202+ "num_warps" : 8 ,
2203+ "num_stages" : 4 ,
2204+ },
2205+ 2048 : {
2206+ "BLOCK_SIZE_M" : 128 ,
2207+ "BLOCK_SIZE_N" : 256 ,
2208+ "BLOCK_SIZE_K" : 64 ,
2209+ "GROUP_SIZE_M" : 1 ,
2210+ "num_warps" : 8 ,
2211+ "num_stages" : 4 ,
2212+ },
2213+ 3072 : {
2214+ "BLOCK_SIZE_M" : 128 ,
2215+ "BLOCK_SIZE_N" : 256 ,
2216+ "BLOCK_SIZE_K" : 64 ,
2217+ "GROUP_SIZE_M" : 1 ,
2218+ "num_warps" : 8 ,
2219+ "num_stages" : 4 ,
2220+ },
2221+ 4096 : {
2222+ "BLOCK_SIZE_M" : 128 ,
2223+ "BLOCK_SIZE_N" : 256 ,
2224+ "BLOCK_SIZE_K" : 64 ,
2225+ "GROUP_SIZE_M" : 1 ,
2226+ "num_warps" : 8 ,
2227+ "num_stages" : 4 ,
2228+ },
19662229 }
19672230
2231+ configs = _SM100_CONFIGS if get_sm_version () >= 100 else _SM90_CONFIGS
2232+ best_key = min (configs .keys (), key = lambda x : abs (x - M ))
2233+ return configs [best_key ]
2234+
19682235 def apply_tp (
19692236 self ,
19702237 layer : nn .Layer ,
0 commit comments