Commit d1c90ef

mla support
Author: Zhuoming Chen
1 parent 5d2aa76, commit d1c90ef

21 files changed

Lines changed: 2115 additions & 57 deletions

AI/AGENTS.md

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -35,23 +35,23 @@ Example existing flows live at the top level (`submissions/example_*`)
3535
and are not under any agent tag — those are framework reference
3636
materials, not work product.
3737

38-
### Environment — activate the `vortex_v04` conda env first
38+
### Environment — activate the `vortex_v1` conda env first
3939

4040
Every python invocation in this contract (`check_engine_config`,
4141
`run_submission_aime24.py`, the pre-flight loops in §5/§5c/§5f,
42-
the iterate driver, etc.) expects the **`vortex_v04`** conda
42+
the iterate driver, etc.) expects the **`vortex_v1`** conda
4343
environment. Activate it once at session start, before running
4444
any bash snippet below:
4545

4646
```bash
4747
source "$(conda info --base)/etc/profile.d/conda.sh"
48-
conda activate vortex_v04
49-
python -c "import sys; print(sys.executable)" # expect .../envs/vortex_v04/...
48+
conda activate vortex_v1
49+
python -c "import sys; print(sys.executable)" # expect .../envs/vortex_v1/...
5050
```
5151

5252
If `conda activate` is unavailable in the current shell, fall
53-
back to `conda run -n vortex_v04 python ...` per call. Either
54-
way, the running interpreter must be the one inside `vortex_v04`
53+
back to `conda run -n vortex_v1 python ...` per call. Either
54+
way, the running interpreter must be the one inside `vortex_v1`
5555
— a system / base / wrong-env python will fail to import
5656
`vortex_torch`'s C extension and the framework's Triton kernels.
5757
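
The "expect .../envs/vortex_v1/..." check in the snippet above is done by eye. A minimal sketch of the same check in Python follows, assuming a standard conda layout in which each environment lives under an `envs/<name>/` directory; the helper name `in_conda_env` is hypothetical and not part of the repository.

```python
import sys
from pathlib import PurePosixPath


def in_conda_env(executable: str, env_name: str) -> bool:
    """Return True if `executable` sits under an `envs/<env_name>/` directory.

    Hypothetical helper: it assumes the usual conda layout
    (<conda base>/envs/<env_name>/bin/python) and POSIX-style paths.
    """
    parts = PurePosixPath(executable).parts
    # Look for the consecutive path components ("envs", env_name).
    return any(a == "envs" and b == env_name for a, b in zip(parts, parts[1:]))


if __name__ == "__main__":
    # Report whether the running interpreter belongs to vortex_v1.
    print(sys.executable, in_conda_env(sys.executable, "vortex_v1"))
```

Matching on whole path components rather than on a substring keeps a similarly named environment (for example the old `vortex_v04`) from passing the check.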

AI/developer_guides/developer_guide.md

Lines changed: 2 additions & 2 deletions
@@ -128,7 +128,7 @@ vortex_torch/
128128
│ └── compiler/ # mirror of indexer/compiler
129129
├── engine/
130130
│ └── sgl.py # get_engine_from_json + check_engine_config
131-
└── third_party/sglang/v0.4.9/sglang/... # patched sglang with the VTX backend
131+
└── third_party/sglang/v0.5.9/sglang/... # patched sglang with the VTX backend
132132
```
133133

134134
Rule of thumb: **every op has one class and one codegen function**;
@@ -1363,7 +1363,7 @@ class, and returns it.
13631363

13641364
## 13. Runtime integration with sglang
13651365

1366-
Patched sglang ships in `third_party/sglang/v0.4.9/sglang/`. Two glue files:
1366+
Patched sglang ships in `third_party/sglang/v0.5.9/sglang/`. Two glue files:
13671367

13681368
### 13.1 `VTXGraphAttnBackend` (sglang/srt/layers/attention/vtx_graph_backend.py)
13691369

AI/generate_claude_folder.py

Lines changed: 12 additions & 12 deletions
@@ -146,25 +146,25 @@
146146
across requests with matching prompt prefixes, corrupting
147147
Save/Load values. `check_engine_config` rejects the violation.
148148
149-
## Environment — activate the `vortex_v04` conda env first
149+
## Environment — activate the `vortex_v1` conda env first
150150
151151
Every python invocation in this project (`check_engine_config`,
152152
`run_submission_aime24.py`, the pre-flight loops in the slash
153-
commands, etc.) expects the **`vortex_v04`** conda environment.
153+
commands, etc.) expects the **`vortex_v1`** conda environment.
154154
**Activate it once at session start** before running any of the
155155
bash snippets below:
156156
157157
```bash
158158
source "$(conda info --base)/etc/profile.d/conda.sh"
159-
conda activate vortex_v04
160-
python -c "import sys; print(sys.executable)" # expect a path under .../envs/vortex_v04/
159+
conda activate vortex_v1
160+
python -c "import sys; print(sys.executable)" # expect a path under .../envs/vortex_v1/
161161
```
162162
163163
If `conda activate` isn't available in the current shell (e.g. a
164164
non-interactive sub-shell that didn't source the conda profile),
165-
fall back to `conda run -n vortex_v04 python ...` for every
165+
fall back to `conda run -n vortex_v1 python ...` for every
166166
python call. Either form is acceptable; what matters is that the
167-
running interpreter is the one inside `vortex_v04`.
167+
running interpreter is the one inside `vortex_v1`.
168168
169169
## Running the benchmark — policy
170170
@@ -413,19 +413,19 @@
413413
yet exist, create it; otherwise resume into it. Confirm the tag
414414
with the user only if you cannot determine your model name.
415415
416-
### Second action — activate the `vortex_v04` conda env
416+
### Second action — activate the `vortex_v1` conda env
417417
418418
Every python call in this workflow must run inside the
419-
**`vortex_v04`** conda env. Activate once at session start:
419+
**`vortex_v1`** conda env. Activate once at session start:
420420
421421
```bash
422422
source "$(conda info --base)/etc/profile.d/conda.sh"
423-
conda activate vortex_v04
424-
python -c "import sys; print(sys.executable)" # must be .../envs/vortex_v04/...
423+
conda activate vortex_v1
424+
python -c "import sys; print(sys.executable)" # must be .../envs/vortex_v1/...
425425
```
426426
427427
If `conda activate` isn't usable in the current shell, prefix
428-
each python invocation with `conda run -n vortex_v04` instead.
428+
each python invocation with `conda run -n vortex_v1` instead.
429429
A wrong-env python will fail to import the framework's C
430430
extension and every pre-flight / benchmark call below will error.
431431
@@ -1074,7 +1074,7 @@
10741074
```bash
10751075
CONDA_BASE=$(conda info --base 2>/dev/null || echo /root/anaconda3)
10761076
source "$CONDA_BASE/etc/profile.d/conda.sh"
1077-
conda activate vortex_v04
1077+
conda activate vortex_v1
10781078
python -c "import sys; print(sys.executable)"
10791079
```
10801080

docker/Dockerfile.pd

Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
1+
# vortex_torch + sglang v0.5.9, with PD (prefill/decode) disaggregation.
2+
#
3+
# Builds the v0.5 branch (vendored sglang lives at
4+
# third_party/sglang/v0.5.9/sglang) and adds the RDMA/InfiniBand userspace
5+
# stack + the Mooncake transfer engine that sglang's disaggregation backend
6+
# needs to move KV cache between prefill and decode workers.
7+
#
8+
# Build:
9+
# DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile.pd \
10+
# --build-arg VORTEX_TORCH_REF=v0.5 \
11+
# --build-arg TORCH_CUDA_ARCH_LIST="9.0" \
12+
# -t vortex-torch:pd-0.5.9 .
13+
# # Blackwell (B200/sm100): pass TORCH_CUDA_ARCH_LIST="10.0".
14+
#
15+
# Run (needs RDMA devices + IPC/host net for the transfer engine), e.g.:
16+
# docker run --gpus all --ipc=host --network=host \
17+
# --device=/dev/infiniband --cap-add=IPC_LOCK \
18+
# -v /raid/catalyst/models:/models -e HF_HOME=/models \
19+
# -it vortex-torch:pd-0.5.9
20+
# # then: bash marks/pd/run_p1d1.sh (mooncake backend, --disaggregation-ib-device mlx5_0)
21+
22+
# CUDA 12.9 matches sglang 0.5.9's pins (cuda-python==12.9, torch==2.9.1).
23+
ARG CUDA_VERSION=12.9.1
24+
FROM nvidia/cuda:${CUDA_VERSION}-cudnn-devel-ubuntu24.04
25+
26+
ENV DEBIAN_FRONTEND=noninteractive
27+
28+
# Hopper=9.0, Blackwell=10.0. Override at build time as needed.
29+
ARG TORCH_CUDA_ARCH_LIST="9.0"
30+
ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST}"
31+
32+
SHELL ["/bin/bash", "-c"]
33+
34+
# --- system deps: toolchain + libnuma + RDMA/InfiniBand userspace (for Mooncake) ---
35+
RUN apt-get update && apt-get install -y --no-install-recommends \
36+
python3 \
37+
python3-dev \
38+
python3-pip \
39+
python3-venv \
40+
build-essential \
41+
ca-certificates \
42+
cmake \
43+
curl \
44+
git \
45+
ninja-build \
46+
wget \
47+
libnuma1 \
48+
libnuma-dev \
49+
# InfiniBand / RDMA userspace — required by the Mooncake transfer engine
50+
rdma-core \
51+
libibverbs-dev \
52+
libibverbs1 \
53+
libibumad3 \
54+
librdmacm1 \
55+
ibverbs-providers \
56+
infiniband-diags \
57+
perftest \
58+
&& ln -sf /usr/bin/python3 /usr/bin/python \
59+
&& rm -rf /var/lib/apt/lists/*
60+
61+
# Isolated venv (Ubuntu 24.04 is PEP-668 externally-managed).
62+
RUN python3 -m venv /opt/venv
63+
ENV PATH="/opt/venv/bin:${PATH}"
64+
ENV VIRTUAL_ENV="/opt/venv"
65+
RUN python -m pip install --upgrade pip setuptools wheel
66+
67+
ARG VORTEX_TORCH_REF=v0.5
68+
ARG MOONCAKE_VERSION=0.3.9
69+
70+
WORKDIR /workspace
71+
72+
# The PD-disaggregation support (server-args overlap-schedule force, the
73+
# decode-side rebuild_aux hook, the page-major get_contiguous_buf_infos) lives
74+
# in the vendored sglang on this branch — make sure VORTEX_TORCH_REF is pushed.
75+
RUN git clone -b "${VORTEX_TORCH_REF}" --recursive \
76+
https://github.com/Infini-AI-Lab/vortex_torch.git
77+
78+
# --- sglang v0.5.9 (vendored): editable install of its python package ---
79+
# (v0.5.9 has no install.sh; pulls torch==2.9.1, flashinfer==0.6.3,
80+
# sgl-kernel==0.3.21 as wheels.)
81+
WORKDIR /workspace/vortex_torch/third_party/sglang/v0.5.9/sglang
82+
RUN pip install --no-cache-dir -e "python"
83+
84+
# --- vortex_torch (pure Python + Triton JIT; no compiled C extension) ---
85+
WORKDIR /workspace/vortex_torch
86+
RUN pip install --no-cache-dir -e .
87+
88+
# --- Mooncake transfer engine (KV transport for PD disaggregation) ---
89+
# CUDA 12.x → pip wheel. (CUDA>=13 would need a from-source build.)
90+
RUN pip install --no-cache-dir "mooncake-transfer-engine==${MOONCAKE_VERSION}"
91+
92+
# --- sanity checks ---
93+
RUN which python && python --version && which pip && pip --version
94+
RUN python - <<'PY'
95+
import ctypes
96+
for lib in ("libnuma.so.1", "libibverbs.so.1", "librdmacm.so.1"):
97+
ctypes.CDLL(lib)
98+
print(f"OK: {lib} loaded")
99+
import sglang, vortex_torch
100+
print("OK: import sglang", getattr(sglang, "__version__", "?"))
101+
print("OK: import vortex_torch")
102+
from mooncake.engine import TransferEngine # mooncake transport entrypoint
103+
print("OK: mooncake TransferEngine importable")
104+
PY
105+
106+
CMD ["/bin/bash"]

examples/algo3.sh

Lines changed: 5 additions & 5 deletions
@@ -10,10 +10,10 @@ models=(
1010
MiniMaxAI/MiniMax-M2.7
1111
)
1212
trials=(
13-
32
13+
16
1414
)
1515
topk_val=(
16-
125
16+
61
1717
)
1818
for algo in "${sparse_algos[@]}"; do
1919
for model in "${models[@]}"; do
@@ -23,9 +23,9 @@ for algo in "${sparse_algos[@]}"; do
2323
python examples/verify_algo.py \
2424
--trials ${trial} \
2525
--topk-val ${k_val} \
26-
--page-size 16 \
26+
--page-size 32 \
2727
--workload-chunk-size 64 \
28-
--block-size 16 \
28+
--block-size 32 \
2929
--topk-ratio 0.00 \
3030
--vortex-module-name "${algo}" \
3131
--model-name "${model}" \
@@ -37,7 +37,7 @@ for algo in "${sparse_algos[@]}"; do
3737
--vortex-attention-backend trtllm \
3838
--vortex-impl-backend triton \
3939
--vortex-use-tensor-core \
40-
--vortex-layers-skip 0 \
40+
--vortex-layers-skip \
4141
--summary-dir summary-MiniMax-M2.7-sglang-trtllm \
4242
--skip-already-finished-check
4343
done

third_party/sglang/v0.5.9/sglang/python/sglang/srt/layers/attention/attention_registry.py

Lines changed: 17 additions & 0 deletions
@@ -50,6 +50,9 @@ def create_flashinfer_backend(runner):
5050
runner, init_new_workspace=runner.init_new_workspace
5151
)
5252
else:
53+
# MLA + vortex sparsity is wired on the trtllm_mla backend (see
54+
# create_trtllm_mla_backend) so prefill dispatches through the
55+
# materialized MHA path. The flashinfer MLA path stays dense-only.
5356
from sglang.srt.layers.attention.flashinfer_mla_backend import (
5457
FlashInferMLAAttnBackend,
5558
)
@@ -61,6 +64,13 @@ def create_flashinfer_backend(runner):
6164
def create_trtllm_mla_backend(runner):
6265
if not runner.use_mla_backend:
6366
raise ValueError("trtllm_mla backend can only be used with MLA models.")
67+
if runner.server_args.enable_vortex_sparsity:
68+
# MLA + vortex sparsity. Use trtllm_mla (not flashinfer) so the model
69+
# dispatches PREFILL through the materialized MHA path (192/128) — same
70+
# as the dense baseline — and DECODE through absorb (the sparse path).
71+
from vortex_torch.engine.sgl.attention_backend import VortexTRTLLMMLABackend
72+
73+
return VortexTRTLLMMLABackend(runner)
6474
from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
6575

6676
return TRTLLMMLABackend(runner)
@@ -108,6 +118,13 @@ def create_triton_backend(runner):
108118
)
109119

110120
return DoubleSparseAttnBackend(runner)
121+
elif runner.use_mla_backend and runner.server_args.enable_vortex_sparsity:
122+
# MLA + vortex on the Triton decode kernel (not geometry-locked like
123+
# trtllm_mla; handles GLM-4.7-Flash's qk_nope=192 / v_head=256). Prefill
124+
# + skipped-layer decode delegate to the dense TritonAttnBackend.
125+
from vortex_torch.engine.sgl.attention_backend import VortexTritonMLABackend
126+
127+
return VortexTritonMLABackend(runner)
111128
else:
112129
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
113130
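
The two registry changes above amount to a small decision table. The sketch below restates it with plain strings in place of the real backend classes so that it runs without sglang or vortex_torch installed. `RunnerFlags` and the `pick_*` helpers are hypothetical names, and the `enable_double_sparsity` condition is an assumption about the branch that precedes the new `elif`, since the hunk only shows its `return`.

```python
from dataclasses import dataclass


@dataclass
class RunnerFlags:
    """Hypothetical stand-in for the fields the registry reads off `runner`."""
    use_mla_backend: bool
    enable_vortex_sparsity: bool
    enable_double_sparsity: bool = False  # assumed guard of the preceding branch


def pick_trtllm_mla_backend(flags: RunnerFlags) -> str:
    """Mirror create_trtllm_mla_backend: the vortex check comes before the dense default."""
    if not flags.use_mla_backend:
        raise ValueError("trtllm_mla backend can only be used with MLA models.")
    if flags.enable_vortex_sparsity:
        return "VortexTRTLLMMLABackend"
    return "TRTLLMMLABackend"


def pick_triton_backend(flags: RunnerFlags) -> str:
    """Mirror create_triton_backend: MLA + vortex is checked before the dense fallback."""
    if flags.enable_double_sparsity:
        return "DoubleSparseAttnBackend"
    if flags.use_mla_backend and flags.enable_vortex_sparsity:
        return "VortexTritonMLABackend"
    return "TritonAttnBackend"


if __name__ == "__main__":
    print(pick_trtllm_mla_backend(RunnerFlags(True, True)))
    print(pick_triton_backend(RunnerFlags(True, False)))
```

In both functions the vortex branch has to be tested before the dense default, otherwise an MLA model with sparsity enabled would silently get the dense backend.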

third_party/sglang/v0.5.9/sglang/python/sglang/srt/model_executor/model_runner.py

Lines changed: 18 additions & 7 deletions
@@ -594,13 +594,24 @@ def initialize(self, min_per_gpu_memory: float):
594594
self.server_args.vortex_module_name,
595595
user_file=self.server_args.vortex_module_path
596596
)
597-
self.sparse_attention.initialize(
598-
block_size=self.block_size,
599-
head_dim=self.model_config.head_dim,
600-
kv_cache_dtype=self.kv_cache_dtype,
601-
q_data_type=self.dtype,
602-
intermediate_dtype=self.server_args.vortex_dtype,
603-
)
597+
if isinstance(self.sparse_attention, vortex_torch.flow.vFlowMLA):
598+
# MLA flow: latent geometry instead of a single head_dim.
599+
self.sparse_attention.initialize(
600+
block_size=self.block_size,
601+
kv_lora_rank=self.model_config.kv_lora_rank,
602+
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
603+
kv_cache_dtype=self.kv_cache_dtype,
604+
q_data_type=self.dtype,
605+
intermediate_dtype=self.server_args.vortex_dtype,
606+
)
607+
else:
608+
self.sparse_attention.initialize(
609+
block_size=self.block_size,
610+
head_dim=self.model_config.head_dim,
611+
kv_cache_dtype=self.kv_cache_dtype,
612+
q_data_type=self.dtype,
613+
intermediate_dtype=self.server_args.vortex_dtype,
614+
)
604615

605616
# Init memory pool and attention backends
606617
self.init_memory_pool(min_per_gpu_memory)
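
The two `initialize` calls above differ only in their geometry arguments. One way to make that explicit is to build the keyword arguments once and swap in the geometry, as in the sketch below. This is an illustration rather than what the commit does: `build_initialize_kwargs` is a hypothetical helper, and `model_config` is modelled as a plain dict instead of sglang's config object.

```python
from typing import Any, Dict


def build_initialize_kwargs(
    is_mla_flow: bool,
    block_size: int,
    model_config: Dict[str, int],
    kv_cache_dtype: Any,
    q_data_type: Any,
    intermediate_dtype: Any,
) -> Dict[str, Any]:
    """Assemble kwargs for `sparse_attention.initialize` (hypothetical helper)."""
    # Arguments shared by the MLA and non-MLA flows.
    kwargs: Dict[str, Any] = {
        "block_size": block_size,
        "kv_cache_dtype": kv_cache_dtype,
        "q_data_type": q_data_type,
        "intermediate_dtype": intermediate_dtype,
    }
    if is_mla_flow:
        # MLA flow: latent geometry instead of a single head_dim.
        kwargs["kv_lora_rank"] = model_config["kv_lora_rank"]
        kwargs["qk_rope_head_dim"] = model_config["qk_rope_head_dim"]
    else:
        kwargs["head_dim"] = model_config["head_dim"]
    return kwargs


if __name__ == "__main__":
    # Illustrative values only; they are not taken from the commit.
    cfg = {"kv_lora_rank": 512, "qk_rope_head_dim": 64, "head_dim": 128}
    print(sorted(build_initialize_kwargs(True, 32, cfg, "bf16", "bf16", "bf16")))
```

In the real code `is_mla_flow` would correspond to the `isinstance(self.sparse_attention, vortex_torch.flow.vFlowMLA)` test shown in the diff.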

third_party/sglang/v0.5.9/sglang/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py

Lines changed: 24 additions & 2 deletions
@@ -546,7 +546,12 @@ def init_memory_pool(self: ModelRunner, total_gpu_memory: int):
546546
end_layer=self.end_layer,
547547
index_head_dim=get_nsa_index_head_dim(self.model_config.hf_config),
548548
)
549-
elif self.use_mla_backend and not self.mambaish_config:
549+
elif (
550+
self.use_mla_backend
551+
and not self.mambaish_config
552+
and not self.server_args.enable_vortex_sparsity
553+
):
554+
# vortex+MLA falls through to the VortexMLACachePool branch below.
550555
assert not is_nsa_model
551556
if is_float4_e2m1fn_x2(self.kv_cache_dtype):
552557
self.token_to_kv_pool = MLATokenToKVPoolFP4(
@@ -588,8 +593,25 @@ def init_memory_pool(self: ModelRunner, total_gpu_memory: int):
588593
start_layer=self.start_layer,
589594
end_layer=self.end_layer,
590595
)
596+
elif self.server_args.enable_vortex_sparsity and self.use_mla_backend:
597+
# MLA: single fused latent pool (kv_c | k_pe), no per-head K/V.
598+
from vortex_torch.engine.sgl.memory_pool_mla import VortexMLACachePool
599+
self.token_to_kv_pool = VortexMLACachePool(
600+
self.max_total_num_tokens,
601+
page_size=self.page_size,
602+
dtype=self.kv_cache_dtype,
603+
kv_lora_rank=self.model_config.kv_lora_rank,
604+
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
605+
layer_num=self.num_effective_layers,
606+
device=self.device,
607+
enable_memory_saver=self.server_args.enable_memory_saver,
608+
sparse_attention=self.sparse_attention,
609+
model_runner=self,
610+
start_layer=self.start_layer,
611+
end_layer=self.end_layer,
612+
)
591613
elif self.server_args.enable_vortex_sparsity:
592-
614+
593615
from vortex_torch.engine.sgl.memory_pool import VortexCachePool
594616
self.token_to_kv_pool = VortexCachePool(
595617
self.max_total_num_tokens,
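
The comment "single fused latent pool (kv_c | k_pe), no per-head K/V" implies a per-token footprint of `kv_lora_rank + qk_rope_head_dim` elements per layer. The sketch below works that arithmetic through next to a conventional per-head K/V cache. It counts the latent payload only; any auxiliary buffers that `VortexMLACachePool` may allocate are not modelled, and all numeric values are illustrative assumptions rather than figures from the commit.

```python
def latent_bytes_per_token(
    kv_lora_rank: int, qk_rope_head_dim: int, dtype_bytes: int, layer_num: int
) -> int:
    """Bytes per token for a fused latent entry (kv_c | k_pe) across all layers."""
    per_layer = (kv_lora_rank + qk_rope_head_dim) * dtype_bytes
    return per_layer * layer_num


def per_head_kv_bytes_per_token(
    num_kv_heads: int, head_dim: int, dtype_bytes: int, layer_num: int
) -> int:
    """Bytes per token for separate per-head K and V tensors across all layers."""
    return 2 * num_kv_heads * head_dim * dtype_bytes * layer_num


if __name__ == "__main__":
    # Illustrative geometry: rank 512 + rope 64, 2-byte elements, one layer.
    print(latent_bytes_per_token(512, 64, 2, 1))       # 1152
    # Illustrative per-head layout: 8 KV heads of dim 128, 2-byte elements.
    print(per_head_kv_bytes_per_token(8, 128, 2, 1))   # 4096
```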

third_party/sglang/v0.5.9/sglang/python/sglang/srt/server_args.py

Lines changed: 18 additions & 0 deletions
@@ -2565,6 +2565,24 @@ def _handle_load_format(self):
25652565
)
25662566

25672567
def _handle_pd_disaggregation(self):
2568+
# Vortex (PD disagg Option B) rebuilds its auxiliary cache on the decode
2569+
# side from the transferred K/V. That rebuild shares the cache-side
2570+
# compiled scratch with the live decode forward; the overlap scheduler
2571+
# would run them on different streams concurrently and corrupt it. Force
2572+
# overlap off on any vortex PD server — mirrors the non-PD engine helper
2573+
# in vortex_torch/engine/sgl/api.py, which hardcodes the same.
2574+
if self.enable_vortex_sparsity and self.disaggregation_mode in (
2575+
"prefill",
2576+
"decode",
2577+
):
2578+
if not self.disable_overlap_schedule:
2579+
self.disable_overlap_schedule = True
2580+
logger.warning(
2581+
"Overlap schedule is disabled for vortex sparsity under PD "
2582+
"disaggregation (the decode-side aux rebuild shares cache "
2583+
"scratch with the decode forward)."
2584+
)
2585+
25682586
if self.disaggregation_mode == "decode":
25692587
assert (
25702588
self.disaggregation_decode_tp is None
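
The guard added to `_handle_pd_disaggregation` can be exercised in isolation. The sketch below copies its logic onto a small dataclass so the behaviour is easy to check without sglang. `Args` and `force_overlap_off` are hypothetical names, and the `"null"` default for `disaggregation_mode` is an assumption about how a non-PD server is represented.

```python
import logging
from dataclasses import dataclass

logger = logging.getLogger(__name__)


@dataclass
class Args:
    """Hypothetical subset of ServerArgs used by the guard."""
    enable_vortex_sparsity: bool = False
    disaggregation_mode: str = "null"  # assumed value for a non-PD server
    disable_overlap_schedule: bool = False


def force_overlap_off(args: Args) -> bool:
    """Disable the overlap scheduler for vortex PD servers.

    Returns True only when the flag was actually changed.
    """
    is_pd = args.disaggregation_mode in ("prefill", "decode")
    if args.enable_vortex_sparsity and is_pd and not args.disable_overlap_schedule:
        args.disable_overlap_schedule = True
        logger.warning(
            "Overlap schedule is disabled for vortex sparsity under PD disaggregation."
        )
        return True
    return False


if __name__ == "__main__":
    decode_args = Args(enable_vortex_sparsity=True, disaggregation_mode="decode")
    print(force_overlap_off(decode_args), decode_args.disable_overlap_schedule)
```

The guard is idempotent: a second call leaves the flag set and emits no further warning, which matches the `if not self.disable_overlap_schedule` check in the diff.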
