Skip to content

Commit 87e92e2

Browse files
committed
perf(vllm): optimize MiniMax M3 MI300X inference
1 parent 6f5a399 commit 87e92e2

2 files changed

Lines changed: 1195 additions & 17 deletions

File tree

benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh

Lines changed: 146 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
# MiniMax-M3 MXFP8 MI300X (gfx942) single-node vLLM recipe.
44
# Reuses the dedicated ROCm image and converts MXFP8 MoE weights to 128x128
55
# block FP8 at load time. Block size 128 is mandatory for MSA sparse attention.
6+
# The second runtime patch carries the profiled sparse-attention, indexer, MoE,
7+
# router, and collective changes. Only TP8 enables the pinned AITER Gemma fusion;
8+
# EP keeps the faster native collectives.
69
# Keep the default BF16 KV cache on gfx942: the checkpoint has no calibrated
710
# q/prob scales for ROCm FP8 attention, and vLLM's fallback scale of 1.0
811
# corrupts model accuracy.
@@ -43,26 +46,144 @@ if [[ -z "$VLLM_PACKAGE_ROOT" || ! -d "$VLLM_PACKAGE_ROOT/vllm" ]]; then
4346
exit 1
4447
fi
4548

46-
MXFP8_PATCH="$(dirname "$0")/minimaxm3_mi300x_mxfp8.patch"
47-
if [[ ! -f "$MXFP8_PATCH" ]]; then
48-
echo "MI300X MXFP8 patch is missing: $MXFP8_PATCH" >&2
49-
exit 1
50-
fi
49+
apply_vllm_patch() {
50+
local patch_label="$1"
51+
local patch_path="$2"
52+
local -a patch_check_args=(
53+
--batch
54+
--silent
55+
-d "$VLLM_PACKAGE_ROOT"
56+
-p1
57+
--dry-run
58+
)
5159

52-
PATCH_CHECK_ARGS=(--batch --silent -d "$VLLM_PACKAGE_ROOT" -p1 --dry-run)
53-
if patch "${PATCH_CHECK_ARGS[@]}" --reverse --forward < "$MXFP8_PATCH"; then
54-
echo "MI300X MXFP8 patch is already fully applied"
55-
elif patch "${PATCH_CHECK_ARGS[@]}" --forward < "$MXFP8_PATCH"; then
56-
if ! patch --batch --forward -d "$VLLM_PACKAGE_ROOT" -p1 < "$MXFP8_PATCH"; then
57-
echo "Failed to apply the MI300X MXFP8 patch" >&2
60+
if [[ ! -f "$patch_path" ]]; then
61+
echo "$patch_label patch is missing: $patch_path" >&2
5862
exit 1
5963
fi
60-
else
61-
echo "Installed vLLM is neither cleanly patchable nor fully patched" >&2
62-
exit 1
63-
fi
64-
if ! patch "${PATCH_CHECK_ARGS[@]}" --reverse --forward < "$MXFP8_PATCH"; then
65-
echo "MI300X MXFP8 patch verification failed" >&2
64+
if patch "${patch_check_args[@]}" --reverse --forward < "$patch_path"; then
65+
echo "$patch_label patch is already fully applied"
66+
elif patch "${patch_check_args[@]}" --forward < "$patch_path"; then
67+
if ! patch --batch --forward -d "$VLLM_PACKAGE_ROOT" -p1 < "$patch_path"; then
68+
echo "Failed to apply the $patch_label patch" >&2
69+
exit 1
70+
fi
71+
else
72+
echo "Installed vLLM cannot cleanly apply the $patch_label patch" >&2
73+
exit 1
74+
fi
75+
if ! patch "${patch_check_args[@]}" --reverse --forward < "$patch_path"; then
76+
echo "$patch_label patch verification failed" >&2
77+
exit 1
78+
fi
79+
}
80+
81+
download_verified() {
82+
local url="$1"
83+
local sha256="$2"
84+
local output="$3"
85+
local temporary="${output}.tmp.$$"
86+
87+
if [[ -f "$output" ]] \
88+
&& printf '%s %s\n' "$sha256" "$output" \
89+
| sha256sum --check --status; then
90+
return 0
91+
fi
92+
rm -f "$output" "$temporary"
93+
if ! curl \
94+
--fail \
95+
--location \
96+
--retry 5 \
97+
--retry-delay 2 \
98+
--output "$temporary" \
99+
"$url"; then
100+
echo "Failed to download $url" >&2
101+
return 1
102+
fi
103+
if ! printf '%s %s\n' "$sha256" "$temporary" | sha256sum --check --status; then
104+
echo "SHA256 verification failed for $url" >&2
105+
rm -f "$temporary"
106+
return 1
107+
fi
108+
mv "$temporary" "$output"
109+
}
110+
111+
setup_tp_aiter_gemma_fusion() {
112+
export VLLM_ROCM_USE_AITER=0
113+
export VLLM_ROCM_USE_AITER_MOE=0
114+
export VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS=0
115+
export VLLM_ROCM_USE_AITER_FUSED_ALLREDUCE_GEMMA_RMSNORM=0
116+
117+
# The fused collective wins on the profiled TP8 decode shapes, but loses
118+
# at both EP boundaries. Keep every other AITER backend disabled.
119+
if [[ "$DP_ATTENTION" == "true" || "$EP_SIZE" -gt 1 || "$TP" -ne 8 ]]; then
120+
echo "Using native collectives for MiniMax M3 EP/non-TP8"
121+
return 0
122+
fi
123+
124+
local aiter_commit="a40c487b3c01dc03fd3872d65b1f7404f669471f"
125+
local cache_root="${XDG_CACHE_HOME:-$HOME/.cache}/inferencex/minimax-m3-aiter"
126+
local aiter_archive="$cache_root/aiter-${aiter_commit}.tar.gz"
127+
local aiter_root="/tmp/aiter-${aiter_commit}"
128+
local flydsl_wheel="$cache_root/flydsl-0.2.1-cp312-cp312-manylinux_2_27_x86_64.whl"
129+
130+
mkdir -p "$cache_root" || return 1
131+
download_verified \
132+
"https://codeload.github.com/ROCm/aiter/tar.gz/${aiter_commit}" \
133+
"8cf142a4210e7a6fb88211b1a521c789f652e9f819ac6a0218cdeebc18f4808d" \
134+
"$aiter_archive" || return 1
135+
download_verified \
136+
"https://files.pythonhosted.org/packages/59/16/c87972f06b8f9a9b6ab08b598d706b687a969750df7131fc27aebae1a87a/flydsl-0.2.1-cp312-cp312-manylinux_2_27_x86_64.whl" \
137+
"98aa84678a515535283bf1a4b3e491c6f38de1fe16452dc8bfa44e9bd28ca99c" \
138+
"$flydsl_wheel" || return 1
139+
140+
rm -rf "$aiter_root"
141+
mkdir -p "$aiter_root" || return 1
142+
tar \
143+
--extract \
144+
--gzip \
145+
--file "$aiter_archive" \
146+
--directory "$aiter_root" \
147+
--strip-components 1 || return 1
148+
printf 'develop\n' > "$aiter_root/aiter/install_mode" || return 1
149+
python3 -m pip install \
150+
--disable-pip-version-check \
151+
--no-index \
152+
--no-deps \
153+
"$flydsl_wheel" || return 1
154+
155+
export PYTHONPATH="$aiter_root${PYTHONPATH:+:$PYTHONPATH}"
156+
export AITER_JIT_DIR="$cache_root/jit"
157+
export TORCH_EXTENSIONS_DIR="$cache_root/torch-extensions"
158+
export AITER_REBUILD=0
159+
export MAX_JOBS=32
160+
export VLLM_ROCM_USE_AITER_FUSED_ALLREDUCE_GEMMA_RMSNORM=1
161+
mkdir -p "$AITER_JIT_DIR" "$TORCH_EXTENSIONS_DIR"
162+
163+
(
164+
flock -w 1800 9 || {
165+
echo "Timed out waiting for the MiniMax M3 AITER build lock" >&2
166+
exit 1
167+
}
168+
python3 - <<'PY'
169+
import inspect
170+
171+
from aiter.dist.device_communicators.custom_all_reduce import CustomAllreduce
172+
173+
assert "gemma_norm" in inspect.signature(CustomAllreduce.fused_ar_rms).parameters
174+
PY
175+
) 9> "$cache_root/build.lock" || return 1
176+
}
177+
178+
PATCH_DIR="$(dirname "$0")"
179+
apply_vllm_patch \
180+
"MI300X block-FP8 conversion" \
181+
"$PATCH_DIR/minimaxm3_mi300x_mxfp8.patch"
182+
apply_vllm_patch \
183+
"MI300X profile-guided kernels and collectives" \
184+
"$PATCH_DIR/minimaxm3_mi300x_profiled.patch"
185+
if ! setup_tp_aiter_gemma_fusion; then
186+
echo "Failed to install the pinned TP-only AITER collective" >&2
66187
exit 1
67188
fi
68189

@@ -91,11 +212,19 @@ elif [ "$EP_SIZE" -gt 1 ]; then
91212
PARALLEL_ARGS+=(--enable-expert-parallel)
92213
fi
93214

215+
SCHEDULER_ARGS=()
216+
if (( ISL >= 8192 && CONC >= 16 )); then
217+
# The 32K budget keeps long-prefill chunks large enough to avoid starving
218+
# decode at the measured 8k1k c16/c128/c256 and 32k1k c16 points.
219+
SCHEDULER_ARGS+=(--max-num-batched-tokens 32768)
220+
fi
221+
94222
start_gpu_monitor
95223

96224
set -x
97225
vllm serve "$MODEL" --port "$PORT" \
98226
"${PARALLEL_ARGS[@]}" \
227+
"${SCHEDULER_ARGS[@]}" \
99228
--block-size 128 \
100229
--no-enable-prefix-caching \
101230
--language-model-only \

0 commit comments

Comments
 (0)