Qwen3Next MTP for vLLM plugin mode #772

Open: ganyi1996ppo wants to merge 9 commits into main from ganyi/qwen3next_mtp

Conversation

ganyi1996ppo (Contributor) commented May 13, 2026

Motivation

Enable Qwen3Next MTP (multi-token prediction speculative decoding) when ATOM runs as a vLLM plugin.

Server script:

```bash
export VLLM_TORCH_PROFILER_DIR=./vllm_profile
export ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1
export AITER_QUICK_REDUCE_QUANTIZATION=INT4
export HIP_VISIBLE_DEVICES=0,1,2,3
export ATOM_DISABLE_VLLM_PLUGIN_ATTENTION=0
export ATOM_USE_CUSTOM_ALL_GATHER=0
export ATOM_DISABLE_VLLM_PLUGIN=0
MODEL=/mnt/data/pretrained_model/Qwen/Qwen3-Next-80B-A3B-Instruct-FP8

vllm serve $MODEL \
  --port 8200 \
  --no-enable-prefix-caching \
  --tensor-parallel-size 1 \
  --gpu-memory-utilization 0.8 \
  --max-num-batched-tokens 32768 \
  --kv-cache-dtype fp8 \
  --compilation-config '{"cudagraph_mode": "FULL_AND_PIECEWISE"}' \
  --profiler-config '{"profiler": "torch", "torch_profiler_dir": "./vllm_profile", "torch_profiler_with_stack": "True"}' \
  --speculative-config '{"num_speculative_tokens": 1, "method": "mtp"}'
```

Verify script:

```bash
MODEL_ID=/mnt/data/pretrained_model/Qwen/Qwen3-Next-80B-A3B-Instruct-FP8
lm_eval \
  --model local-completions \
  --model_args model=$MODEL_ID,base_url=http://localhost:8200/v1/completions,num_concurrent=256,max_retries=10,timeout=3000,seed=1234,max_gen_toks=2048,temperature=0,tokenized_requests=False,trust_remote_code=True \
  --batch_size auto \
  --tasks gsm8k \
  --num_fewshot 5
```

Result:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9212|±  |0.0074|
|     |       |strict-match    |     5|exact_match|↑  |0.9143|±  |0.0077|

SpecDecoding metrics: Mean acceptance length: 1.91, Accepted throughput: 371.69 tokens/s, Drafted throughput: 410.38 tokens/s, Accepted: 3717 tokens, Drafted: 4104 tokens, Per-position acceptance rate: 0.906, Avg Draft acceptance rate: 90.6%
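As a sanity check on these numbers: with num_speculative_tokens=1, the per-position acceptance rate is 3717 accepted / 4104 drafted ≈ 0.906, and a mean acceptance length of 1 + 0.906 ≈ 1.91 matches the reported value (one guaranteed token per step plus the accepted draft token).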

Technical Details

Test Plan

Test Result

Submission Checklist

Copilot AI review requested due to automatic review settings May 13, 2026 08:26

Copilot AI (Contributor) left a comment
Pull request overview

This PR adds support for running Qwen3Next MTP (multi-token prediction / EAGLE-style speculative decoding) under vLLM plugin mode, including draft-model construction, KV-cache indexing fixes, and attention/metadata handling for multi-token verification.

Changes:

  • Register Qwen3NextMTP for vLLM plugin mode and add model-class routing to the ATOM vLLM wrapper.
  • Teach the vLLM wrapper to detect draft-model construction, load draft weights correctly (spec_decode=True), and swap the global atom_config during forward() to keep layer lookups consistent across target/draft alternation (a rough sketch of this swap follows this list).
  • Update plugin attention metadata + paged attention implementations to correctly handle multi-token decode layouts used by MTP/EAGLE.
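The atom_config swap called out in the second bullet might look roughly like the following. This is a minimal sketch only, where set_current_atom_config is assumed as the counterpart of the get_current_atom_config helper visible in this diff; it is not the PR's actual implementation:

```python
from contextlib import contextmanager

@contextmanager
def _swap_atom_config(draft_config):
    """Install the draft model's atom_config globally for one forward pass."""
    prev = get_current_atom_config()        # helper that appears in this diff
    set_current_atom_config(draft_config)   # assumed setter, hypothetical
    try:
        yield
    finally:
        # Restore the target model's config so subsequent target-model
        # layer lookups resolve correctly during target/draft alternation.
        set_current_atom_config(prev)

# Hypothetical use inside the wrapper's forward():
# with _swap_atom_config(self.atom_config):
#     hidden_states = self.model(input_ids, positions)
```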

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 6 comments.

Summary per file:

| File | Description |
|------|-------------|
| atom/plugin/vllm/register.py | Registers the Qwen3NextMTP architecture override for vLLM plugin mode. |
| atom/plugin/vllm/model_wrapper.py | Detects draft vs. target, routes the draft architecture, swaps/restores the global atom_config for forwards, and passes spec_decode into weight loading. |
| atom/plugin/vllm/attention_backend/attention_gdn.py | Fixes GDN attention output writeback for speculative decode and adjusts imports/code paths. |
| atom/plugin/attention.py | Adjusts attention metadata builder thresholds/logic for MTP/EAGLE multi-token verification and async spec-decode metadata. |
| atom/plugin/attention_mha.py | Updates paged-attention decode kernels and buffer sizing to support the MTP multi-token decode layout; fixes extend block-table slicing. |
| atom/models/qwen3_next.py | Adds explicit layer_num for attention KV slot isolation in MTP, fixes the speculative_config fallback for vLLM, and exposes embed_tokens for sharing. |
| atom/models/qwen3_next_mtp.py | Implements the Qwen3Next MTP draft model with correct layer indexing, quant prefixing, and expert mapping for shared-expert fusion. |
| atom/model_loader/loader.py | Plumbs spec_decode through plugin-mode loading so draft models can load mtp.* weights and apply MTP remapping. |


Outdated comment threads: atom/plugin/vllm/attention_backend/attention_gdn.py (×3), atom/plugin/vllm/model_wrapper.py, atom/models/qwen3_next.py (×2).
Copilot AI review requested due to automatic review settings May 13, 2026 09:25

Copilot AI (Contributor) left a comment

Copilot reviewed 8 out of 8 changed files in this pull request and generated 1 comment.

Outdated comment thread on atom/models/qwen3_next.py:

```python
# `kernel_size-1 + num_spec` rows per slot and the extra row spilled
# into the page-adjacent ssm_state, corrupting layer 0's recurrent
# state. Pull the spec config from the vLLM config as a fallback.
if is_vllm() and self.speculative_config is None:
```
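To make the sizing issue described in that comment concrete, a tiny illustration (numbers assumed for illustration only, not taken from the PR):

```python
# Each GDN conv cache slot must hold kernel_size - 1 history rows plus
# one row per speculative token; undercounting num_spec let the extra
# row spill into the page-adjacent ssm_state.
kernel_size = 4   # assumed conv kernel size, illustration only
num_spec = 1      # matches num_speculative_tokens in the server script
rows_per_slot = kernel_size - 1 + num_spec   # 4 rows, not 3
```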
ganyi1996ppo force-pushed the ganyi/qwen3next_mtp branch from f38481f to 3af7ccb on May 14, 2026 08:01

Copilot AI review requested due to automatic review settings May 14, 2026 08:05

zejunchen-zejun requested a review from whx-sjtu May 14, 2026 08:05
Copilot AI (Contributor) left a comment

Copilot reviewed 10 out of 10 changed files in this pull request and generated 3 comments.

Comments suppressed due to low confidence (1):

atom/plugin/vllm/model_wrapper.py:421

  • The draft-model detection in load_weights checks for "Qwen3NextMTP", but other parts of the repo still use the architecture key "Qwen3NextMTPModel" for Qwen3-Next MTP. If the draft model’s HF config reports "Qwen3NextMTPModel", spec_decode-specific loading (hf_config_override / weight filtering) won’t activate. Please align this set with the actual HF architecture string used for the draft model.
```python
is_mtp_draft_model = self.model_arch in {
    "DeepSeekMTPModel",
    "Qwen3NextMTP",
}
```

Comment threads on atom/plugin/vllm/register.py and atom/plugin/vllm/model_wrapper.py.
Comment on lines 153 to +161:

```python
main_model_arch = vllm_config.model_config.architectures[0]
model_arch = _select_model_arch(vllm_config)
self.is_mtp_draft_model = self.is_mtp and model_arch != main_model_arch
if self.is_mtp_draft_model:
    self.atom_config = get_current_atom_config()
else:
    self.atom_config = generate_atom_config_for_plugin_mode(vllm_config)
self.model_arch = model_arch
_prepare_env(atom_config=self.atom_config)
```
zejunchen-zejun (Collaborator) commented:

Hi @ganyi1996ppo, could you add Qwen3Next MTP to the atom-vllm nightly and benchmark workflows, so that accuracy and performance can be tracked once this is merged?

Copilot AI review requested due to automatic review settings May 14, 2026 13:55

Copilot AI (Contributor) left a comment

Copilot reviewed 10 out of 10 changed files in this pull request and generated 2 comments.

Comments suppressed due to low confidence (2):

atom/plugin/vllm/model_wrapper.py:196

  • _expose_spec_decode_attrs() is now only executed when model_arch in _MTP_MASK_INPUT_ARCH (currently only DeepSeekMTPModel). The new Qwen3NextMTP model has the same extra .model nesting and does not expose embed_tokens/layers on the outer module, so vLLM speculative decoding weight/embedding sharing is likely to fail. Suggest calling _expose_spec_decode_attrs() for all MTP draft models that wrap an inner .model (and keep _adapt_mtp_layers_for_vllm() gated separately if it’s DeepSeek-specific), or add Qwen3NextMTP to the relevant allowlist.
```python
logger.info(f"Construct ATOM model {model_arch} for vLLM plugin mode")
self.model = model_cls(self.atom_config)

if model_arch in _MTP_MASK_INPUT_ARCH:
    self._adapt_mtp_layers_for_vllm()
    # Mirror nested attributes required by vLLM speculative decoding.
    self._expose_spec_decode_attrs()
```
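For context, a minimal sketch of what such attribute mirroring could look like (hypothetical body, not the PR's actual implementation; it assumes the draft nests its weights under an inner .model as the comment describes):

```python
def _expose_spec_decode_attrs(self):
    # Hypothetical sketch: vLLM's speculative decoding expects
    # embed_tokens and layers on the outer module, while ATOM's MTP
    # drafts nest them one level deeper under an inner .model.
    inner = self.model.model
    self.embed_tokens = inner.embed_tokens
    self.layers = inner.layers
```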

atom/plugin/vllm/model_wrapper.py:422

  • Draft-model detection only checks self.model_arch against { "DeepSeekMTPModel", "Qwen3NextMTP" }. If the HF draft config still reports Qwen3NextMTPModel (as referenced elsewhere in the repo), this branch won’t treat it as spec-decode, and hf_config_override won’t be applied. Consider accepting both Qwen3NextMTP and Qwen3NextMTPModel here (and in _ATOM_MODEL_CLASSES) so both draft-arch spellings work.
```python
is_mtp_draft_model = self.model_arch in {
    "DeepSeekMTPModel",
    "Qwen3NextMTP",
}
```
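The reviewer's proposed widening, sketched directly from the comment above:

```python
is_mtp_draft_model = self.model_arch in {
    "DeepSeekMTPModel",
    "Qwen3NextMTP",
    "Qwen3NextMTPModel",  # standalone-ATOM spelling of the same draft arch
}
```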

Comment on lines 146 to +160:

```python
self.vllm_config = vllm_config
self.atom_config = generate_atom_config_for_plugin_mode(vllm_config)
self.is_mtp = False
speculative_config = getattr(vllm_config, "speculative_config", None)
if speculative_config is not None:
    spec_method = speculative_config.method
    self.is_mtp = spec_method == "mtp"

_prepare_env(atom_config=self.atom_config)

main_model_arch = vllm_config.model_config.architectures[0]
model_arch = _select_model_arch(vllm_config)
self.is_mtp_draft_model = self.is_mtp and model_arch != main_model_arch
if self.is_mtp_draft_model:
    self.atom_config = get_current_atom_config()
else:
    self.atom_config = generate_atom_config_for_plugin_mode(vllm_config)
```
"GlmMoeDsaForCausalLM": ATOM_MOE_CAUSAL_LM_MODEL_WRAPPER,
"DeepSeekMTPModel": ATOM_MOE_CAUSAL_LM_MODEL_WRAPPER,
"Qwen3NextForCausalLM": "atom.models.qwen3_next:Qwen3NextForCausalLMVllm",
"Qwen3NextMTP": ATOM_MOE_CAUSAL_LM_MODEL_WRAPPER,
Copilot AI review requested due to automatic review settings May 14, 2026 14:57

Copilot AI (Contributor) left a comment

Copilot reviewed 11 out of 11 changed files in this pull request and generated 2 comments.

Comments suppressed due to low confidence (1):

recipes/atom_vllm/Qwen3.5.md:137

  • The "Key Environment Variables" list no longer includes ATOM_DISABLE_VLLM_PLUGIN_ATTENTION=1, but the earlier text still refers to three required variables. Please ensure this section stays consistent with the intended required/optional env var set for Qwen3.5.

```markdown
## Key Environment Variables

- `ATOM_USE_CUSTOM_ALL_GATHER=0`: **Required** - disables custom all-gather for compatibility with Qwen3.5 model architecture
- `AITER_QUICK_REDUCE_QUANTIZATION=INT4`: **Performance optimization** - enables INT4 quantization for quick reduce operations
  - **Benefit**: Significantly improves TTFT (Time To First Token) performance by reducing communication overhead during tensor parallelism all-reduce operations
```

```python
self._expose_spec_decode_attrs()

if model_arch in _MTP_MASK_INPUT_ARCH:
    self._adapt_mtp_layers_for_vllm()
```
Comment on lines 72 to 76:

```markdown
**Important**: The following three environment variables are required for Qwen3.5:

- `ATOM_DISABLE_VLLM_PLUGIN_ATTENTION=1`: Disables ATOM attention plugin to use vLLM's implementation for full attention layers (required because Qwen3.5 uses a hybrid architecture with both linear attention (GatedDeltaNet) and full attention layers)
- `ATOM_USE_CUSTOM_ALL_GATHER=0`: Disables custom all-gather for compatibility with Qwen3.5 model architecture
- `AITER_QUICK_REDUCE_QUANTIZATION=INT4`: **Performance optimization** - enables INT4 quantization for quick reduce operations, which can significantly improve TTFT (Time To First Token) performance. **Note**: This optimization may introduce a risk of accuracy degradation. For accuracy-critical workloads, consider validating with your specific use case.
```

Copilot AI review requested due to automatic review settings May 14, 2026 15:22

Copilot AI (Contributor) left a comment

Copilot reviewed 10 out of 10 changed files in this pull request and generated 3 comments.

Comments suppressed due to low confidence (1):

atom/plugin/vllm/model_wrapper.py:432

  • atom.config.SpeculativeConfig does not expose draft_model_config, so draft_model_config = getattr(self.atom_config.speculative_config, "draft_model_config", None) will always be None and hf_config_override will not be applied for MTP draft-model weight loading. This can cause the draft model to load with the target model's HF config. Use self.atom_config.speculative_config.draft_model_hf_config (or fall back to self.vllm_config.speculative_config.draft_model_config.hf_config) when building draft_hf_config.
```python
is_mtp_draft_model = self.model_arch in {
    "DeepSeekMTPModel",
    "Qwen3NextMTP",
}
draft_hf_config = None
if is_mtp_draft_model:
    draft_model_config = getattr(
        getattr(self.atom_config, "speculative_config", None),
        "draft_model_config",
        None,
    )
    if draft_model_config is not None:
        draft_hf_config = getattr(
            draft_model_config, "hf_config", draft_model_config
        )
```
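A sketch of the fallback the reviewer suggests, using the attribute names given in the comment (hypothetical wiring, not the PR's code):

```python
# Prefer ATOM's draft_model_hf_config; fall back to vLLM's draft
# model config, per the review comment above.
spec_cfg = getattr(self.atom_config, "speculative_config", None)
draft_hf_config = getattr(spec_cfg, "draft_model_hf_config", None)
if draft_hf_config is None:
    vllm_spec = getattr(self.vllm_config, "speculative_config", None)
    if vllm_spec is not None and vllm_spec.draft_model_config is not None:
        draft_hf_config = vllm_spec.draft_model_config.hf_config
```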

Comment on lines +194 to 196:

```python
# Mirror nested attributes required by vLLM speculative decoding.
self._expose_spec_decode_attrs()
```

And from the Qwen3.5 recipe doc:

```markdown
**Important**: The following three environment variables are required for Qwen3.5:

- `ATOM_DISABLE_VLLM_PLUGIN_ATTENTION=1`: Disables ATOM attention plugin to use vLLM's implementation for full attention layers (required because Qwen3.5 uses a hybrid architecture with both linear attention (GatedDeltaNet) and full attention layers)
- `ATOM_USE_CUSTOM_ALL_GATHER=0`: Disables custom all-gather for compatibility with Qwen3.5 model architecture
```
Comment thread atom/plugin/config.py
Comment on lines +75 to +112:

```python
def _build_atom_speculative_config_from_vllm(vllm_spec_config: Any):
    """Translate vLLM's SpeculativeConfig into ATOM's SpeculativeConfig.

    Reuses vLLM's already-loaded draft hf_config (skips a second disk fetch
    in ATOM SpeculativeConfig.__post_init__) but still runs ATOM's
    hf_config_override on it — so MTP model_type remap, n_routed_experts
    backfill (Qwen families), and architecture rewrite all land on the
    draft config in one place. Mirrors how standalone ATOM MTP exposes
    the draft hf_config via atom_config.speculative_config.

    The draft hf_config is deepcopied first because hf_config_override
    mutates `architectures` to ATOM's standalone naming (e.g.
    "Qwen3NextMTPModel"), which differs from vLLM's registry name
    ("Qwen3NextMTP"). Mutating in place would make vLLM's later draft
    architecture lookup fail.
    """
    if vllm_spec_config is None:
        return None

    from atom.config import SpeculativeConfig

    draft_model_config = getattr(vllm_spec_config, "draft_model_config", None)
    draft_hf_config = getattr(draft_model_config, "hf_config", None)
    if draft_hf_config is not None:
        draft_hf_config = copy.deepcopy(draft_hf_config)
    model_path = getattr(draft_model_config, "model", None) or getattr(
        vllm_spec_config, "model", None
    )

    return SpeculativeConfig(
        method=getattr(vllm_spec_config, "method", "") or "",
        model=model_path,
        num_speculative_tokens=getattr(
            vllm_spec_config, "num_speculative_tokens", None
        ),
        draft_model_hf_config=draft_hf_config,
    )
```
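A hypothetical call site, assuming the helper is invoked from generate_atom_config_for_plugin_mode (the exact wiring is not shown in this excerpt):

```python
# Hypothetical: translate vLLM's spec config once during plugin-mode
# config generation and attach it to the ATOM config.
atom_config.speculative_config = _build_atom_speculative_config_from_vllm(
    getattr(vllm_config, "speculative_config", None)
)
```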

