Skip to content

Commit 0d42047

Browse files
committed
test(static_inference): generalize MTP static benchmark
- Dispatch to the MTP bench whenever mtp_mode is set (the check was hard-coded to 'deepseekv3', which left the MTP branch dead)
- init_mtp_model: dispatch by config model_type (deepseek_v3/qwen3_moe/mistral/glm4_moe_lite/qwen3_5/qwen3_5_moe); handle eagle (1 instance) vs vanilla (mtp_step instances); fix mem_faction typo; pass full att/kv/quant kvargs
- run_forward_once: adapt to the new ModelInput API (mem_indexes_cpu + CPU tensors, max_q/kv_seq_len, b_mtp_index, b_prefill_start_loc); reuse draft instances via _step % num_instances; pad/truncate draft_ids to mtp_step + 1
- Cap max_req_num at 512 to avoid GDN req-state cache OOM under MTP
1 parent 16170f3 commit 0d42047

3 files changed

Lines changed: 129 additions & 50 deletions

File tree

test/benchmark/static_inference/model_infer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def test_model_inference(args):
3636
"graph_max_len_in_batch": args.max_req_total_len,
3737
"graph_max_batch_size": args.graph_max_batch_size,
3838
"mem_fraction": args.mem_fraction,
39-
"max_req_num": 2048,
39+
"max_req_num": 512,
4040
"batch_max_tokens": 1024,
4141
"run_mode": "normal",
4242
"max_seq_length": args.max_req_total_len,

test/benchmark/static_inference/model_infer_mtp.py

Lines changed: 127 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -9,42 +9,85 @@
99
from lightllm.models import get_model
1010
from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
1111
from lightllm.server.core.objs.start_args_type import StartArgs
12-
from torch.profiler import profile, record_function, ProfilerActivity
12+
from torch.profiler import profile, ProfilerActivity
1313
from lightllm.utils.log_utils import init_logger
1414
from lightllm.models.deepseek_mtp.model import Deepseek3MTPModel
15-
import torch.cuda as cuda
15+
from lightllm.models.qwen3_moe_mtp.model import Qwen3MOEMTPModel
16+
from lightllm.models.mistral_mtp.model import MistralMTPModel
17+
from lightllm.models.glm4_moe_lite_mtp.model import Glm4MoeLiteMTPModel
1618

1719
logger = init_logger(__name__)
1820

1921

2022
def init_mtp_model(args: StartArgs, kvargs, main_model):
21-
mtp_step = args.mtp_step
2223
draft_models = []
2324

