@@ -1879,17 +1879,169 @@ def process_loaded_weights(self, layer: nn.Layer, state_dict):
18791879
18801880 def _get_default_config (self , M : int , E : int ) -> dict :
18811881 """
1882- Heuristic tile config for BF16 MoE, ported verbatim from vLLM's
1883- `get_default_config` (bf16/fp16 non-block_shape branch).
1884- See vllm/model_executor/layers/fused_moe/fused_moe.py:1273-1319.
1882+ GPU-aware heuristic tile config for BF16 MoE.
18851883
1886- M: number of tokens (A.size(0) in vLLM), i.e. pre-expansion token count.
1884+ SM100 (B200): nearest-key lookup from SGLang tuned config
1885+ (triton_3_5_1/E=64,N=1856,device_name=NVIDIA_B200.json).
1886+ Others: original vLLM-ported heuristic.
1887+
1888+ M: number of tokens (pre-expansion token count).
18871889 E: number of (local) experts.
18881890 """
1891+ from fastdeploy .model_executor .utils import get_sm_version
1892+
1893+ if get_sm_version () >= 100 :
1894+ # SM100 (B200): use SGLang tuned lookup, nearest key by abs diff
1895+ _SM100_CONFIGS = {
1896+ 1 : {
1897+ "BLOCK_SIZE_M" : 16 ,
1898+ "BLOCK_SIZE_N" : 64 ,
1899+ "BLOCK_SIZE_K" : 64 ,
1900+ "GROUP_SIZE_M" : 32 ,
1901+ "num_warps" : 4 ,
1902+ "num_stages" : 5 ,
1903+ },
1904+ 2 : {
1905+ "BLOCK_SIZE_M" : 16 ,
1906+ "BLOCK_SIZE_N" : 64 ,
1907+ "BLOCK_SIZE_K" : 128 ,
1908+ "GROUP_SIZE_M" : 32 ,
1909+ "num_warps" : 4 ,
1910+ "num_stages" : 3 ,
1911+ },
1912+ 4 : {
1913+ "BLOCK_SIZE_M" : 16 ,
1914+ "BLOCK_SIZE_N" : 64 ,
1915+ "BLOCK_SIZE_K" : 128 ,
1916+ "GROUP_SIZE_M" : 64 ,
1917+ "num_warps" : 4 ,
1918+ "num_stages" : 4 ,
1919+ },
1920+ 8 : {
1921+ "BLOCK_SIZE_M" : 16 ,
1922+ "BLOCK_SIZE_N" : 64 ,
1923+ "BLOCK_SIZE_K" : 128 ,
1924+ "GROUP_SIZE_M" : 32 ,
1925+ "num_warps" : 4 ,
1926+ "num_stages" : 3 ,
1927+ },
1928+ 16 : {
1929+ "BLOCK_SIZE_M" : 16 ,
1930+ "BLOCK_SIZE_N" : 64 ,
1931+ "BLOCK_SIZE_K" : 128 ,
1932+ "GROUP_SIZE_M" : 1 ,
1933+ "num_warps" : 4 ,
1934+ "num_stages" : 3 ,
1935+ },
1936+ 24 : {
1937+ "BLOCK_SIZE_M" : 16 ,
1938+ "BLOCK_SIZE_N" : 128 ,
1939+ "BLOCK_SIZE_K" : 128 ,
1940+ "GROUP_SIZE_M" : 16 ,
1941+ "num_warps" : 4 ,
1942+ "num_stages" : 4 ,
1943+ },
1944+ 32 : {
1945+ "BLOCK_SIZE_M" : 16 ,
1946+ "BLOCK_SIZE_N" : 64 ,
1947+ "BLOCK_SIZE_K" : 128 ,
1948+ "GROUP_SIZE_M" : 16 ,
1949+ "num_warps" : 4 ,
1950+ "num_stages" : 4 ,
1951+ },
1952+ 48 : {
1953+ "BLOCK_SIZE_M" : 16 ,
1954+ "BLOCK_SIZE_N" : 64 ,
1955+ "BLOCK_SIZE_K" : 128 ,
1956+ "GROUP_SIZE_M" : 1 ,
1957+ "num_warps" : 4 ,
1958+ "num_stages" : 4 ,
1959+ },
1960+ 64 : {
1961+ "BLOCK_SIZE_M" : 16 ,
1962+ "BLOCK_SIZE_N" : 64 ,
1963+ "BLOCK_SIZE_K" : 128 ,
1964+ "GROUP_SIZE_M" : 1 ,
1965+ "num_warps" : 4 ,
1966+ "num_stages" : 4 ,
1967+ },
1968+ 96 : {
1969+ "BLOCK_SIZE_M" : 16 ,
1970+ "BLOCK_SIZE_N" : 64 ,
1971+ "BLOCK_SIZE_K" : 128 ,
1972+ "GROUP_SIZE_M" : 1 ,
1973+ "num_warps" : 4 ,
1974+ "num_stages" : 3 ,
1975+ },
1976+ 128 : {
1977+ "BLOCK_SIZE_M" : 16 ,
1978+ "BLOCK_SIZE_N" : 64 ,
1979+ "BLOCK_SIZE_K" : 128 ,
1980+ "GROUP_SIZE_M" : 1 ,
1981+ "num_warps" : 4 ,
1982+ "num_stages" : 3 ,
1983+ },
1984+ 256 : {
1985+ "BLOCK_SIZE_M" : 32 ,
1986+ "BLOCK_SIZE_N" : 128 ,
1987+ "BLOCK_SIZE_K" : 64 ,
1988+ "GROUP_SIZE_M" : 1 ,
1989+ "num_warps" : 4 ,
1990+ "num_stages" : 5 ,
1991+ },
1992+ 512 : {
1993+ "BLOCK_SIZE_M" : 64 ,
1994+ "BLOCK_SIZE_N" : 256 ,
1995+ "BLOCK_SIZE_K" : 64 ,
1996+ "GROUP_SIZE_M" : 1 ,
1997+ "num_warps" : 8 ,
1998+ "num_stages" : 5 ,
1999+ },
2000+ 1024 : {
2001+ "BLOCK_SIZE_M" : 128 ,
2002+ "BLOCK_SIZE_N" : 256 ,
2003+ "BLOCK_SIZE_K" : 64 ,
2004+ "GROUP_SIZE_M" : 1 ,
2005+ "num_warps" : 8 ,
2006+ "num_stages" : 4 ,
2007+ },
2008+ 1536 : {
2009+ "BLOCK_SIZE_M" : 256 ,
2010+ "BLOCK_SIZE_N" : 256 ,
2011+ "BLOCK_SIZE_K" : 64 ,
2012+ "GROUP_SIZE_M" : 1 ,
2013+ "num_warps" : 8 ,
2014+ "num_stages" : 3 ,
2015+ },
2016+ 2048 : {
2017+ "BLOCK_SIZE_M" : 256 ,
2018+ "BLOCK_SIZE_N" : 256 ,
2019+ "BLOCK_SIZE_K" : 64 ,
2020+ "GROUP_SIZE_M" : 1 ,
2021+ "num_warps" : 8 ,
2022+ "num_stages" : 3 ,
2023+ },
2024+ 3072 : {
2025+ "BLOCK_SIZE_M" : 128 ,
2026+ "BLOCK_SIZE_N" : 256 ,
2027+ "BLOCK_SIZE_K" : 64 ,
2028+ "GROUP_SIZE_M" : 1 ,
2029+ "num_warps" : 8 ,
2030+ "num_stages" : 4 ,
2031+ },
2032+ 4096 : {
2033+ "BLOCK_SIZE_M" : 256 ,
2034+ "BLOCK_SIZE_N" : 256 ,
2035+ "BLOCK_SIZE_K" : 64 ,
2036+ "GROUP_SIZE_M" : 1 ,
2037+ "num_warps" : 8 ,
2038+ "num_stages" : 3 ,
2039+ },
2040+ }
2041+ best_key = min (_SM100_CONFIGS .keys (), key = lambda x : abs (x - M ))
2042+ return _SM100_CONFIGS [best_key ]
18892043
1890- # Tile sizes scale with batch: small batches are memory-bound
1891- # (favor tall-K tiles), large batches are compute-bound (favor
1892- # large M/N tiles with more warps).
2044+ # Default heuristic for all other GPUs (ported from vLLM)
18932045 if M <= 32 :
18942046 block_m = 16
18952047 elif M <= 96 :
@@ -1900,19 +2052,12 @@ def _get_default_config(self, M: int, E: int) -> dict:
19002052 block_m = 128
19012053
19022054 block_n = 64 if M <= 64 else 128
1903-
19042055 block_k = 64
19052056
1906- # Grouping adjacent M-blocks lets them share weight tiles in L2.
1907- # Only helps when there are enough M-blocks per expert to group;
1908- # with many experts each one sees few tokens so grouping is useless.
19092057 tokens_per_expert = M // max (E , 1 )
19102058 group_m = 16 if tokens_per_expert > 128 else 1
19112059
1912- # Large batches have enough blocks to saturate the GPU, so we
1913- # use more warps per block to increase arithmetic intensity.
19142060 num_warps = 4 if M <= 128 else 8
1915-
19162061 num_stages = 4 if M <= 32 else 3
19172062
19182063 return {
0 commit comments