99from lightllm .models import get_model
1010from lightllm .common .basemodel .batch_objs import ModelInput , ModelOutput
1111from lightllm .server .core .objs .start_args_type import StartArgs
12- from torch .profiler import profile , record_function , ProfilerActivity
12+ from torch .profiler import profile , ProfilerActivity
1313from lightllm .utils .log_utils import init_logger
1414from lightllm .models .deepseek_mtp .model import Deepseek3MTPModel
15- import torch .cuda as cuda
15+ from lightllm .models .qwen3_moe_mtp .model import Qwen3MOEMTPModel
16+ from lightllm .models .mistral_mtp .model import MistralMTPModel
17+ from lightllm .models .glm4_moe_lite_mtp .model import Glm4MoeLiteMTPModel
1618
1719logger = init_logger (__name__ )
1820
1921
def init_mtp_model(args: StartArgs, kvargs, main_model):
    """Build the list of MTP draft models that accompany ``main_model``.

    The number of draft model instances depends on ``args.mtp_mode``:
    the ``vanilla_*`` modes load ``args.mtp_step`` instances, the ``eagle_*``
    modes load exactly one. The concrete model class is selected from the
    ``model_type`` field of the config found in ``args.mtp_draft_model_dir[i]``.

    Side effect: sets the ``DISABLE_CHECK_MAX_LEN_INFER`` environment
    variable to ``"1"``.

    Args:
        args: Start arguments; reads ``mtp_mode``, ``mtp_step``,
            ``mtp_draft_model_dir`` (indexed per draft model),
            ``disable_chunked_prefill`` and the attention / kv-cache backend
            settings used as fallbacks.
        kvargs: Keyword arguments of the main model. ``"load_way"`` and
            ``"mem_fraction"`` are required; the other keys are optional and
            fall back to the defaults written below.
        main_model: The already constructed main model; its
            ``mem_manager.size`` is used as ``max_total_token_num`` and the
            object itself is handed to every draft model.

    Returns:
        The list of constructed draft models, in creation order.

    Raises:
        AssertionError: ``args.mtp_mode`` is unknown, or is not compatible
            with the detected draft model type.
        ValueError: the draft model's ``model_type`` is not supported.
        KeyError: ``kvargs`` lacks ``"load_way"`` or ``"mem_fraction"``.
    """
    draft_models = []

    os.environ["DISABLE_CHECK_MAX_LEN_INFER"] = "1"

    if args.mtp_mode in ["vanilla_with_att", "vanilla_no_att"]:
        num_mtp_modules = args.mtp_step
    elif args.mtp_mode in ["eagle_with_att", "eagle_no_att"]:
        num_mtp_modules = 1
    else:
        # Raised explicitly (instead of ``assert False``) so the check is not
        # stripped under ``python -O``, which would leave ``num_mtp_modules``
        # unbound and fail later with an unrelated UnboundLocalError.
        raise AssertionError(f"error mtp mode {args.mtp_mode}")

    for i in range(num_mtp_modules):
        mtp_model_cfg, _ = PretrainedConfig.get_config_dict(args.mtp_draft_model_dir[i])
        model_type = mtp_model_cfg.get("model_type", "")
        mtp_model_kvargs = {
            "weight_dir": args.mtp_draft_model_dir[i],
            "max_total_token_num": main_model.mem_manager.size,
            "load_way": kvargs["load_way"],
            "max_req_num": kvargs.get("max_req_num", 1000),
            "max_seq_length": kvargs.get("max_seq_length", 1024 * 5),
            "is_token_healing": False,
            "return_all_prompt_logics": False,
            "disable_chunked_prefill": args.disable_chunked_prefill,
            "data_type": kvargs.get("data_type", "float16"),
            "graph_max_batch_size": kvargs.get("graph_max_batch_size", 16),
            "graph_max_len_in_batch": kvargs.get("graph_max_len_in_batch", 8196),
            "disable_cudagraph": kvargs.get("disable_cudagraph", False),
            "mem_fraction": kvargs["mem_fraction"],
            "batch_max_tokens": kvargs.get("batch_max_tokens", None),
            "quant_type": kvargs.get("quant_type", None),
            "quant_cfg": kvargs.get("quant_cfg", None),
            "run_mode": "normal",
            "llm_prefill_att_backend": kvargs.get("llm_prefill_att_backend", args.llm_prefill_att_backend),
            "llm_decode_att_backend": kvargs.get("llm_decode_att_backend", args.llm_decode_att_backend),
            "vit_att_backend": kvargs.get("vit_att_backend", args.vit_att_backend),
            "llm_kv_type": kvargs.get("llm_kv_type", args.llm_kv_type),
            "llm_kv_quant_group_size": kvargs.get("llm_kv_quant_group_size", args.llm_kv_quant_group_size),
            "main_model": main_model,
            # Snapshot of the draft models built so far; a copy so that later
            # appends to ``draft_models`` do not leak into this model's view.
            "mtp_previous_draft_models": draft_models.copy(),
            "mtp_mode": args.mtp_mode,
        }

        if model_type == "deepseek_v3":
            assert args.mtp_mode in ["vanilla_with_att", "eagle_with_att"]
            draft_models.append(Deepseek3MTPModel(mtp_model_kvargs))
        elif model_type == "qwen3_moe":
            assert args.mtp_mode in ["vanilla_no_att", "eagle_no_att"]
            draft_models.append(Qwen3MOEMTPModel(mtp_model_kvargs))
        elif model_type == "mistral":
            assert args.mtp_mode in ["vanilla_no_att", "eagle_no_att"]
            draft_models.append(MistralMTPModel(mtp_model_kvargs))
        elif model_type == "glm4_moe_lite":
            # Uses ``model_type`` like every other branch: indexing
            # ``mtp_model_cfg["model_type"]`` here raised KeyError for configs
            # without that field instead of reaching the ValueError below.
            assert args.mtp_mode in ["vanilla_with_att", "eagle_with_att"]
            draft_models.append(Glm4MoeLiteMTPModel(mtp_model_kvargs))
        elif model_type in ("qwen3_5", "qwen3_5_text"):
            assert args.mtp_mode in ["vanilla_with_att", "eagle_with_att"]
            # Imported only when this model type is actually requested.
            from lightllm.models.qwen3_5_mtp.model import Qwen3_5MTPModel

            draft_models.append(Qwen3_5MTPModel(mtp_model_kvargs))
        elif model_type in ("qwen3_5_moe", "qwen3_5_moe_text"):
            assert args.mtp_mode in ["vanilla_with_att", "eagle_with_att"]
            # Imported only when this model type is actually requested.
            from lightllm.models.qwen3_5_moe_mtp.model import Qwen3_5MoeMTPModel

            draft_models.append(Qwen3_5MoeMTPModel(mtp_model_kvargs))
        else:
            raise ValueError(f"Unsupported MTP model type: {model_type}")

        logger.info(f"loaded mtp model class {draft_models[i].__class__}")
    return draft_models