2425
os.environ["DISABLE_CHECK_MAX_LEN_INFER"] = "1"
25-
mtp_model_kvargs = kvargs
26-
mtp_model_kvargs.update(
27-
{
28-
"weight_dir": args.mtp_draft_model_dir,
26+
27+
if args.mtp_mode in ["vanilla_with_att", "vanilla_no_att"]:
28+
num_mtp_modules = args.mtp_step
29+
elif args.mtp_mode in ["eagle_with_att", "eagle_no_att"]:
30+
num_mtp_modules = 1
31+
else:
32+
assert False, f"error mtp mode {args.mtp_mode}"
33+
34+
for i in range(num_mtp_modules):
35+
mtp_model_cfg, _ = PretrainedConfig.get_config_dict(args.mtp_draft_model_dir[i])
36+
model_type = mtp_model_cfg.get("model_type", "")
37+
mtp_model_kvargs = {
38+
"weight_dir": args.mtp_draft_model_dir[i],
2939
"max_total_token_num": main_model.mem_manager.size,
30-
"disable_chunked_prefill": True,
31-
"mtp_mode": args.mtp_mode,
40+
"load_way": kvargs["load_way"],
41+
"max_req_num": kvargs.get("max_req_num", 1000),
42+
"max_seq_length": kvargs.get("max_seq_length", 1024 * 5),
43+
"is_token_healing": False,
44+
"return_all_prompt_logics": False,
45+
"disable_chunked_prefill": args.disable_chunked_prefill,
46+
"data_type": kvargs.get("data_type", "float16"),
47+
"graph_max_batch_size": kvargs.get("graph_max_batch_size", 16),
48+
"graph_max_len_in_batch": kvargs.get("graph_max_len_in_batch", 8196),
49+
"disable_cudagraph": kvargs.get("disable_cudagraph", False),
50+
"mem_fraction": kvargs["mem_fraction"],
51+
"batch_max_tokens": kvargs.get("batch_max_tokens", None),
52+
"quant_type": kvargs.get("quant_type", None),
53+
"quant_cfg": kvargs.get("quant_cfg", None),
54+
"run_mode": "normal",
55+
"llm_prefill_att_backend": kvargs.get("llm_prefill_att_backend", args.llm_prefill_att_backend),
56+
"llm_decode_att_backend": kvargs.get("llm_decode_att_backend", args.llm_decode_att_backend),
57+
"vit_att_backend": kvargs.get("vit_att_backend", args.vit_att_backend),
58+
"llm_kv_type": kvargs.get("llm_kv_type", args.llm_kv_type),
59+
"llm_kv_quant_group_size": kvargs.get("llm_kv_quant_group_size", args.llm_kv_quant_group_size),
3260
"main_model": main_model,
61+
"mtp_previous_draft_models": draft_models.copy(),
62+
"mtp_mode": args.mtp_mode,
3363
}
34-
)
35-
for i in range(mtp_step):
36-
mtp_model_cfg, _ = PretrainedConfig.get_config_dict(args.mtp_draft_model_dir)
37-
mtp_model_kvargs.update(
38-
{
39-
"weight_dir": args.spec_model_dir,
40-
"max_total_token_num": main_model.mem_manager.size,
41-
"disable_chunked_prefill": True,
42-
"mtp_mode": args.mtp_mode,
43-
"main_model": main_model,
44-
"mem_layer_start": main_model.config["num_hidden_layers"] + i * mtp_model_cfg["num_hidden_layers"],
45-
}
46-
)
47-
draft_models.append(Deepseek3MTPModel(mtp_model_kvargs))
64+
65+
if model_type == "deepseek_v3":
66+
assert args.mtp_mode in ["vanilla_with_att", "eagle_with_att"]
67+
draft_models.append(Deepseek3MTPModel(mtp_model_kvargs))
68+
elif model_type == "qwen3_moe":
69+
assert args.mtp_mode in ["vanilla_no_att", "eagle_no_att"]
70+
draft_models.append(Qwen3MOEMTPModel(mtp_model_kvargs))
71+
elif model_type == "mistral":
72+
assert args.mtp_mode in ["vanilla_no_att", "eagle_no_att"]
73+
draft_models.append(MistralMTPModel(mtp_model_kvargs))
74+
elif mtp_model_cfg["model_type"] == "glm4_moe_lite":
75+
assert args.mtp_mode in ["vanilla_with_att", "eagle_with_att"]
76+
draft_models.append(Glm4MoeLiteMTPModel(mtp_model_kvargs))
77+
elif model_type in ("qwen3_5", "qwen3_5_text"):
78+
assert args.mtp_mode in ["vanilla_with_att", "eagle_with_att"]
79+
from lightllm.models.qwen3_5_mtp.model import Qwen3_5MTPModel
80+
81+
draft_models.append(Qwen3_5MTPModel(mtp_model_kvargs))
82+
elif model_type in ("qwen3_5_moe", "qwen3_5_moe_text"):
83+
assert args.mtp_mode in ["vanilla_with_att", "eagle_with_att"]
84+
from lightllm.models.qwen3_5_moe_mtp.model import Qwen3_5MoeMTPModel
85+
86+
draft_models.append(Qwen3_5MoeMTPModel(mtp_model_kvargs))
87+
else:
88+
raise ValueError(f"Unsupported MTP model type: {model_type}")
89+
90+
logger.info(f"loaded mtp model class {draft_models[i].__class__}")
4891
return draft_models
4992

5093

@@ -68,13 +111,22 @@ def test_model_inference_mtp(args):
68111
"max_total_token_num": args.max_total_token_num,
69112
"graph_max_len_in_batch": args.max_req_total_len,
70113
"graph_max_batch_size": args.graph_max_batch_size,
71-
"mem_faction": args.mem_fraction,
72-
"max_req_num": 2000,
114+
"mem_fraction": args.mem_fraction,
115+
# Static bench runs explicit batch sizes (<= a few hundred). The hybrid Qwen3.5
116+
# GDN req-state cache is sized max_req_num * (mtp_step + 1) at ~34 MB/slot, so the
117+
# old default of 2000 alloc'd ~140 GB and OOM'd under MTP. 512 covers any realistic
118+
# static batch sweep while keeping the GDN cache small.
119+
"max_req_num": 512,
73120
"batch_max_tokens": 2048,
74121
"run_mode": "normal",
75122
"max_seq_length": args.max_req_total_len,
76-
"spec_algo": args.spec_algo,
77123
"disable_cudagraph": args.disable_cudagraph,
124+
"quant_cfg": args.quant_cfg,
125+
"llm_prefill_att_backend": args.llm_prefill_att_backend,
126+
"llm_decode_att_backend": args.llm_decode_att_backend,
127+
"vit_att_backend": args.vit_att_backend,
128+
"llm_kv_type": args.llm_kv_type,
129+
"llm_kv_quant_group_size": args.llm_kv_quant_group_size,
78130
}
79131
proc = multiprocessing.Process(
80132
target=tppart_model_infer,
@@ -113,28 +165,36 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_
113165

114166
test_data = np.vstack([np.random.randint(0, 50256, input_len) for _ in range(batch_size)])
115167
test_data = test_data.reshape(-1)
116-
test_data = torch.from_numpy(test_data).cuda()
168+
test_data = torch.from_numpy(test_data)
117169

118170
b_req_idx = torch.tensor(
119-
[main_model.req_manager.alloc() for _ in range(batch_size)], dtype=torch.int32, device="cuda"
171+
[main_model.req_manager.alloc() for _ in range(batch_size)], dtype=torch.int32, device="cpu"
120172
)
121-
b_seq_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
122-
b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
173+
b_seq_len = torch.zeros(batch_size, dtype=torch.int32, device="cpu")
174+
b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cpu")
123175
for i in range(batch_size):
124176
b_seq_len[i] = input_len
125177

126178
total_token_num = input_len * batch_size
127-
mem_indexes = main_model.req_manager.mem_manager.alloc(test_data.shape[0]).cuda()
179+
mem_indexes = main_model.req_manager.mem_manager.alloc(test_data.shape[0])
180+
b_mtp_index = torch.zeros(batch_size, dtype=torch.int32)
181+
b_prefill_start_loc = b_seq_len.cumsum(dim=0, dtype=torch.int32) - b_seq_len
128182
# Main model Prefill
129183
model_input = ModelInput(
130184
batch_size=batch_size,
131185
total_token_num=total_token_num,
186+
max_q_seq_len=input_len,
187+
max_kv_seq_len=input_len,
188+
max_cache_len=0,
132189
input_ids=test_data,
133-
mem_indexes=mem_indexes,
190+
mem_indexes_cpu=mem_indexes,
134191
b_req_idx=b_req_idx,
192+
b_mtp_index=b_mtp_index,
135193
b_seq_len=b_seq_len,
136194
is_prefill=True,
137195
b_ready_cache_len=b_ready_cache_len,
196+
b_prefill_start_loc=b_prefill_start_loc,
197+
prefix_total_token_num=0,
138198
multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size)],
139199
)
140200

