44from lightllm .distributed import dist_group_manager
55from lightllm .common .triton_utils .autotuner import Autotuner
66from lightllm .common .quantization .quantize_method import WeightPack
7- from lightllm .utils .envs_utils import get_deepep_num_max_dispatch_tokens_per_rank
7+ from lightllm .utils .envs_utils import (
8+ get_deepep_num_max_dispatch_tokens_per_rank_prefill ,
9+ get_deepep_num_max_dispatch_tokens_per_rank_decode ,
10+ )
811from lightllm .common .basemodel .triton_kernel .fused_moe .grouped_fused_moe_ep import (
9- fused_experts_impl ,
12+ fused_experts ,
13+ get_ep_num_sms ,
1014 masked_group_gemm ,
11- _deepgemm_grouped_fp8_nt_contiguous ,
15+ deepgemm_grouped_fp8_nt_contiguous ,
16+ quantize_fused_experts_input ,
1217)
1318from lightllm .common .basemodel .triton_kernel .quantization .fp8act_quant_kernel import (
1419 per_token_group_quant_fp8 ,
@@ -72,23 +77,15 @@ def _fused_experts(
7277 router_logits : Optional [torch .Tensor ] = None ,
7378 is_prefill : Optional [bool ] = None ,
7479 ):
75- w13_weight , w13_scale = w13 .weight , w13 .weight_scale
76- w2_weight , w2_scale = w2 .weight , w2 .weight_scale
77- use_fp8_w8a8 = self .quant_method .method_name != "none"
78- output = fused_experts_impl (
80+ output = fused_experts (
7981 hidden_states = input_tensor ,
80- w1 = w13_weight ,
81- w2 = w2_weight ,
82+ w13 = w13 ,
83+ w2 = w2 ,
8284 topk_weights = topk_weights ,
8385 topk_idx = topk_ids .to (torch .long ),
8486 num_experts = self .total_expert_num_contain_redundancy , # number of all experts contain redundancy
85- buffer = dist_group_manager . ep_buffer ,
87+ quant_method = self . quant_method ,
8688 is_prefill = is_prefill ,
87- use_fp8_w8a8 = use_fp8_w8a8 ,
88- use_fp8_all2all = use_fp8_w8a8 ,
89- use_int8_w8a16 = False , # default to False
90- w1_scale = w13_scale ,
91- w2_scale = w2_scale ,
9289 previous_event = None , # for overlap
9390 )
9491 return output
@@ -118,13 +115,13 @@ def low_latency_dispatch(
118115 )
119116
120117 topk_idx = topk_idx .to (torch .long )
121- num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank ()
118+ num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank_decode ()
122119 use_fp8_w8a8 = self .quant_method .method_name != "none"
123- recv_x , masked_m , handle , event , hook = dist_group_manager .ep_buffer .low_latency_dispatch (
124- hidden_states ,
125- topk_idx ,
126- num_max_dispatch_tokens_per_rank ,
127- self .total_expert_num_contain_redundancy ,
120+ recv_x , masked_m , handle , event , hook = dist_group_manager .ep_low_latency_buffer .low_latency_dispatch (
121+ topk_idx = topk_idx ,
122+ x = hidden_states ,
123+ num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank ,
124+ num_experts = self .total_expert_num_contain_redundancy ,
128125 use_fp8 = use_fp8_w8a8 ,
129126 async_finish = False ,
130127 return_recv_hook = True ,
@@ -155,13 +152,8 @@ def select_experts_and_quant_input(
155152 num_expert_group = n_group ,
156153 scoring_func = scoring_func ,
157154 )
158- w13_weight , w13_scale = w13 .weight , w13 .weight_scale
159- block_size_k = 0
160- if w13_weight .ndim == 3 :
161- block_size_k = w13_weight .shape [2 ] // w13_scale .shape [2 ]
162- assert block_size_k == 128 , "block_size_k must be 128"
163- qinput_tensor , input_scale = per_token_group_quant_fp8 (hidden_states , block_size_k , dtype = w13_weight .dtype )
164- return topk_weights , topk_idx .to (torch .long ), (qinput_tensor , input_scale )
155+ qinput_tensor = quantize_fused_experts_input (hidden_states , w13 , self .quant_method )
156+ return topk_weights , topk_idx .to (torch .long ), qinput_tensor
165157
166158 def dispatch (
167159 self ,
@@ -171,38 +163,26 @@ def dispatch(
171163 overlap_event : Optional [Any ] = None ,
172164 ):
173165 buffer = dist_group_manager .ep_buffer
174- # get_dispatch_layout
175- (
176- num_tokens_per_rank ,
177- num_tokens_per_rdma_rank ,
178- num_tokens_per_expert ,
179- is_token_in_rank ,
180- previous_event ,
181- ) = buffer .get_dispatch_layout (
182- topk_idx ,
183- self .total_expert_num_contain_redundancy ,
184- previous_event = overlap_event ,
185- async_finish = True ,
186- allocate_on_comm_stream = True ,
187- )
188- recv_x , recv_topk_idx , recv_topk_weights , num_recv_tokens_per_expert_list , handle , event = buffer .dispatch (
166+ num_max_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank_prefill ()
167+ recv_x , recv_topk_idx , recv_topk_weights , handle , event = buffer .dispatch (
189168 qinput_tensor ,
190169 topk_idx = topk_idx ,
191170 topk_weights = topk_weights ,
192- num_tokens_per_rank = num_tokens_per_rank ,
193- num_tokens_per_rdma_rank = num_tokens_per_rdma_rank ,
194- is_token_in_rank = is_token_in_rank ,
195- num_tokens_per_expert = num_tokens_per_expert ,
196- previous_event = previous_event ,
197- async_finish = True ,
198- allocate_on_comm_stream = True ,
171+ num_experts = self .total_expert_num_contain_redundancy ,
172+ num_max_tokens_per_rank = num_max_tokens_per_rank ,
199173 expert_alignment = 128 ,
174+ num_sms = get_ep_num_sms (),
175+ previous_event = overlap_event ,
176+ async_with_compute_stream = True ,
177+ allocate_on_comm_stream = True ,
178+ do_cpu_sync = True ,
179+ do_handle_copy = False ,
200180 )
201181
202182 def hook ():
203183 event .current_stream_wait ()
204184
205- return recv_x , recv_topk_idx , recv_topk_weights , num_recv_tokens_per_expert_list , handle , hook
185+ return recv_x , recv_topk_idx , recv_topk_weights , handle . num_recv_tokens_per_expert_list , handle , hook
206186
207187 def masked_group_gemm (
208188 self ,
@@ -281,7 +261,7 @@ def prefilled_group_gemm(
281261 # groupgemm (contiguous layout)
282262 gemm_out_a = torch .empty ((all_tokens , N ), device = device , dtype = hidden_dtype )
283263
284- _deepgemm_grouped_fp8_nt_contiguous (input_tensor , (w13_weight , w13_scale ), gemm_out_a , m_indices )
264+ deepgemm_grouped_fp8_nt_contiguous (input_tensor , (w13_weight , w13_scale ), gemm_out_a , m_indices )
285265
286266 # silu_and_mul_fwd + qaunt
287267 # TODO fused kernel
@@ -295,7 +275,7 @@ def prefilled_group_gemm(
295275 # groupgemm (contiguous layout)
296276 gemm_out_b = torch .empty ((all_tokens , K ), device = device , dtype = hidden_dtype )
297277
298- _deepgemm_grouped_fp8_nt_contiguous (
278+ deepgemm_grouped_fp8_nt_contiguous (
299279 (qsilu_out , qsilu_out_scale ), (w2_weight , w2_scale ), gemm_out_b , m_indices
300280 )
301281 # gather and local reduce
@@ -319,7 +299,7 @@ def low_latency_combine(
319299 topk_weights : torch .Tensor ,
320300 handle : Any ,
321301 ):
322- combined_x , event_overlap , hook = dist_group_manager .ep_buffer .low_latency_combine (
302+ combined_x , event_overlap , hook = dist_group_manager .ep_low_latency_buffer .low_latency_combine (
323303 gemm_out_b , topk_idx , topk_weights , handle , async_finish = False , return_recv_hook = True
324304 )
325305 return combined_x , hook
@@ -335,8 +315,9 @@ def combine(
335315 gemm_out_b ,
336316 handle ,
337317 topk_weights = None ,
338- async_finish = True ,
318+ num_sms = get_ep_num_sms () ,
339319 previous_event = overlap_event ,
320+ async_with_compute_stream = True ,
340321 allocate_on_comm_stream = True ,
341322 )
342323
0 commit comments