Skip to content

Commit 63a0ab8

Browse files
committed
test(mtp): MTP unit tests + static benchmark
Behavioural/CUDA coverage for the subtle MTP paths: verify-extra-state metadata, decode CUDA-graph verify layouts, fa3 fp8 verify narrowing, GDN verify equivalence, the spec causal_conv1d kernel and its prefill->decode roundtrip, and the linear-att conv/SSM widened-slot split + snapshot + CPU-cache persistence. Also extends the static-inference MTP benchmark and anchors the .gitignore benchmark-output rule to /benchmark.
1 parent f535bbe commit 63a0ab8

15 files changed

Lines changed: 1483 additions & 71 deletions

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,6 @@ dist
77
.vscode
88
tmp/
99
requirements-musa.txt
10-
logs/
10+
logs/
11+
12+
/benchmark/

test/benchmark/static_inference/model_infer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def test_model_inference(args):
3636
"graph_max_len_in_batch": args.max_req_total_len,
3737
"graph_max_batch_size": args.graph_max_batch_size,
3838
"mem_fraction": args.mem_fraction,
39-
"max_req_num": 2048,
39+
"max_req_num": 512,
4040
"batch_max_tokens": 1024,
4141
"run_mode": "normal",
4242
"max_seq_length": args.max_req_total_len,

test/benchmark/static_inference/model_infer_mtp.py

Lines changed: 216 additions & 66 deletions
Large diffs are not rendered by default.

test/benchmark/static_inference/test_model.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,29 @@
1111
from lightllm.utils.config_utils import get_config_json, get_dtype
1212

1313

14+
def parse_batch_size(value):
15+
parts = [part.strip() for part in value.split(",") if part.strip()]
16+
if not parts:
17+
raise ValueError("batch_size must contain at least one integer")
18+
19+
batch_sizes = []
20+
for part in parts:
21+
size = int(part)
22+
if size <= 0:
23+
raise ValueError("batch_size values must be positive integers")
24+
batch_sizes.append(size)
25+
26+
if len(batch_sizes) == 1:
27+
return batch_sizes[0]
28+
return batch_sizes
29+
30+
1431
class TestModelInfer(unittest.TestCase):
1532
def test_model_infer(self):
1633
args = get_env_start_args()
1734
if args.data_type is None:
1835
args.data_type = get_dtype(args.model_dir)
19-
if args.mtp_mode == "deepseekv3":
36+
if args.mtp_mode is not None:
2037
test_model_inference_mtp(args)
2138
else:
2239
test_model_inference(args)
@@ -27,7 +44,7 @@ def test_model_infer(self):
2744
import torch
2845

2946
parser = make_argument_parser()
30-
parser.add_argument("--batch_size", type=int, default=None, help="batch size")
47+
parser.add_argument("--batch_size", type=parse_batch_size, default=None, help="batch size, e.g. 8 or 1,2,4,8")
3148
parser.add_argument("--input_len", type=int, default=64, help="input sequence length")
3249
parser.add_argument("--output_len", type=int, default=128, help="output sequence length")
3350
parser.add_argument(

test/cpu_cache_kernel/test_speed.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@
104104
buffer_count = triton.cdiv(SEQ_LEN, big_page_token_num) + 2 # matches Qwen3NextMemManager
105105

106106

107-
conv_shape = linear_config.get_conv_state_shape()
107+
conv_shape = linear_config.get_persisted_conv_state_shape()
108108
cpu_kv_conv_state = torch.empty(
109109
(buffer_count, linear_config.linear_layer_num, *conv_shape),
110110
dtype=linear_config.conv_state_dtype,
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import types
2+
import torch
3+
import pytest
4+
5+
import lightllm.common.basemodel.attention.fa3.fp8 as fp8_mod
6+
from lightllm.common.basemodel.attention.fa3.fp8 import Fp8Fa3DecodeAttState
7+
8+
9+
def _make_verify_state(n_real, mtp_size, head_num=2, head_dim=8):
10+
"""Build an Fp8Fa3DecodeAttState as init_state would leave it in MTP-verify mode,
11+
bypassing init_state. b_att_seq_len/page_table are NARROW (n_real); infer_state.b_seq_len
12+
is the FULL expanded tensor (n_real*mtp_size) that must NOT be used as cache_seqlens."""
13+
state = object.__new__(Fp8Fa3DecodeAttState)
14+
batch = n_real * mtp_size
15+
state.b_att_seq_len = torch.full((n_real,), 16, dtype=torch.int32)
16+
state.page_table = torch.zeros((n_real, 16), dtype=torch.int32)
17+
state.cu_seqlens_q = torch.arange(0, (n_real + 1) * mtp_size, mtp_size, dtype=torch.int32)
18+
state.cu_seqlens_k = torch.zeros((n_real + 1,), dtype=torch.int32)
19+
state.decode_max_q_seq_len = mtp_size
20+
state.infer_state = types.SimpleNamespace(
21+
b_seq_len=torch.full((batch,), 16, dtype=torch.int32),
22+
batch_size=batch,
23+
)
24+
# k/v descale sized per real request (att_batch_size), indexed by layer
25+
state.k_descale = torch.ones((1, n_real, head_num))
26+
state.v_descale = torch.ones((1, n_real, head_num))
27+
state.backend = types.SimpleNamespace(_find_layer_index=lambda k, v, att_state: 0)
28+
return state, batch
29+
30+
31+
def test_fp8_decode_uses_narrowed_cache_seqlens_and_causal(monkeypatch):
32+
n_real, mtp_size, head_num, head_dim = 3, 4, 2, 8
33+
state, batch = _make_verify_state(n_real, mtp_size, head_num, head_dim)
34+
35+
captured = {}
36+
37+
def fake_flash(**kwargs):
38+
captured.update(kwargs)
39+
q = kwargs["q"]
40+
return torch.zeros((q.shape[0], q.shape[1], q.shape[2]))
41+
42+
def fake_quant(x, use_per_token_if_dynamic=True):
43+
return x, torch.ones((x.shape[0], 1))
44+
45+
monkeypatch.setattr(fp8_mod, "flash_attn_with_kvcache", fake_flash)
46+
monkeypatch.setattr(fp8_mod, "scaled_fp8_quant", fake_quant)
47+
48+
q = torch.randn((batch, head_num, head_dim))
49+
k = torch.randn((batch, head_num, head_dim))
50+
v = torch.randn((batch, head_num, head_dim))
51+
52+
state._fp8_decode_att(q=q, k=k, v=v)
53+
54+
# The KV-side seqlens must be the NARROW per-real-request tensor, matching page_table rows.
55+
assert captured["cache_seqlens"] is state.b_att_seq_len
56+
assert captured["cache_seqlens"].shape[0] == n_real
57+
assert captured["cache_seqlens"].shape[0] == captured["page_table"].shape[0]
58+
# Verify decode must be causal, like the non-fp8 sibling.
59+
assert captured["causal"] is True

0 commit comments

Comments
 (0)