4992
5093
@@ -68,13 +111,22 @@ def test_model_inference_mtp(args):
68111 "max_total_token_num" : args .max_total_token_num ,
69112 "graph_max_len_in_batch" : args .max_req_total_len ,
70113 "graph_max_batch_size" : args .graph_max_batch_size ,
71- "mem_faction" : args .mem_fraction ,
72- "max_req_num" : 2000 ,
114+ "mem_fraction" : args .mem_fraction ,
115+ # Static bench runs explicit batch sizes (<= a few hundred). The hybrid Qwen3.5
116+ # GDN req-state cache is sized max_req_num * (mtp_step + 1) at ~34 MB/slot, so the
117+ # old default of 2000 alloc'd ~140 GB and OOM'd under MTP. 512 covers any realistic
118+ # static batch sweep while keeping the GDN cache small.
119+ "max_req_num" : 512 ,
73120 "batch_max_tokens" : 2048 ,
74121 "run_mode" : "normal" ,
75122 "max_seq_length" : args .max_req_total_len ,
76- "spec_algo" : args .spec_algo ,
77123 "disable_cudagraph" : args .disable_cudagraph ,
124+ "quant_cfg" : args .quant_cfg ,
125+ "llm_prefill_att_backend" : args .llm_prefill_att_backend ,
126+ "llm_decode_att_backend" : args .llm_decode_att_backend ,
127+ "vit_att_backend" : args .vit_att_backend ,
128+ "llm_kv_type" : args .llm_kv_type ,
129+ "llm_kv_quant_group_size" : args .llm_kv_quant_group_size ,
78130 }
79131 proc = multiprocessing .Process (
80132 target = tppart_model_infer ,
@@ -113,28 +165,36 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_
113165
114166 test_data = np .vstack ([np .random .randint (0 , 50256 , input_len ) for _ in range (batch_size )])
115167 test_data = test_data .reshape (- 1 )
116- test_data = torch .from_numpy (test_data ). cuda ()
168+ test_data = torch .from_numpy (test_data )
117169
118170 b_req_idx = torch .tensor (
119- [main_model .req_manager .alloc () for _ in range (batch_size )], dtype = torch .int32 , device = "cuda "
171+ [main_model .req_manager .alloc () for _ in range (batch_size )], dtype = torch .int32 , device = "cpu "
120172 )
121- b_seq_len = torch .zeros (batch_size , dtype = torch .int32 , device = "cuda " )
122- b_ready_cache_len = torch .zeros (batch_size , dtype = torch .int32 , device = "cuda " )
173+ b_seq_len = torch .zeros (batch_size , dtype = torch .int32 , device = "cpu " )
174+ b_ready_cache_len = torch .zeros (batch_size , dtype = torch .int32 , device = "cpu " )
123175 for i in range (batch_size ):
124176 b_seq_len [i ] = input_len
125177
126178 total_token_num = input_len * batch_size
127- mem_indexes = main_model .req_manager .mem_manager .alloc (test_data .shape [0 ]).cuda ()
179+ mem_indexes = main_model .req_manager .mem_manager .alloc (test_data .shape [0 ])
180+ b_mtp_index = torch .zeros (batch_size , dtype = torch .int32 )
181+ b_prefill_start_loc = b_seq_len .cumsum (dim = 0 , dtype = torch .int32 ) - b_seq_len
128182 # Main model Prefill
129183 model_input = ModelInput (
130184 batch_size = batch_size ,
131185 total_token_num = total_token_num ,
186+ max_q_seq_len = input_len ,
187+ max_kv_seq_len = input_len ,
188+ max_cache_len = 0 ,
132189 input_ids = test_data ,
133- mem_indexes = mem_indexes ,
190+ mem_indexes_cpu = mem_indexes ,
134191 b_req_idx = b_req_idx ,
192+ b_mtp_index = b_mtp_index ,
135193 b_seq_len = b_seq_len ,
136194 is_prefill = True ,
137195 b_ready_cache_len = b_ready_cache_len ,
196+ b_prefill_start_loc = b_prefill_start_loc ,
197+ prefix_total_token_num = 0 ,
138198 multimodal_params = [{"images" : [], "audios" : []} for _ in range (batch_size )],
139199 )
140200
@@ -167,8 +227,22 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_
167227
168228 torch .cuda .synchronize ()
169229
230+ # Speculative width = args.mtp_step in BOTH modes (mirrors base_backend: self.mtp_step =
231+ # args.mtp_step). The number of draft MODEL INSTANCES differs: vanilla loads mtp_step
232+ # instances (each forwarded once), eagle loads ONE instance forwarded mtp_step times
233+ # (chunked_prefill/impl.py: draft_models[_step % num_instances]). The verify batch always
234+ # expands to (mtp_step + 1) rows per request.
235+ spec_width = args .mtp_step
236+ num_instances = len (draft_models )
237+ # The draft prefill above produced (1 + num_instances) columns; pad/truncate to
238+ # (spec_width + 1) so the decode verify batch matches the server's expand width. Only the
239+ # SHAPE matters for throughput here (argmax over random inputs); token values do not.
240+ while len (draft_ids ) < spec_width + 1 :
241+ draft_ids .append (draft_ids [- 1 ])
242+ draft_ids = draft_ids [: spec_width + 1 ]
170243 decode_input_ids = np .stack (draft_ids , axis = - 1 ).reshape (- 1 )
171- decode_input_ids = torch .from_numpy (decode_input_ids ).cuda ()
244+ decode_input_ids = torch .from_numpy (decode_input_ids )
245+ mtp_step = spec_width
172246
173247 # build main decode input:
174248 nopad_b_seq_idx = []
@@ -177,35 +251,39 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_
177251 nopad_max_len_in_batch = 0
178252
179253 for i in range (batch_size ):
180- nopad_b_seq_idx .append (b_req_idx [i ])
254+ nopad_b_seq_idx .append (b_req_idx [i ]. item () )
181255 seq_len = b_seq_len [i ].item ()
182256 nopad_b_seq_len .append (seq_len + 1 )
183257 nopad_total_token_num += seq_len + 1
184- nopad_max_len_in_batch = max (nopad_max_len_in_batch , b_seq_len [ i ] + 1 )
258+ nopad_max_len_in_batch = max (nopad_max_len_in_batch , seq_len + 1 )
185259
186- for step in range (len ( draft_models ) ):
187- nopad_b_seq_idx .append (b_req_idx [i ])
260+ for step in range (mtp_step ):
261+ nopad_b_seq_idx .append (b_req_idx [i ]. item () )
188262 nopad_b_seq_len .append (seq_len + step + 2 )
189263 nopad_total_token_num += seq_len + step + 2
190264 nopad_max_len_in_batch = max (nopad_max_len_in_batch , seq_len + step + 2 )
191265
192- nopad_b_seq_idx = torch .tensor (nopad_b_seq_idx , dtype = torch .int32 , device = "cuda" )
193- nopad_b_seq_len = torch .tensor (nopad_b_seq_len , dtype = torch .int32 , device = "cuda" )
194- mem_indexes = main_model .req_manager .mem_manager .alloc (batch_size * (len (draft_models ) + 1 )).cuda ()
266+ nopad_b_seq_idx = torch .tensor (nopad_b_seq_idx , dtype = torch .int32 , device = "cpu" )
267+ nopad_b_seq_len = torch .tensor (nopad_b_seq_len , dtype = torch .int32 , device = "cpu" )
268+ b_mtp_index = torch .arange (mtp_step + 1 , dtype = torch .int32 ).repeat (batch_size )
269+ mem_indexes = main_model .req_manager .mem_manager .alloc (batch_size * (mtp_step + 1 ))
195270
196271 model_input = ModelInput (
197- batch_size = batch_size * (len ( draft_models ) + 1 ),
272+ batch_size = batch_size * (mtp_step + 1 ),
198273 total_token_num = nopad_total_token_num ,
274+ max_q_seq_len = 1 ,
275+ max_kv_seq_len = nopad_max_len_in_batch ,
199276 input_ids = decode_input_ids ,
200- mem_indexes = mem_indexes ,
277+ mem_indexes_cpu = mem_indexes ,
201278 b_req_idx = nopad_b_seq_idx ,
279+ b_mtp_index = b_mtp_index ,
202280 b_seq_len = nopad_b_seq_len ,
203281 is_prefill = False ,
204- multimodal_params = [{"images" : [], "audios" : []} for _ in range (batch_size * (len ( draft_models ) + 1 ))],
282+ multimodal_params = [{"images" : [], "audios" : []} for _ in range (batch_size * (mtp_step + 1 ))],
205283 )
206284
207285 # Main decode
208- for i in range (0 , output_len , len ( draft_models ) + 1 ):
286+ for i in range (0 , output_len , mtp_step + 1 ):
209287 torch .cuda .synchronize ()
210288 step_start_time = time .time ()
211289 model_output = main_model .forward (
@@ -214,12 +292,13 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_
214292 prob_out = torch .softmax (model_output .logits , dim = - 1 )
215293 predict_ids = torch .argmax (prob_out , dim = 1 , keepdim = True )
216294
217- # draft decode
295+ # draft decode: mtp_step forwards, reusing draft_models[_step % num_instances]
296+ # (eagle: one instance reused mtp_step times; vanilla: a distinct instance per step).
218297 model_input .input_ids = predict_ids .reshape (- 1 )
219298 model_input .mtp_draft_input_hiddens = model_output .mtp_main_output_hiddens
220299
221- for draft_model_id in range (len ( draft_models ) ):
222- draft_model = draft_models [draft_model_id ]
300+ for _step in range (mtp_step ):
301+ draft_model = draft_models [_step % num_instances ]
223302 model_output = draft_model .forward (
224303 model_input ,
225304 )
@@ -237,7 +316,7 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_
237316 if get_current_rank_in_dp () == 0 and not warmup :
238317 step_time = step_end_time - step_start_time
239318 print (i , " step cost time:" , step_time * 1000 )
240- print (f"Decode throughput: { batch_size * (len ( draft_models ) + 1 ) * args .dp / step_time } tokens/s" )
319+ print (f"Decode throughput: { batch_size * (mtp_step + 1 ) * args .dp / step_time } tokens/s" )
241320
242321 main_model .mem_manager .free_all ()
243322 main_model .req_manager .free_all ()
0 commit comments