Skip to content

Commit 2b449ab

Browse files
committed
perf(vllm): optimize MiniMax M3 MI300X inference
1 parent f4960d1 commit 2b449ab

3 files changed

Lines changed: 1163 additions & 206 deletions

File tree

benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh

Lines changed: 114 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
# MiniMax-M3 MXFP8 MI300X (gfx942) single-node vLLM recipe.
44
# Reuses the dedicated ROCm image and converts MXFP8 MoE weights to 128x128
55
# block FP8 at load time. Block size 128 is mandatory for MSA sparse attention.
6+
# The second runtime patch carries the profiled sparse-attention, indexer, MoE,
7+
# router, and collective changes. Only TP8 enables the pinned AITER Gemma fusion;
8+
# EP keeps the faster native collectives.
69
# Keep the default BF16 KV cache on gfx942: the checkpoint has no calibrated
710
# q/prob scales for ROCm FP8 attention, and vLLM's fallback scale of 1.0
811
# corrupts model accuracy.
@@ -75,13 +78,114 @@ apply_vllm_patch() {
7578
fi
7679
}
7780

81+
#######################################
# Compare a file's SHA-256 digest with an expected hex digest.
# Arguments:
#   $1 - expected digest (hex, upper or lower case)
#   $2 - path of the file to hash
# Returns:
#   0 on match; 1 on mismatch, empty expected digest, or unreadable file.
#######################################
_sha256_matches() {
  local expected actual
  expected="$(printf '%s' "$1" | tr 'A-F' 'a-f')"
  [[ -n "$expected" ]] || return 1
  # Hash through stdin so the path is never parsed as checksum-list text;
  # paths with newlines or a leading space/'*' would otherwise be misread.
  actual="$(sha256sum < "$2" 2>/dev/null)" || return 1
  [[ "${actual%% *}" == "$expected" ]]
}

