@@ -7,7 +7,6 @@
 from lightllm.utils.envs_utils import (
     enable_diverse_mode_gqa_decode_fast_kernel,
     enable_triton_mtp_kernel,
-    enable_triton_mtp_kernel,
     get_diverse_max_batch_shared_group_size,
     enable_dynamic_mtp_verify,
     get_env_start_args
@@ -240,30 +239,3 @@ def build_mtp_shared_group_infos(
     if current_group:
         b_mark_shared_group.extend([0] * (len(current_group) - 1) + [len(current_group)])
     return torch.tensor(b_mark_shared_group, dtype=torch.int32, device="cpu")
-
-
-def build_mtp_shared_group_infos(
-    b_mtp_index: torch.Tensor,
-) -> torch.Tensor:
-    # Similar to build_diverse_shared_group_infos,
-    # but the grouping logic is based on b_mtp_index, which indicates the MTP step of each request
-    max_batch_shared_group_size = get_diverse_max_batch_shared_group_size()
-    b_mark_shared_group = []
-    num_reqs = b_mtp_index.shape[0]
-    if num_reqs == 0:
-        return torch.zeros_like(b_mtp_index, dtype=torch.int32, device="cpu")
-    current_group = []
-    for i in range(num_reqs):
-        step = b_mtp_index[i].item()
-        if len(current_group) == 0:
-            current_group.append(i)
-        else:
-            prev_step = b_mtp_index[i - 1].item()
-            if step == prev_step + 1 and len(current_group) < max_batch_shared_group_size:
-                current_group.append(i)
-            else:
-                b_mark_shared_group.extend([0] * (len(current_group) - 1) + [len(current_group)])
-                current_group = [i]
-    if current_group:
-        b_mark_shared_group.extend([0] * (len(current_group) - 1) + [len(current_group)])
-    return torch.tensor(b_mark_shared_group, dtype=torch.int32, device="cpu")
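
For context, here is a minimal standalone sketch of the grouping scheme that the retained copy of build_mtp_shared_group_infos implements. The function name mark_shared_groups and the explicit max_group_size parameter are illustrative stand-ins, not the library API; in the actual code the cap comes from get_diverse_max_batch_shared_group_size() and the input is the batch's b_mtp_index tensor.

import torch

def mark_shared_groups(b_mtp_index: torch.Tensor, max_group_size: int) -> torch.Tensor:
    # Illustrative stand-in for build_mtp_shared_group_infos: requests whose
    # MTP steps increase by exactly 1 are chained into one shared group,
    # capped at max_group_size. Each group is encoded as (len - 1) zeros
    # followed by the group length.
    marks = []
    if b_mtp_index.shape[0] == 0:
        return torch.zeros_like(b_mtp_index, dtype=torch.int32, device="cpu")
    current_group = []
    for i in range(b_mtp_index.shape[0]):
        step = b_mtp_index[i].item()
        if not current_group:
            current_group.append(i)
        elif step == b_mtp_index[i - 1].item() + 1 and len(current_group) < max_group_size:
            current_group.append(i)
        else:
            marks.extend([0] * (len(current_group) - 1) + [len(current_group)])
            current_group = [i]
    if current_group:
        marks.extend([0] * (len(current_group) - 1) + [len(current_group)])
    return torch.tensor(marks, dtype=torch.int32, device="cpu")

# MTP steps [0, 1, 2, 0, 1, 0] split into groups [0, 1, 2], [3, 4], [5]:
print(mark_shared_groups(torch.tensor([0, 1, 2, 0, 1, 0]), max_group_size=4))
# tensor([0, 0, 3, 0, 2, 1], dtype=torch.int32)

Note that the group size is written on the last member of each group, so a consumer can recover group boundaries by scanning for non-zero marks.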