@@ -1879,51 +1879,318 @@ def process_loaded_weights(self, layer: nn.Layer, state_dict):
18791879
def _get_default_config(self, M: int, E: int) -> dict:
    """
    GPU-aware default tile config for the BF16 MoE Triton kernel.

    On Hopper (SM90) and datacenter Blackwell (SM100) the config comes from
    tables derived from SGLang's per-device tuned JSON configs for
    E=64, N=1856:
      - SM100 (B200): triton_3_5_1/E=64,N=1856,device_name=NVIDIA_B200.json
      - SM90  (H100): triton_3_5_1/E=64,N=1856,device_name=NVIDIA_H100_80GB_HBM3.json

    Table lookup mirrors SGLang's try_get_optimal_moe_config: the entry whose
    key is closest to M by absolute difference is used.  On a tie the smaller
    key wins, because `min` returns the first minimum and the keys are
    listed in ascending order.

    On every other architecture a batch-size heuristic (the bf16/fp16
    non-block_shape branch of vLLM's `get_default_config`) is used instead
    of a tuned table.

    Args:
        M: number of tokens (pre-expansion token count).
        E: number of (local) experts.  Only the heuristic fallback reads it;
            the tuned tables are keyed by M alone.

    Returns:
        A new dict with the keys BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K,
        GROUP_SIZE_M, num_warps and num_stages.
    """
    from fastdeploy.model_executor.utils import get_sm_version

    fields = (
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_N",
        "BLOCK_SIZE_K",
        "GROUP_SIZE_M",
        "num_warps",
        "num_stages",
    )

    # Rows are keyed by the tuned token count; values follow `fields`.
    sm100_table = {
        1: (16, 64, 64, 32, 4, 5),
        2: (16, 64, 128, 32, 4, 3),
        4: (16, 64, 128, 64, 4, 4),
        8: (16, 64, 128, 32, 4, 3),
        16: (16, 64, 128, 1, 4, 3),
        24: (16, 128, 128, 16, 4, 4),
        32: (16, 64, 128, 16, 4, 4),
        48: (16, 64, 128, 1, 4, 4),
        64: (16, 64, 128, 1, 4, 4),
        96: (16, 64, 128, 1, 4, 3),
        128: (16, 64, 128, 1, 4, 3),
        256: (32, 128, 64, 1, 4, 5),
        512: (64, 256, 64, 1, 8, 5),
        1024: (128, 256, 64, 1, 8, 4),
        1536: (256, 256, 64, 1, 8, 3),
        2048: (256, 256, 64, 1, 8, 3),
        3072: (128, 256, 64, 1, 8, 4),
        4096: (256, 256, 64, 1, 8, 3),
    }
    sm90_table = {
        1: (16, 32, 128, 32, 4, 3),
        2: (16, 64, 128, 16, 4, 5),
        4: (16, 64, 128, 1, 8, 2),
        8: (16, 32, 128, 64, 4, 5),
        16: (16, 64, 256, 1, 4, 5),
        24: (16, 64, 256, 1, 4, 5),
        32: (16, 32, 256, 16, 4, 5),
        48: (16, 32, 256, 1, 4, 5),
        64: (16, 32, 256, 1, 4, 5),
        96: (16, 32, 256, 1, 4, 5),
        128: (32, 32, 128, 1, 4, 5),
        256: (32, 128, 128, 1, 8, 2),
        512: (64, 128, 128, 1, 8, 4),
        1024: (128, 128, 64, 1, 8, 5),
        1536: (128, 256, 64, 1, 8, 4),
        2048: (128, 256, 64, 1, 8, 4),
        3072: (128, 256, 64, 1, 8, 4),
        4096: (128, 256, 64, 1, 8, 4),
    }

    # The tuned rows pair large tiles with deep pipelines (e.g. 256x256x64
    # with 3 stages is roughly 192 KiB of bf16 operand tiles), sized for the
    # shared memory of H100/B200.  Only hand them to the architecture family
    # they were tuned on; pre-Hopper parts and sm_120-class parts are assumed
    # to have less shared memory per block, so they take the heuristic.
    # NOTE(review): the sm_version ranges below assume get_sm_version()
    # returns major * 10 + minor — confirm against its implementation.
    sm_version = get_sm_version()
    if 100 <= sm_version < 120:
        table = sm100_table
    elif 90 <= sm_version < 100:
        table = sm90_table
    else:
        table = None

    if table is not None:
        nearest_m = min(table, key=lambda tuned_m: abs(tuned_m - M))
        return dict(zip(fields, table[nearest_m]))

    # ---- Portable heuristic for architectures without a tuned table ----
    # Tile sizes scale with batch: small batches are memory-bound, large
    # batches are compute-bound (favor large M/N tiles with more warps).
    if M <= 32:
        block_m = 16
    elif M <= 96:
        block_m = 32
    elif M <= 512:
        block_m = 64
    else:
        block_m = 128

    block_n = 64 if M <= 64 else 128
    block_k = 64

    # Grouping adjacent M-blocks lets them share weight tiles in L2.  It only
    # helps when each expert sees enough tokens to form several M-blocks.
    tokens_per_expert = M // max(E, 1)
    group_m = 16 if tokens_per_expert > 128 else 1

    # Large batches already saturate the GPU, so use more warps per block.
    num_warps = 4 if M <= 128 else 8
    num_stages = 4 if M <= 32 else 3

    return {
        "BLOCK_SIZE_M": block_m,
        "BLOCK_SIZE_N": block_n,
        "BLOCK_SIZE_K": block_k,
        "GROUP_SIZE_M": group_m,
        "num_warps": num_warps,
        "num_stages": num_stages,
    }
2193+
19272194 def apply_tp (
19282195 self ,
19292196 layer : nn .Layer ,
@@ -2094,4 +2361,4 @@ def apply_ep_prefill(
def apply_ep_decode(
    self,
    layer,
    x,
    gate,
    topk_ids_hookfunc=None,
    shared_experts=None,
    fc1_latent_proj=None,
    fc2_latent_proj=None,
):
    """
    Run the expert-parallel decode path.

    Delegates to `self._apply_ep_no_deepep`, forwarding `layer`, `x`, `gate`
    and `topk_ids_hookfunc` positionally, and returns its result unchanged.

    `shared_experts`, `fc1_latent_proj` and `fc2_latent_proj` are accepted
    but are not forwarded by this method.
    """
    decode_output = self._apply_ep_no_deepep(layer, x, gate, topk_ids_hookfunc)
    return decode_output
0 commit comments