-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathrun_ruler_mla.py
More file actions
145 lines (125 loc) · 6.73 KB
/
Copy pathrun_ruler_mla.py
File metadata and controls
145 lines (125 loc) · 6.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
#!/usr/bin/env python
"""Run the RULER needle-in-a-haystack benchmark on GLM-4.7-Flash with vortex
sparse MLA attention (rope-aware block-sparse routing on the hand-written CUDA
decode kernel).
This is a self-contained, single-GPU example of the vortex sparse-MLA decode
path. The defaults reproduce the known-good GLM-4.7-Flash configuration
(block=page=32, topk=61, ~100% on RULER); every knob is overridable on the CLI.
cuda_mla sglang attention backend -> vortex CUDA MLA decode + flashinfer prefill
trtllm vortex indexer backend -> 2D block-table page selection
rope_aware_block_sparse_mla -> scores pages by the FULL fused-latent dot
<q_nope,c_kv> + <q_pe,c_pe>, then top-k
GLM-4.7-Flash (model type ``glm4_moe_lite``) requires transformers >= 5.0, so run
this in the ``vortex_glm`` conda env (NOT ``vortex_v1``, which can't load it):
conda activate vortex_glm
CUDA_VISIBLE_DEVICES=<free-gpu> python examples/run_ruler_mla.py
# or pin the GPU / shrink the slice:
python examples/run_ruler_mla.py --gpu 3 --n 20 --dump
The script forces HF offline mode by default (the model is expected to be cached);
pass ``--online`` to allow hub access.
"""
import argparse
import json
import os
# Must be set before sglang/vortex are imported. setdefault is used so that a
# value already present in the caller's environment (or the shell) still wins.
os.environ.setdefault("SGLANG_ENABLE_TORCH_COMPILE", "0")
def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the benchmark's command-line options.

    Args:
        argv: Argument strings to parse (without the program name). ``None``
            (the default) makes argparse read ``sys.argv[1:]``, which is what
            the script entry point relies on; passing a list lets the parser
            be driven programmatically.

    Returns:
        The populated ``argparse.Namespace``. Dashed option names are exposed
        with underscores (e.g. ``--max-new-tokens`` -> ``max_new_tokens``).
    """
    # RawDescriptionHelpFormatter keeps the module docstring's line layout
    # intact when it is printed as the --help description.
    p = argparse.ArgumentParser(description=__doc__,
                                formatter_class=argparse.RawDescriptionHelpFormatter)

    # What to run and on which data.
    p.add_argument("--model", default="zai-org/GLM-4.7-Flash",
                   help="HF model id (default: GLM-4.7-Flash).")
    p.add_argument("--module", default="rope_aware_block_sparse_mla",
                   help="vortex MLA flow name (default: rope_aware_block_sparse_mla).")
    p.add_argument("--data", default="examples/validation.jsonl",
                   help="RULER jsonl with {input, outputs:[str]} rows.")
    # Kept as a string (no type=int): it is only ever written to an env var.
    p.add_argument("--gpu", default=None,
                   help="GPU index to pin (sets CUDA_VISIBLE_DEVICES). Default: inherit env.")
    p.add_argument("--n", type=int, default=100, help="Number of examples (default: 100).")

    # Sparsity / engine sizing knobs.
    p.add_argument("--block", type=int, default=32, help="vortex block size == page size.")
    p.add_argument("--topk", type=int, default=61, help="vortex_topk_val (selected blocks).")
    p.add_argument("--max-new-tokens", type=int, default=128)
    p.add_argument("--mem-fraction", type=float, default=0.85)
    p.add_argument("--tp", type=int, default=1, help="tensor-parallel size.")

    # Behaviour switches (all default to off).
    p.add_argument("--thinking", action="store_true",
                   help="Enable GLM thinking mode (default off; needles answer directly).")
    p.add_argument("--disable-cuda-graph", action="store_true",
                   help="Run eager (no cuda graph capture).")
    p.add_argument("--dense", action="store_true",
                   help="Disable vortex sparsity: run plain dense trtllm_mla.")
    p.add_argument("--kv-cache-dtype", default="auto",
                   help="sglang kv_cache_dtype (auto|fp8_e4m3|bfloat16).")
    p.add_argument("--online", action="store_true",
                   help="Allow HF hub access (default: HF_HUB_OFFLINE=1).")
    p.add_argument("--dump", action="store_true",
                   help="Print the first few (expected, generated) pairs.")
    return p.parse_args(argv)
def main() -> None:
    """Run a RULER slice through an sglang engine and print the hit rate.

    Steps: parse the CLI, export the GPU / HF-offline environment variables,
    load up to ``--n`` rows from the jsonl file, build the engine (adding the
    flat ``vortex_*`` kwargs unless ``--dense``), generate greedily, and count
    a row as a hit when its first expected output is a substring of the
    generated text.

    Raises:
        SystemExit: if the data file yields no examples.
    """
    args = parse_args()

    # Export the environment before the function-local imports further down
    # (presumably the libraries read these variables when they are imported).
    if args.gpu is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    if args.online:
        # An explicit --online has to beat an inherited HF_HUB_OFFLINE=1, the
        # same way --gpu overrides an inherited CUDA_VISIBLE_DEVICES.
        os.environ["HF_HUB_OFFLINE"] = "0"
    else:
        os.environ.setdefault("HF_HUB_OFFLINE", "1")

    # Read the data before the engine exists, so a bad path or an empty file
    # fails immediately instead of after the model has been loaded.
    with open(args.data, encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing newline at EOF) instead of letting
        # json.loads crash on them.
        rows = [json.loads(line) for line in f if line.strip()][: args.n]
    if not rows:
        # Also protects the accuracy division at the end from a zero divisor.
        raise SystemExit(
            f"[run_ruler_mla] no examples loaded from {args.data!r} (n={args.n})")

    import sglang as sgl
    import vortex_torch  # noqa: F401 -- installs the ServerArgs / VortexConfig adapter
    from transformers import AutoTokenizer

    mode = "dense" if args.dense else "sparse"
    print(f"[run_ruler_mla] model={args.model} mode={mode} module={args.module} "
          f"block={args.block} topk={args.topk} n={args.n} "
          f"kv_cache_dtype={args.kv_cache_dtype} "
          f"cuda_graph={'off' if args.disable_cuda_graph else 'on'}", flush=True)

    engine_kwargs = dict(
        model_path=args.model,
        trust_remote_code=True,
        tp_size=args.tp,
        page_size=args.block,            # page == block (one block per page)
        attention_backend="trtllm_mla",  # vortex CUDA MLA decode kernel
        kv_cache_dtype=args.kv_cache_dtype,
        mem_fraction_static=args.mem_fraction,
        disable_cuda_graph=args.disable_cuda_graph,
    )
    if not args.dense:
        # Flat vortex_* kwargs; the adapter folds them into one VortexConfig and ships
        # them across the spawn boundary. These reproduce the known-good GLM run.
        engine_kwargs.update(
            enable_vortex_sparsity=True,
            vortex_module_name=args.module,
            vortex_attention_backend="trtllm",  # 2D block-table indexer
            vortex_impl_backend="triton",       # tensor-core indexer GeMM
            vortex_use_tensor_core=True,
            vortex_block_size=args.block,
            vortex_topk_val=args.topk,
            vortex_topk_ratio=0.0,
            vortex_block_reserved_bos=1,
            vortex_block_reserved_eos=2,
            vortex_dtype="bfloat16",
            vortex_layers_skip=[],
            vortex_max_seq_lens=8192,
            vortex_workload_chunk_size=64,
        )

    llm = sgl.Engine(**engine_kwargs)
    try:
        tok = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)

        def render(text: str) -> str:
            """Wrap ``text`` as a single user turn using the model's chat template."""
            msg = [{"role": "user", "content": text}]
            try:
                return tok.apply_chat_template(
                    msg, tokenize=False, add_generation_prompt=True,
                    enable_thinking=args.thinking)
            except TypeError:  # tokenizer without the enable_thinking kwarg
                return tok.apply_chat_template(
                    msg, tokenize=False, add_generation_prompt=True)

        prompts = [render(row["input"]) for row in rows]
        # temperature 0.0 -> greedy decoding, so reruns are comparable.
        outs = llm.generate(
            prompts, {"temperature": 0.0, "max_new_tokens": args.max_new_tokens})

        # A hit means the first expected answer appears verbatim in the output.
        hits = [row["outputs"][0] in out["text"] for row, out in zip(rows, outs)]

        if args.dump:
            # Show at most the first three examples, generation capped at 400 chars.
            for i, (row, out, hit) in enumerate(zip(rows[:3], outs, hits)):
                print(f"\n--- ex{i} hit={hit} expect={row['outputs'][0]!r} ---")
                print(f"GEN: {out['text'][:400]!r}", flush=True)

        acc = sum(hits)
        tag = "dense trtllm_mla" if args.dense else f"sparse {args.module}"
        print(f"\n>>> RULER {args.model.split('/')[-1]} | {tag} | kv={args.kv_cache_dtype}: "
              f"{acc}/{len(rows)} = {acc / len(rows) * 100:.1f}%", flush=True)
    finally:
        # Always shut the engine down, even when generation or scoring raised.
        llm.shutdown()
# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()