@@ -1920,17 +1920,169 @@ def process_loaded_weights(self, layer: nn.Layer, state_dict):
19201920
19211921 def _get_default_config (self , M : int , E : int ) -> dict :
19221922 """
1923- Heuristic tile config for BF16 MoE, ported verbatim from vLLM's
1924- `get_default_config` (bf16/fp16 non-block_shape branch).
1925- See vllm/model_executor/layers/fused_moe/fused_moe.py:1273-1319.
1923+ GPU-aware heuristic tile config for BF16 MoE.
19261924
1927- M: number of tokens (A.size(0) in vLLM), i.e. pre-expansion token count.
1925+ SM version >= 100 (tuned on B200): nearest-key lookup on M from the SGLang tuned config
1926+ (triton_3_5_1/E=64,N=1856,device_name=NVIDIA_B200.json); E is not consulted on this path.
1927+ Others: original vLLM-ported heuristic.
1928+
1929+ M: number of tokens (pre-expansion token count).
19281930 E: number of (local) experts.
19291931 """
1932+ from fastdeploy .model_executor .utils import get_sm_version
1933+
1934+ if get_sm_version () >= 100 :
1935+ # SM version >= 100 (table tuned on B200): use SGLang tuned lookup, picking the key with the smallest |key - M|
1936+ _SM100_CONFIGS = {
1937+ 1 : {
1938+ "BLOCK_SIZE_M" : 16 ,
1939+ "BLOCK_SIZE_N" : 64 ,
1940+ "BLOCK_SIZE_K" : 64 ,
1941+ "GROUP_SIZE_M" : 32 ,
1942+ "num_warps" : 4 ,
1943+ "num_stages" : 5 ,
1944+ },
1945+ 2 : {
1946+ "BLOCK_SIZE_M" : 16 ,
1947+ "BLOCK_SIZE_N" : 64 ,
1948+ "BLOCK_SIZE_K" : 128 ,
1949+ "GROUP_SIZE_M" : 32 ,
1950+ "num_warps" : 4 ,
1951+ "num_stages" : 3 ,
1952+ },
1953+ 4 : {
1954+ "BLOCK_SIZE_M" : 16 ,
1955+ "BLOCK_SIZE_N" : 64 ,
1956+ "BLOCK_SIZE_K" : 128 ,
1957+ "GROUP_SIZE_M" : 64 ,
1958+ "num_warps" : 4 ,
1959+ "num_stages" : 4 ,
1960+ },
1961+ 8 : {
1962+ "BLOCK_SIZE_M" : 16 ,
1963+ "BLOCK_SIZE_N" : 64 ,
1964+ "BLOCK_SIZE_K" : 128 ,
1965+ "GROUP_SIZE_M" : 32 ,
1966+ "num_warps" : 4 ,
1967+ "num_stages" : 3 ,
1968+ },
1969+ 16 : {
1970+ "BLOCK_SIZE_M" : 16 ,
1971+ "BLOCK_SIZE_N" : 64 ,
1972+ "BLOCK_SIZE_K" : 128 ,
1973+ "GROUP_SIZE_M" : 1 ,
1974+ "num_warps" : 4 ,
1975+ "num_stages" : 3 ,
1976+ },
1977+ 24 : {
1978+ "BLOCK_SIZE_M" : 16 ,
1979+ "BLOCK_SIZE_N" : 128 ,
1980+ "BLOCK_SIZE_K" : 128 ,
1981+ "GROUP_SIZE_M" : 16 ,
1982+ "num_warps" : 4 ,
1983+ "num_stages" : 4 ,
1984+ },
1985+ 32 : {
1986+ "BLOCK_SIZE_M" : 16 ,
1987+ "BLOCK_SIZE_N" : 64 ,
1988+ "BLOCK_SIZE_K" : 128 ,
1989+ "GROUP_SIZE_M" : 16 ,
1990+ "num_warps" : 4 ,
1991+ "num_stages" : 4 ,
1992+ },
1993+ 48 : {
1994+ "BLOCK_SIZE_M" : 16 ,
1995+ "BLOCK_SIZE_N" : 64 ,
1996+ "BLOCK_SIZE_K" : 128 ,
1997+ "GROUP_SIZE_M" : 1 ,
1998+ "num_warps" : 4 ,
1999+ "num_stages" : 4 ,
2000+ },
2001+ 64 : {
2002+ "BLOCK_SIZE_M" : 16 ,
2003+ "BLOCK_SIZE_N" : 64 ,
2004+ "BLOCK_SIZE_K" : 128 ,
2005+ "GROUP_SIZE_M" : 1 ,
2006+ "num_warps" : 4 ,
2007+ "num_stages" : 4 ,
2008+ },
2009+ 96 : {
2010+ "BLOCK_SIZE_M" : 16 ,
2011+ "BLOCK_SIZE_N" : 64 ,
2012+ "BLOCK_SIZE_K" : 128 ,
2013+ "GROUP_SIZE_M" : 1 ,
2014+ "num_warps" : 4 ,
2015+ "num_stages" : 3 ,
2016+ },
2017+ 128 : {
2018+ "BLOCK_SIZE_M" : 16 ,
2019+ "BLOCK_SIZE_N" : 64 ,
2020+ "BLOCK_SIZE_K" : 128 ,
2021+ "GROUP_SIZE_M" : 1 ,
2022+ "num_warps" : 4 ,
2023+ "num_stages" : 3 ,
2024+ },
2025+ 256 : {
2026+ "BLOCK_SIZE_M" : 32 ,
2027+ "BLOCK_SIZE_N" : 128 ,
2028+ "BLOCK_SIZE_K" : 64 ,
2029+ "GROUP_SIZE_M" : 1 ,
2030+ "num_warps" : 4 ,
2031+ "num_stages" : 5 ,
2032+ },
2033+ 512 : {
2034+ "BLOCK_SIZE_M" : 64 ,
2035+ "BLOCK_SIZE_N" : 256 ,
2036+ "BLOCK_SIZE_K" : 64 ,
2037+ "GROUP_SIZE_M" : 1 ,
2038+ "num_warps" : 8 ,
2039+ "num_stages" : 5 ,
2040+ },
2041+ 1024 : {
2042+ "BLOCK_SIZE_M" : 128 ,
2043+ "BLOCK_SIZE_N" : 256 ,
2044+ "BLOCK_SIZE_K" : 64 ,
2045+ "GROUP_SIZE_M" : 1 ,
2046+ "num_warps" : 8 ,
2047+ "num_stages" : 4 ,
2048+ },
2049+ 1536 : {
2050+ "BLOCK_SIZE_M" : 256 ,
2051+ "BLOCK_SIZE_N" : 256 ,
2052+ "BLOCK_SIZE_K" : 64 ,
2053+ "GROUP_SIZE_M" : 1 ,
2054+ "num_warps" : 8 ,
2055+ "num_stages" : 3 ,
2056+ },
2057+ 2048 : {
2058+ "BLOCK_SIZE_M" : 256 ,
2059+ "BLOCK_SIZE_N" : 256 ,
2060+ "BLOCK_SIZE_K" : 64 ,
2061+ "GROUP_SIZE_M" : 1 ,
2062+ "num_warps" : 8 ,
2063+ "num_stages" : 3 ,
2064+ },
2065+ 3072 : {
2066+ "BLOCK_SIZE_M" : 128 ,
2067+ "BLOCK_SIZE_N" : 256 ,
2068+ "BLOCK_SIZE_K" : 64 ,
2069+ "GROUP_SIZE_M" : 1 ,
2070+ "num_warps" : 8 ,
2071+ "num_stages" : 4 ,
2072+ },
2073+ 4096 : {
2074+ "BLOCK_SIZE_M" : 256 ,
2075+ "BLOCK_SIZE_N" : 256 ,
2076+ "BLOCK_SIZE_K" : 64 ,
2077+ "GROUP_SIZE_M" : 1 ,
2078+ "num_warps" : 8 ,
2079+ "num_stages" : 3 ,
2080+ },
2081+ }
2082+ best_key = min (_SM100_CONFIGS .keys (), key = lambda x : abs (x - M ))
2083+ return _SM100_CONFIGS [best_key ]
19302084
1931- # Tile sizes scale with batch: small batches are memory-bound
1932- # (favor tall-K tiles), large batches are compute-bound (favor
1933- # large M/N tiles with more warps).
2085+ # Default heuristic for all other GPUs (ported from vLLM)
19342086 if M <= 32 :
19352087 block_m = 16
19362088 elif M <= 96 :
@@ -1941,19 +2093,12 @@ def _get_default_config(self, M: int, E: int) -> dict:
19412093 block_m = 128
19422094
19432095 block_n = 64 if M <= 64 else 128
1944-
19452096 block_k = 64
19462097
1947- # Grouping adjacent M-blocks lets them share weight tiles in L2.
1948- # Only helps when there are enough M-blocks per expert to group;
1949- # with many experts each one sees few tokens so grouping is useless.
19502098 tokens_per_expert = M // max (E , 1 )
19512099 group_m = 16 if tokens_per_expert > 128 else 1
19522100
1953- # Large batches have enough blocks to saturate the GPU, so we
1954- # use more warps per block to increase arithmetic intensity.
19552101 num_warps = 4 if M <= 128 else 8
1956-
19572102 num_stages = 4 if M <= 32 else 3
19582103
19592104 return {
0 commit comments