@@ -81,8 +81,9 @@ def __init__(
8181 layer_index : int = 0 ,
8282 top_k : int = 8 ,
8383 out_dtype : torch .dtype = torch .bfloat16 ,
84+ num_max_dispatch_tokens_per_rank : int = 128 ,
8485 ):
85- from dlblas . layers . moe .token_dispatcher import DeepEPTokenDispatcherNormal
86+ from lmdeploy . pytorch . backends . cuda .token_dispatcher import DeepEPTokenDispatcherNormal
8687 self .layer_index = layer_index
8788 self .top_k = top_k
8889 self .num_experts = num_experts
@@ -94,6 +95,7 @@ def __init__(
9495 num_local_experts = self .num_local_experts ,
9596 hidden_size = hidden_dim ,
9697 params_dtype = out_dtype ,
98+ num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank ,
9799 )
98100
99101 def forward (
@@ -148,7 +150,7 @@ def fusedmoe_forward(self, state, up_weight, down_weight):
148150
149151
150152def _disposible_tensor (tensor ):
151- from dlblas . utils . utils import DisposibleTensor
153+ from lmdeploy . pytorch . backends . cuda . token_dispatcher import DisposibleTensor
152154 if isinstance (tensor , torch .Tensor ):
153155 tensor = DisposibleTensor (tensor )
154156 else :
@@ -237,8 +239,9 @@ def __init__(
237239 hidden_dim : int ,
238240 layer_index : int ,
239241 out_dtype : torch .dtype = torch .bfloat16 ,
242+ num_max_dispatch_tokens_per_rank : int = 128 ,
240243 ):
241- from dlblas . layers . moe .token_dispatcher import DeepEPTokenDispatcherLowLatency
244+ from lmdeploy . pytorch . backends . cuda .token_dispatcher import DeepEPTokenDispatcherLowLatency
242245 self .num_experts = num_experts
243246 self .layer_index = layer_index
244247 self .out_dtype = out_dtype
@@ -248,6 +251,7 @@ def __init__(
248251 num_local_experts = num_experts // ep_size ,
249252 hidden_size = hidden_dim ,
250253 params_dtype = out_dtype ,
254+ num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank ,
251255 )
252256
253257 def experts (
@@ -258,8 +262,7 @@ def experts(
258262 masked_m : torch .Tensor ,
259263 expected_m : int ,
260264 ):
261- from dlblas .utils .utils import DisposibleTensor
262-
265+ from lmdeploy .pytorch .backends .cuda .token_dispatcher import DisposibleTensor
263266 from lmdeploy .pytorch .kernels .cuda .activation import silu_and_mul_moe_ep
264267 from lmdeploy .pytorch .third_party .deep_gemm import m_grouped_bf16_gemm_nt_masked
265268 num_groups , m , _ = hidden_states .shape
@@ -339,22 +342,25 @@ def build_deepep_moe(
339342 top_k : int ,
340343 layer_idx : int = 0 ,
341344 out_dtype : torch .dtype = torch .bfloat16 ,
345+ num_max_dispatch_tokens_per_rank : int = 128 ,
342346):
343347 if low_latency_mode :
344348 return FusedMoELowLatency (ep_size = ep_size ,
345349 ep_group = ep_group ,
346350 num_experts = num_experts ,
347351 hidden_dim = hidden_dim ,
348352 layer_index = layer_idx ,
349- out_dtype = out_dtype )
353+ out_dtype = out_dtype ,
354+ num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank )
350355 else :
351356 return FusedMoENormal (ep_size = ep_size ,
352357 ep_group = ep_group ,
353358 num_experts = num_experts ,
354359 hidden_dim = hidden_dim ,
355360 layer_index = layer_idx ,
356361 top_k = top_k ,
357- out_dtype = out_dtype )
362+ out_dtype = out_dtype ,
363+ num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank )
358364
359365
360366class FusedMoEEPImpl (TritonFusedMoEImpl ):
@@ -370,6 +376,7 @@ def __init__(
370376 renormalize : bool = False ,
371377 layer_idx : int = 0 ,
372378 out_dtype : torch .dtype = torch .bfloat16 ,
379+ num_max_dispatch_tokens_per_rank : int = 128 ,
373380 ):
374381 super ().__init__ (top_k , num_experts , renormalize )
375382 self .num_experts = num_experts
@@ -378,19 +385,21 @@ def __init__(
378385 self .hidden_dim = hidden_dim
379386 self .layer_idx = layer_idx
380387 self .out_dtype = out_dtype
388+ self .num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
381389
382390 try :
383391 import deep_gemm # noqa: F401
384392 except ImportError :
385393 logger .exception ('DeepGEMM is required for DeepEP MoE implementation.' )
394+ raise
386395
387- try :
388- from dlblas . layers . moe . token_dispatcher import DeepEPBuffer , DeepEPMode , use_deepep # noqa: F401
389- get_deepep_state (). enable ()
390- if hasattr ( DeepEPBuffer , 'set_explicitly_destroy' ):
391- DeepEPBuffer . set_explicitly_destroy ()
392- except ImportError :
393- logger . warning ( 'For higher performance, please install DeepEP https://github.com/deepseek-ai/DeepEP' )
396+ from lmdeploy . pytorch . backends . cuda . token_dispatcher import DeepEPBuffer , use_deepep
397+ if not use_deepep :
398+ raise ImportError ( 'DeepEP is required for DeepEP MoE implementation. Please install '
399+ 'https://github.com/deepseek-ai/DeepEP.' )
400+ get_deepep_state (). enable ()
401+ if hasattr ( DeepEPBuffer , 'set_explicitly_destroy' ) :
402+ DeepEPBuffer . set_explicitly_destroy ( )
394403
395404 # pre-allocate buffer
396405 self .fusedmoe_build (True )
@@ -440,7 +449,8 @@ def fusedmoe_build(self, low_latency_mode: bool = False):
440449 self .hidden_dim ,
441450 self .top_k ,
442451 layer_idx = self .layer_idx ,
443- out_dtype = self .out_dtype )
452+ out_dtype = self .out_dtype ,
453+ num_max_dispatch_tokens_per_rank = self .num_max_dispatch_tokens_per_rank )
444454 return deepep_moe
445455
446456
@@ -457,6 +467,7 @@ def build(
457467 ep_group : dist .ProcessGroup = None ,
458468 layer_idx : int = 0 ,
459469 out_dtype : torch .dtype = torch .bfloat16 ,
470+ num_max_dispatch_tokens_per_rank : int = 128 ,
460471 ):
461472 """Build from mlp."""
462473 if ep_size > 1 :
@@ -467,5 +478,6 @@ def build(
467478 hidden_dim = hidden_dim ,
468479 renormalize = renormalize ,
469480 layer_idx = layer_idx ,
470- out_dtype = out_dtype )
481+ out_dtype = out_dtype ,
482+ num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank )
471483 return TritonFusedMoEImpl (top_k = top_k , num_experts = num_experts , renormalize = renormalize )
0 commit comments