diff --git a/docker/install.sh b/docker/install.sh index 5fec95ebbf..5bd9cef2b0 100755 --- a/docker/install.sh +++ b/docker/install.sh @@ -53,7 +53,7 @@ else fi pip install /wheels/*.whl -pip install dlblas==0.0.7 dlslime==0.0.2.post1 +pip install dlslime==0.0.2.post1 pip install ninja einops packaging diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py index e82eefbdcb..bd2397c14b 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -308,7 +308,7 @@ def prepare_inputs_for_generation( """Prepare inputs.""" if get_deepep_state().enabled(): - from dlblas.layers.moe.token_dispatcher import DeepEPBuffer, DeepEPMode + from lmdeploy.pytorch.backends.cuda.token_dispatcher import DeepEPBuffer, DeepEPMode deepep_mode = DeepEPMode.LOW_LATENCY if context.global_is_decoding() else DeepEPMode.NORMAL DeepEPBuffer.set_deepep_mode(deepep_mode) @@ -322,7 +322,7 @@ def reset(self): """Remove all graphs to prevent hanging on exit.""" self._runner_map.clear() if get_deepep_state().enabled(): - from dlblas.layers.moe.token_dispatcher import DeepEPBuffer + from lmdeploy.pytorch.backends.cuda.token_dispatcher import DeepEPBuffer if hasattr(DeepEPBuffer, 'destroy'): from torch import distributed as dist diff --git a/lmdeploy/pytorch/backends/cuda/moe/blocked_fp8.py b/lmdeploy/pytorch/backends/cuda/moe/blocked_fp8.py index f08530006d..22eb436d27 100644 --- a/lmdeploy/pytorch/backends/cuda/moe/blocked_fp8.py +++ b/lmdeploy/pytorch/backends/cuda/moe/blocked_fp8.py @@ -5,12 +5,21 @@ import torch import torch.distributed as dist +from lmdeploy.pytorch.backends.cuda.token_dispatcher import ( + DeepEPBuffer, + DeepEPTokenDispatcherLowLatency, + DeepEPTokenDispatcherNormal, + DisposibleTensor, + use_deepep, +) from lmdeploy.pytorch.backends.deepep_state import get_deepep_state from lmdeploy.pytorch.backends.moe import FusedMoEBlockedF8Builder, FusedMoEBlockedF8Impl from lmdeploy.pytorch.distributed import get_dist_manager +from lmdeploy.pytorch.kernels.cuda.activation import silu_and_mul_masked_post_quant_fwd from lmdeploy.pytorch.kernels.cuda.blocked_fp8_fused_moe import fused_moe_blocked_fp8 -from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import quant_fp8 +from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import per_token_group_quant_fp8, quant_fp8 from lmdeploy.pytorch.kernels.cuda.fused_moe import _renormalize +from lmdeploy.pytorch.kernels.cuda.fused_moe_ep_fp8 import fused_moe_v3_fp8 from lmdeploy.pytorch.model_inputs import get_step_ctx_manager from lmdeploy.utils import get_logger @@ -19,6 +28,258 @@ logger = get_logger('lmdeploy') +class FusedMoENormal: + + def __init__( + self, + ep_size: int, + ep_group: dist.ProcessGroup, + num_experts: int, + hidden_dim: int, + layer_index: int = 0, + block_size: int = 128, + top_k: int = 8, + out_dtype: torch.dtype = torch.bfloat16, + fp8_dtype: torch.dtype | None = None, + scale_fmt: str | None = None, + num_max_dispatch_tokens_per_rank: int = 128, + chunk_size: int | None = 32 * 1024, + expert_alignment: int = 128, + ): + self.layer_index = layer_index + self.top_k = top_k + self.num_experts = num_experts + self.block_size = block_size + self.num_local_experts = num_experts // ep_size + self.out_dtype = out_dtype + self.fp8_dtype = fp8_dtype + self.scale_fmt = scale_fmt + self.token_dispatcher = DeepEPTokenDispatcherNormal( + group=ep_group, + num_experts=num_experts, + num_local_experts=self.num_local_experts, + hidden_size=hidden_dim, + params_dtype=out_dtype, + num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank, + expert_alignment=expert_alignment, + ) + + def forward( + self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.LongTensor, + up_weights: torch.Tensor, + up_scale: torch.Tensor, + down_weights: torch.Tensor, + down_scale: torch.Tensor, + expert_list: list[int] = None, + ): + hs_quant, hs_scale = per_token_group_quant_fp8(hidden_states, + self.block_size, + dtype=up_weights.dtype, + scale_fmt=self.scale_fmt) + x, recv_topk_ids, recv_topk_weights, recv_tokens_per_expert = self.token_dispatcher.dispatch( + (hs_quant, hs_scale), + topk_ids, + topk_weights, + expert_list, + ) + out_states = fused_moe_v3_fp8(x, recv_topk_ids, recv_topk_weights, (up_weights, up_scale), + (down_weights, down_scale), recv_tokens_per_expert) + return self.token_dispatcher.combine(out_states) + + def capture(self): + return self.token_dispatcher.buffer_normal.capture() + + def wait(self, event): + self.token_dispatcher.release() + event.current_stream_wait() + + def dispatch_async(self, + x: tuple[torch.Tensor, torch.Tensor] | torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + num_experts: int | None = None, + previous_event=None, + async_finish=True): + if isinstance(x, torch.Tensor): + x = self.per_token_group_quant_fp8(x) + return self.token_dispatcher.dispatch_normal_async(x, topk_idx, topk_weights, num_experts, previous_event, + async_finish) + + def combine_async(self, x: torch.Tensor, handle: tuple, previous_event=None, async_finish=True): + return self.token_dispatcher.combine_normal_async(x, handle, previous_event, async_finish) + + def release(self): + return self.token_dispatcher.release() + + def fusedmoe_forward(self, state, up_weight, up_scale, down_weight, down_scale): + return fused_moe_v3_fp8(state['recv_hidden_states'], state['recv_topk_idx'], state['recv_topk_weights'], + (up_weight, up_scale), (down_weight, down_scale), state['recv_tokens_per_expert']) + + def per_token_group_quant_fp8(self, + x: torch.Tensor, + dtype: torch.dtype | None = None, + scale_fmt: str | None = None): + dtype = dtype if dtype is not None else self.fp8_dtype + scale_fmt = scale_fmt if scale_fmt is not None else self.scale_fmt + return per_token_group_quant_fp8(x, self.block_size, dtype=dtype, scale_fmt=scale_fmt) + + +class FusedMoELowLatency: + + def __init__( + self, + ep_size: int, + ep_group: dist.ProcessGroup, + num_experts: int, + hidden_dim: int, + layer_index: int, + block_size: int = 128, + out_dtype: torch.dtype = torch.bfloat16, + num_max_dispatch_tokens_per_rank: int = 128, + ): + self.num_experts = num_experts + self.layer_index = layer_index + self.block_size = block_size + self.out_dtype = out_dtype + self.token_dispatcher = DeepEPTokenDispatcherLowLatency( + group=ep_group, + num_experts=num_experts, + num_local_experts=num_experts // ep_size, + hidden_size=hidden_dim, + params_dtype=out_dtype, + num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank, + ) + + def _deepgemm_grouped_fp8_nt_masked(self, input_tuple, w_tuple, out: torch.Tensor, masked_m: torch.Tensor, + expected_m: int): + from lmdeploy.pytorch.third_party.deep_gemm import m_grouped_fp8_gemm_nt_masked + return m_grouped_fp8_gemm_nt_masked(input_tuple, w_tuple, out, masked_m, expected_m) + + def experts( + self, + hidden_states_fp8, + gate_up_weight: torch.Tensor, + gate_up_scale: torch.Tensor, + gate_down_weight: torch.Tensor, + gate_down_scale: torch.Tensor, + masked_m: torch.Tensor, + expected_m: int, + ): + gate_up_weight_fp8 = (gate_up_weight, gate_up_scale) + gate_down_weight_fp8 = (gate_down_weight, gate_down_scale) + num_groups, m, _ = hidden_states_fp8[0].shape + n = gate_up_weight.size(1) + expected_m = min(expected_m, m) + gateup_output = torch.empty((num_groups, m, n), device=hidden_states_fp8[0].device, dtype=self.out_dtype) + self._deepgemm_grouped_fp8_nt_masked([DisposibleTensor.maybe_unwrap(x) for x in hidden_states_fp8], + gate_up_weight_fp8, gateup_output, masked_m, expected_m) + DisposibleTensor.maybe_dispose(hidden_states_fp8[0]) + DisposibleTensor.maybe_dispose(hidden_states_fp8[1]) + down_input = torch.empty((gateup_output.shape[0], gateup_output.shape[1], gateup_output.shape[2] // 2), + device=gateup_output.device, + dtype=gate_down_weight.dtype) + down_input_scale = torch.empty( + (gateup_output.shape[0], gateup_output.shape[1], gateup_output.shape[2] // 2 // self.block_size), + device=gateup_output.device, + dtype=torch.float32) + silu_and_mul_masked_post_quant_fwd(gateup_output, down_input, down_input_scale, self.block_size, masked_m) + del gateup_output + down_output = torch.empty((num_groups, m, gate_down_weight.size(1)), + device=down_input.device, + dtype=self.out_dtype) + self._deepgemm_grouped_fp8_nt_masked((down_input, down_input_scale), gate_down_weight_fp8, down_output, + masked_m, expected_m) + return down_output + + def forward( + self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.LongTensor, + up_weights: torch.Tensor, + up_scale: torch.Tensor, + down_weights: torch.Tensor, + down_scale: torch.Tensor, + expert_list: list[int] = None, + ): + recv_hidden_states, topk_idx, topk_weights, masked_m, expected_m = self.token_dispatcher.dispatch( + hidden_states, topk_ids, topk_weights, self.num_experts) + out_states = self.experts(recv_hidden_states, up_weights, up_scale, down_weights, down_scale, masked_m, + expected_m) + return self.token_dispatcher.combine(out_states, topk_idx, topk_weights) + + def wait(self, event): + event.current_stream_wait() + + def dispatch_async(self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + num_experts: int | None = None, + use_fp8: bool = True, + async_finish: bool = True): + return self.token_dispatcher.dispatch_async(hidden_states, topk_idx, num_experts, use_fp8, async_finish) + + def combine_async(self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + handle: tuple, + async_finish: bool): + return self.token_dispatcher.combine_async(hidden_states, topk_idx, topk_weights, handle, async_finish) + + def fusedmoe_forward(self, state, up_weight, up_scale, down_weight, down_scale): + recv_hidden_states = state['recv_hidden_states'] + masked_m = state['recv_expert_count'] + hidden_shape = state['raw_hidden_shape'] + topk_idx = state['topk_idx'] + expected_m = (hidden_shape[0] * self.token_dispatcher.buffer_low_latency.group_size * topk_idx.shape[1] + + self.token_dispatcher.num_experts) // self.token_dispatcher.num_experts + return self.experts(recv_hidden_states, up_weight, up_scale, down_weight, down_scale, masked_m, expected_m) + + +def build_deepep_moe( + low_latency_mode: bool, + ep_size: int, + ep_group: dist.ProcessGroup, + num_experts: int, + hidden_dim: int, + block_size: int, + top_k: int, + out_dtype: torch.dtype, + fp8_dtype: torch.dtype | None = None, + scale_fmt: str | None = None, + layer_idx: int = 0, + num_max_dispatch_tokens_per_rank: int = 128, + chunk_size: int | None = 32 * 1024, + expert_alignment: int = 128, +): + if low_latency_mode: + return FusedMoELowLatency(ep_size=ep_size, + ep_group=ep_group, + num_experts=num_experts, + hidden_dim=hidden_dim, + layer_index=layer_idx, + block_size=block_size, + out_dtype=out_dtype, + num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank) + return FusedMoENormal(ep_size=ep_size, + ep_group=ep_group, + num_experts=num_experts, + hidden_dim=hidden_dim, + layer_index=layer_idx, + block_size=block_size, + top_k=top_k, + out_dtype=out_dtype, + fp8_dtype=fp8_dtype, + scale_fmt=scale_fmt, + num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank, + chunk_size=chunk_size, + expert_alignment=expert_alignment) + + class TritonFusedMoEBlockedF8Impl(FusedMoEBlockedF8Impl): """Triton fused moe blocked f8 implementation.""" @@ -98,6 +359,8 @@ def __init__(self, renormalize: bool = False, block_size: int = 128, out_dtype: torch.dtype = torch.bfloat16, + fp8_dtype: torch.dtype = torch.float8_e4m3fn, + num_max_dispatch_tokens_per_rank: int = 128, layer_idx: int = 0): super().__init__(top_k, num_experts, renormalize, block_size, out_dtype) self.num_experts = num_experts @@ -106,21 +369,22 @@ def __init__(self, self.hidden_dim = hidden_dim self.block_size = block_size self.out_dtype = out_dtype + self.fp8_dtype = fp8_dtype + self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank self.layer_idx = layer_idx try: import deep_gemm # noqa: F401 self.use_deep_gemm = True except ImportError: - self.use_deep_gemm = False - logger.warning('For higher performance, please install DeepGEMM https://github.com/deepseek-ai/DeepGEMM') + logger.exception('DeepGEMM is required for DeepEP MoE implementation.') + raise - try: - from dlblas.layers.moe.token_dispatcher import DeepEPBuffer, DeepEPMode, use_deepep # noqa: F401 - get_deepep_state().enable() - if hasattr(DeepEPBuffer, 'set_explicitly_destroy'): - DeepEPBuffer.set_explicitly_destroy() - except ImportError: - logger.warning('For higher performance, please install DeepEP https://github.com/deepseek-ai/DeepEP') + if not use_deepep: + raise ImportError('DeepEP is required for DeepEP MoE implementation. Please install ' + 'https://github.com/deepseek-ai/DeepEP.') + get_deepep_state().enable() + if hasattr(DeepEPBuffer, 'set_explicitly_destroy'): + DeepEPBuffer.set_explicitly_destroy() # pre-allocate buffer self.fusedmoe_build(True) @@ -128,7 +392,7 @@ def __init__(self, def ep_expert_list(self, world_size: int, rank: int): """Experts list of current rank.""" if get_dist_manager().current_context().dist_config.enable_eplb: - from dlblas.layers.moe.eplb import get_eplb_phy2log_metadata_by_layer + from lmdeploy.pytorch.nn.eplb import get_eplb_phy2log_metadata_by_layer phy2log = get_eplb_phy2log_metadata_by_layer(self.layer_idx) expert_per_rank = (self.num_experts + world_size - 1) // world_size first_expert = rank * expert_per_rank @@ -169,7 +433,6 @@ def do_renormalize(self, topk_weights): return _renormalize(topk_weights, self.renormalize) def fusedmoe_build(self, low_latency_mode: bool = False): - from dlblas.layers.moe.ep_moe import build_deepep_moe deepep_moe = build_deepep_moe(low_latency_mode, self.ep_size, self.ep_group, @@ -178,7 +441,10 @@ def fusedmoe_build(self, low_latency_mode: bool = False): self.block_size, self.top_k, self.out_dtype, + fp8_dtype=self.fp8_dtype, + scale_fmt=self.scale_fmt, layer_idx=self.layer_idx, + num_max_dispatch_tokens_per_rank=self.num_max_dispatch_tokens_per_rank, chunk_size=16 * 1024) return deepep_moe @@ -195,6 +461,8 @@ def build(top_k: int, ep_size: int = 1, ep_group: dist.ProcessGroup = None, out_dtype: torch.dtype = torch.float16, + fp8_dtype: torch.dtype = torch.float8_e4m3fn, + num_max_dispatch_tokens_per_rank: int = 128, layer_idx: int = 0, custom_gateup_act: bool = False): """Build from mlp.""" @@ -208,6 +476,8 @@ def build(top_k: int, renormalize=renormalize, block_size=block_size, out_dtype=out_dtype, + fp8_dtype=fp8_dtype, + num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank, layer_idx=layer_idx) else: return TritonFusedMoEBlockedF8Impl(top_k=top_k, diff --git a/lmdeploy/pytorch/backends/cuda/moe/default.py b/lmdeploy/pytorch/backends/cuda/moe/default.py index 448f2ac257..400dd6b820 100644 --- a/lmdeploy/pytorch/backends/cuda/moe/default.py +++ b/lmdeploy/pytorch/backends/cuda/moe/default.py @@ -81,8 +81,9 @@ def __init__( layer_index: int = 0, top_k: int = 8, out_dtype: torch.dtype = torch.bfloat16, + num_max_dispatch_tokens_per_rank: int = 128, ): - from dlblas.layers.moe.token_dispatcher import DeepEPTokenDispatcherNormal + from lmdeploy.pytorch.backends.cuda.token_dispatcher import DeepEPTokenDispatcherNormal self.layer_index = layer_index self.top_k = top_k self.num_experts = num_experts @@ -94,6 +95,7 @@ def __init__( num_local_experts=self.num_local_experts, hidden_size=hidden_dim, params_dtype=out_dtype, + num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank, ) def forward( @@ -148,7 +150,7 @@ def fusedmoe_forward(self, state, up_weight, down_weight): def _disposible_tensor(tensor): - from dlblas.utils.utils import DisposibleTensor + from lmdeploy.pytorch.backends.cuda.token_dispatcher import DisposibleTensor if isinstance(tensor, torch.Tensor): tensor = DisposibleTensor(tensor) else: @@ -237,8 +239,9 @@ def __init__( hidden_dim: int, layer_index: int, out_dtype: torch.dtype = torch.bfloat16, + num_max_dispatch_tokens_per_rank: int = 128, ): - from dlblas.layers.moe.token_dispatcher import DeepEPTokenDispatcherLowLatency + from lmdeploy.pytorch.backends.cuda.token_dispatcher import DeepEPTokenDispatcherLowLatency self.num_experts = num_experts self.layer_index = layer_index self.out_dtype = out_dtype @@ -248,6 +251,7 @@ def __init__( num_local_experts=num_experts // ep_size, hidden_size=hidden_dim, params_dtype=out_dtype, + num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank, ) def experts( @@ -258,8 +262,7 @@ def experts( masked_m: torch.Tensor, expected_m: int, ): - from dlblas.utils.utils import DisposibleTensor - + from lmdeploy.pytorch.backends.cuda.token_dispatcher import DisposibleTensor from lmdeploy.pytorch.kernels.cuda.activation import silu_and_mul_moe_ep from lmdeploy.pytorch.third_party.deep_gemm import m_grouped_bf16_gemm_nt_masked num_groups, m, _ = hidden_states.shape @@ -339,6 +342,7 @@ def build_deepep_moe( top_k: int, layer_idx: int = 0, out_dtype: torch.dtype = torch.bfloat16, + num_max_dispatch_tokens_per_rank: int = 128, ): if low_latency_mode: return FusedMoELowLatency(ep_size=ep_size, @@ -346,7 +350,8 @@ def build_deepep_moe( num_experts=num_experts, hidden_dim=hidden_dim, layer_index=layer_idx, - out_dtype=out_dtype) + out_dtype=out_dtype, + num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank) else: return FusedMoENormal(ep_size=ep_size, ep_group=ep_group, @@ -354,7 +359,8 @@ def build_deepep_moe( hidden_dim=hidden_dim, layer_index=layer_idx, top_k=top_k, - out_dtype=out_dtype) + out_dtype=out_dtype, + num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank) class FusedMoEEPImpl(TritonFusedMoEImpl): @@ -370,6 +376,7 @@ def __init__( renormalize: bool = False, layer_idx: int = 0, out_dtype: torch.dtype = torch.bfloat16, + num_max_dispatch_tokens_per_rank: int = 128, ): super().__init__(top_k, num_experts, renormalize) self.num_experts = num_experts @@ -378,19 +385,21 @@ def __init__( self.hidden_dim = hidden_dim self.layer_idx = layer_idx self.out_dtype = out_dtype + self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank try: import deep_gemm # noqa: F401 except ImportError: logger.exception('DeepGEMM is required for DeepEP MoE implementation.') + raise - try: - from dlblas.layers.moe.token_dispatcher import DeepEPBuffer, DeepEPMode, use_deepep # noqa: F401 - get_deepep_state().enable() - if hasattr(DeepEPBuffer, 'set_explicitly_destroy'): - DeepEPBuffer.set_explicitly_destroy() - except ImportError: - logger.warning('For higher performance, please install DeepEP https://github.com/deepseek-ai/DeepEP') + from lmdeploy.pytorch.backends.cuda.token_dispatcher import DeepEPBuffer, use_deepep + if not use_deepep: + raise ImportError('DeepEP is required for DeepEP MoE implementation. Please install ' + 'https://github.com/deepseek-ai/DeepEP.') + get_deepep_state().enable() + if hasattr(DeepEPBuffer, 'set_explicitly_destroy'): + DeepEPBuffer.set_explicitly_destroy() # pre-allocate buffer self.fusedmoe_build(True) @@ -440,7 +449,8 @@ def fusedmoe_build(self, low_latency_mode: bool = False): self.hidden_dim, self.top_k, layer_idx=self.layer_idx, - out_dtype=self.out_dtype) + out_dtype=self.out_dtype, + num_max_dispatch_tokens_per_rank=self.num_max_dispatch_tokens_per_rank) return deepep_moe @@ -457,6 +467,7 @@ def build( ep_group: dist.ProcessGroup = None, layer_idx: int = 0, out_dtype: torch.dtype = torch.bfloat16, + num_max_dispatch_tokens_per_rank: int = 128, ): """Build from mlp.""" if ep_size > 1: @@ -467,5 +478,6 @@ def build( hidden_dim=hidden_dim, renormalize=renormalize, layer_idx=layer_idx, - out_dtype=out_dtype) + out_dtype=out_dtype, + num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank) return TritonFusedMoEImpl(top_k=top_k, num_experts=num_experts, renormalize=renormalize) diff --git a/lmdeploy/pytorch/backends/cuda/token_dispatcher.py b/lmdeploy/pytorch/backends/cuda/token_dispatcher.py index 4fe4d1f897..7ea3353919 100644 --- a/lmdeploy/pytorch/backends/cuda/token_dispatcher.py +++ b/lmdeploy/pytorch/backends/cuda/token_dispatcher.py @@ -1,12 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. +import os +import sys +from enum import Enum + +from lmdeploy.pytorch.envs import deep_ep_buffer_num_sms, env_to_int + try: from deep_ep import Buffer - from lmdeploy.pytorch.envs import deep_ep_buffer_num_sms, deep_ep_max_tokens_per_rank - Buffer.set_num_sms(deep_ep_buffer_num_sms) use_deepep = True except ImportError: + Buffer = None use_deepep = False @@ -16,41 +21,251 @@ from ..default.token_dispatcher import AlltoAllTokenDispatcher from ..token_dispatcher import TokenDispatcherImpl -_buffer_normal = None -_buffer_low_latency = None -_buffer_common = None +class DeepEPMode(Enum): + """DeepEP communication mode.""" -def get_buffer_common( - group: dist.ProcessGroup, - num_max_dispatch_tokens_per_rank: int, - hidden: int, - num_experts: int, - hidden_bytes: int, -): - global _buffer_common - num_nvl_bytes, num_rdma_bytes = 0, 0 - for config in ( - Buffer.get_dispatch_config(group.size()), - Buffer.get_combine_config(group.size()), - ): - num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes) - num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes) + NORMAL = 'normal' + LOW_LATENCY = 'low_latency' + AUTO = 'auto' + + +class DisposibleTensor: + """Tensor wrapper that allows eager disposal while preserving metadata.""" + + def __init__(self, value: torch.Tensor): + self._value = value + self._backup_metadata = None + + @property + def value(self): + assert not self.is_disposed + return self._value - num_rdma_bytes = max( - Buffer.get_low_latency_rdma_size_hint(num_max_dispatch_tokens_per_rank, hidden, group.size(), num_experts), - num_rdma_bytes) + @property + def is_disposed(self): + return self._value is None + + def dispose(self, backup_metadata: bool = True): + assert not self.is_disposed + if not torch.compiler.is_compiling() and sys.getrefcount(self._value) != 2: + return + if backup_metadata: + self._backup_metadata = {key: getattr(self._value, key) for key in ('shape', 'device', 'dtype')} + self._value = None + + @staticmethod + def maybe_unwrap(value): + return value.value if isinstance(value, DisposibleTensor) else value + + @staticmethod + def maybe_dispose(value): + if isinstance(value, DisposibleTensor): + value.dispose() + + @property + def shape(self): + return self._get_metadata('shape') + + @property + def device(self): + return self._get_metadata('device') + + @property + def dtype(self): + return self._get_metadata('dtype') + + def _get_metadata(self, name: str): + if not self.is_disposed: + return getattr(self._value, name) + assert self._backup_metadata is not None + return self._backup_metadata[name] + + +class DeepEPBuffer: + """LMDeploy-owned DeepEP buffer facade.""" + + _buffer_normal = None + _buffer_low_latency = None + _buffer_common = None + _deepep_mode = DeepEPMode.AUTO + _deepep_sms = deep_ep_buffer_num_sms if use_deepep else 20 + _num_max_dispatch_tokens_per_rank = 128 + _allow_mnnvl = True + _latest_mode = DeepEPMode.AUTO + _hidden_size = -1 + _num_experts = -1 + _explicitly_destroy = False + + @classmethod + def _build_buffer(cls, *args, **kwargs): + """Build a DeepEP Buffer while tolerating older constructor + signatures.""" + try: + return Buffer(*args, **kwargs) + except TypeError: + kwargs.pop('allow_mnnvl', None) + kwargs.pop('explicitly_destroy', None) + return Buffer(*args, **kwargs) + + @classmethod + def set_explicitly_destroy(cls): + if cls._buffer_common is not None or cls._buffer_normal is not None or cls._buffer_low_latency is not None: + return False + if not cls._explicitly_destroy: + cls._explicitly_destroy = True + return True + return False + + @classmethod + def get_explicitly_destroy(cls): + return cls._explicitly_destroy + + @classmethod + def destroy(cls): + if not cls._explicitly_destroy: + return False + if cls._buffer_common is not None: + cls._buffer_common.destroy() + cls._buffer_common = None + return True + if cls._buffer_low_latency is not None: + cls._buffer_low_latency.destroy() + cls._buffer_low_latency = None + return True + if cls._buffer_normal is not None: + cls._buffer_normal.destroy() + cls._buffer_normal = None + return True + return False + + @classmethod + def update_parameters(cls, hidden_size: int, num_experts: int): + cls._hidden_size = hidden_size + cls._num_experts = num_experts + cls._deepep_sms = env_to_int('DEEPEP_BUFFER_NUM_SMS', cls._deepep_sms) + cls._allow_mnnvl = os.getenv('DEEPEP_ENABLE_MNNVL', '1') != '0' + env_mode = os.getenv('DEEPEP_MODE', 'auto').strip().lower() + if env_mode == 'normal': + cls._deepep_mode = DeepEPMode.NORMAL + elif env_mode == 'low_latency': + cls._deepep_mode = DeepEPMode.LOW_LATENCY + else: + cls._deepep_mode = DeepEPMode.AUTO + + @classmethod + def set_deepep_mode(cls, mode: DeepEPMode): + low_latency_buffer_cleaned = False + if (cls._deepep_mode == DeepEPMode.AUTO and mode == DeepEPMode.LOW_LATENCY + and cls._latest_mode == DeepEPMode.NORMAL): + cls.clean_low_latency_buffer(cls._buffer_common) + low_latency_buffer_cleaned = True + cls._latest_mode = mode + return cls._latest_mode, low_latency_buffer_cleaned + + @classmethod + def clean_low_latency_buffer(cls, buffer=None): + if buffer is None: + buffer = cls._buffer_common + if use_deepep and isinstance(buffer, Buffer): + buffer.clean_low_latency_buffer(cls._num_max_dispatch_tokens_per_rank, cls._hidden_size, cls._num_experts) + + @classmethod + def get_buffer_common( + cls, + group: dist.ProcessGroup, + num_max_dispatch_tokens_per_rank: int, + hidden: int, + num_experts: int, + hidden_bytes: int, + ): + if cls._buffer_common is not None: + # Match dlblas/DeepEP's process-wide common buffer lifetime. + return cls._buffer_common + + cls.update_parameters(hidden, num_experts) + num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank or cls._num_max_dispatch_tokens_per_rank + cls._num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank + + num_nvl_bytes, num_rdma_bytes = 0, 0 + for config in ( + Buffer.get_dispatch_config(group.size()), + Buffer.get_combine_config(group.size()), + ): + num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes) + num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes) + + num_rdma_bytes = max( + Buffer.get_low_latency_rdma_size_hint(num_max_dispatch_tokens_per_rank, hidden, group.size(), num_experts), + num_rdma_bytes) - if (_buffer_common is None or _buffer_common.group != group or _buffer_common.num_nvl_bytes < num_nvl_bytes - or _buffer_common.num_rdma_bytes < num_rdma_bytes): - _buffer_common = Buffer( + assert num_experts % group.size( + ) == 0, f'num_experts: {num_experts} must be divisible by ep_size: {group.size()}' + cls._buffer_common = cls._build_buffer( group, num_nvl_bytes=num_nvl_bytes, num_rdma_bytes=num_rdma_bytes, low_latency_mode=True, - num_qps_per_rank=max(num_experts // group.size(), Buffer.num_sms // 2), + num_qps_per_rank=max(num_experts // group.size(), cls._deepep_sms), + allow_mnnvl=cls._allow_mnnvl, + explicitly_destroy=cls._explicitly_destroy, ) - return _buffer_common + cls._buffer_common.set_num_sms(cls._deepep_sms) + return cls._buffer_common + + @classmethod + def get_buffer_normal(cls, group: dist.ProcessGroup, hidden_bytes: int): + num_nvl_bytes, num_rdma_bytes = 0, 0 + for config in ( + Buffer.get_dispatch_config(group.size()), + Buffer.get_combine_config(group.size()), + ): + num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes) + num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes) + + if (cls._buffer_normal is None or cls._buffer_normal.group != group + or cls._buffer_normal.num_nvl_bytes < num_nvl_bytes + or cls._buffer_normal.num_rdma_bytes < num_rdma_bytes): + cls._buffer_normal = cls._build_buffer(group, + num_nvl_bytes, + num_rdma_bytes, + explicitly_destroy=cls._explicitly_destroy) + return cls._buffer_normal + + @classmethod + def get_buffer_low_latency( + cls, + group: dist.ProcessGroup, + num_max_dispatch_tokens_per_rank: int, + hidden: int, + num_experts: int, + ): + num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(num_max_dispatch_tokens_per_rank, hidden, group.size(), + num_experts) + + if (cls._buffer_low_latency is None or cls._buffer_low_latency.group != group + or not cls._buffer_low_latency.low_latency_mode + or cls._buffer_low_latency.num_rdma_bytes < num_rdma_bytes): + assert num_experts % group.size( + ) == 0, f'num_experts: {num_experts} must be divisible by ep_size: {group.size()}' + cls._buffer_low_latency = cls._build_buffer( + group, + num_rdma_bytes=num_rdma_bytes, + low_latency_mode=True, + num_qps_per_rank=max(num_experts // group.size(), Buffer.num_sms // 2), + explicitly_destroy=cls._explicitly_destroy, + ) + return cls._buffer_low_latency + + +def get_buffer_common( + group: dist.ProcessGroup, + num_max_dispatch_tokens_per_rank: int, + hidden: int, + num_experts: int, + hidden_bytes: int, +): + return DeepEPBuffer.get_buffer_common(group, num_max_dispatch_tokens_per_rank, hidden, num_experts, hidden_bytes) def get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int): @@ -58,19 +273,7 @@ def get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int): https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-model-training-or-inference-prefilling """ - global _buffer_normal - num_nvl_bytes, num_rdma_bytes = 0, 0 - for config in ( - Buffer.get_dispatch_config(group.size()), - Buffer.get_combine_config(group.size()), - ): - num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes) - num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes) - - if (_buffer_normal is None or _buffer_normal.group != group or _buffer_normal.num_nvl_bytes < num_nvl_bytes - or _buffer_normal.num_rdma_bytes < num_rdma_bytes): - _buffer_normal = Buffer(group, num_nvl_bytes, num_rdma_bytes) - return _buffer_normal + return DeepEPBuffer.get_buffer_normal(group, hidden_bytes) def get_buffer_low_latency( @@ -84,21 +287,7 @@ def get_buffer_low_latency( https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding """ - global _buffer_low_latency - num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(num_max_dispatch_tokens_per_rank, hidden, group.size(), - num_experts) - - if (_buffer_low_latency is None or _buffer_low_latency.group != group or not _buffer_low_latency.low_latency_mode - or _buffer_low_latency.num_rdma_bytes < num_rdma_bytes): - assert num_experts % group.size( - ) == 0, f'num_experts: {num_experts} must be divisible by ep_size: {group.size()}' - _buffer_low_latency = Buffer( - group, - num_rdma_bytes=num_rdma_bytes, - low_latency_mode=True, - num_qps_per_rank=max(num_experts // group.size(), Buffer.num_sms // 2), - ) - return _buffer_low_latency + return DeepEPBuffer.get_buffer_low_latency(group, num_max_dispatch_tokens_per_rank, hidden, num_experts) class DeepEPTokenDispatcher(TokenDispatcherImpl): @@ -346,6 +535,157 @@ def get_restored_hidden_states_by_experts( return hidden_states.to(input_dtype) +class DeepEPTokenDispatcherNormal(TokenDispatcherImpl): + """DeepEP normal-mode dispatcher used by LMDeploy EP MoE.""" + + def __init__( + self, + group: torch.distributed.ProcessGroup, + num_experts: int = None, + num_local_experts: int = None, + hidden_size: int = None, + params_dtype: torch.dtype = None, + num_max_dispatch_tokens_per_rank: int = 128, + expert_alignment: int = 128, + ): + self.group = group + self.num_experts = num_experts + self.num_local_experts = num_local_experts + self.hidden_size = hidden_size + self.params_bytes = params_dtype.itemsize + self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank + self.handle = None + if not use_deepep: + raise ImportError('DeepEP is not installed. Please install DeepEP package from ' + 'https://github.com/deepseek-ai/deepep.') + self.buffer_normal = DeepEPBuffer.get_buffer_common( + self.group, + self.num_max_dispatch_tokens_per_rank, + self.hidden_size, + self.num_experts, + hidden_bytes=self.hidden_size * self.params_bytes, + ) + self.expert_alignment = expert_alignment + + def get_buffer(self): + return self.buffer_normal + + def dispatch( + self, + x, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + expert_list: list[int] = None, + previous_event=None, + ): + hidden_states = x[0] if isinstance(x, tuple) else x + self.hidden_shape = hidden_states.shape + topk_idx = topk_idx.to(torch.int64) + x, topk_idx, topk_weights, recv_tokens_per_expert, handle, event = self.dispatch_normal( + x, topk_idx, topk_weights, self.num_experts, previous_event) + + self.handle = handle + self.topk_idx = topk_idx + self.topk_weights = topk_weights + return x, topk_idx, topk_weights, recv_tokens_per_expert + + def dispatch_normal( + self, + x, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + num_experts: int, + previous_event=None, + ): + num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, previous_event = ( + self.get_buffer().get_dispatch_layout( + topk_idx, + num_experts, + previous_event=previous_event, + async_finish=False, + allocate_on_comm_stream=False, + )) + + recv_x, recv_topk_idx, recv_topk_weights, recv_tokens_per_expert, handle, event = self.get_buffer().dispatch( + x, + topk_idx=topk_idx, + topk_weights=topk_weights.to(torch.float32), + num_tokens_per_rank=num_tokens_per_rank, + num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, + is_token_in_rank=is_token_in_rank, + num_tokens_per_expert=num_tokens_per_expert, + previous_event=previous_event, + async_finish=False, + allocate_on_comm_stream=False, + expert_alignment=self.expert_alignment, + ) + + return recv_x, recv_topk_idx, recv_topk_weights, recv_tokens_per_expert, handle, event + + def dispatch_normal_async(self, + x, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + num_experts: int | None = None, + previous_event=None, + async_finish=True): + num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, previous_event = ( + self.get_buffer().get_dispatch_layout( + topk_idx, + num_experts=self.num_experts if num_experts is None else num_experts, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=previous_event is not None and async_finish, + )) + + recv_x, recv_topk_idx, recv_topk_weights, recv_tokens_per_expert, handle, event = self.get_buffer().dispatch( + x, + topk_idx=topk_idx, + topk_weights=topk_weights, + num_tokens_per_rank=num_tokens_per_rank, + num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, + is_token_in_rank=is_token_in_rank, + num_tokens_per_expert=num_tokens_per_expert, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=previous_event is not None and async_finish, + expert_alignment=self.expert_alignment, + ) + + return recv_x, recv_topk_idx, recv_topk_weights, recv_tokens_per_expert, handle, event + + def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, event = self.combine_normal(hidden_states, self.handle) + self.handle = None + return hidden_states.view(self.hidden_shape) + + def combine_normal(self, x: torch.Tensor, handle: tuple, previous_event=None): + combined_x, _, event = self.get_buffer().combine( + x, + handle, + async_finish=False, + previous_event=previous_event, + allocate_on_comm_stream=False, + ) + return combined_x, event + + def combine_normal_async(self, x: torch.Tensor, handle: tuple, previous_event=None, async_finish=True): + combined_x, _, event = self.get_buffer().combine( + x, + handle, + async_finish=async_finish, + previous_event=previous_event, + allocate_on_comm_stream=previous_event is not None and async_finish, + ) + return combined_x, event + + def release(self): + self.handle = None + self.topk_idx = None + self.topk_weights = None + return True + + class DeepEPTokenDispatcherLowLatency(TokenDispatcherImpl): def __init__( @@ -355,6 +695,7 @@ def __init__( num_local_experts: int = None, hidden_size: int = None, params_dtype: torch.dtype = None, + num_max_dispatch_tokens_per_rank: int = 128, return_recv_hook: bool = False, ): if not use_deepep: @@ -366,14 +707,17 @@ def __init__( self.hidden_size = hidden_size self.params_bytes = params_dtype.itemsize self.handle = None - self.num_max_dispatch_tokens_per_rank = deep_ep_max_tokens_per_rank - self.buffer_low_latency = get_buffer_common(self.group, - self.num_max_dispatch_tokens_per_rank, - self.hidden_size, - self.num_experts, - hidden_bytes=self.hidden_size * self.params_bytes) + self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank + self.buffer_low_latency = DeepEPBuffer.get_buffer_common(self.group, + self.num_max_dispatch_tokens_per_rank, + self.hidden_size, + self.num_experts, + hidden_bytes=self.hidden_size * self.params_bytes) self.return_recv_hook = return_recv_hook + def get_buffer(self): + return self.buffer_low_latency + def dispatch( self, hidden_states: torch.Tensor, @@ -382,10 +726,10 @@ def dispatch( num_experts: int, ) -> tuple[torch.Tensor, torch.Tensor]: topk_idx = topk_idx.to(torch.int64) - expected_m = (hidden_states.shape[0] * self.buffer_low_latency.group_size * topk_idx.shape[1] + + expected_m = (hidden_states.shape[0] * self.get_buffer().group_size * topk_idx.shape[1] + num_experts) // num_experts - packed_recv_hidden, masked_m, self.handle, event, hook = (self.buffer_low_latency.low_latency_dispatch( + packed_recv_hidden, masked_m, self.handle, event, hook = (self.get_buffer().low_latency_dispatch( hidden_states, topk_idx, self.num_max_dispatch_tokens_per_rank, @@ -395,6 +739,7 @@ def dispatch( return_recv_hook=self.return_recv_hook, )) hook() if self.return_recv_hook else event.current_stream_wait() + packed_recv_hidden = [DisposibleTensor(x) for x in packed_recv_hidden] return ( packed_recv_hidden, topk_idx, @@ -412,7 +757,7 @@ def dispatch_async( async_finish: bool = True, ): assert topk_idx.dtype == torch.int64 - recv_hidden_states, recv_expert_count, handle, event, hook = (self.buffer_low_latency.low_latency_dispatch( + recv_hidden_states, recv_expert_count, handle, event, hook = (self.get_buffer().low_latency_dispatch( hidden_states, topk_idx, self.num_max_dispatch_tokens_per_rank, @@ -421,6 +766,7 @@ def dispatch_async( async_finish=async_finish, return_recv_hook=not async_finish, )) + recv_hidden_states = [DisposibleTensor(x) for x in recv_hidden_states] return recv_hidden_states, recv_expert_count, handle, event, hook def combine( @@ -428,8 +774,8 @@ def combine( hidden_states: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor | None]: - combined_hidden_states, event, hook = (self.buffer_low_latency.low_latency_combine( + ) -> torch.Tensor: + combined_hidden_states, event, hook = (self.get_buffer().low_latency_combine( hidden_states, topk_idx, topk_weights.to(torch.float32), @@ -450,7 +796,7 @@ def combine_async( ) -> tuple[torch.Tensor, torch.Tensor | None]: assert topk_idx.dtype == torch.int64 assert topk_weights.dtype == torch.float32 - combined_hidden_states, event, hook = self.buffer_low_latency.low_latency_combine( + combined_hidden_states, event, hook = self.get_buffer().low_latency_combine( hidden_states, topk_idx, topk_weights, diff --git a/lmdeploy/pytorch/backends/dlinfer/moe.py b/lmdeploy/pytorch/backends/dlinfer/moe.py index e574388f39..2a011ff46f 100644 --- a/lmdeploy/pytorch/backends/dlinfer/moe.py +++ b/lmdeploy/pytorch/backends/dlinfer/moe.py @@ -103,7 +103,8 @@ def build(top_k: int, ep_size: int = 1, ep_group: torch.distributed.ProcessGroup = None, layer_idx: int = 0, - out_dtype: torch.dtype = torch.bfloat16): + out_dtype: torch.dtype = torch.bfloat16, + num_max_dispatch_tokens_per_rank: int = 128): """Build from mlp.""" return DlinferFusedMoEImpl(top_k=top_k, num_experts=num_experts, diff --git a/lmdeploy/pytorch/backends/moe.py b/lmdeploy/pytorch/backends/moe.py index 10a3c5e702..ec8487bdab 100644 --- a/lmdeploy/pytorch/backends/moe.py +++ b/lmdeploy/pytorch/backends/moe.py @@ -70,7 +70,8 @@ def build(top_k: int, ep_size: int = 1, ep_group: dist.ProcessGroup = None, layer_idx: int = 0, - out_dtype: torch.dtype = torch.bfloat16): + out_dtype: torch.dtype = torch.bfloat16, + num_max_dispatch_tokens_per_rank: int = 128): """Build from mlp.""" raise NotImplementedError @@ -166,6 +167,8 @@ def build(top_k: int, ep_size: int = 1, ep_group: dist.ProcessGroup = None, out_dtype: torch.dtype = torch.float16, + fp8_dtype: torch.dtype = torch.float8_e4m3fn, + num_max_dispatch_tokens_per_rank: int = 128, layer_idx: int = 0, custom_gateup_act: bool = False): """Build from mlp.""" diff --git a/lmdeploy/pytorch/check_env/dist.py b/lmdeploy/pytorch/check_env/dist.py index eb6e48e882..703ea0358d 100644 --- a/lmdeploy/pytorch/check_env/dist.py +++ b/lmdeploy/pytorch/check_env/dist.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from lmdeploy.pytorch.config import DistConfig -from lmdeploy.utils import is_dlblas_installed +from lmdeploy.utils import is_deep_ep_installed, is_deep_gemm_installed from .base import BaseChecker @@ -41,9 +41,10 @@ def check(self): f'Get distributed_executor_backend={distributed_executor_backend}.') if self.ep > 1: - if self.device_type == 'cuda' and not is_dlblas_installed(): + if self.device_type == 'cuda' and (not is_deep_ep_installed() or not is_deep_gemm_installed()): self.log_and_exit(mod_name='Dist', - message='ep>1 requires install dlblas(https://github.com/DeepLink-org/dlBLAS).') + message='ep>1 requires install DeepEP(https://github.com/deepseek-ai/DeepEP) ' + 'and DeepGEMM(https://github.com/deepseek-ai/DeepGEMM).') if self.ep % self.dp != 0: self.log_and_exit(mod_name='Dist', message=f'ep>1 requires ep % dp == 0. Get dp={self.dp} and ep={self.ep}.') diff --git a/lmdeploy/pytorch/engine/model_agent/agent.py b/lmdeploy/pytorch/engine/model_agent/agent.py index f9c28e4159..8b27d3dcc8 100644 --- a/lmdeploy/pytorch/engine/model_agent/agent.py +++ b/lmdeploy/pytorch/engine/model_agent/agent.py @@ -1065,7 +1065,8 @@ def _build_model(self): quant_config=self.model_config.quant_config, fp32_lm_head=self.model_config.fp32_lm_head, tie_word_embeddings=self.model_config.tie_word_embeddings, - num_spec_tokens=self.spec_agent.num_spec_tokens) + num_spec_tokens=self.spec_agent.num_spec_tokens, + max_batch_size=self.cache_config.max_batches) patched_model = build_patched_model(self.model_config, device=device, build_model_ctx=build_model_ctx) logger.debug(msg_with_rank(rank, 'loading weights.')) if not self.misc_config.empty_init: diff --git a/lmdeploy/pytorch/envs.py b/lmdeploy/pytorch/envs.py index 170ba7fa75..39ecbf9a76 100644 --- a/lmdeploy/pytorch/envs.py +++ b/lmdeploy/pytorch/envs.py @@ -145,16 +145,17 @@ def _patched_get_env( os.getenv('HCCL_OP_EXPANSION_MODE', None) os.getenv('HCCL_IF_IP', None) - # dlblas - # we don't need to read this, it would be passed to ray workers - # If Ray is launched from outside, it may fail to access the environment variables. - deep_ep_max_tokens_per_rank = env_to_int('DEEPEP_MAX_TOKENS_PER_RANK', 128) + # deepep os.getenv('DEEPEP_ENABLE_MNNVL', None) os.getenv('DEEPEP_MODE', 'auto') - - # deepep deep_ep_buffer_num_sms = env_to_int('DEEPEP_BUFFER_NUM_SMS', 20) + # eplb + eplb_num_groups = env_to_int('LMDEPLOY_EPLB_NUM_GROUPS', 4) + eplb_experts_statistic_file = os.getenv('LMDEPLOY_EPLB_EXPERTS_STATISTIC_FILE', None) + eplb_ranks_per_node = env_to_int('LMDEPLOY_EPLB_RANKS_PER_NODE', 8) + eplb_num_redundant_experts = env_to_int('LMDEPLOY_EPLB_NUM_REDUNDANT_EXPERTS', 32) + # deepgemm os.getenv('DG_JIT_DEBUG', '0') os.getenv('DG_JIT_PRINT_COMPILER_COMMAND', '0') diff --git a/lmdeploy/pytorch/kernels/cuda/activation.py b/lmdeploy/pytorch/kernels/cuda/activation.py index a45dad18c9..d8dd77526c 100644 --- a/lmdeploy/pytorch/kernels/cuda/activation.py +++ b/lmdeploy/pytorch/kernels/cuda/activation.py @@ -197,3 +197,22 @@ def silu_and_mul_moe_ep(gate_up: torch.Tensor, mask_m: torch.Tensor, out: torch. num_stages=num_stages) return out + + +def silu_and_mul_masked_post_quant_fwd(input: torch.Tensor, output: torch.Tensor, output_scale: torch.Tensor, + quant_group_size: int, masked_m: torch.Tensor): + """Apply masked MoE SiLU-and-mul, then quantize to the preallocated FP8 + output.""" + assert input.is_contiguous() + assert output.is_contiguous() + assert input.dim() == 3 + assert input.shape[0] == masked_m.shape[0] + assert input.shape[-1] % 2 == 0 + size_n = input.shape[-1] // 2 + assert size_n % quant_group_size == 0 + activated = silu_and_mul_moe_ep(input, masked_m) + from .blocked_gemm_fp8 import _quant_fp8_launcher + _quant_fp8_launcher(activated.reshape(-1, size_n), + quant_group_size, + output.reshape(-1, size_n), + output_scale.reshape(-1, size_n // quant_group_size)) diff --git a/lmdeploy/pytorch/kernels/cuda/blocked_gemm_fp8.py b/lmdeploy/pytorch/kernels/cuda/blocked_gemm_fp8.py index e6717b5ef4..6b81eb0a5f 100644 --- a/lmdeploy/pytorch/kernels/cuda/blocked_gemm_fp8.py +++ b/lmdeploy/pytorch/kernels/cuda/blocked_gemm_fp8.py @@ -173,6 +173,22 @@ def quant_fp8(A: Tensor, return _quant_fp8_launcher(A, group_size, out, scales, scale_fmt=scale_fmt) +def per_token_group_quant_fp8(A: Tensor, + group_size: int, + dtype: torch.dtype = torch.float8_e4m3fn, + scale_fmt: str | None = None): + """Per-token-group FP8 quantization for tensors with arbitrary leading + dims.""" + assert A.dim() >= 2 + assert A.stride(-1) == 1, f'{A} groups must be contiguous' + M = A.numel() // A.size(-1) + K = A.size(-1) + assert K % group_size == 0 + out = torch.empty_like(A, dtype=dtype) + scales = A.new_empty(*A.shape[:-1], K // group_size, dtype=torch.float32) + return _quant_fp8_launcher(A.view(M, K), group_size, out.view(M, K), scales.view(M, K // group_size), scale_fmt) + + def quant_fp8_tma(A: Tensor, group_size: int, dtype: torch.dtype = torch.float8_e4m3fn, diff --git a/lmdeploy/pytorch/kernels/cuda/fused_moe_ep_fp8.py b/lmdeploy/pytorch/kernels/cuda/fused_moe_ep_fp8.py new file mode 100644 index 0000000000..dfcda0db0e --- /dev/null +++ b/lmdeploy/pytorch/kernels/cuda/fused_moe_ep_fp8.py @@ -0,0 +1,194 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# modify from dlblas: https://github.com/DeepLink-org/DLBlas +import torch +import triton +import triton.language as tl + +from .activation import silu_and_mul +from .blocked_gemm_fp8 import per_token_group_quant_fp8 +from .fused_moe_ep import ep_gather + + +@triton.jit +def _fwd_kernel_ep_scatter_fp8_step1( + num_recv_tokens_per_expert, + expert_start_loc, + m_indices, + num_experts: tl.constexpr, + BLOCK_E: tl.constexpr, + BLOCK_EXPERT_NUM: tl.constexpr, +): + cur_expert = tl.program_id(0) + offset_cumsum = tl.arange(0, BLOCK_EXPERT_NUM) + tokens_per_expert = tl.load(num_recv_tokens_per_expert + offset_cumsum, + mask=offset_cumsum < num_experts, + other=0) + cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert + tl.store(expert_start_loc + offset_cumsum, cumsum, mask=offset_cumsum < num_experts) + cur_expert_start = tl.load(expert_start_loc + cur_expert) + cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert) + m_indices_start_ptr = m_indices + cur_expert_start + off_expert = tl.arange(0, BLOCK_E) + for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4): + tl.store(m_indices_start_ptr + start_m + off_expert, cur_expert) + + +@triton.jit +def _fwd_kernel_ep_scatter_fp8_step2( + total_token_num, + expert_start_loc, + recv_x, + recv_x_stride0, + recv_x_stride1, + recv_x_scale, + recv_x_scale_stride0, + recv_x_scale_stride1, + recv_topk, + recv_topk_stride0, + recv_topk_stride1, + output_tensor, + output_tensor_stride0, + output_tensor_stride1, + output_tensor_scale, + output_tensor_scale_stride0, + output_tensor_scale_stride1, + output_index, + output_index_stride0, + output_index_stride1, + topk_num: tl.constexpr, + HIDDEN_SIZE: tl.constexpr, + HIDDEN_SIZE_PAD: tl.constexpr, + SCALE_HIDDEN_SIZE: tl.constexpr, + SCALE_HIDDEN_SIZE_PAD: tl.constexpr, +): + start_token_id = tl.program_id(0) + grid_num = tl.num_programs(0) + offset_in = tl.arange(0, HIDDEN_SIZE_PAD) + mask = offset_in < HIDDEN_SIZE + offset_scale = tl.arange(0, SCALE_HIDDEN_SIZE_PAD) + mask_scale = offset_scale < SCALE_HIDDEN_SIZE + for token_id in range(start_token_id, total_token_num, grid_num): + to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask) + to_copy_scale = tl.load(recv_x_scale + token_id * recv_x_scale_stride0 + offset_scale, mask=mask_scale) + for topk_index in tl.range(0, topk_num, 1, num_stages=4): + expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index) + if expert_id >= 0: + dest_token_index = tl.atomic_add(expert_start_loc + expert_id, 1) + dest_token_index = dest_token_index.to(tl.int64) + tl.store(output_index + token_id * output_index_stride0 + topk_index, dest_token_index) + output_tensor_ptr = output_tensor + dest_token_index * output_tensor_stride0 + output_tensor_scale_ptr = output_tensor_scale + dest_token_index * output_tensor_scale_stride0 + tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask) + tl.store(output_tensor_scale_ptr + offset_scale, to_copy_scale, mask=mask_scale) + + +@torch.no_grad() +def ep_scatter_fp8( + recv_x: torch.Tensor, + recv_x_scale: torch.Tensor, + recv_topk: torch.Tensor, + num_recv_tokens_per_expert: torch.Tensor, + expert_start_loc: torch.Tensor, + output_tensor: torch.Tensor, + output_tensor_scale: torch.Tensor, + m_indices: torch.Tensor, + output_index: torch.Tensor, +): + block_e = 128 + num_warps = 8 + num_experts = num_recv_tokens_per_expert.shape[0] + hidden_size = recv_x.shape[1] + scale_hidden_size = recv_x_scale.shape[1] + assert m_indices.shape[0] % block_e == 0 + _fwd_kernel_ep_scatter_fp8_step1[(num_experts, )]( + num_recv_tokens_per_expert, + expert_start_loc, + m_indices, + num_experts=num_experts, + num_warps=num_warps, + BLOCK_E=block_e, + BLOCK_EXPERT_NUM=triton.next_power_of_2(num_experts), + ) + grid = min(recv_topk.shape[0], 1024 * 8) + _fwd_kernel_ep_scatter_fp8_step2[(grid, )]( + recv_topk.shape[0], + expert_start_loc, + recv_x, + recv_x.stride(0), + recv_x.stride(1), + recv_x_scale, + recv_x_scale.stride(0), + recv_x_scale.stride(1), + recv_topk, + recv_topk.stride(0), + recv_topk.stride(1), + output_tensor, + output_tensor.stride(0), + output_tensor.stride(1), + output_tensor_scale, + output_tensor_scale.stride(0), + output_tensor_scale.stride(1), + output_index, + output_index.stride(0), + output_index.stride(1), + topk_num=recv_topk.shape[1], + num_warps=num_warps, + HIDDEN_SIZE=hidden_size, + HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size), + SCALE_HIDDEN_SIZE=scale_hidden_size, + SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(scale_hidden_size), + ) + + +def _deepgemm_grouped_fp8_nt_contiguous(input_tuple, w_tuple, out: torch.Tensor, m_indices: torch.Tensor): + from lmdeploy.pytorch.third_party import deep_gemm + return deep_gemm.m_grouped_fp8_gemm_nt_contiguous(input_tuple, w_tuple, out, m_indices) + + +def fused_moe_v3_fp8( + hidden_states_fp8: tuple[torch.Tensor, torch.Tensor], + topk_idx, + topk_weights, + w13_weight_fp8: tuple[torch.Tensor, torch.Tensor], + w2_weight_fp8: tuple[torch.Tensor, torch.Tensor], + num_recv_tokens_per_expert: list[int] | None, +): + hidden_states_fp8, hidden_states_scale = hidden_states_fp8 + if num_recv_tokens_per_expert is None: + return hidden_states_fp8.to(torch.bfloat16) + all_tokens = sum(num_recv_tokens_per_expert) + if all_tokens <= 0: + return hidden_states_fp8.to(torch.bfloat16) + from lmdeploy.pytorch.third_party.deep_gemm import get_mn_major_tma_aligned_tensor + m, k = hidden_states_fp8.size() + n = w13_weight_fp8[0].size(1) + block_size = k // hidden_states_scale.size(1) + gather_out = torch.empty_like(hidden_states_fp8, device=hidden_states_fp8.device, dtype=torch.bfloat16) + input_tensor = torch.empty((all_tokens, k), device=hidden_states_fp8.device, dtype=hidden_states_fp8.dtype) + input_tensor_scale = torch.empty((all_tokens, k // block_size), + device=hidden_states_fp8.device, + dtype=torch.float32) + m_indices = torch.empty(all_tokens, device=hidden_states_fp8.device, dtype=torch.int32) + output_index = torch.empty_like(topk_idx) + num_recv_tokens_per_expert_gpu = torch.tensor(num_recv_tokens_per_expert, + dtype=torch.int32, + pin_memory=True, + device='cpu').cuda(non_blocking=True) + expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu) + ep_scatter_fp8(hidden_states_fp8, hidden_states_scale, topk_idx, num_recv_tokens_per_expert_gpu, expert_start_loc, + input_tensor, input_tensor_scale, m_indices, output_index) + del hidden_states_fp8 + + gateup_output = torch.empty((all_tokens, n), device=gather_out.device, dtype=torch.bfloat16) + input_tensor_scale = get_mn_major_tma_aligned_tensor(input_tensor_scale) + _deepgemm_grouped_fp8_nt_contiguous((input_tensor, input_tensor_scale), w13_weight_fp8, gateup_output, m_indices) + + down_input = torch.empty((all_tokens, n // 2), device=gateup_output.device, dtype=torch.bfloat16) + silu_and_mul(gateup_output.view(-1, n), down_input) + del gateup_output + down_input_fp8, down_input_scale = per_token_group_quant_fp8(down_input, block_size) + down_input_scale = get_mn_major_tma_aligned_tensor(down_input_scale) + down_output = torch.empty((all_tokens, k), device=gather_out.device, dtype=torch.bfloat16) + _deepgemm_grouped_fp8_nt_contiguous((down_input_fp8, down_input_scale), w2_weight_fp8, down_output, m_indices) + ep_gather(down_output, topk_idx, topk_weights, output_index, gather_out) + return gather_out diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py index 270bc34b0e..bb66e53f91 100644 --- a/lmdeploy/pytorch/model_inputs.py +++ b/lmdeploy/pytorch/model_inputs.py @@ -461,6 +461,14 @@ class BuildModelContext: fp32_lm_head: bool = False tie_word_embeddings: bool = False num_spec_tokens: int = 0 + max_batch_size: int = 0 + + @property + def deep_ep_max_tokens_per_rank(self) -> int: + """Infer DeepEP low-latency max dispatch tokens per rank.""" + if self.max_batch_size <= 0: + return 128 + return self.max_batch_size * (1 + self.num_spec_tokens) class StepContextManager(CtxMgrBase[StepContext]): diff --git a/lmdeploy/pytorch/nn/eplb.py b/lmdeploy/pytorch/nn/eplb.py index ee87546a3b..b9a82da338 100644 --- a/lmdeploy/pytorch/nn/eplb.py +++ b/lmdeploy/pytorch/nn/eplb.py @@ -1,5 +1,305 @@ # Copyright (c) OpenMMLab. All rights reserved. +import json +import random +from dataclasses import dataclass + import torch +import torch.nn.functional as F + +from lmdeploy.pytorch.envs import ( + eplb_experts_statistic_file, + eplb_num_groups, + eplb_num_redundant_experts, + eplb_ranks_per_node, +) + + +def balanced_packing(weight: torch.Tensor, num_packs: int) -> tuple[torch.Tensor, torch.Tensor]: + """Pack expert groups with approximately balanced weights.""" + num_layers, num_groups = weight.shape + assert num_groups % num_packs == 0 + groups_per_pack = num_groups // num_packs + if groups_per_pack == 1: + pack_index = torch.arange(weight.size(-1), dtype=torch.int64, device=weight.device).expand(weight.shape) + rank_in_pack = torch.zeros_like(weight, dtype=torch.int64) + return pack_index, rank_in_pack + + indices = weight.float().sort(-1, descending=True).indices + pack_index = torch.full_like(weight, fill_value=-1, dtype=torch.int64) + rank_in_pack = torch.full_like(pack_index, fill_value=-1) + for i in range(num_layers): + pack_weights = [0] * num_packs + pack_items = [0] * num_packs + for group in indices[i]: + pack = min((idx for idx in range(num_packs) if pack_items[idx] < groups_per_pack), + key=pack_weights.__getitem__) + assert pack_items[pack] < groups_per_pack + pack_index[i, group] = pack + rank_in_pack[i, group] = pack_items[pack] + pack_weights[pack] += weight[i, group] + pack_items[pack] += 1 + return pack_index, rank_in_pack + + +def replicate_experts(weight: torch.Tensor, num_phy: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Create redundant physical experts for the heaviest logical experts.""" + num_layers, num_log = weight.shape + assert num_phy >= num_log + device = weight.device + phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(num_layers, 1) + rank = torch.zeros(num_layers, num_phy, dtype=torch.int64, device=device) + logcnt = torch.ones(num_layers, num_log, dtype=torch.int64, device=device) + arange_layers = torch.arange(num_layers, dtype=torch.int64, device=device) + for i in range(num_log, num_phy): + redundant_indices = (weight / logcnt).max(dim=-1).indices + phy2log[:, i] = redundant_indices + rank[:, i] = logcnt[arange_layers, redundant_indices] + logcnt[arange_layers, redundant_indices] += 1 + return phy2log, rank, logcnt + + +def _inverse(perm: torch.Tensor) -> torch.Tensor: + inv = torch.empty_like(perm) + inv.scatter_(1, perm, torch.arange(perm.size(1), dtype=torch.int64, device=perm.device).expand(perm.shape)) + return inv + + +def rebalance_experts_hierarchical(weight: torch.Tensor, num_physical_experts: int, num_groups: int, num_nodes: int, + num_gpus: int): + num_layers, num_logical_experts = weight.shape + assert num_logical_experts % num_groups == 0 + group_size = num_logical_experts // num_groups + assert num_groups % num_nodes == 0 + groups_per_node = num_groups // num_nodes + assert num_gpus % num_nodes == 0 + assert num_physical_experts % num_gpus == 0 + phy_experts_per_gpu = num_physical_experts // num_gpus + + tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1) + group_pack_index, group_rank_in_pack = balanced_packing(tokens_per_group, num_nodes) + log2mlog = (((group_pack_index * groups_per_node + group_rank_in_pack) * group_size).unsqueeze(-1) + + torch.arange(group_size, dtype=torch.int64, device=group_pack_index.device)).flatten(-2) + mlog2log = _inverse(log2mlog) + tokens_per_mlog = weight.gather(-1, mlog2log).view(-1, num_logical_experts // num_nodes) + phy2mlog, phyrank, mlogcnt = replicate_experts(tokens_per_mlog, num_physical_experts // num_nodes) + tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog) + pack_index, rank_in_pack = balanced_packing(tokens_per_phy, num_gpus // num_nodes) + phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack + pphy2phy = _inverse(phy2pphy) + pphy2mlog = phy2mlog.gather(-1, pphy2phy) + node_offsets = torch.arange(0, + num_logical_experts, + num_logical_experts // num_nodes, + dtype=torch.int64, + device=weight.device) + pphy2mlog = (pphy2mlog.view(num_layers, num_nodes, -1) + node_offsets.view(1, -1, 1)).flatten(-2) + pphy2log = mlog2log.gather(-1, pphy2mlog) + pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1) + logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog) + return pphy2log, pphyrank, logcnt + + +def rebalance_experts(weight: torch.Tensor, num_replicas: int, num_groups: int, num_nodes: int, + num_gpus: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + num_layers, num_logical_experts = weight.shape + weight = weight.float() + if num_groups % num_nodes == 0: + phy2log, phyrank, logcnt = rebalance_experts_hierarchical(weight, num_replicas, num_groups, num_nodes, num_gpus) + else: + phy2log, phyrank, logcnt = replicate_experts(weight, num_replicas) + maxlogcnt = logcnt.max().item() + log2phy = torch.full((num_layers, num_logical_experts, maxlogcnt), + -1, + dtype=torch.int64, + device=logcnt.device) + log2phy.view(num_layers, -1).scatter_( + -1, phy2log * maxlogcnt + phyrank, + torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(num_layers, -1)) + return phy2log, log2phy, logcnt + + +def logical_to_all_physical_raw(logical_to_all_physical_map: torch.Tensor, layer_id: int, + logical_expert_id: int) -> list[int]: + return [ + physical_expert_id for physical_expert_id in logical_to_all_physical_map[layer_id, + logical_expert_id].tolist() + if physical_expert_id != -1 + ] + + +def _compute_gpu_id_of_physical_expert(physical_expert_id: int, num_local_physical_experts: int) -> int: + return physical_expert_id // num_local_physical_experts + + +def _fair_choices(arr: list[int], k: int, r: random.Random) -> list[int]: + quotient, remainder = divmod(k, len(arr)) + res = arr * quotient + r.sample(arr, k=remainder) + r.shuffle(res) + return res + + +def compute_logical_to_rank_dispatch_physical_map( + logical_to_all_physical_map: torch.Tensor, + num_gpus: int, + num_physical_experts: int, + seed: int = 42, +): + r = random.Random(seed) + num_local_physical_experts = num_physical_experts // num_gpus + num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape + dtype = logical_to_all_physical_map.dtype + logical_to_rank_dispatch_physical_map = torch.full( + size=(num_gpus, num_layers, num_logical_experts), + fill_value=-1, + dtype=dtype, + device=logical_to_all_physical_map.device, + ) + + for layer_id in range(num_layers): + for logical_expert_id in range(num_logical_experts): + candidates = logical_to_all_physical_raw(logical_to_all_physical_map, layer_id, logical_expert_id) + output_partial = logical_to_rank_dispatch_physical_map[:, layer_id, logical_expert_id] + for gpu_id in range(num_gpus): + same_gpu = [ + physical_expert_id for physical_expert_id in candidates + if _compute_gpu_id_of_physical_expert(physical_expert_id, num_local_physical_experts) == gpu_id + ] + if same_gpu: + output_partial[gpu_id] = same_gpu[0] + + num_remain = torch.sum(output_partial == -1).item() + output_partial[output_partial == -1] = torch.tensor(_fair_choices(candidates, k=num_remain, r=r), + dtype=dtype, + device=logical_to_all_physical_map.device) + + assert torch.all(logical_to_rank_dispatch_physical_map != -1) + return logical_to_rank_dispatch_physical_map + + +@dataclass +class EPLBMetadata: + physical_to_logical_map: torch.Tensor + logical_to_all_physical_map: torch.Tensor + logical_to_all_physical_map_num_valid: torch.Tensor + logical_to_rank_dispatch_physical_map: torch.Tensor + + def num_physical_experts(self) -> int: + return self.physical_to_logical_map.shape[1] + + def __post_init__(self): + num_layers_0, num_physical_experts_0 = self.physical_to_logical_map.shape + num_layers_1, num_logical_experts_0, num_physical_experts_1 = self.logical_to_all_physical_map.shape + num_layers_2, num_logical_experts_1 = self.logical_to_all_physical_map_num_valid.shape + _, num_layers_3, num_logical_experts_2 = self.logical_to_rank_dispatch_physical_map.shape + assert num_layers_0 == num_layers_1 == num_layers_2 == num_layers_3 + assert num_logical_experts_0 == num_logical_experts_1 == num_logical_experts_2 + assert num_physical_experts_0 == num_physical_experts_1 + + @staticmethod + def _init_raw(ep_size: int, physical_to_logical_map: torch.Tensor, logical_to_all_physical_map: torch.Tensor): + _, num_physical_experts = physical_to_logical_map.shape + logical_to_all_physical_map_padded = F.pad( + logical_to_all_physical_map, + (0, num_physical_experts - logical_to_all_physical_map.shape[-1]), + value=-1, + ) + logical_to_all_physical_map_num_valid = torch.count_nonzero(logical_to_all_physical_map != -1, dim=-1) + return EPLBMetadata( + physical_to_logical_map=physical_to_logical_map, + logical_to_all_physical_map=logical_to_all_physical_map_padded, + logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid, + logical_to_rank_dispatch_physical_map=compute_logical_to_rank_dispatch_physical_map( + logical_to_all_physical_map, + num_gpus=ep_size, + num_physical_experts=num_physical_experts, + ), + ) + + @staticmethod + def init(ep_size: int, num_routed_experts: int, num_hidden_layers: int): + num_groups = eplb_num_groups + weight_path = eplb_experts_statistic_file + ranks_per_node = eplb_ranks_per_node + num_redundant_experts = eplb_num_redundant_experts + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + if weight_path is None: + experts_statistic = torch.arange(num_routed_experts, dtype=torch.int32, + device=device).flip(dims=(0, )).expand(num_hidden_layers, -1) + else: + try: + with open(weight_path) as f: + experts_statistic = torch.tensor(json.load(f), dtype=torch.float32, device=device) + except Exception as exc: + raise RuntimeError(f'Load eplb experts statistic data failed, path: {weight_path}') from exc + target_shape = torch.Size([num_hidden_layers, num_routed_experts]) + assert experts_statistic.shape == target_shape, f'Shape of {weight_path} must be {target_shape}' + + num_nodes = 1 if ep_size < ranks_per_node else ep_size // ranks_per_node + num_physical_experts = num_routed_experts + num_redundant_experts + physical_to_logical_map, logical_to_all_physical_map, _ = rebalance_experts( + weight=experts_statistic, + num_replicas=num_physical_experts, + num_groups=num_groups, + num_nodes=num_nodes, + num_gpus=ep_size, + ) + return EPLBMetadata._init_raw( + ep_size=ep_size, + physical_to_logical_map=physical_to_logical_map, + logical_to_all_physical_map=logical_to_all_physical_map, + ) + + +_global_eplb_metadata: EPLBMetadata | None = None + + +def init_global_eplb_metadata(ep_size: int, num_routed_experts: int, num_hidden_layers: int): + global _global_eplb_metadata + if _global_eplb_metadata is not None: + raise RuntimeError('Global EPLB metadata has already been initialized.') + _global_eplb_metadata = EPLBMetadata.init(ep_size=ep_size, + num_routed_experts=num_routed_experts, + num_hidden_layers=num_hidden_layers) + + +def get_global_eplb_metadata(): + global _global_eplb_metadata + if _global_eplb_metadata is None: + raise RuntimeError('Global EPLB metadata has not been initialized.') + return _global_eplb_metadata + + +def get_eplb_phy2log_metadata_by_layer(layer_idx: int): + return get_global_eplb_metadata().physical_to_logical_map[layer_idx] + + +@dataclass +class _EPLBDispatchInfo: + partial_logical_to_rank_dispatch_physical_map: torch.Tensor + partial_logical_to_all_physical_map: torch.Tensor + partial_logical_to_all_physical_map_num_valid: torch.Tensor + + @classmethod + def init_new(cls, ep_rank: int, layer_idx: int): + eplb_metadata = get_global_eplb_metadata() + return cls( + partial_logical_to_rank_dispatch_physical_map=eplb_metadata.logical_to_rank_dispatch_physical_map[ + ep_rank, layer_idx, :], + partial_logical_to_all_physical_map=eplb_metadata.logical_to_all_physical_map[layer_idx, :], + partial_logical_to_all_physical_map_num_valid=eplb_metadata.logical_to_all_physical_map_num_valid[ + layer_idx, :], + ) + + +def topk_ids_logical_to_physical(topk_ids: torch.Tensor, info: _EPLBDispatchInfo | None) -> torch.Tensor: + if info is None: + return topk_ids + original_shape = topk_ids.shape + topk_ids = topk_ids.flatten() + chosen_dispatch_index = (torch.randint(0, 65536, topk_ids.shape, dtype=torch.int32, device=topk_ids.device) % + info.partial_logical_to_all_physical_map_num_valid[topk_ids]) + topk_ids = info.partial_logical_to_all_physical_map[topk_ids, chosen_dispatch_index] + return topk_ids.view(original_shape) class EPLBDispatchInfo: @@ -9,26 +309,22 @@ def __init__(self, info) -> None: class EPLBManager: - eplb = None @classmethod def init_global_eplb_metadata(cls, ep_size: int, num_routed_experts: int, num_hidden_layers: int): assert ep_size > 1, 'eplb requires ep_size > 1' - from dlblas.layers.moe import eplb - EPLBManager.eplb = eplb - eplb.init_global_eplb_metadata(ep_size=ep_size, - num_routed_experts=num_routed_experts, - num_hidden_layers=num_hidden_layers) + init_global_eplb_metadata(ep_size=ep_size, + num_routed_experts=num_routed_experts, + num_hidden_layers=num_hidden_layers) @classmethod def num_physical_experts(cls) -> int: - return EPLBManager.eplb.get_global_eplb_metadata().num_physical_experts() + return get_global_eplb_metadata().num_physical_experts() @classmethod def topk_ids_logical_to_physical(cls, topk_ids: torch.Tensor, eplb_dispatch_info: EPLBDispatchInfo): - return EPLBManager.eplb.topk_ids_logical_to_physical(topk_ids=topk_ids, info=eplb_dispatch_info.info) + return topk_ids_logical_to_physical(topk_ids=topk_ids, info=eplb_dispatch_info.info) @classmethod def get_dispatch_info(cls, ep_rank, layer_idx) -> EPLBDispatchInfo: - info = EPLBManager.eplb.EPLBDispatchInfo.init_new(ep_rank=ep_rank, layer_idx=layer_idx) - return EPLBDispatchInfo(info) + return EPLBDispatchInfo(_EPLBDispatchInfo.init_new(ep_rank=ep_rank, layer_idx=layer_idx)) diff --git a/lmdeploy/pytorch/nn/moe/blocked_fp8.py b/lmdeploy/pytorch/nn/moe/blocked_fp8.py index 807833212e..74f5e26d9e 100644 --- a/lmdeploy/pytorch/nn/moe/blocked_fp8.py +++ b/lmdeploy/pytorch/nn/moe/blocked_fp8.py @@ -5,6 +5,7 @@ from lmdeploy.pytorch.backends import OpType, get_backend from lmdeploy.pytorch.distributed import get_dist_manager, get_ep_world_rank, get_tp_world_rank +from lmdeploy.pytorch.models.patch import get_build_model_context from ..quant_utils import quant_blocked_fp8 from ..utils import div_up @@ -173,6 +174,7 @@ def __init__(self, dist_ctx = get_dist_manager().current_context() self.ep_size, rank = get_ep_world_rank() impl_builder = get_backend().get_layer_impl_builder(OpType.FusedMoEBlockedF8) + deep_ep_max_tokens_per_rank = get_build_model_context().deep_ep_max_tokens_per_rank self.impl = impl_builder.build(top_k, num_experts, hidden_dim, @@ -181,6 +183,8 @@ def __init__(self, ep_size=self.ep_size, ep_group=dist_ctx.ep_gpu_group, out_dtype=dtype, + fp8_dtype=fp8_dtype, + num_max_dispatch_tokens_per_rank=deep_ep_max_tokens_per_rank, layer_idx=layer_idx, custom_gateup_act=act_func is not None) self.impl.set_scale_fmt(scale_fmt) @@ -247,7 +251,9 @@ def before_dispatch(self, state: DispatchInputs): fusedmoe = self.fusedmoe_build(low_latency_mode=False) state['fusedmoe'] = fusedmoe if hasattr(fusedmoe, 'per_token_group_quant_fp8'): - state['hidden_states'] = fusedmoe.per_token_group_quant_fp8(state['hidden_states']) + state['hidden_states'] = fusedmoe.per_token_group_quant_fp8(state['hidden_states'], + dtype=self.gate_up.weight.dtype, + scale_fmt=self.scale_fmt) previous_event = fusedmoe.capture() state['previous_event'] = previous_event return state diff --git a/lmdeploy/pytorch/nn/moe/default.py b/lmdeploy/pytorch/nn/moe/default.py index efb5f4483c..8dc668355b 100644 --- a/lmdeploy/pytorch/nn/moe/default.py +++ b/lmdeploy/pytorch/nn/moe/default.py @@ -7,6 +7,7 @@ from lmdeploy.pytorch.backends import OpType, get_backend from lmdeploy.pytorch.distributed import get_dist_manager, get_ep_world_rank, get_tp_world_rank +from lmdeploy.pytorch.models.patch import get_build_model_context from .base import DispatchInputs, FusedMoEBase, MoeType, moe_gather_inputs, moe_reduce, update_dims @@ -136,6 +137,7 @@ def __init__(self, dist_ctx = get_dist_manager().current_context() self.ep_size, rank = get_ep_world_rank() impl_builder = get_backend().get_layer_impl_builder(OpType.FusedMoE) + deep_ep_max_tokens_per_rank = get_build_model_context().deep_ep_max_tokens_per_rank self.impl = impl_builder.build( top_k, num_experts, @@ -144,6 +146,7 @@ def __init__(self, ep_size=self.ep_size, ep_group=dist_ctx.ep_gpu_group, layer_idx=layer_idx, + num_max_dispatch_tokens_per_rank=deep_ep_max_tokens_per_rank, ) # create weights diff --git a/lmdeploy/utils.py b/lmdeploy/utils.py index c16e2497f5..e79bf8c54e 100644 --- a/lmdeploy/utils.py +++ b/lmdeploy/utils.py @@ -506,13 +506,20 @@ def serialize_state_dict(state_dict: dict) -> str: return pybase64.b64encode(buf.read()).decode('utf-8') -def is_dlblas_installed(): - is_dlblas_installed = True +def is_deep_ep_installed(): try: - import dlblas # noqa: F401 + import deep_ep # noqa: F401 except Exception: - is_dlblas_installed = False - return is_dlblas_installed + return False + return True + + +def is_deep_gemm_installed(): + try: + import deep_gemm # noqa: F401 + except Exception: + return False + return True # from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/weight_sync/tensor_bucket.py diff --git a/tests/pytorch/nn/test_moe_deepep.py b/tests/pytorch/nn/test_moe_deepep.py new file mode 100644 index 0000000000..455fe5c8dd --- /dev/null +++ b/tests/pytorch/nn/test_moe_deepep.py @@ -0,0 +1,463 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import ast +import builtins +import importlib +from pathlib import Path + +import pytest +import torch + + +def test_dist_checker_requires_deepep_and_deepgemm(monkeypatch): + from lmdeploy.pytorch.check_env import dist as dist_check + + failures = [] + monkeypatch.setattr(dist_check, 'is_deep_ep_installed', lambda: False) + monkeypatch.setattr(dist_check, 'is_deep_gemm_installed', lambda: True) + + checker = dist_check.DistChecker(tp=1, dp=1, ep=2, distributed_executor_backend='mp', device_type='cuda') + monkeypatch.setattr(checker, 'log_and_exit', lambda **kwargs: failures.append(kwargs)) + + checker.check() + + assert failures + assert 'DeepEP' in failures[0]['message'] + assert 'DeepGEMM' in failures[0]['message'] + assert 'dl' + 'blas' not in failures[0]['message'].lower() + + +def test_eplb_metadata_and_dispatch_mapping(monkeypatch): + from lmdeploy.pytorch.nn import eplb + + physical_to_logical = torch.tensor([[0, 1, 1]]) + logical_to_all_physical = torch.tensor([[[0, -1], [1, 2]]]) + metadata = eplb.EPLBMetadata._init_raw( + ep_size=1, + physical_to_logical_map=physical_to_logical, + logical_to_all_physical_map=logical_to_all_physical, + ) + monkeypatch.setattr(eplb, '_global_eplb_metadata', metadata) + + assert eplb.EPLBManager.num_physical_experts() == 3 + assert eplb.get_eplb_phy2log_metadata_by_layer(0).tolist() == [0, 1, 1] + + info = eplb.EPLBManager.get_dispatch_info(ep_rank=0, layer_idx=0) + topk_ids = torch.tensor([[0, 1, 1]]) + physical = eplb.EPLBManager.topk_ids_logical_to_physical(topk_ids, info) + + assert physical[0, 0].item() == 0 + assert physical[0, 1].item() in (1, 2) + assert physical[0, 2].item() in (1, 2) + + +def test_deepep_buffer_uses_internal_default_token_limit(monkeypatch): + from lmdeploy.pytorch.backends.cuda import token_dispatcher as td + + class FakeConfig: + + def get_nvl_buffer_size_hint(self, hidden_bytes, group_size): + return hidden_bytes + group_size + + def get_rdma_buffer_size_hint(self, hidden_bytes, group_size): + return hidden_bytes + group_size + 1 + + class FakeBuffer: + num_sms = 20 + low_latency_size_hint_args = None + build_count = 0 + + @staticmethod + def get_dispatch_config(group_size): + return FakeConfig() + + @staticmethod + def get_combine_config(group_size): + return FakeConfig() + + @staticmethod + def get_low_latency_rdma_size_hint(*args): + FakeBuffer.low_latency_size_hint_args = args + return 4096 + + def __init__(self, group, num_nvl_bytes=0, num_rdma_bytes=0, low_latency_mode=False, **kwargs): + FakeBuffer.build_count += 1 + self.group = group + self.num_nvl_bytes = num_nvl_bytes + self.num_rdma_bytes = num_rdma_bytes + self.low_latency_mode = low_latency_mode + self.kwargs = kwargs + self.destroyed = False + self.group_size = group.size() + + def set_num_sms(self, num_sms): + self.num_sms = num_sms + + def destroy(self): + self.destroyed = True + + def clean_low_latency_buffer(self, *args): + self.clean_args = args + + class FakeGroup: + + def size(self): + return 2 + + monkeypatch.setenv('DEEPEP_MAX_TOKENS' + '_PER_RANK', '999') + monkeypatch.setenv('DEEPEP_BUFFER_NUM_SMS', '13') + monkeypatch.setattr(td, 'Buffer', FakeBuffer) + monkeypatch.setattr(td, 'use_deepep', True) + td.DeepEPBuffer._buffer_common = None + td.DeepEPBuffer._buffer_normal = None + td.DeepEPBuffer._buffer_low_latency = None + td.DeepEPBuffer._explicitly_destroy = False + td.DeepEPBuffer._deepep_sms = 20 + td.DeepEPBuffer._num_max_dispatch_tokens_per_rank = 128 + FakeBuffer.build_count = 0 + + assert td.DeepEPBuffer.set_explicitly_destroy() is True + buffer = td.DeepEPBuffer.get_buffer_common(FakeGroup(), 128, hidden=16, num_experts=4, hidden_bytes=32) + reused_buffer = td.DeepEPBuffer.get_buffer_common(FakeGroup(), 256, hidden=32, num_experts=4, hidden_bytes=1024) + + assert FakeBuffer.low_latency_size_hint_args[0] == 128 + assert FakeBuffer.build_count == 1 + assert reused_buffer is buffer + assert buffer.kwargs['explicitly_destroy'] is True + assert buffer.kwargs['num_qps_per_rank'] == 13 + assert buffer.num_sms == 13 + assert td.DeepEPBuffer.destroy() is True + assert buffer.destroyed is True + assert td.DeepEPBuffer.destroy() is False + + +def test_disposible_tensor_dispose_is_best_effort_with_extra_refs(): + from lmdeploy.pytorch.backends.cuda.token_dispatcher import DisposibleTensor + + tensor = torch.empty(1) + wrapped = DisposibleTensor(tensor) + extra_refs = [tensor] + + wrapped.dispose() + + assert wrapped.value is tensor + assert extra_refs[0] is tensor + + +def test_low_latency_dispatcher_accepts_explicit_token_limit(monkeypatch): + from lmdeploy.pytorch.backends.cuda import token_dispatcher as td + + class FakeConfig: + + def get_nvl_buffer_size_hint(self, hidden_bytes, group_size): + return hidden_bytes + group_size + + def get_rdma_buffer_size_hint(self, hidden_bytes, group_size): + return hidden_bytes + group_size + 1 + + class FakeBuffer: + num_sms = 20 + low_latency_size_hint_args = None + + @staticmethod + def get_dispatch_config(group_size): + return FakeConfig() + + @staticmethod + def get_combine_config(group_size): + return FakeConfig() + + @staticmethod + def get_low_latency_rdma_size_hint(*args): + FakeBuffer.low_latency_size_hint_args = args + return 4096 + + def __init__(self, group, *args, **kwargs): + self.group = group + self.num_nvl_bytes = kwargs.get('num_nvl_bytes', args[0] if len(args) > 0 else 0) + self.num_rdma_bytes = kwargs.get('num_rdma_bytes', args[1] if len(args) > 1 else 0) + self.low_latency_mode = kwargs.get('low_latency_mode', False) + self.group_size = group.size() + + def set_num_sms(self, num_sms): + self.num_sms = num_sms + + class FakeGroup: + + def size(self): + return 2 + + monkeypatch.setattr(td, 'Buffer', FakeBuffer) + monkeypatch.setattr(td, 'use_deepep', True) + td.DeepEPBuffer._buffer_common = None + td.DeepEPBuffer._num_max_dispatch_tokens_per_rank = 128 + + dispatcher = td.DeepEPTokenDispatcherLowLatency( + group=FakeGroup(), + num_experts=4, + num_local_experts=2, + hidden_size=16, + params_dtype=torch.bfloat16, + num_max_dispatch_tokens_per_rank=256, + ) + + assert dispatcher.num_max_dispatch_tokens_per_rank == 256 + assert FakeBuffer.low_latency_size_hint_args[0] == 256 + + +def test_normal_dispatcher_accepts_explicit_token_limit_for_common_buffer(monkeypatch): + from lmdeploy.pytorch.backends.cuda import token_dispatcher as td + + class FakeConfig: + + def get_nvl_buffer_size_hint(self, hidden_bytes, group_size): + return hidden_bytes + group_size + + def get_rdma_buffer_size_hint(self, hidden_bytes, group_size): + return hidden_bytes + group_size + 1 + + class FakeBuffer: + num_sms = 20 + low_latency_size_hint_args = None + + @staticmethod + def get_dispatch_config(group_size): + return FakeConfig() + + @staticmethod + def get_combine_config(group_size): + return FakeConfig() + + @staticmethod + def get_low_latency_rdma_size_hint(*args): + FakeBuffer.low_latency_size_hint_args = args + return 4096 + + def __init__(self, group, *args, **kwargs): + self.group = group + self.num_nvl_bytes = kwargs.get('num_nvl_bytes', args[0] if len(args) > 0 else 0) + self.num_rdma_bytes = kwargs.get('num_rdma_bytes', args[1] if len(args) > 1 else 0) + self.low_latency_mode = kwargs.get('low_latency_mode', False) + + def set_num_sms(self, num_sms): + self.num_sms = num_sms + + class FakeGroup: + + def size(self): + return 2 + + monkeypatch.setattr(td, 'Buffer', FakeBuffer) + monkeypatch.setattr(td, 'use_deepep', True) + td.DeepEPBuffer._buffer_common = None + td.DeepEPBuffer._num_max_dispatch_tokens_per_rank = 128 + + dispatcher = td.DeepEPTokenDispatcherNormal( + group=FakeGroup(), + num_experts=4, + num_local_experts=2, + hidden_size=16, + params_dtype=torch.bfloat16, + num_max_dispatch_tokens_per_rank=256, + ) + + assert dispatcher.num_max_dispatch_tokens_per_rank == 256 + assert FakeBuffer.low_latency_size_hint_args[0] == 256 + + +def test_deepep_token_limit_is_inferred_from_engine_max_batch_size(): + from lmdeploy.messages import PytorchEngineConfig + from lmdeploy.pytorch.engine.config_builder import ConfigBuilder + from lmdeploy.pytorch.model_inputs import BuildModelContext + + engine_config = PytorchEngineConfig(max_batch_size=32) + cache_config = ConfigBuilder.build_cache_config(engine_config) + build_ctx = BuildModelContext(max_batch_size=cache_config.max_batches, num_spec_tokens=3) + + assert cache_config.max_batches == 32 + assert build_ctx.deep_ep_max_tokens_per_rank == 128 + + +def test_all_fused_moe_builders_accept_deepep_token_limit(): + def build_args(module_path, class_name): + tree = ast.parse((Path(__file__).parents[3] / module_path).read_text()) + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef) and node.name == class_name: + for item in node.body: + if isinstance(item, ast.FunctionDef) and item.name == 'build': + return [arg.arg for arg in item.args.args] + raise AssertionError(f'{class_name}.build not found') + + assert 'num_max_dispatch_tokens_per_rank' in build_args('lmdeploy/pytorch/backends/cuda/moe/default.py', + 'TritonFusedMoEBuilder') + assert 'num_max_dispatch_tokens_per_rank' in build_args('lmdeploy/pytorch/backends/dlinfer/moe.py', + 'DlinferFusedMoEBuilder') + + +def test_eplb_env_vars_are_lmdeploy_prefixed(): + envs_text = (Path(__file__).parents[3] / 'lmdeploy/pytorch/envs.py').read_text() + + assert "'LMDEPLOY_EPLB_NUM_GROUPS'" in envs_text + assert "'LMDEPLOY_EPLB_EXPERTS_STATISTIC_FILE'" in envs_text + assert "'LMDEPLOY_EPLB_RANKS_PER_NODE'" in envs_text + assert "'LMDEPLOY_EPLB_NUM_REDUNDANT_EXPERTS'" in envs_text + + old_env_vars = [ + 'EPLB' + '_NUM_GROUPS', + 'EPLB' + '_EXPERTS_STATISTIC_FILE', + 'RANKS' + '_PER_NODES', + 'EPLB' + '_NUM_REDUNDANT_EXPERTS', + ] + for env_var in old_env_vars: + assert f"'{env_var}'" not in envs_text + + +def test_imports_do_not_require_removed_or_ep_only_packages(monkeypatch): + real_import = builtins.__import__ + blocked_package = 'dl' + 'blas' + + def guarded_import(name, *args, **kwargs): + if (name == blocked_package or name.startswith(blocked_package + '.') or name == 'deep_gemm' + or name.startswith('deep_gemm.')): + raise AssertionError(f'unexpected optional package import: {name}') + return real_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, '__import__', guarded_import) + modules = [ + 'lmdeploy.pytorch.backends.cuda.moe', + 'lmdeploy.pytorch.backends.cuda.moe.default', + 'lmdeploy.pytorch.backends.cuda.moe.blocked_fp8', + 'lmdeploy.pytorch.backends.cuda.graph_runner', + 'lmdeploy.pytorch.nn.eplb', + 'lmdeploy.pytorch.check_env.dist', + ] + for module in modules: + importlib.import_module(module) + + +def test_eplb_global_metadata_uses_explicit_runtime_errors(monkeypatch): + from lmdeploy.pytorch.nn import eplb + + monkeypatch.setattr(eplb, '_global_eplb_metadata', None) + with pytest.raises(RuntimeError, match='not been initialized'): + eplb.get_global_eplb_metadata() + + monkeypatch.setattr(eplb, '_global_eplb_metadata', object()) + with pytest.raises(RuntimeError, match='already been initialized'): + eplb.init_global_eplb_metadata(ep_size=1, num_routed_experts=1, num_hidden_layers=1) + + +def test_fp8_ep_prefill_quant_uses_configured_dtype_and_scale_fmt(monkeypatch): + from lmdeploy.pytorch.backends.cuda.moe import blocked_fp8 + + calls = [] + + def fake_quant(x, block_size, dtype=None, scale_fmt=None): + calls.append((x, block_size, dtype, scale_fmt)) + return 'quant', 'scale' + + monkeypatch.setattr(blocked_fp8, 'per_token_group_quant_fp8', fake_quant) + + fusedmoe = blocked_fp8.FusedMoENormal.__new__(blocked_fp8.FusedMoENormal) + fusedmoe.block_size = 64 + fusedmoe.fp8_dtype = torch.float8_e5m2 + fusedmoe.scale_fmt = 'ue8m0' + + hidden_states = object() + assert fusedmoe.per_token_group_quant_fp8(hidden_states) == ('quant', 'scale') + assert calls[-1] == (hidden_states, 64, torch.float8_e5m2, 'ue8m0') + + assert fusedmoe.per_token_group_quant_fp8(hidden_states, dtype=torch.float8_e4m3fn, scale_fmt=None) == ('quant', + 'scale') + assert calls[-1] == (hidden_states, 64, torch.float8_e4m3fn, 'ue8m0') + + +def test_fp8_ep_builder_passes_activation_dtype_and_scale_fmt(monkeypatch): + from lmdeploy.pytorch.backends.cuda.moe import blocked_fp8 + + calls = [] + + def fake_build_deepep_moe(*args, **kwargs): + calls.append((args, kwargs)) + return 'moe' + + monkeypatch.setattr(blocked_fp8, 'build_deepep_moe', fake_build_deepep_moe) + impl = blocked_fp8.FusedDeepEpMoEBlockedF8Impl.__new__(blocked_fp8.FusedDeepEpMoEBlockedF8Impl) + impl.ep_size = 2 + impl.ep_group = object() + impl.num_experts = 8 + impl.hidden_dim = 16 + impl.block_size = 64 + impl.top_k = 2 + impl.out_dtype = torch.bfloat16 + impl.fp8_dtype = torch.float8_e5m2 + impl.scale_fmt = 'ue8m0' + impl.num_max_dispatch_tokens_per_rank = 256 + impl.layer_idx = 3 + + assert blocked_fp8.FusedDeepEpMoEBlockedF8Impl.fusedmoe_build(impl, low_latency_mode=False) == 'moe' + + assert calls[0][1]['fp8_dtype'] == torch.float8_e5m2 + assert calls[0][1]['scale_fmt'] == 'ue8m0' + assert calls[0][1]['num_max_dispatch_tokens_per_rank'] == 256 + + +def test_bf16_ep_builder_passes_low_latency_token_limit(monkeypatch): + from lmdeploy.pytorch.backends.cuda.moe import default + + calls = [] + + def fake_build_deepep_moe(*args, **kwargs): + calls.append((args, kwargs)) + return 'moe' + + monkeypatch.setattr(default, 'build_deepep_moe', fake_build_deepep_moe) + impl = default.FusedMoEEPImpl.__new__(default.FusedMoEEPImpl) + impl.ep_size = 2 + impl.ep_group = object() + impl.num_experts = 8 + impl.hidden_dim = 16 + impl.top_k = 2 + impl.layer_idx = 3 + impl.out_dtype = torch.bfloat16 + impl.num_max_dispatch_tokens_per_rank = 256 + + assert default.FusedMoEEPImpl.fusedmoe_build(impl, low_latency_mode=True) == 'moe' + + assert calls[0][1]['num_max_dispatch_tokens_per_rank'] == 256 + + +def test_blocked_fp8_async_prefill_passes_weight_dtype_and_scale_fmt(): + from lmdeploy.pytorch.nn.moe.base import MoeType + from lmdeploy.pytorch.nn.moe.blocked_fp8 import FusedMoEBlockedF8 + + class FakeWeight: + dtype = torch.float8_e5m2 + + class FakeGateUp: + weight = FakeWeight() + + class FakeFusedMoE: + + def __init__(self): + self.quant_args = None + + def per_token_group_quant_fp8(self, hidden_states, dtype=None, scale_fmt=None): + self.quant_args = (hidden_states, dtype, scale_fmt) + return ('quant', 'scale') + + def capture(self): + return 'event' + + layer = FusedMoEBlockedF8.__new__(FusedMoEBlockedF8) + layer.scale_fmt = 'ue8m0' + layer.gate_up = FakeGateUp() + fusedmoe = FakeFusedMoE() + layer.fusedmoe_build = lambda low_latency_mode=False: fusedmoe + hidden_states = object() + state = {'moe_type': MoeType.DSAsyncPrefill, 'hidden_states': hidden_states} + + out_state = FusedMoEBlockedF8.before_dispatch(layer, state) + + assert fusedmoe.quant_args == (hidden_states, torch.float8_e5m2, 'ue8m0') + assert out_state['hidden_states'] == ('quant', 'scale') + assert out_state['previous_event'] == 'event'