
Commit 0c05620

lint

committed
1 parent 1e7343b commit 0c05620

1 file changed

Lines changed: 180 additions & 0 deletions

@@ -0,0 +1,180 @@
#!/usr/bin/env python3
"""Model-level decode benchmark for Qwen3.5 MoE split-K SDPA.

Measures prefill tok/s and decode tok/s across different prompt sizes
and decode lengths to evaluate FlashDecoding++ async softmax impact.
"""

import json
import sys
import time

import torch

# Register Triton kernels before model import
import executorch.backends.cuda.triton.kernels  # noqa: F401

from executorch.examples.models.qwen3_5_moe.export import load_prequantized_model


PROMPT_SIZES = [1, 15, 59, 143, 1694]
DECODE_LENGTHS = [16, 64, 256, 1024]
MODEL_PATH = "/home/gasoonjia/models/qwen35_moe_int4_hqq"
NUM_WARMUP = 2  # warmup runs before timing


def _move_to_cuda(model, config):
    # Re-register every buffer on CUDA; meta buffers are materialized as zeros.
    for fqn, buf in list(model.named_buffers()):
        parts = fqn.rsplit(".", 1)
        parent = model.get_submodule(parts[0]) if len(parts) > 1 else model
        if buf.device.type == "meta":
            dtype = torch.bfloat16 if buf.dtype != torch.bool else torch.bool
            parent.register_buffer(
                parts[-1], torch.zeros(buf.shape, dtype=dtype, device="cuda")
            )
        else:
            parent.register_buffer(parts[-1], buf.to("cuda"))

    # Move parameters to CUDA; inference only, so no gradients are needed.
    for name, p in model.named_parameters():
        parts = name.rsplit(".", 1)
        parent = model.get_submodule(parts[0]) if len(parts) > 1 else model
        setattr(
            parent,
            parts[-1],
            torch.nn.Parameter(p.data.to("cuda"), requires_grad=False),
        )

    # Rebuild RoPE inverse frequencies and the causal mask directly on CUDA.
    for layer in model.layers:
        if hasattr(layer.attn, "rotary_emb"):
            rope = layer.attn.rotary_emb
            inv_freq = 1.0 / (
                config.rope_theta
                ** (
                    torch.arange(0, rope.rotary_dim, 2, dtype=torch.float32)
                    / rope.rotary_dim
                )
            )
            rope.inv_freq = inv_freq.to("cuda")
        if hasattr(layer.attn, "mask"):
            layer.attn.register_buffer(
                "mask",
                torch.tril(
                    torch.ones(
                        config.max_seq_len,
                        config.max_seq_len,
                        dtype=torch.bool,
                        device="cuda",
                    )
                ),
            )


def _reset_state(model):
    """Reset all KV caches, conv_state, and recurrent_state to zero."""
    for layer in model.layers:
        attn = layer.attn
        if hasattr(attn, "kv_cache"):
            attn.kv_cache.k_cache.zero_()
            attn.kv_cache.v_cache.zero_()
        if hasattr(attn, "conv_state"):
            attn.conv_state.zero_()
        if hasattr(attn, "recurrent_state"):
            attn.recurrent_state.zero_()


@torch.inference_mode()
def benchmark_prefill(model, prompt_size):
    """Prefill prompt_size tokens one at a time, return tok/s."""
    _reset_state(model)
    tokens = torch.randint(0, 1000, (1, 1), device="cuda", dtype=torch.long)

    # Warmup
    for i in range(min(prompt_size, NUM_WARMUP)):
        pos = torch.tensor([i], device="cuda")
        model(tokens, pos)

    _reset_state(model)
    torch.cuda.synchronize()
    t0 = time.perf_counter()

    for i in range(prompt_size):
        pos = torch.tensor([i], device="cuda")
        model(tokens, pos)

    torch.cuda.synchronize()
    elapsed = time.perf_counter() - t0
    return prompt_size / elapsed if elapsed > 0 else 0.0


@torch.inference_mode()
def benchmark_decode(model, prompt_size, decode_length):
    """Prefill prompt_size tokens, then decode decode_length tokens. Return decode tok/s."""
    _reset_state(model)
    tokens = torch.randint(0, 1000, (1, 1), device="cuda", dtype=torch.long)

    # Prefill
    for i in range(prompt_size):
        pos = torch.tensor([i], device="cuda")
        logits = model(tokens, pos)

    # Get first decode token
    next_token = logits[:, -1:, :].argmax(dim=-1)

    # No decode warmup: it would pollute the KV cache beyond prompt_size + decode_length.

    torch.cuda.synchronize()
    t0 = time.perf_counter()

    for i in range(decode_length):
        pos = torch.tensor([prompt_size + i], device="cuda")
        logits = model(next_token, pos)
        next_token = logits[:, -1:, :].argmax(dim=-1)

    torch.cuda.synchronize()
    elapsed = time.perf_counter() - t0
    return decode_length / elapsed if elapsed > 0 else 0.0


def main():
    max_seq = max(PROMPT_SIZES) + max(DECODE_LENGTHS) + 16
    print(f"Loading model from {MODEL_PATH} (max_seq_len={max_seq})...")
    model, config = load_prequantized_model(MODEL_PATH, max_seq_len=max_seq)
    _move_to_cuda(model, config)
    model.eval()

    results = {"prefill": {}, "decode": {}}

    # Prefill benchmark
    print("\n=== Prefill Benchmark ===")
    print(f"{'Prompt Size':>12} | {'tok/s':>10}")
    print("-" * 27)
    for ps in PROMPT_SIZES:
        tps = benchmark_prefill(model, ps)
        results["prefill"][ps] = round(tps, 2)
        print(f"{ps:>12} | {tps:>10.2f}")

    # Decode benchmark
    print("\n=== Decode Benchmark ===")
    header = f"{'Prompt Size':>12}"
    for dl in DECODE_LENGTHS:
        header += f" | {'dec=' + str(dl):>12}"
    print(header)
    print("-" * len(header))

    for ps in PROMPT_SIZES:
        row = f"{ps:>12}"
        results["decode"][ps] = {}
        for dl in DECODE_LENGTHS:
            tps = benchmark_decode(model, ps, dl)
            results["decode"][ps][dl] = round(tps, 2)
            row += f" | {tps:>12.2f}"
        print(row)

    # Dump JSON for easy comparison
    print("\n--- JSON ---")
    print(json.dumps(results, indent=2))


if __name__ == "__main__":
    main()
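
Not part of this commit: a minimal sketch of how two of the JSON dumps printed by this script (say, a baseline run and a split-K run) could be compared. The file paths, argument order, helper name, and speedup column are assumptions; only the results layout ({"prefill": {prompt: tok/s}, "decode": {prompt: {decode_len: tok/s}}}) comes from the script above.

#!/usr/bin/env python3
"""Hypothetical helper: compare two result dumps from the benchmark above.

Usage: compare_results.py baseline.json splitk.json
Each file is assumed to contain only the JSON object printed after "--- JSON ---".
"""
import json
import sys


def _load(path):
    with open(path) as f:
        return json.load(f)


def main():
    # First argument is the baseline dump, second the candidate (e.g. split-K) dump.
    baseline, candidate = _load(sys.argv[1]), _load(sys.argv[2])
    print(f"{'Prompt Size':>12} | {'Decode Len':>10} | {'Speedup':>8}")
    for ps, per_len in candidate["decode"].items():
        for dl, tps in per_len.items():
            # JSON keys are strings, so they index the baseline dict directly.
            base = baseline["decode"][ps][dl]
            ratio = tps / base if base else float("inf")
            print(f"{ps:>12} | {dl:>10} | {ratio:>7.2f}x")


if __name__ == "__main__":
    main()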
