diff --git a/.github/benchmark/oot_benchmark_models.json b/.github/benchmark/oot_benchmark_models.json index 1b54c6265..f06a8e3c3 100644 --- a/.github/benchmark/oot_benchmark_models.json +++ b/.github/benchmark/oot_benchmark_models.json @@ -181,7 +181,7 @@ "1024x8192" ], "extra_args": "--trust-remote-code --tensor-parallel-size 4 --attention-backend ROCM_AITER_FA --gpu-memory-utilization 0.8 --max-num-batched-tokens 16384 --max-model-len 16384", - "env_vars": "AITER_QUICK_REDUCE_QUANTIZATION=INT4\nATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1\nATOM_USE_CUSTOM_ALL_GATHER=0" + "env_vars": "AITER_QUICK_REDUCE_QUANTIZATION=INT4\nATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1\nATOM_USE_CUSTOM_ALL_GATHER=0\nATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE=0" }, { "tp_size": 8, @@ -192,7 +192,7 @@ "1024x8192" ], "extra_args": "--trust-remote-code --tensor-parallel-size 8 --attention-backend ROCM_AITER_FA --gpu-memory-utilization 0.8 --max-num-batched-tokens 16384 --max-model-len 16384", - "env_vars": "AITER_QUICK_REDUCE_QUANTIZATION=INT4\nATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1\nATOM_USE_CUSTOM_ALL_GATHER=0" + "env_vars": "AITER_QUICK_REDUCE_QUANTIZATION=INT4\nATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1\nATOM_USE_CUSTOM_ALL_GATHER=0\nATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE=0" } ] }, @@ -213,7 +213,7 @@ "1024x8192" ], "extra_args": "--trust-remote-code --tensor-parallel-size 8 --attention-backend ROCM_AITER_FA --gpu-memory-utilization 0.8 --max-num-batched-tokens 16384 --max-model-len 16384", - "env_vars": "AITER_QUICK_REDUCE_QUANTIZATION=INT4\nATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1\nATOM_USE_CUSTOM_ALL_GATHER=0" + "env_vars": "AITER_QUICK_REDUCE_QUANTIZATION=INT4\nATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1\nATOM_USE_CUSTOM_ALL_GATHER=0\nATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE=0" } ] }, @@ -231,7 +231,7 @@ "prefix": "qwen3-next-80b-a3b-instruct-fp8-tp1-met", "bench_args": "", "extra_args": "--trust-remote-code --tensor-parallel-size 1 --max-num-batched-tokens 32768 --max-model-len 16384", - "env_vars": "ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1\nATOM_USE_CUSTOM_ALL_GATHER=0\nATOM_USE_FLYDSL_GDR=0" + "env_vars": "ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1\nATOM_USE_CUSTOM_ALL_GATHER=0\nATOM_USE_FLYDSL_GDR=0\nATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE=0" }, { "tp_size": 4, @@ -240,7 +240,7 @@ "prefix": "qwen3-next-80b-a3b-instruct-fp8-tp4-met", "bench_args": "", "extra_args": "--trust-remote-code --tensor-parallel-size 4 --max-num-batched-tokens 32768 --max-model-len 16384", - "env_vars": "AITER_QUICK_REDUCE_QUANTIZATION=INT4\nATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1\nATOM_USE_CUSTOM_ALL_GATHER=0\nATOM_USE_FLYDSL_GDR=0" + "env_vars": "AITER_QUICK_REDUCE_QUANTIZATION=INT4\nATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1\nATOM_USE_CUSTOM_ALL_GATHER=0\nATOM_USE_FLYDSL_GDR=0\nATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE=0" }, { "tp_size": 1, @@ -249,7 +249,7 @@ "prefix": "qwen3-next-80b-a3b-instruct-fp8-aw-tp1", "bench_args": "", "extra_args": "--trust-remote-code --tensor-parallel-size 1 --max-num-batched-tokens 32768 --max-model-len 16384", - "env_vars": "ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1\nATOM_USE_CUSTOM_ALL_GATHER=0\nATOM_USE_FLYDSL_GDR=0" + "env_vars": "ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1\nATOM_USE_CUSTOM_ALL_GATHER=0\nATOM_USE_FLYDSL_GDR=0\nATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE=0" }, { "tp_size": 2, @@ -258,7 +258,7 @@ "prefix": "qwen3-next-80b-a3b-instruct-fp8-aw-tp2", "bench_args": "", "extra_args": "--trust-remote-code --tensor-parallel-size 2 --max-num-batched-tokens 
32768 --max-model-len 16384", - "env_vars": "AITER_QUICK_REDUCE_QUANTIZATION=INT4\nATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1\nATOM_USE_CUSTOM_ALL_GATHER=0\nATOM_USE_FLYDSL_GDR=0" + "env_vars": "AITER_QUICK_REDUCE_QUANTIZATION=INT4\nATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1\nATOM_USE_CUSTOM_ALL_GATHER=0\nATOM_USE_FLYDSL_GDR=0\nATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE=0" }, { "tp_size": 4, @@ -267,7 +267,25 @@ "prefix": "qwen3-next-80b-a3b-instruct-fp8-aw-tp4", "bench_args": "", "extra_args": "--trust-remote-code --tensor-parallel-size 4 --max-num-batched-tokens 32768 --max-model-len 16384", - "env_vars": "AITER_QUICK_REDUCE_QUANTIZATION=INT4\nATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1\nATOM_USE_CUSTOM_ALL_GATHER=0\nATOM_USE_FLYDSL_GDR=0" + "env_vars": "AITER_QUICK_REDUCE_QUANTIZATION=INT4\nATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1\nATOM_USE_CUSTOM_ALL_GATHER=0\nATOM_USE_FLYDSL_GDR=0\nATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE=0" + }, + { + "tp_size": 1, + "display": "Qwen3-Next-80B-A3B-Instruct-FP8-MTP TP1 (AW)", + "dashboard_model": "Qwen3-Next-80B-A3B-Instruct-FP8-mtp-tp1", + "prefix": "qwen3-next-80b-a3b-instruct-fp8-mtp-tp1-aw", + "bench_args": "", + "extra_args": "--trust-remote-code --tensor-parallel-size 1 --max-num-batched-tokens 32768 --max-model-len 16384 --speculative-config '{\"num_speculative_tokens\":1, \"method\": \"mtp\"}'", + "env_vars": "ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1\nATOM_USE_CUSTOM_ALL_GATHER=0\nATOM_USE_FLYDSL_GDR=0\nATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE=0" + }, + { + "tp_size": 4, + "display": "Qwen3-Next-80B-A3B-Instruct-FP8-MTP TP4 (AW)", + "dashboard_model": "Qwen3-Next-80B-A3B-Instruct-FP8-mtp-tp4", + "prefix": "qwen3-next-80b-a3b-instruct-fp8-mtp-tp4-aw", + "bench_args": "", + "extra_args": "--trust-remote-code --tensor-parallel-size 4 --max-num-batched-tokens 32768 --max-model-len 16384 --speculative-config '{\"num_speculative_tokens\":1, \"method\": \"mtp\"}'", + "env_vars": "AITER_QUICK_REDUCE_QUANTIZATION=INT4\nATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1\nATOM_USE_CUSTOM_ALL_GATHER=0\nATOM_USE_FLYDSL_GDR=0\nATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE=0" } ] }, diff --git a/.github/benchmark/oot_models_accuracy.json b/.github/benchmark/oot_models_accuracy.json index 1050e3f85..226c374ed 100644 --- a/.github/benchmark/oot_models_accuracy.json +++ b/.github/benchmark/oot_models_accuracy.json @@ -3,7 +3,7 @@ "model_name": "Qwen3-235B-A22B-Instruct-2507-FP8 TP8+EP8", "model_path": "Qwen/Qwen3-235B-A22B-Instruct-2507-FP8", "extraArgs": "--tensor-parallel-size 8 --enable-expert-parallel", - "env_vars": "ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1", + "env_vars": "ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1\nATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE=0", "runner": "linux-atom-mi35x-8", "test_level": "nightly", "accuracy_threshold": 0.87, @@ -14,7 +14,7 @@ "model_name": "Qwen3-Next-80B-A3B-Instruct-FP8 TP4", "model_path": "Qwen/Qwen3-Next-80B-A3B-Instruct-FP8", "extraArgs": "--tensor-parallel-size 4 --attention-backend ROCM_AITER_FA", - "env_vars": "ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1\nATOM_USE_CUSTOM_ALL_GATHER=0", + "env_vars": "ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1\nATOM_USE_CUSTOM_ALL_GATHER=0\nATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE=0", "runner": "linux-atom-mi35x-4", "test_level": "nightly", "accuracy_threshold": 0.76, @@ -25,7 +25,7 @@ "model_name": "Qwen3.5-397B-A17B-FP8 TP8", "model_path": "Qwen/Qwen3.5-397B-A17B-FP8", "extraArgs": "--tensor-parallel-size 8 --attention-backend ROCM_AITER_FA", - "env_vars": 
"AITER_QUICK_REDUCE_QUANTIZATION=INT4\nATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1\nATOM_USE_CUSTOM_ALL_GATHER=0", + "env_vars": "AITER_QUICK_REDUCE_QUANTIZATION=INT4\nATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1\nATOM_USE_CUSTOM_ALL_GATHER=0\nATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE=0", "runner": "linux-atom-mi35x-8", "test_level": "nightly", "accuracy_threshold": 0.83, @@ -36,7 +36,7 @@ "model_name": "Qwen3.5-397B-A17B TP8", "model_path": "Qwen/Qwen3.5-397B-A17B", "extraArgs": "--tensor-parallel-size 8 --attention-backend ROCM_AITER_FA", - "env_vars": "AITER_QUICK_REDUCE_QUANTIZATION=INT4\nATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1\nATOM_USE_CUSTOM_ALL_GATHER=0", + "env_vars": "AITER_QUICK_REDUCE_QUANTIZATION=INT4\nATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1\nATOM_USE_CUSTOM_ALL_GATHER=0\nATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE=0", "runner": "linux-atom-mi35x-8", "test_level": "nightly", "accuracy_threshold": 0.83, @@ -47,7 +47,7 @@ "model_name": "Qwen3.5-397B-A17B-MXFP4 TP4", "model_path": "amd/Qwen3.5-397B-A17B-MXFP4", "extraArgs": "--tensor-parallel-size 4 --attention-backend ROCM_AITER_FA", - "env_vars": "AITER_QUICK_REDUCE_QUANTIZATION=INT4\nATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1\nATOM_USE_CUSTOM_ALL_GATHER=0", + "env_vars": "AITER_QUICK_REDUCE_QUANTIZATION=INT4\nATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1\nATOM_USE_CUSTOM_ALL_GATHER=0\nATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE=0", "runner": "linux-atom-mi35x-4", "test_level": "nightly", "accuracy_threshold": 0.82, @@ -55,6 +55,18 @@ "accuracy_baseline_model": "Qwen/Qwen3-235B-A22B-Instruct-2507", "_baseline_note": "Using Qwen3-235B baseline as proxy; needs CI measurement for Qwen3.5 specific baseline" }, + { + "model_name": "Qwen3-Next-80B-A3B-Instruct-FP8-MTP TP4", + "model_path": "Qwen/Qwen3-Next-80B-A3B-Instruct-FP8", + "extraArgs": "--tensor-parallel-size 4 --speculative-config '{\"num_speculative_tokens\":1, \"method\": \"mtp\"}'", + "env_vars": "AITER_QUICK_REDUCE_QUANTIZATION=INT4\nATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1\nATOM_USE_CUSTOM_ALL_GATHER=0\nATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE=0", + "runner": "linux-atom-mi35x-4", + "test_level": "nightly", + "accuracy_threshold": 0.8, + "accuracy_baseline": 0.81, + "accuracy_baseline_model": "Qwen/Qwen3-Next-80B-A3B-Instruct-FP8", + "_baseline_note": "Qwen3-Next-80B-A3B-Instruct-FP8 baseline with TP4 (no MTP) as proxy; needs CI measurement for MTP-specific baseline" + }, { "model_name": "Llama-3.1-8B-Instruct TP1", "model_path": "meta-llama/Llama-3.1-8B-Instruct", @@ -157,7 +169,7 @@ "runner": "linux-atom-mi35x-1", "test_level": "nightly", "accuracy_threshold": 0.88, - "accuracy_baseline": 0.90, + "accuracy_baseline": 0.9, "accuracy_baseline_model": "openai/gpt-oss-120b" }, { @@ -169,7 +181,7 @@ "runner": "linux-atom-mi35x-4", "test_level": "nightly", "accuracy_threshold": 0.88, - "accuracy_baseline": 0.90, + "accuracy_baseline": 0.9, "accuracy_baseline_model": "openai/gpt-oss-120b" }, { diff --git a/.github/workflows/atom-vllm-accuracy-validation.yaml b/.github/workflows/atom-vllm-accuracy-validation.yaml index 0b2d7e7a0..dc448212e 100644 --- a/.github/workflows/atom-vllm-accuracy-validation.yaml +++ b/.github/workflows/atom-vllm-accuracy-validation.yaml @@ -24,6 +24,11 @@ on: required: false type: boolean default: false + run_qwen3_next_80b_mtp_tp4: + description: "Qwen3-Next-80B-A3B-Instruct-FP8-MTP TP4" + required: false + type: boolean + default: false run_qwen35_397b_fp8_tp8: description: "Qwen3.5-397B-A17B-FP8 TP8" required: false @@ -137,6 +142,7 @@ jobs: 
RUN_QWEN3_MOE_TP8: ${{ inputs.run_qwen3_moe_tp8 }} RUN_QWEN3_NEXT_80B_TP1: ${{ inputs.run_qwen3_next_80b_tp1 }} RUN_QWEN3_NEXT_80B_TP4: ${{ inputs.run_qwen3_next_80b_tp4 }} + RUN_QWEN3_NEXT_80B_MTP_TP4: ${{ inputs.run_qwen3_next_80b_mtp_tp4 }} RUN_QWEN35_397B_FP8_TP8: ${{ inputs.run_qwen35_397b_fp8_tp8 }} RUN_QWEN35_397B_TP8: ${{ inputs.run_qwen35_397b_tp8 }} RUN_QWEN35_397B_FP4_TP4: ${{ inputs.run_qwen35_397b_fp4_tp4 }} @@ -169,7 +175,7 @@ jobs: "model_path": "Qwen/Qwen3-235B-A22B-Instruct-2507-FP8", "extra_args": "--tensor-parallel-size 8 --enable-expert-parallel", "accuracy_test_threshold": 0.87, - "env_vars": "ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1", + "env_vars": "ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1\nATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE=0", "runner": "linux-atom-mi35x-8", }, { @@ -178,7 +184,7 @@ jobs: "model_path": "Qwen/Qwen3-Next-80B-A3B-Instruct-FP8", "extra_args": "--tensor-parallel-size 1", "accuracy_test_threshold": 0.83, - "env_vars": "ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1\nATOM_USE_CUSTOM_ALL_GATHER=0\nATOM_USE_FLYDSL_GDR=1", + "env_vars": "ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1\nATOM_USE_CUSTOM_ALL_GATHER=0\nATOM_USE_FLYDSL_GDR=1\nATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE=0", "runner": "linux-atom-mi35x-1", }, { @@ -187,7 +193,16 @@ jobs: "model_path": "Qwen/Qwen3-Next-80B-A3B-Instruct-FP8", "extra_args": "--tensor-parallel-size 4", "accuracy_test_threshold": 0.83, - "env_vars": "ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1\nATOM_USE_CUSTOM_ALL_GATHER=0\nATOM_USE_FLYDSL_GDR=1", + "env_vars": "ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1\nATOM_USE_CUSTOM_ALL_GATHER=0\nATOM_USE_FLYDSL_GDR=1\nATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE=0", + "runner": "linux-atom-mi35x-4", + }, + { + "toggle_env": "RUN_QWEN3_NEXT_80B_MTP_TP4", + "model_name": "Qwen3-Next-80B-A3B-Instruct-FP8-MTP TP4", + "model_path": "Qwen/Qwen3-Next-80B-A3B-Instruct-FP8", + "extra_args": "--tensor-parallel-size 4 --speculative-config '{\"num_speculative_tokens\":1, \"method\": \"mtp\"}'", + "accuracy_test_threshold": 0.80, + "env_vars": "AITER_QUICK_REDUCE_QUANTIZATION=INT4\nATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1\nATOM_USE_CUSTOM_ALL_GATHER=0\nATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE=0", "runner": "linux-atom-mi35x-4", }, { @@ -196,7 +211,7 @@ jobs: "model_path": "Qwen/Qwen3.5-397B-A17B-FP8", "extra_args": "--tensor-parallel-size 8", "accuracy_test_threshold": 0.83, - "env_vars": "ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1\nATOM_USE_CUSTOM_ALL_GATHER=0", + "env_vars": "ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1\nATOM_USE_CUSTOM_ALL_GATHER=0\nATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE=0", "runner": "linux-atom-mi35x-8", }, { @@ -205,7 +220,7 @@ jobs: "model_path": "Qwen/Qwen3.5-397B-A17B", "extra_args": "--tensor-parallel-size 8", "accuracy_test_threshold": 0.83, - "env_vars": "ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1\nATOM_USE_CUSTOM_ALL_GATHER=0", + "env_vars": "ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1\nATOM_USE_CUSTOM_ALL_GATHER=0\nATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE=0", "runner": "linux-atom-mi35x-8", }, { @@ -214,7 +229,7 @@ jobs: "model_path": "amd/Qwen3.5-397B-A17B-MXFP4", "extra_args": "--tensor-parallel-size 4", "accuracy_test_threshold": 0.83, - "env_vars": "ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1\nATOM_USE_CUSTOM_ALL_GATHER=0", + "env_vars": "ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1\nATOM_USE_CUSTOM_ALL_GATHER=0\nATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE=0", "runner": "linux-atom-mi35x-4", }, { diff --git a/.github/workflows/atom-vllm-benchmark.yaml 
b/.github/workflows/atom-vllm-benchmark.yaml index a12b91483..d7a178d51 100644 --- a/.github/workflows/atom-vllm-benchmark.yaml +++ b/.github/workflows/atom-vllm-benchmark.yaml @@ -36,6 +36,8 @@ on: - Qwen3.5-397B-A17B TP8 (OOB) - Qwen3-Next-80B-A3B-Instruct-FP8 TP1 (MET) - Qwen3-Next-80B-A3B-Instruct-FP8 TP4 (MET) + - Qwen3-Next-80B-A3B-Instruct-FP8-MTP TP1 (AW) + - Qwen3-Next-80B-A3B-Instruct-FP8-MTP TP4 (AW) - Qwen3-Next-80B-A3B-Instruct-FP8 TP1 (AW) - Qwen3-Next-80B-A3B-Instruct-FP8 TP2 (AW) - Qwen3-Next-80B-A3B-Instruct-FP8 TP4 (AW) @@ -69,6 +71,8 @@ on: - Qwen3.5-397B-A17B TP8 (OOB) - Qwen3-Next-80B-A3B-Instruct-FP8 TP1 (MET) - Qwen3-Next-80B-A3B-Instruct-FP8 TP4 (MET) + - Qwen3-Next-80B-A3B-Instruct-FP8-MTP TP1 (AW) + - Qwen3-Next-80B-A3B-Instruct-FP8-MTP TP4 (AW) - Qwen3-Next-80B-A3B-Instruct-FP8 TP1 (AW) - Qwen3-Next-80B-A3B-Instruct-FP8 TP2 (AW) - Qwen3-Next-80B-A3B-Instruct-FP8 TP4 (AW) @@ -102,6 +106,8 @@ on: - Qwen3.5-397B-A17B TP8 (OOB) - Qwen3-Next-80B-A3B-Instruct-FP8 TP1 (MET) - Qwen3-Next-80B-A3B-Instruct-FP8 TP4 (MET) + - Qwen3-Next-80B-A3B-Instruct-FP8-MTP TP1 (AW) + - Qwen3-Next-80B-A3B-Instruct-FP8-MTP TP4 (AW) - Qwen3-Next-80B-A3B-Instruct-FP8 TP1 (AW) - Qwen3-Next-80B-A3B-Instruct-FP8 TP2 (AW) - Qwen3-Next-80B-A3B-Instruct-FP8 TP4 (AW) @@ -135,6 +141,8 @@ on: - Qwen3.5-397B-A17B TP8 (OOB) - Qwen3-Next-80B-A3B-Instruct-FP8 TP1 (MET) - Qwen3-Next-80B-A3B-Instruct-FP8 TP4 (MET) + - Qwen3-Next-80B-A3B-Instruct-FP8-MTP TP1 (AW) + - Qwen3-Next-80B-A3B-Instruct-FP8-MTP TP4 (AW) - Qwen3-Next-80B-A3B-Instruct-FP8 TP1 (AW) - Qwen3-Next-80B-A3B-Instruct-FP8 TP2 (AW) - Qwen3-Next-80B-A3B-Instruct-FP8 TP4 (AW) @@ -168,6 +176,8 @@ on: - Qwen3.5-397B-A17B TP8 (OOB) - Qwen3-Next-80B-A3B-Instruct-FP8 TP1 (MET) - Qwen3-Next-80B-A3B-Instruct-FP8 TP4 (MET) + - Qwen3-Next-80B-A3B-Instruct-FP8-MTP TP1 (AW) + - Qwen3-Next-80B-A3B-Instruct-FP8-MTP TP4 (AW) - Qwen3-Next-80B-A3B-Instruct-FP8 TP1 (AW) - Qwen3-Next-80B-A3B-Instruct-FP8 TP2 (AW) - Qwen3-Next-80B-A3B-Instruct-FP8 TP4 (AW) @@ -201,6 +211,8 @@ on: - Qwen3.5-397B-A17B TP8 (OOB) - Qwen3-Next-80B-A3B-Instruct-FP8 TP1 (MET) - Qwen3-Next-80B-A3B-Instruct-FP8 TP4 (MET) + - Qwen3-Next-80B-A3B-Instruct-FP8-MTP TP1 (AW) + - Qwen3-Next-80B-A3B-Instruct-FP8-MTP TP4 (AW) - Qwen3-Next-80B-A3B-Instruct-FP8 TP1 (AW) - Qwen3-Next-80B-A3B-Instruct-FP8 TP2 (AW) - Qwen3-Next-80B-A3B-Instruct-FP8 TP4 (AW) @@ -234,6 +246,8 @@ on: - Qwen3.5-397B-A17B TP8 (OOB) - Qwen3-Next-80B-A3B-Instruct-FP8 TP1 (MET) - Qwen3-Next-80B-A3B-Instruct-FP8 TP4 (MET) + - Qwen3-Next-80B-A3B-Instruct-FP8-MTP TP1 (AW) + - Qwen3-Next-80B-A3B-Instruct-FP8-MTP TP4 (AW) - Qwen3-Next-80B-A3B-Instruct-FP8 TP1 (AW) - Qwen3-Next-80B-A3B-Instruct-FP8 TP2 (AW) - Qwen3-Next-80B-A3B-Instruct-FP8 TP4 (AW) @@ -267,6 +281,8 @@ on: - Qwen3.5-397B-A17B TP8 (OOB) - Qwen3-Next-80B-A3B-Instruct-FP8 TP1 (MET) - Qwen3-Next-80B-A3B-Instruct-FP8 TP4 (MET) + - Qwen3-Next-80B-A3B-Instruct-FP8-MTP TP1 (AW) + - Qwen3-Next-80B-A3B-Instruct-FP8-MTP TP4 (AW) - Qwen3-Next-80B-A3B-Instruct-FP8 TP1 (AW) - Qwen3-Next-80B-A3B-Instruct-FP8 TP2 (AW) - Qwen3-Next-80B-A3B-Instruct-FP8 TP4 (AW) diff --git a/atom/models/qwen3_next.py b/atom/models/qwen3_next.py index f8abcb867..40ad2380a 100644 --- a/atom/models/qwen3_next.py +++ b/atom/models/qwen3_next.py @@ -485,7 +485,8 @@ def __init__( self.config = config self.quant_config = quant_config - self.speculative_config = speculative_config + + self.speculative_config = speculative_config or atom_config.speculative_config self.num_spec = ( 
self.speculative_config.num_speculative_tokens if self.speculative_config @@ -863,7 +864,6 @@ def forward( residual: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention - if self.input_layernorm.use_fused_quant: if residual is None: residual = hidden_states @@ -1059,6 +1059,11 @@ def __init__( if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight + # Expose embed_tokens at this level for vLLM MTP embedding sharing. + # vLLM's proposer accesses target_wrapper.model.embed_tokens, where + # target_wrapper.model = this class (Qwen3NextForCausalLM). + self.embed_tokens = self.model.embed_tokens + self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors ) @@ -1132,6 +1137,7 @@ def get_mamba_state_shape_from_config( if vllm_config.speculative_config else 0 ) + return MambaStateShapeCalculator.gated_delta_net_state_shape( tp_size, hf_config.linear_num_key_heads, diff --git a/atom/models/qwen3_next_mtp.py b/atom/models/qwen3_next_mtp.py index 2a5f0737e..5c547d955 100644 --- a/atom/models/qwen3_next_mtp.py +++ b/atom/models/qwen3_next_mtp.py @@ -171,9 +171,19 @@ def compute_logits( def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) + # Mirror target's get_expert_mapping: when shared-expert fusion is on, + # the loader rewrites `mlp.shared_expert.*` to `mlp.experts.{N}.*` + # (where N == n_routed_experts), so the expert_mapping must include + # an extra slot for that fused shared-expert. Without this, MTP's + # shared_expert weights get silently dropped during loading. + from atom.model_ops.topK import is_rocm_aiter_fusion_shared_expert_enabled + + n_routed = getattr(self.config, "n_routed_experts", self.config.num_experts) + n_shared = getattr(self.config, "n_shared_experts", 1) return FusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=self.config.num_experts, + num_experts=n_routed + + (n_shared if is_rocm_aiter_fusion_shared_expert_enabled() else 0), ) diff --git a/atom/plugin/attention.py b/atom/plugin/attention.py index 9c674c1cd..81ded7cfc 100644 --- a/atom/plugin/attention.py +++ b/atom/plugin/attention.py @@ -283,6 +283,28 @@ def init_method_under_plugin_mode( i64_kwargs = {"dtype": torch.int64, "device": device} self.positions = CpuGpuBuffer(max_num_batched_tokens, **i64_kwargs) + # Bump reorder_batch_threshold so multi-token spec-decode requests + # (MTP / EAGLE) are routed through the decode path. Mirrors vLLM's + # AttentionMetadataBuilder._init_reorder_batch_threshold(supports_spec_as_decode=True). 
+ speculative_config = getattr(config, "speculative_config", None) + if ( + getattr(self, "reorder_batch_threshold", None) is not None + and speculative_config is not None + and getattr(speculative_config, "num_speculative_tokens", None) is not None + ): + parallel_drafting = getattr(speculative_config, "parallel_drafting", False) + max_num_queries_for_spec = 1 + (2 if parallel_drafting else 1) * ( + speculative_config.num_speculative_tokens + ) + self.reorder_batch_threshold = max( + self.reorder_batch_threshold, max_num_queries_for_spec + ) + logger.info( + "Spec decode: bumped reorder_batch_threshold to %d (num_spec_tokens=%d)", + self.reorder_batch_threshold, + speculative_config.num_speculative_tokens, + ) + return init_method_under_plugin_mode @@ -300,7 +322,7 @@ def setup_attn_metadata_builder_base_class_and_attributes(class_dict: dict): needs_generic = True # align with vllm rocm aiter fa - class_dict["_cudagraph_support"] = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + class_dict["_cudagraph_support"] = AttentionCGSupport.UNIFORM_BATCH class_dict["reorder_batch_threshold"] = 1 return base_class, generic_base, needs_generic, class_dict @@ -324,9 +346,12 @@ def build( from vllm.v1.attention.backends.utils import split_decodes_prefills_and_extends - # here assume the decode num token is 1 per request + # decode_threshold tracks reorder_batch_threshold so MTP/EAGLE + # multi-token verification (query_len > 1) routes through decode. + decode_threshold = getattr(self, "reorder_batch_threshold", 1) or 1 split_ret = split_decodes_prefills_and_extends( - common_attn_metadata=common_attn_metadata, decode_threshold=1 + common_attn_metadata=common_attn_metadata, + decode_threshold=decode_threshold, ) ( @@ -351,6 +376,11 @@ def build( query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] num_computed_tokens_cpu = common_attn_metadata._num_computed_tokens_cpu + # In async spec-decode mode (auto-enabled for MTP/EAGLE), vLLM sets + # _num_computed_tokens_cpu to None because the GPU seq_lens is the + # authoritative source. Reconstruct from CPU tensors we already have. + if num_computed_tokens_cpu is None: + num_computed_tokens_cpu = seq_lens - query_lens_cpu prefill_max_query_len = decode_max_query_len = ( common_attn_metadata.max_query_len diff --git a/atom/plugin/attention_mha.py b/atom/plugin/attention_mha.py index 16c88949d..100c492ab 100644 --- a/atom/plugin/attention_mha.py +++ b/atom/plugin/attention_mha.py @@ -234,15 +234,27 @@ def paged_attention_triton_plugin_mode( v_cache: torch.Tensor, k_scale: torch.Tensor, v_scale: torch.Tensor, + num_decodes: int, out: torch.Tensor, attn_metadata: "AttentionMetaData", ps: bool = True, ): - o = out - num_seqs, num_q_heads_total, head_size = q.shape + # q.shape[0] == num_decodes * max_query_len for MTP (one row per decode + # token, query_len > 1). For non-MTP it equals num_decodes (query_len = 1). + # pa_decode_gluon handles multi-token causal masking internally when + # `query_length > 1` is passed; intermediate buffers must be sized + # `num_decodes` (not q.shape[0]) and `query_group_size` must include + # the max_qlen multiplier — mirroring server-mode `paged_attention_triton`. 
+ _, num_q_heads_total, head_size = q.shape num_blocks, num_kv_heads, _, block_size, _ = k_cache.shape - query_group_size = num_q_heads_total // num_kv_heads + decode_metadata = attn_metadata.plugin_metadata.decode_metadata + max_qlen = decode_metadata.max_query_len if decode_metadata is not None else 1 assert num_q_heads_total % num_kv_heads == 0 + + seq_lens = attn_metadata.plugin_metadata.seq_lens[:num_decodes] + block_tables = attn_metadata.plugin_metadata.block_table[:num_decodes] + + query_group_size = max_qlen * (num_q_heads_total // num_kv_heads) context_partition_size = 256 # use_ps = self.adopt_persistent_kernel( @@ -250,7 +262,9 @@ def paged_attention_triton_plugin_mode( # ) use_ps = True if use_ps: - max_context_partition_num = get_recommended_splits(num_seqs, num_kv_heads) + max_context_partition_num = get_recommended_splits( + num_decodes, num_kv_heads + ) else: max_context_partition_num = _NO_PS_FIXED_SPLITS @@ -258,9 +272,8 @@ def paged_attention_triton_plugin_mode( max_context_partition_num = 1 context_partition_size = 128 - # Output buffers (same as Triton) intermediate_shape = ( - num_seqs, + num_decodes, num_kv_heads, max_context_partition_num, query_group_size, @@ -283,21 +296,19 @@ def paged_attention_triton_plugin_mode( k_scale = k_scale.unsqueeze(-1) v_scale = v_scale.unsqueeze(-1) - num_decode_seqs = q.shape[0] - seq_lens_decode = attn_metadata.plugin_metadata.seq_lens[:num_decode_seqs] - block_tables_decode = attn_metadata.plugin_metadata.block_table[ - :num_decode_seqs - ] - + # Kernel takes natural q layout [batch * query_length, num_q_heads, head_size]. + # Internally it derives batch_size = q.shape[0] // query_length and reshapes + # to [batch, query_length, num_kv_heads, group, head_size]. See + # aiter/aiter/ops/triton/gluon/pa_decode_gluon.py:5371-5377 and 5542-5544. 
torch.ops.aiter.pa_decode_gluon( - o, + out, q, k_cache, v_cache, - seq_lens_decode, - block_tables_decode, + seq_lens, + block_tables, self.scale, - 1, # query_lenth + max_qlen, # query_length — handles multi-token causal mask internally max_context_partition_num, context_partition_size, compute_type, @@ -312,8 +323,7 @@ def paged_attention_triton_plugin_mode( sliding_window=self.sliding_window, ps=use_ps, ) - - return o + return out def paged_attention_asm_plugin_mode( self, @@ -327,6 +337,11 @@ def paged_attention_asm_plugin_mode( attn_metadata: "AttentionMetaData", out: torch.Tensor, ): + decode_metadata = attn_metadata.plugin_metadata.decode_metadata + max_qlen = decode_metadata.max_query_len if decode_metadata is not None else 1 + qo_indptr = ( + decode_metadata.query_start_loc if decode_metadata is not None else None + ) aiter.pa_fwd_asm( Q=q, K=k_cache, @@ -336,9 +351,11 @@ def paged_attention_asm_plugin_mode( block_tables_stride0=attn_metadata.plugin_metadata.block_table[ :num_decodes ].stride(0), + max_qlen=max_qlen, K_QScale=k_scale, V_QScale=v_scale, out_=out[:num_decode_tokens], + qo_indptr=qo_indptr, high_precision=0, ) @@ -706,12 +723,13 @@ def forward_impl_plugin_mode( extend_tokens_slice = slice( num_decode_tokens, num_decode_tokens + num_extend_tokens ) + extend_reqs_slice = slice(num_decodes, num_decodes + num_extends) extend_querys = query[extend_tokens_slice] extend_keys = key[extend_tokens_slice] extend_values = value[extend_tokens_slice] extend_outputs = output[extend_tokens_slice] extend_block_table = attn_metadata.plugin_metadata.block_table[ - extend_tokens_slice + extend_reqs_slice ] extend_slot_mapping = attn_metadata.plugin_metadata.slot_mapping[ extend_tokens_slice @@ -745,6 +763,7 @@ def forward_impl_plugin_mode( v_cache=new_value_cache, k_scale=k_scale, v_scale=v_scale, + num_decodes=num_decodes, out=output_actual_tokens[:num_decode_tokens], attn_metadata=attn_metadata, ) @@ -757,6 +776,7 @@ def forward_impl_plugin_mode( v_cache=new_value_cache, k_scale=k_scale, v_scale=v_scale, + num_decodes=num_decodes, out=output_actual_tokens[:num_decode_tokens], attn_metadata=attn_metadata, ) diff --git a/atom/plugin/config.py b/atom/plugin/config.py index 07aafa4a5..8eb48da4c 100644 --- a/atom/plugin/config.py +++ b/atom/plugin/config.py @@ -1,3 +1,4 @@ +import copy from typing import Any, Optional from dataclasses import dataclass @@ -71,6 +72,45 @@ def _normalize_sglang_parallel_config( return tp_size, 1, 0, tp_rank +def _build_atom_speculative_config_from_vllm(vllm_spec_config: Any): + """Translate vLLM's SpeculativeConfig into ATOM's SpeculativeConfig. + + Reuses vLLM's already-loaded draft hf_config (skips a second disk fetch + in ATOM SpeculativeConfig.__post_init__) but still runs ATOM's + hf_config_override on it — so MTP model_type remap, n_routed_experts + backfill (Qwen families), and architecture rewrite all land on the + draft config in one place. Mirrors how standalone ATOM MTP exposes + the draft hf_config via atom_config.speculative_config. + + The draft hf_config is deepcopied first because hf_config_override + mutates `architectures` to ATOM's standalone naming (e.g. + "Qwen3NextMTPModel"), which differs from vLLM's registry name + ("Qwen3NextMTP"). Mutating in place would make vLLM's later draft + architecture lookup fail. 
+ """ + if vllm_spec_config is None: + return None + + from atom.config import SpeculativeConfig + + draft_model_config = getattr(vllm_spec_config, "draft_model_config", None) + draft_hf_config = getattr(draft_model_config, "hf_config", None) + if draft_hf_config is not None: + draft_hf_config = copy.deepcopy(draft_hf_config) + model_path = getattr(draft_model_config, "model", None) or getattr( + vllm_spec_config, "model", None + ) + + return SpeculativeConfig( + method=getattr(vllm_spec_config, "method", "") or "", + model=model_path, + num_speculative_tokens=getattr( + vllm_spec_config, "num_speculative_tokens", None + ), + draft_model_hf_config=draft_hf_config, + ) + + def _generate_atom_config_from_vllm_config(config: Any) -> PluginConfig: from atom.config import Config, CompilationConfig @@ -117,6 +157,10 @@ def _generate_atom_config_from_vllm_config(config: Any) -> PluginConfig: max_num_batched_tokens = vllm_scheduler_config.max_num_batched_tokens + atom_speculative_config = _build_atom_speculative_config_from_vllm( + getattr(config, "speculative_config", None) + ) + return Config( model=vllm_model_config.model, trust_remote_code=getattr(vllm_model_config, "trust_remote_code", False), @@ -140,6 +184,7 @@ def _generate_atom_config_from_vllm_config(config: Any) -> PluginConfig: master_addr=None, enable_dp_attention=False, plugin_config=plugin_config, + speculative_config=atom_speculative_config, ) diff --git a/atom/plugin/vllm/attention_backend/attention_gdn.py b/atom/plugin/vllm/attention_backend/attention_gdn.py index b6158a086..87a2f2f9f 100644 --- a/atom/plugin/vllm/attention_backend/attention_gdn.py +++ b/atom/plugin/vllm/attention_backend/attention_gdn.py @@ -22,6 +22,7 @@ from atom.model_ops.fla_ops.fused_sigmoid_gating import ( fused_sigmoid_gating_delta_rule_update, ) + from atom.utils import envs from torch import nn @@ -385,7 +386,13 @@ def forward( ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to( ssm_state.dtype ) - core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0) + # Only write directly when there are no spec tokens. With spec + # decode active, mixed_qkv was index_select'd by non_spec_token_indx + # so core_attn_out_non_spec has fewer rows than num_actual_tokens. + # The merge below (index_copy_) handles the scatter back to the + # correct slot positions. 
+ if spec_sequence_masks is None: + core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0) elif attn_metadata.num_decodes > 0: o = core_attn_out[: attn_metadata.num_decode_tokens] if USE_FLYDSL_GDR: diff --git a/atom/plugin/vllm/model_wrapper.py b/atom/plugin/vllm/model_wrapper.py index c2b990c18..4eada7d4c 100644 --- a/atom/plugin/vllm/model_wrapper.py +++ b/atom/plugin/vllm/model_wrapper.py @@ -35,7 +35,9 @@ logger = logging.getLogger("atom") - +_MTP_MASK_INPUT_ARCH: set[str] = { + "DeepSeekMTPModel", +} _ATOM_MODEL_CLASSES: dict[str, str] = { "LlamaForCausalLM": "atom.models.llama:LlamaForCausalLM", "Qwen3ForCausalLM": "atom.models.qwen3:Qwen3ForCausalLM", @@ -47,6 +49,7 @@ "GlmMoeDsaForCausalLM": "atom.models.deepseek_v2:GlmMoeDsaForCausalLM", "DeepSeekMTPModel": "atom.models.deepseek_mtp:DeepSeekMTP", "Qwen3NextForCausalLM": "atom.models.qwen3_next:Qwen3NextForCausalLM", + "Qwen3NextMTP": "atom.models.qwen3_next_mtp:Qwen3NextMTP", "Qwen3_5MoeForConditionalGeneration": "atom.models.qwen3_5:Qwen3_5MoeForConditionalGeneration_", "Qwen3_5ForConditionalGeneration": "atom.models.qwen3_5:Qwen3_5ForConditionalGeneration_", "KimiK25ForConditionalGeneration": "atom.plugin.vllm.models.kimi_k25:KimiK25ForConditionalGeneration_", @@ -121,6 +124,7 @@ def __init_subclass__(cls, *args, **kwargs): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + from atom.config import get_current_atom_config _set_framework_backbone("vllm") @@ -140,19 +144,21 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.ignore_unexpected_suffixes: list[str] = [] self.vllm_config = vllm_config - self.atom_config = generate_atom_config_for_plugin_mode(vllm_config) self.is_mtp = False speculative_config = getattr(vllm_config, "speculative_config", None) if speculative_config is not None: spec_method = speculative_config.method self.is_mtp = spec_method == "mtp" - _prepare_env(atom_config=self.atom_config) - main_model_arch = vllm_config.model_config.architectures[0] model_arch = _select_model_arch(vllm_config) self.is_mtp_draft_model = self.is_mtp and model_arch != main_model_arch + if self.is_mtp_draft_model: + self.atom_config = get_current_atom_config() + else: + self.atom_config = generate_atom_config_for_plugin_mode(vllm_config) self.model_arch = model_arch + _prepare_env(atom_config=self.atom_config) model_cls = _get_atom_model_cls(model_arch) module_remapping = getattr(model_cls, "packed_modules_mapping", {}) weights_mapper = getattr(model_cls, "hf_to_atom_mapper", {}) @@ -182,9 +188,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): logger.info(f"Construct ATOM model {model_arch} for vLLM plugin mode") self.model = model_cls(self.atom_config) - self._adapt_mtp_layers_for_vllm() - # Mirror nested attributes required by vLLM speculative decoding. - self._expose_spec_decode_attrs() + + if model_arch in _MTP_MASK_INPUT_ARCH: + self._adapt_mtp_layers_for_vllm() + # Mirror nested attributes required by vLLM speculative decoding. 
+ self._expose_spec_decode_attrs() # For sparse MLA, register the Indexer's DeepseekV32IndexerCache as # a virtual subclass of vLLM's AttentionLayerBase so vLLM can discover @@ -192,7 +200,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self._register_indexer_caches_with_vllm() if self.model is None: - model_arch = vllm_config.model_config.architectures[0] raise ValueError( f"The model {model_arch} is not supported by model impl backend atom" ) @@ -309,8 +316,7 @@ def _register_indexer_caches_with_vllm(self): if prefix not in vllm_sfc: vllm_sfc[prefix] = module logger.info( - f"Registered indexer cache in vLLM static_forward_context: " - f"{prefix}" + f"Registered indexer cache in vLLM static_forward_context: {prefix}" ) else: logger.warning( @@ -397,7 +403,6 @@ def forward( inputs_embeds=inputs_embeds, **model_kwargs, ) - if not self.pp_group.is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) @@ -412,12 +417,12 @@ def load_weights( is_mtp_draft_model = self.model_arch in { "DeepSeekMTPModel", - "Qwen3NextMTPModel", + "Qwen3NextMTP", } draft_hf_config = None if is_mtp_draft_model: draft_model_config = getattr( - getattr(self.vllm_config, "speculative_config", None), + getattr(self.atom_config, "speculative_config", None), "draft_model_config", None, ) @@ -452,7 +457,6 @@ class ATOMMoEForCausalLM(ATOMModelBase, VllmModelForTextGeneration): ... class ATOMForConditionalGeneration( ATOMModelBase, VllmModelForTextGeneration, SupportsMultiModal, SupportsMRoPE ): - @classmethod def get_placeholder_str(cls, modality: str, i: int) -> str | None: """ diff --git a/atom/plugin/vllm/register.py b/atom/plugin/vllm/register.py index 9ef76e601..91e241e9a 100644 --- a/atom/plugin/vllm/register.py +++ b/atom/plugin/vllm/register.py @@ -30,6 +30,7 @@ "GlmMoeDsaForCausalLM": ATOM_MOE_CAUSAL_LM_MODEL_WRAPPER, "DeepSeekMTPModel": ATOM_MOE_CAUSAL_LM_MODEL_WRAPPER, "Qwen3NextForCausalLM": "atom.models.qwen3_next:Qwen3NextForCausalLMVllm", + "Qwen3NextMTP": ATOM_MOE_CAUSAL_LM_MODEL_WRAPPER, "Qwen3_5ForConditionalGeneration": "atom.models.qwen3_5:Qwen3_5ForConditionalGeneration", "Qwen3_5MoeForConditionalGeneration": "atom.models.qwen3_5:Qwen3_5MoeForConditionalGeneration", "KimiK25ForConditionalGeneration": "atom.plugin.vllm.models.kimi_k25:KimiK25ForConditionalGeneration", diff --git a/recipes/atom_vllm/Qwen3.5.md b/recipes/atom_vllm/Qwen3.5.md index 94a900e07..4e3b8c077 100644 --- a/recipes/atom_vllm/Qwen3.5.md +++ b/recipes/atom_vllm/Qwen3.5.md @@ -18,6 +18,7 @@ The ATOM vLLM plugin backend keeps the standard vLLM CLI, server APIs, and gener export AITER_QUICK_REDUCE_QUANTIZATION=INT4 export ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1 export ATOM_USE_CUSTOM_ALL_GATHER=0 +export ATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE=0 vllm serve Qwen/Qwen3.5-35B-A3B-FP8 \ --host localhost \ @@ -37,6 +38,7 @@ vllm serve Qwen/Qwen3.5-35B-A3B-FP8 \ export AITER_QUICK_REDUCE_QUANTIZATION=INT4 export ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1 export ATOM_USE_CUSTOM_ALL_GATHER=0 +export ATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE=0 vllm serve Qwen/Qwen3.5-397B-A17B-FP8 \ --host localhost \ @@ -56,6 +58,7 @@ vllm serve Qwen/Qwen3.5-397B-A17B-FP8 \ export AITER_QUICK_REDUCE_QUANTIZATION=INT4 export ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1 export ATOM_USE_CUSTOM_ALL_GATHER=0 +export ATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE=0 vllm serve amd/Qwen3.5-397B-A17B-MXFP4 \ --host localhost \ @@ -69,10 +72,10 @@ vllm serve amd/Qwen3.5-397B-A17B-MXFP4 \ --no-enable-prefix-caching ``` -**Important**: 
The following three environment variables are required for Qwen3.5: +**Important**: The following environment variables are required for Qwen3.5: -- `ATOM_DISABLE_VLLM_PLUGIN_ATTENTION=1`: Disables ATOM attention plugin to use vLLM's implementation for full attention layers (required because Qwen3.5 uses a hybrid architecture with both linear attention (GatedDeltaNet) and full attention layers) - `ATOM_USE_CUSTOM_ALL_GATHER=0`: Disables custom all-gather for compatibility with Qwen3.5 model architecture +- `ATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE=0`: Disables FP8 blockscale weight preshuffle - `AITER_QUICK_REDUCE_QUANTIZATION=INT4`: **Performance optimization** - enables INT4 quantization for quick reduce operations, which can significantly improve TTFT (Time To First Token) performance. **Note**: This optimization may introduce a risk of accuracy degradation. For accuracy-critical workloads, consider validating with your specific use case. ## Step 3: Performance Benchmark @@ -133,8 +136,8 @@ Reference result (TP=4): ## Key Environment Variables -- `ATOM_DISABLE_VLLM_PLUGIN_ATTENTION=1`: **Required** - disables ATOM attention plugin to use vLLM's implementation for full attention layers - `ATOM_USE_CUSTOM_ALL_GATHER=0`: **Required** - disables custom all-gather for compatibility with Qwen3.5 model architecture +- `ATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE=0`: **Required** - disables FP8 blockscale weight preshuffle - `AITER_QUICK_REDUCE_QUANTIZATION=INT4`: **Performance optimization** - enables INT4 quantization for quick reduce operations - **Benefit**: Significantly improves TTFT (Time To First Token) performance by reducing communication overhead during tensor parallelism all-reduce operations - **Risk**: May cause slight accuracy degradation due to lower quantization precision diff --git a/recipes/atom_vllm/Qwen3Next.md b/recipes/atom_vllm/Qwen3Next.md index e22f80d1c..019e297f6 100644 --- a/recipes/atom_vllm/Qwen3Next.md +++ b/recipes/atom_vllm/Qwen3Next.md @@ -17,6 +17,8 @@ The ATOM vLLM plugin backend keeps the standard vLLM CLI, server APIs, and gener ```bash export ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1 export ATOM_USE_CUSTOM_ALL_GATHER=0 +export AITER_QUICK_REDUCE_QUANTIZATION=INT4 +export ATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE=0 vllm serve Qwen/Qwen3-Next-80B-A3B-Instruct-FP8 \ --host localhost \ @@ -31,8 +33,26 @@ vllm serve Qwen/Qwen3-Next-80B-A3B-Instruct-FP8 \ --no-enable-prefix-caching ``` -**Important**: `ATOM_DISABLE_VLLM_PLUGIN_ATTENTION=1` is required for Qwen3-Next because it uses a hybrid architecture with both linear attention (GatedDeltaNet) and full attention layers. This env var ensures full attention layers use vLLM's default implementation. +### Qwen3-Next-80B-A3B-Instruct-FP8 MTP (TP=1, MI355X) +```bash +export ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1 +export ATOM_USE_CUSTOM_ALL_GATHER=0 +export AITER_QUICK_REDUCE_QUANTIZATION=INT4 +export ATOM_FP8_BLOCKSCALE_WEIGHT_PRESHUFFLE=0 +vllm serve Qwen/Qwen3-Next-80B-A3B-Instruct-FP8 \ + --host localhost \ + --port 8000 \ + --tensor-parallel-size 1 \ + --kv-cache-dtype fp8 \ + --gpu_memory_utilization 0.9 \ + --async-scheduling \ + --compilation-config '{"cudagraph_mode": "FULL_AND_PIECEWISE"}' \ + --max-model-len 16384 \ + --max-num-batched-tokens 32768 \ + --speculative-config '{"num_speculative_tokens":1, "method": "mtp"}' \ + --no-enable-prefix-caching +``` ## Step 3: Performance Benchmark Users can use the default vllm bench commands for performance benchmarking. 
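Editor's note (not part of the patch): before running the Step 3 benchmark against the MTP serving example added above, a quick smoke test can confirm the endpoint is answering. A minimal sketch using only the Python standard library and vLLM's OpenAI-compatible completions API; it assumes the MTP `vllm serve` command above is running on `localhost:8000`, and the prompt and `max_tokens` values are arbitrary:

```python
# Minimal smoke test for the MTP-enabled endpoint above (illustrative only;
# assumes the serve command from the MTP section is listening on port 8000).
import json
import urllib.request

payload = {
    "model": "Qwen/Qwen3-Next-80B-A3B-Instruct-FP8",
    "prompt": "The quick brown fox",
    "max_tokens": 32,
    "temperature": 0.0,
}
req = urllib.request.Request(
    "http://localhost:8000/v1/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    body = json.loads(resp.read())
# A non-empty completion means the server (with speculative decoding enabled)
# is generating tokens and the benchmark can be started.
print(body["choices"][0]["text"])
```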
@@ -70,9 +90,6 @@ lm_eval --model local-completions \ --num_fewshot 3 ``` -## Key Environment Variables - -- `ATOM_DISABLE_VLLM_PLUGIN_ATTENTION=1`: **Required** - disables ATOM attention plugin to use vLLM's implementation for full attention layers ## Architecture Notes
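Editor's note (not part of the patch): the `atom/plugin/attention.py` change above bumps `reorder_batch_threshold` from the speculative config so multi-token MTP/EAGLE verification queries stay on the decode path. A standalone sketch of that arithmetic, for illustration only; the helper name is hypothetical, and the formula mirrors the patch:

```python
def spec_decode_query_budget(
    num_speculative_tokens: int,
    parallel_drafting: bool = False,
    base_threshold: int = 1,
) -> int:
    """Mirror of the threshold bump in atom/plugin/attention.py (sketch only)."""
    # One verified token per request plus the speculative tokens,
    # doubled when parallel drafting is enabled.
    max_num_queries_for_spec = 1 + (2 if parallel_drafting else 1) * num_speculative_tokens
    # Keep whichever is larger so plain single-token decode is unaffected.
    return max(base_threshold, max_num_queries_for_spec)


# MTP with num_speculative_tokens=1 (as in the recipes above) yields 2, so
# verification requests with query_len == 2 are still routed through decode.
assert spec_decode_query_budget(1) == 2
assert spec_decode_query_budget(1, parallel_drafting=True) == 3
```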