diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py index be215065db3..d763afa6575 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py @@ -1879,17 +1879,169 @@ def process_loaded_weights(self, layer: nn.Layer, state_dict): def _get_default_config(self, M: int, E: int) -> dict: """ - Heuristic tile config for BF16 MoE, ported verbatim from vLLM's - `get_default_config` (bf16/fp16 non-block_shape branch). - See vllm/model_executor/layers/fused_moe/fused_moe.py:1273-1319. + GPU-aware heuristic tile config for BF16 MoE. - M: number of tokens (A.size(0) in vLLM), i.e. pre-expansion token count. + SM >= 100 (tuned on B200): nearest-key lookup on M from SGLang tuned config + (triton_3_5_1/E=64,N=1856,device_name=NVIDIA_B200.json); E is ignored here. + Others: original vLLM-ported heuristic. + + M: number of tokens (pre-expansion token count). E: number of (local) experts.
""" + from fastdeploy.model_executor.utils import get_sm_version + + if get_sm_version() >= 100: + # SM100 (B200): use SGLang tuned lookup, nearest key by abs diff + _SM100_CONFIGS = { + 1: { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5, + }, + 2: { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3, + }, + 4: { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4, + }, + 8: { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3, + }, + 16: { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + }, + 24: { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4, + }, + 32: { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4, + }, + 48: { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4, + }, + 64: { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4, + }, + 96: { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + }, + 128: { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + }, + 256: { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5, + }, + 512: { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5, + }, + 1024: { + 
"BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4, + }, + 1536: { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3, + }, + 2048: { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3, + }, + 3072: { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4, + }, + 4096: { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3, + }, + } + best_key = min(_SM100_CONFIGS.keys(), key=lambda x: abs(x - M)) + return _SM100_CONFIGS[best_key] - # Tile sizes scale with batch: small batches are memory-bound - # (favor tall-K tiles), large batches are compute-bound (favor - # large M/N tiles with more warps). + # Default heuristic for all other GPUs (ported from vLLM) if M <= 32: block_m = 16 elif M <= 96: @@ -1900,19 +2052,12 @@ def _get_default_config(self, M: int, E: int) -> dict: block_m = 128 block_n = 64 if M <= 64 else 128 - block_k = 64 - # Grouping adjacent M-blocks lets them share weight tiles in L2. - # Only helps when there are enough M-blocks per expert to group; - # with many experts each one sees few tokens so grouping is useless. tokens_per_expert = M // max(E, 1) group_m = 16 if tokens_per_expert > 128 else 1 - # Large batches have enough blocks to saturate the GPU, so we - # use more warps per block to increase arithmetic intensity. 
num_warps = 4 if M <= 128 else 8 - num_stages = 4 if M <= 32 else 3 return { diff --git a/tests/layers/test_fused_moe_triton_backend.py b/tests/layers/test_fused_moe_triton_backend.py index b42db5cc3d3..08669db6b9f 100644 --- a/tests/layers/test_fused_moe_triton_backend.py +++ b/tests/layers/test_fused_moe_triton_backend.py @@ -964,98 +964,174 @@ def test_process_loaded_weights_stacks_experts(self): abs(actual_down - expected_down) < 1e-3 ), f"Expert {i} down_proj weight mean={actual_down}, expected {expected_down}" - # ------------------------------------------------------------------ # ------------------------------------------------------------------ # _get_default_config — tile heuristic # ------------------------------------------------------------------ - def test_get_default_config_decode(self): - """M<=32 decode path → 16x64x64.""" + def _mock_sm90(self, monkeypatch): + """Patch get_sm_version to return 90 (H100 / non-SM100 path). + + The function is decorated with @cache, so we must replace the cached + object on the module directly (the local import inside _get_default_config + re-fetches from the module each call, so setattr is sufficient). 
+ """ + import fastdeploy.model_executor.utils as fd_utils + + monkeypatch.setattr(fd_utils, "get_sm_version", lambda: 90) + + def _mock_sm100(self, monkeypatch): + """Patch get_sm_version to return 100 (B200 / SM100 path).""" + import fastdeploy.model_executor.utils as fd_utils + + monkeypatch.setattr(fd_utils, "get_sm_version", lambda: 100) + + # -- SM90 (default heuristic) tests -- + + def test_get_default_config_decode(self, monkeypatch): + """SM90: M<=32 decode path → BLOCK_SIZE_M=16.""" + self._mock_sm90(monkeypatch) method = backend.TritonMoEMethod() cfg = method._get_default_config(M=4, E=8) assert cfg["BLOCK_SIZE_M"] == 16 assert cfg["BLOCK_SIZE_N"] == 64 assert cfg["BLOCK_SIZE_K"] == 64 - def test_get_default_config_mid(self): - """96 < M <= 512 mid path → 64x128x64.""" + def test_get_default_config_mid(self, monkeypatch): + """SM90: 96 < M <= 512 mid path → BLOCK_SIZE_M=64.""" + self._mock_sm90(monkeypatch) method = backend.TritonMoEMethod() cfg = method._get_default_config(M=128, E=8) assert cfg["BLOCK_SIZE_M"] == 64 assert cfg["BLOCK_SIZE_N"] == 128 assert cfg["BLOCK_SIZE_K"] == 64 - def test_get_default_config_prefill(self): - """M > 512 prefill path → 128x128x64.""" + def test_get_default_config_prefill(self, monkeypatch): + """SM90: M > 512 prefill path → BLOCK_SIZE_M=128.""" + self._mock_sm90(monkeypatch) method = backend.TritonMoEMethod() cfg = method._get_default_config(M=1024, E=8) assert cfg["BLOCK_SIZE_M"] == 128 assert cfg["BLOCK_SIZE_N"] == 128 assert cfg["BLOCK_SIZE_K"] == 64 - def test_get_default_config_boundary_32(self): - """M==32 is decode (<=32).""" + def test_get_default_config_boundary_32(self, monkeypatch): + """SM90: M==32 is decode (<=32).""" + self._mock_sm90(monkeypatch) method = backend.TritonMoEMethod() cfg = method._get_default_config(M=32, E=8) assert cfg["BLOCK_SIZE_M"] == 16 - def test_get_default_config_boundary_96(self): - """M==96 is small-mid (32 < M <= 96) → BLOCK_SIZE_M=32.""" + def 
test_get_default_config_boundary_96(self, monkeypatch): + """SM90: M==96 is small-mid (32 < M <= 96) → BLOCK_SIZE_M=32.""" + self._mock_sm90(monkeypatch) method = backend.TritonMoEMethod() cfg = method._get_default_config(M=96, E=8) assert cfg["BLOCK_SIZE_M"] == 32 - def test_get_default_config_boundary_512(self): - """M==512 is mid (<=512) → BLOCK_SIZE_M=64.""" + def test_get_default_config_boundary_512(self, monkeypatch): + """SM90: M==512 is mid (<=512) → BLOCK_SIZE_M=64.""" + self._mock_sm90(monkeypatch) method = backend.TritonMoEMethod() cfg = method._get_default_config(M=512, E=8) assert cfg["BLOCK_SIZE_M"] == 64 - def test_get_default_config_has_group_size_m(self): - """All configs must include GROUP_SIZE_M key.""" + def test_get_default_config_has_group_size_m(self, monkeypatch): + """SM90: all configs must include GROUP_SIZE_M key.""" + self._mock_sm90(monkeypatch) method = backend.TritonMoEMethod() for M in (1, 64, 1024): cfg = method._get_default_config(M=M, E=8) assert "GROUP_SIZE_M" in cfg - def test_get_default_config_block_n_boundary(self): - """M<=64 → BLOCK_SIZE_N=64; M>64 → BLOCK_SIZE_N=128.""" + def test_get_default_config_block_n_boundary(self, monkeypatch): + """SM90: M<=64 → BLOCK_SIZE_N=64; M>64 → BLOCK_SIZE_N=128.""" + self._mock_sm90(monkeypatch) method = backend.TritonMoEMethod() cfg64 = method._get_default_config(M=64, E=8) assert cfg64["BLOCK_SIZE_N"] == 64 cfg65 = method._get_default_config(M=65, E=8) assert cfg65["BLOCK_SIZE_N"] == 128 - def test_get_default_config_group_m_16(self): - """tokens_per_expert > 128 → GROUP_SIZE_M=16.""" + def test_get_default_config_group_m_16(self, monkeypatch): + """SM90: tokens_per_expert > 128 → GROUP_SIZE_M=16.""" + self._mock_sm90(monkeypatch) method = backend.TritonMoEMethod() # M=1024, E=1 → tokens_per_expert=1024 > 128 → group_m=16 cfg = method._get_default_config(M=1024, E=1) assert cfg["GROUP_SIZE_M"] == 16 - def test_get_default_config_group_m_1(self): - """tokens_per_expert <= 128 → 
GROUP_SIZE_M=1.""" + def test_get_default_config_group_m_1(self, monkeypatch): + """SM90: tokens_per_expert <= 128 → GROUP_SIZE_M=1.""" + self._mock_sm90(monkeypatch) method = backend.TritonMoEMethod() # M=128, E=8 → tokens_per_expert=16 <= 128 → group_m=1 cfg = method._get_default_config(M=128, E=8) assert cfg["GROUP_SIZE_M"] == 1 - def test_get_default_config_num_warps(self): - """M<=128 → num_warps=4; M>128 → num_warps=8.""" + def test_get_default_config_num_warps(self, monkeypatch): + """SM90: M<=128 → num_warps=4; M>128 → num_warps=8.""" + self._mock_sm90(monkeypatch) method = backend.TritonMoEMethod() cfg128 = method._get_default_config(M=128, E=8) assert cfg128["num_warps"] == 4 cfg256 = method._get_default_config(M=256, E=8) assert cfg256["num_warps"] == 8 - def test_get_default_config_num_stages(self): - """M<=32 → num_stages=4; M>32 → num_stages=3.""" + def test_get_default_config_num_stages(self, monkeypatch): + """SM90: M<=32 → num_stages=4; M>32 → num_stages=3.""" + self._mock_sm90(monkeypatch) method = backend.TritonMoEMethod() cfg32 = method._get_default_config(M=32, E=8) assert cfg32["num_stages"] == 4 cfg33 = method._get_default_config(M=33, E=8) assert cfg33["num_stages"] == 3 + # -- SM100 (B200 lookup table) tests -- + + def test_get_default_config_sm100_small_batch(self, monkeypatch): + """SM100: M=8 → nearest key=8, BLOCK_SIZE_M=16, BLOCK_SIZE_N=64, BLOCK_SIZE_K=128.""" + self._mock_sm100(monkeypatch) + method = backend.TritonMoEMethod() + cfg = method._get_default_config(M=8, E=64) + assert cfg["BLOCK_SIZE_M"] == 16 + assert cfg["BLOCK_SIZE_N"] == 64 + assert cfg["BLOCK_SIZE_K"] == 128 + + def test_get_default_config_sm100_large_batch(self, monkeypatch): + """SM100: M=2048 → key 2048, BLOCK_SIZE_M=256 (key 3072 uses 128, so not all M>=1536).""" + self._mock_sm100(monkeypatch) + method = backend.TritonMoEMethod() + cfg = method._get_default_config(M=2048, E=64) + assert cfg["BLOCK_SIZE_M"] == 256 + assert cfg["BLOCK_SIZE_N"] == 256 + assert cfg["BLOCK_SIZE_K"]
== 64 + + def test_get_default_config_sm100_nearest_key(self, monkeypatch): + """SM100: M=100 lies between keys 96 and 128 and is closer to 96, so key 96 is picked.""" + self._mock_sm100(monkeypatch) + method = backend.TritonMoEMethod() + cfg = method._get_default_config(M=100, E=64) + # nearest key between 96 and 128: abs(96-100)=4, abs(128-100)=28 → picks 96 + assert cfg["BLOCK_SIZE_M"] == 16 + assert cfg["BLOCK_SIZE_K"] == 128 + + def test_get_default_config_sm100_mid_range(self, monkeypatch): + """SM100: M=512 → BLOCK_SIZE_M=64, BLOCK_SIZE_N=256.""" + self._mock_sm100(monkeypatch) + method = backend.TritonMoEMethod() + cfg = method._get_default_config(M=512, E=64) + assert cfg["BLOCK_SIZE_M"] == 64 + assert cfg["BLOCK_SIZE_N"] == 256 + + def test_get_default_config_sm100_all_keys_present(self, monkeypatch): + """SM100: every returned config must have all 6 required keys.""" + self._mock_sm100(monkeypatch) + method = backend.TritonMoEMethod() + required = {"BLOCK_SIZE_M", "BLOCK_SIZE_N", "BLOCK_SIZE_K", "GROUP_SIZE_M", "num_warps", "num_stages"} + for M in (1, 16, 64, 256, 1024, 4096): + cfg = method._get_default_config(M=M, E=64) + assert required.issubset(cfg.keys()), f"Missing keys at M={M}: {required - set(cfg.keys())}" + # ------------------------------------------------------------------ # apply — empty-batch fast path # ------------------------------------------------------------------ @@ -1161,29 +1237,61 @@ def test_apply_tp_delegates_to_apply(self, fake_ops, monkeypatch): assert list(out.shape) == [2, layer.hidden_size] # ------------------------------------------------------------------ - # EP methods raise NotImplementedError + # EP methods (not yet implemented) # ------------------------------------------------------------------ def test_apply_ep_prefill_raises(self): + """apply_ep_prefill raises NotImplementedError until EP is implemented.""" method = backend.TritonMoEMethod() layer = self._make_layer() with pytest.raises(NotImplementedError):
method.apply_ep_prefill(layer, None, None) def test_apply_ep_decode_raises(self): + """apply_ep_decode raises NotImplementedError until EP is implemented.""" method = backend.TritonMoEMethod() layer = self._make_layer() with pytest.raises(NotImplementedError): method.apply_ep_decode(layer, None, None) + def test_apply_tp_calls_kernel_twice(self, fake_ops, monkeypatch): + """apply_tp must invoke fused_moe_kernel_bf16 exactly twice (GEMM1 + GEMM2).""" + import fastdeploy.model_executor.utils as fd_utils + + monkeypatch.setattr(fd_utils, "get_sm_version", lambda: 90) + + method = backend.TritonMoEMethod() + layer = self._make_layer(hidden_size=8) + self._create_weights(method, layer) + kernel = self._patch_bf16_kernel(monkeypatch) + + x = paddle.randn([2, layer.hidden_size], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + method.apply_tp(layer, x, gate) + + assert len(kernel.calls) == 2 + # ------------------------------------------------------------------ # apply — kernel argument verification # ------------------------------------------------------------------ def test_apply_kernel_even_ks_true(self, fake_ops, monkeypatch): - """When hidden_size is divisible by BLOCK_SIZE_K, even_Ks=True in GEMM1.""" + """When hidden_size is divisible by BLOCK_SIZE_K, even_Ks=True in GEMM1. + + Patch _get_default_config at class level to fix BLOCK_SIZE_K=64 so the + test is independent of which GPU model (SM90 vs SM100) is running. 
+ """ + _fixed_cfg = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + } + monkeypatch.setattr(backend.TritonMoEMethod, "_get_default_config", lambda self, M, E: _fixed_cfg) method = backend.TritonMoEMethod() - # hidden_size=64, BLOCK_SIZE_K=64 → even_Ks=True for GEMM1 + # hidden_size=64 % BLOCK_SIZE_K=64 == 0 → even_Ks=True layer = self._make_layer(hidden_size=64, intermediate_size=32) self._create_weights(method, layer) kernel = self._patch_bf16_kernel(monkeypatch) @@ -1197,8 +1305,17 @@ def test_apply_kernel_even_ks_true(self, fake_ops, monkeypatch): def test_apply_kernel_even_ks_false(self, fake_ops, monkeypatch): """When hidden_size is NOT divisible by BLOCK_SIZE_K, even_Ks=False in GEMM1.""" + _fixed_cfg = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + } + monkeypatch.setattr(backend.TritonMoEMethod, "_get_default_config", lambda self, M, E: _fixed_cfg) method = backend.TritonMoEMethod() - # hidden_size=8, BLOCK_SIZE_K=64 → even_Ks=False for GEMM1 + # hidden_size=8 % BLOCK_SIZE_K=64 != 0 → even_Ks=False layer = self._make_layer(hidden_size=8, intermediate_size=4) self._create_weights(method, layer) kernel = self._patch_bf16_kernel(monkeypatch) @@ -1240,13 +1357,20 @@ def test_apply_gemm1_no_mul_weight_gemm2_mul_weight(self, fake_ops, monkeypatch) assert kernel.calls[1]["kwargs"]["MUL_ROUTED_WEIGHT"] is True def test_apply_large_batch_config(self, fake_ops, monkeypatch): - """Large token count picks larger tile config (BLOCK_SIZE_M=128, num_warps=8).""" + """Large token count picks larger tile config (BLOCK_SIZE_M=128, num_warps=8). + + Force SM90 config so the expectation is GPU-model-independent. 
+ """ + import fastdeploy.model_executor.utils as fd_utils + + monkeypatch.setattr(fd_utils, "get_sm_version", lambda: 90) + method = backend.TritonMoEMethod() layer = self._make_layer(hidden_size=8) self._create_weights(method, layer) kernel = self._patch_bf16_kernel(monkeypatch) - # 1024 tokens → prefill config: BLOCK_SIZE_M=128 + # 1024 tokens → SM90 prefill config: BLOCK_SIZE_M=128, num_warps=8 x = paddle.randn([1024, layer.hidden_size], dtype="bfloat16") gate = DummyGate(layer.num_local_experts) method.apply(layer, x, gate)