@@ -167,8 +227,22 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_
167227

168228
torch.cuda.synchronize()
169229

230+
# Speculative width = args.mtp_step in BOTH modes (mirrors base_backend: self.mtp_step =
231+
# args.mtp_step). The number of draft MODEL INSTANCES differs: vanilla loads mtp_step
232+
# instances (each forwarded once), eagle loads ONE instance forwarded mtp_step times
233+
# (chunked_prefill/impl.py: draft_models[_step % num_instances]). The verify batch always
234+
# expands to (mtp_step + 1) rows per request.
235+
spec_width = args.mtp_step
236+
num_instances = len(draft_models)
237+
# The draft prefill above produced (1 + num_instances) columns; pad/truncate to
238+
# (spec_width + 1) so the decode verify batch matches the server's expand width. Only the
239+
# SHAPE matters for throughput here (argmax over random inputs); token values do not.
240+
while len(draft_ids) < spec_width + 1:
241+
draft_ids.append(draft_ids[-1])
242+
draft_ids = draft_ids[: spec_width + 1]
170243
decode_input_ids = np.stack(draft_ids, axis=-1).reshape(-1)
171-
decode_input_ids = torch.from_numpy(decode_input_ids).cuda()
244+
decode_input_ids = torch.from_numpy(decode_input_ids)
245+
mtp_step = spec_width
172246

173247
# build main decode input:
174248
nopad_b_seq_idx = []
@@ -177,35 +251,39 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_
177251
nopad_max_len_in_batch = 0
178252