#######################################
# Download a file and keep it only if its SHA-256 digest matches.
# An existing destination whose digest already matches is reused without
# invoking curl. Otherwise the destination is removed, the URL is fetched
# into "<output>.tmp.<pid>", verified, and then renamed into place.
# Arguments:
#   $1 - URL to fetch
#   $2 - expected SHA-256 digest (hex)
#   $3 - destination path
# Outputs:
#   A diagnostic on stderr when the download or the verification fails.
# Returns:
#   0 if the destination holds verified content, non-zero otherwise.
#######################################
download_verified() {
  local url="$1"
  local sha256="$2"
  local output="$3"
  local temporary="${output}.tmp.$$"

  if [[ -f "$output" ]] && _sha256_matches "$sha256" "$output"; then
    return 0
  fi

  rm -f -- "$output" "$temporary"
  if ! curl \
    --fail \
    --location \
    --retry 5 \
    --retry-delay 2 \
    --output "$temporary" \
    "$url"; then
    echo "Failed to download $url" >&2
    # curl may leave a partial file; the name embeds the PID, so leftovers
    # would otherwise pile up next to the destination.
    rm -f -- "$temporary"
    return 1
  fi
  if ! _sha256_matches "$sha256" "$temporary"; then
    echo "SHA256 verification failed for $url" >&2
    rm -f -- "$temporary"
    return 1
  fi
  if ! mv -- "$temporary" "$output"; then
    rm -f -- "$temporary"
    return 1
  fi
}
110+
111+
#######################################
# Configure the AITER-related environment for this recipe.
# All four VLLM_ROCM_USE_AITER* switches are first exported as 0. When
# DP_ATTENTION is "true", EP_SIZE > 1, or TP != 8 the function stops there.
# Otherwise it fetches a pinned AITER source archive and a pinned flydsl
# wheel (both SHA-256 verified), unpacks AITER under /tmp, installs the
# wheel offline, puts the AITER tree on PYTHONPATH, exports the JIT/build
# cache settings, turns the fused allreduce + Gemma RMSNorm switch on, and
# finally checks under a file lock that the imported AITER exposes the
# expected `gemma_norm` parameter.
# Globals:
#   DP_ATTENTION, EP_SIZE, TP, XDG_CACHE_HOME, HOME (read)
#   VLLM_ROCM_USE_AITER, VLLM_ROCM_USE_AITER_MOE,
#   VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS,
#   VLLM_ROCM_USE_AITER_FUSED_ALLREDUCE_GEMMA_RMSNORM, PYTHONPATH,
#   AITER_JIT_DIR, TORCH_EXTENSIONS_DIR, AITER_REBUILD, MAX_JOBS (exported)
# Outputs:
#   One informational line on stdout for the native-collective path;
#   diagnostics on stderr on failure.
# Returns:
#   0 on success, 1 if any step fails.
#######################################
setup_tp_aiter_gemma_fusion() {
  export VLLM_ROCM_USE_AITER=0
  export VLLM_ROCM_USE_AITER_MOE=0
  export VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS=0
  export VLLM_ROCM_USE_AITER_FUSED_ALLREDUCE_GEMMA_RMSNORM=0

  # The fused collective wins on the profiled TP8 decode shapes, but loses
  # at both EP boundaries. Keep every other AITER backend disabled.
  if [[ "$DP_ATTENTION" == "true" || "$EP_SIZE" -gt 1 || "$TP" -ne 8 ]]; then
    echo "Using native collectives for MiniMax M3 EP/non-TP8"
    return 0
  fi

  local aiter_commit="a40c487b3c01dc03fd3872d65b1f7404f669471f"
  local cache_root="${XDG_CACHE_HOME:-$HOME/.cache}/inferencex/minimax-m3-aiter"
  local aiter_archive="$cache_root/aiter-${aiter_commit}.tar.gz"
  local aiter_root="/tmp/aiter-${aiter_commit}"
  local flydsl_wheel="$cache_root/flydsl-0.2.1-cp312-cp312-manylinux_2_27_x86_64.whl"

  # The caller runs this function inside an `if !` condition, where
  # `set -e` has no effect, so every step must be checked explicitly.
  mkdir -p -- "$cache_root" || return 1
  download_verified \
    "https://codeload.github.com/ROCm/aiter/tar.gz/${aiter_commit}" \
    "8cf142a4210e7a6fb88211b1a521c789f652e9f819ac6a0218cdeebc18f4808d" \
    "$aiter_archive" || return 1
  download_verified \
    "https://files.pythonhosted.org/packages/59/16/c87972f06b8f9a9b6ab08b598d706b687a969750df7131fc27aebae1a87a/flydsl-0.2.1-cp312-cp312-manylinux_2_27_x86_64.whl" \
    "98aa84678a515535283bf1a4b3e491c6f38de1fe16452dc8bfa44e9bd28ca99c" \
    "$flydsl_wheel" || return 1

  # Always start from a fresh tree so stale files cannot shadow the pinned
  # sources; refuse to continue if the old tree cannot be removed.
  # NOTE(review): this tree lives at a fixed /tmp path and is replaced
  # outside the lock taken below - confirm no concurrent job on the same
  # node can be importing from it at this point.
  rm -rf -- "${aiter_root:?}" || return 1
  mkdir -p -- "$aiter_root" || return 1
  tar \
    --extract \
    --gzip \
    --file "$aiter_archive" \
    --directory "$aiter_root" \
    --strip-components 1 || return 1
  printf 'develop\n' > "$aiter_root/aiter/install_mode" || return 1
  python3 -m pip install \
    --disable-pip-version-check \
    --no-index \
    --no-deps \
    "$flydsl_wheel" || return 1

  export PYTHONPATH="$aiter_root${PYTHONPATH:+:$PYTHONPATH}"
  export AITER_JIT_DIR="$cache_root/jit"
  export TORCH_EXTENSIONS_DIR="$cache_root/torch-extensions"
  export AITER_REBUILD=0
  export MAX_JOBS=32
  export VLLM_ROCM_USE_AITER_FUSED_ALLREDUCE_GEMMA_RMSNORM=1
  mkdir -p -- "$AITER_JIT_DIR" "$TORCH_EXTENSIONS_DIR" || return 1

  # Serialize the import check on a lock file in the shared cache; wait at
  # most 1800 seconds for the lock.
  (
    flock -w 1800 9 || {
      echo "Timed out waiting for the MiniMax M3 AITER build lock" >&2
      exit 1
    }
    python3 - <<'PY'
import inspect

from aiter.dist.device_communicators.custom_all_reduce import CustomAllreduce

assert "gemma_norm" in inspect.signature(CustomAllreduce.fused_ar_rms).parameters
PY
  ) 9> "$cache_root/build.lock" || return 1
}
177+
78178
PATCH_DIR="$(dirname "$0")"
79179
apply_vllm_patch \
80180
"MI300X block-FP8 conversion" \
81181
"$PATCH_DIR/minimaxm3_mi300x_mxfp8.patch"
82182
apply_vllm_patch \
83-
"MI300X block-FP8 EP route compaction" \
84-
"$PATCH_DIR/minimaxm3_mi300x_ep_mxfp8.patch"
183+
"MI300X profile-guided kernels and collectives" \
184+
"$PATCH_DIR/minimaxm3_mi300x_profiled.patch"
185+
# Stop the whole run when the AITER setup step reports a failure.
setup_tp_aiter_gemma_fusion || {
  echo "Failed to install the pinned TP-only AITER collective" >&2
  exit 1
}
85189

86190
# A MODEL beginning with "/" is used as-is; any other value is handed to
# `hf download` before the server starts.
case "$MODEL" in
  /*) ;;
  *) hf download "$MODEL" ;;
esac
87191

@@ -108,11 +212,19 @@ elif [ "$EP_SIZE" -gt 1 ]; then
108212
PARALLEL_ARGS+=(--enable-expert-parallel)
109213
fi
110214

215+
# Extra scheduler flags for `vllm serve`; empty unless both thresholds hold.
SCHEDULER_ARGS=()
if (( ISL >= 8192 )) && (( CONC >= 16 )); then
  # The 32K budget keeps long-prefill chunks large enough to avoid starving
  # decode at the measured 8k1k c16/c128/c256 and 32k1k c16 points.
  SCHEDULER_ARGS=(--max-num-batched-tokens 32768)
fi
221+
111222
start_gpu_monitor
112223

113224
set -x
114225
vllm serve "$MODEL" --port "$PORT" \
115226
"${PARALLEL_ARGS[@]}" \
227+
"${SCHEDULER_ARGS[@]}" \
116228
--block-size 128 \
117229
--no-enable-prefix-caching \
118230
--language-model-only \

benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_ep_mxfp8.patch

Lines changed: 0 additions & 204 deletions
This file was deleted.

0 commit comments

Comments
 (0)