179253
for i in range(batch_size):
180-
nopad_b_seq_idx.append(b_req_idx[i])
254+
nopad_b_seq_idx.append(b_req_idx[i].item())
181255
seq_len = b_seq_len[i].item()
182256
nopad_b_seq_len.append(seq_len + 1)
183257
nopad_total_token_num += seq_len + 1
184-
nopad_max_len_in_batch = max(nopad_max_len_in_batch, b_seq_len[i] + 1)
258+
nopad_max_len_in_batch = max(nopad_max_len_in_batch, seq_len + 1)
185259

186-
for step in range(len(draft_models)):
187-
nopad_b_seq_idx.append(b_req_idx[i])
260+
for step in range(mtp_step):
261+
nopad_b_seq_idx.append(b_req_idx[i].item())
188262
nopad_b_seq_len.append(seq_len + step + 2)
189263
nopad_total_token_num += seq_len + step + 2
190264
nopad_max_len_in_batch = max(nopad_max_len_in_batch, seq_len + step + 2)
191265

192-
nopad_b_seq_idx = torch.tensor(nopad_b_seq_idx, dtype=torch.int32, device="cuda")
193-
nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda")
194-
mem_indexes = main_model.req_manager.mem_manager.alloc(batch_size * (len(draft_models) + 1)).cuda()
266+
nopad_b_seq_idx = torch.tensor(nopad_b_seq_idx, dtype=torch.int32, device="cpu")
267+
nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cpu")
268+
b_mtp_index = torch.arange(mtp_step + 1, dtype=torch.int32).repeat(batch_size)
269+
mem_indexes = main_model.req_manager.mem_manager.alloc(batch_size * (mtp_step + 1))
195270

196271
model_input = ModelInput(
197-
batch_size=batch_size * (len(draft_models) + 1),
272+
batch_size=batch_size * (mtp_step + 1),
198273
total_token_num=nopad_total_token_num,
274+
max_q_seq_len=1,
275+
max_kv_seq_len=nopad_max_len_in_batch,
199276
input_ids=decode_input_ids,
200-
mem_indexes=mem_indexes,
277+
mem_indexes_cpu=mem_indexes,
201278
b_req_idx=nopad_b_seq_idx,
279+
b_mtp_index=b_mtp_index,
202280
b_seq_len=nopad_b_seq_len,
203281
is_prefill=False,
204-
multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size * (len(draft_models) + 1))],
282+
multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size * (mtp_step + 1))],
205283
)
206284

207285
# Main decode
208-
for i in range(0, output_len, len(draft_models) + 1):
286+
for i in range(0, output_len, mtp_step + 1):
209287
torch.cuda.synchronize()
210288
step_start_time = time.time()
211289
model_output = main_model.forward(
@@ -214,12 +292,13 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_
214292
prob_out = torch.softmax(model_output.logits, dim=-1)
215293
predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
216294

217-
# draft decode
295+
# draft decode: mtp_step forwards, reusing draft_models[_step % num_instances]
296+
# (eagle: one instance reused mtp_step times; vanilla: a distinct instance per step).
218297
model_input.input_ids = predict_ids.reshape(-1)
219298
model_input.mtp_draft_input_hiddens = model_output.mtp_main_output_hiddens
220299

221-
for draft_model_id in range(len(draft_models)):
222-
draft_model = draft_models[draft_model_id]
300+
for _step in range(mtp_step):
301+
draft_model = draft_models[_step % num_instances]
223302
model_output = draft_model.forward(
224303
model_input,
225304
)
@@ -237,7 +316,7 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_
237316
if get_current_rank_in_dp() == 0 and not warmup:
238317
step_time = step_end_time - step_start_time
239318
print(i, " step cost time:", step_time * 1000)
240-
print(f"Decode throughput: {batch_size * (len(draft_models) + 1) * args.dp / step_time} tokens/s")
319+
print(f"Decode throughput: {batch_size * (mtp_step + 1) * args.dp / step_time} tokens/s")
241320

242321
main_model.mem_manager.free_all()
243322
main_model.req_manager.free_all()

test/benchmark/static_inference/test_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def test_model_infer(self):
1616
args = get_env_start_args()
1717
if args.data_type is None:
1818
args.data_type = get_dtype(args.model_dir)
19-
if args.mtp_mode == "deepseekv3":
19+
if args.mtp_mode is not None:
2020
test_model_inference_mtp(args)
2121
else:
2222
test_model_inference(args)

0 commit comments

Comments
 (0)