|
3 | 3 | # MiniMax-M3 MXFP8 MI300X (gfx942) single-node vLLM recipe. |
4 | 4 | # Reuses the dedicated ROCm image and converts MXFP8 MoE weights to 128x128 |
5 | 5 | # block FP8 at load time. Block size 128 is mandatory for MSA sparse attention. |
| 6 | +# The second runtime patch carries the profiled sparse-attention, indexer, MoE, |
| 7 | +# router, and collective changes. Only TP8 enables the pinned AITER Gemma fusion; |
| 8 | +# EP keeps the faster native collectives. |
6 | 9 | # Keep the default BF16 KV cache on gfx942: the checkpoint has no calibrated |
7 | 10 | # q/prob scales for ROCm FP8 attention, and vLLM's fallback scale of 1.0 |
8 | 11 | # corrupts model accuracy. |
@@ -43,26 +46,144 @@ if [[ -z "$VLLM_PACKAGE_ROOT" || ! -d "$VLLM_PACKAGE_ROOT/vllm" ]]; then |
43 | 46 | exit 1 |
44 | 47 | fi |
45 | 48 |
|
46 | | -MXFP8_PATCH="$(dirname "$0")/minimaxm3_mi300x_mxfp8.patch" |
47 | | -if [[ ! -f "$MXFP8_PATCH" ]]; then |
48 | | - echo "MI300X MXFP8 patch is missing: $MXFP8_PATCH" >&2 |
49 | | - exit 1 |
50 | | -fi |
#######################################
# Idempotently apply one patch file to the installed vLLM package tree.
#
# The tree is only probed with dry runs until a clean outcome is known:
#   1. a clean reverse dry run means the patch is already applied;
#   2. otherwise a clean forward dry run must pass before the tree is touched;
#   3. after applying, a reverse dry run verifies the result.
#
# Globals:
#   VLLM_PACKAGE_ROOT (read) - directory passed to patch via -d (with -p1)
# Arguments:
#   $1 - human-readable label used in status and error messages
#   $2 - path to the patch file
# Outputs:
#   Status message on stdout, failures on stderr.
# Returns:
#   0 when the patch is applied (now or previously). Terminates the script
#   with exit status 1 on any failure.
#######################################
apply_vllm_patch() {
  local patch_label="$1"
  local patch_path="$2"
  local -a patch_check_args=(
    --batch
    --silent
    -d "$VLLM_PACKAGE_ROOT"
    -p1
    --dry-run
  )

  if [[ ! -f "$patch_path" ]]; then
    echo "$patch_label patch is missing: $patch_path" >&2
    exit 1
  fi
  # Without this guard a missing patch(1) makes both dry runs below fail and
  # the problem is misreported as an unpatchable vLLM installation.
  if ! command -v patch > /dev/null 2>&1; then
    echo "The patch utility is required to apply the $patch_label patch" >&2
    exit 1
  fi

  if patch "${patch_check_args[@]}" --reverse --forward < "$patch_path"; then
    echo "$patch_label patch is already fully applied"
  elif patch "${patch_check_args[@]}" --forward < "$patch_path"; then
    if ! patch --batch --forward -d "$VLLM_PACKAGE_ROOT" -p1 < "$patch_path"; then
      echo "Failed to apply the $patch_label patch" >&2
      exit 1
    fi
  else
    # Neither cleanly reversible nor cleanly applicable: partially patched or
    # a different vLLM revision.
    echo "Installed vLLM cannot cleanly apply the $patch_label patch" >&2
    exit 1
  fi

  # Post-condition: the patch must now reverse cleanly.
  if ! patch "${patch_check_args[@]}" --reverse --forward < "$patch_path"; then
    echo "$patch_label patch verification failed" >&2
    exit 1
  fi
}
| 80 | + |
#######################################
# Fetch a file and accept it only if its SHA-256 digest matches.
#
# An existing output file with the expected digest is reused without any
# network access. Otherwise the file is downloaded to a sibling temporary
# path, verified, and then moved into place, so the output path never holds
# unverified content.
#
# Arguments:
#   $1 - URL to download
#   $2 - expected SHA-256 digest (hex)
#   $3 - destination path
# Outputs:
#   Failure reasons on stderr.
# Returns:
#   0 if the destination holds verified content, non-zero otherwise. No
#   temporary file is left behind on failure.
#######################################
download_verified() {
  local url="$1"
  local sha256="$2"
  local output="$3"
  local temporary="${output}.tmp.$$"

  # "<digest>  <file>" (two spaces) is the canonical sha256sum check format.
  if [[ -f "$output" ]] \
    && printf '%s  %s\n' "$sha256" "$output" \
      | sha256sum --check --status; then
    return 0
  fi

  # Drop a stale or corrupt cached copy and any leftover from this PID.
  rm -f -- "$output" "$temporary"
  if ! curl \
    --fail \
    --location \
    --retry 5 \
    --retry-delay 2 \
    --output "$temporary" \
    "$url"; then
    echo "Failed to download $url" >&2
    # curl may have written a partial body before failing.
    rm -f -- "$temporary"
    return 1
  fi
  if ! printf '%s  %s\n' "$sha256" "$temporary" | sha256sum --check --status; then
    echo "SHA256 verification failed for $url" >&2
    rm -f -- "$temporary"
    return 1
  fi
  if ! mv -- "$temporary" "$output"; then
    echo "Failed to move the verified download into place: $output" >&2
    rm -f -- "$temporary"
    return 1
  fi
}
| 110 | + |
#######################################
# Configure AITER for this run and, for pure TP8 only, install the pinned
# AITER source tree that provides the fused all-reduce + Gemma RMSNorm.
#
# All AITER backends are disabled first. When DP attention is on, EP is in
# use, or TP is not 8, the function stops there. Otherwise it downloads the
# checksum-pinned AITER archive and FlyDSL wheel, unpacks AITER under /tmp,
# installs the wheel offline, points PYTHONPATH and the JIT/extension caches
# at it, and checks under a file lock that the imported CustomAllreduce
# exposes the gemma_norm parameter. The fused-collective flag is switched on
# only after that check passes.
#
# Globals:
#   DP_ATTENTION, EP_SIZE, TP (read)
#   XDG_CACHE_HOME, HOME (read) - cache location
#   VLLM_ROCM_USE_AITER* (exported)
#   PYTHONPATH, AITER_JIT_DIR, TORCH_EXTENSIONS_DIR, AITER_REBUILD, MAX_JOBS
#     (exported on the TP8 path)
# Arguments:
#   None
# Outputs:
#   Status on stdout, failures on stderr.
# Returns:
#   0 on success (including the native-collective path), 1 on any failure.
#######################################
setup_tp_aiter_gemma_fusion() {
  export VLLM_ROCM_USE_AITER=0
  export VLLM_ROCM_USE_AITER_MOE=0
  export VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS=0
  export VLLM_ROCM_USE_AITER_FUSED_ALLREDUCE_GEMMA_RMSNORM=0

  # The fused collective wins on the profiled TP8 decode shapes, but loses
  # at both EP boundaries. Keep every other AITER backend disabled.
  if [[ "$DP_ATTENTION" == "true" || "$EP_SIZE" -gt 1 || "$TP" -ne 8 ]]; then
    echo "Using native collectives for MiniMax M3 EP/non-TP8"
    return 0
  fi

  local aiter_commit="a40c487b3c01dc03fd3872d65b1f7404f669471f"
  local cache_root="${XDG_CACHE_HOME:-$HOME/.cache}/inferencex/minimax-m3-aiter"
  local aiter_archive="$cache_root/aiter-${aiter_commit}.tar.gz"
  local aiter_root="/tmp/aiter-${aiter_commit}"
  local flydsl_wheel="$cache_root/flydsl-0.2.1-cp312-cp312-manylinux_2_27_x86_64.whl"

  mkdir -p "$cache_root" || return 1
  download_verified \
    "https://codeload.github.com/ROCm/aiter/tar.gz/${aiter_commit}" \
    "8cf142a4210e7a6fb88211b1a521c789f652e9f819ac6a0218cdeebc18f4808d" \
    "$aiter_archive" || return 1
  download_verified \
    "https://files.pythonhosted.org/packages/59/16/c87972f06b8f9a9b6ab08b598d706b687a969750df7131fc27aebae1a87a/flydsl-0.2.1-cp312-cp312-manylinux_2_27_x86_64.whl" \
    "98aa84678a515535283bf1a4b3e491c6f38de1fe16452dc8bfa44e9bd28ca99c" \
    "$flydsl_wheel" || return 1

  # Always start from a pristine tree. Refuse to continue if a stale tree
  # cannot be removed, since it would otherwise end up on PYTHONPATH.
  rm -rf -- "$aiter_root" || return 1
  mkdir -p "$aiter_root" || return 1
  tar \
    --extract \
    --gzip \
    --file "$aiter_archive" \
    --directory "$aiter_root" \
    --strip-components 1 || return 1
  printf 'develop\n' > "$aiter_root/aiter/install_mode" || return 1
  python3 -m pip install \
    --disable-pip-version-check \
    --no-index \
    --no-deps \
    "$flydsl_wheel" || return 1

  export PYTHONPATH="$aiter_root${PYTHONPATH:+:$PYTHONPATH}"
  export AITER_JIT_DIR="$cache_root/jit"
  export TORCH_EXTENSIONS_DIR="$cache_root/torch-extensions"
  export AITER_REBUILD=0
  export MAX_JOBS=32
  mkdir -p "$AITER_JIT_DIR" "$TORCH_EXTENSIONS_DIR" || return 1

  # Hold the build lock (fd 9) while importing AITER and checking that the
  # pinned revision exposes the gemma_norm fused all-reduce parameter.
  (
    flock -w 1800 9 || {
      echo "Timed out waiting for the MiniMax M3 AITER build lock" >&2
      exit 1
    }
    python3 - <<'PY'
import inspect

from aiter.dist.device_communicators.custom_all_reduce import CustomAllreduce

assert "gemma_norm" in inspect.signature(CustomAllreduce.fused_ar_rms).parameters
PY
  ) 9> "$cache_root/build.lock" || return 1

  # Enable the fused collective only once the pinned AITER is verified.
  export VLLM_ROCM_USE_AITER_FUSED_ALLREDUCE_GEMMA_RMSNORM=1
}
| 177 | + |
# Both runtime patches are looked up in the directory this script was
# invoked from. Apply the block-FP8 conversion first, then the profiled
# kernel/collective patch, and finally configure AITER.
PATCH_DIR="$(dirname "$0")"
apply_vllm_patch "MI300X block-FP8 conversion" \
  "$PATCH_DIR/minimaxm3_mi300x_mxfp8.patch"
apply_vllm_patch "MI300X profile-guided kernels and collectives" \
  "$PATCH_DIR/minimaxm3_mi300x_profiled.patch"
setup_tp_aiter_gemma_fusion || {
  echo "Failed to install the pinned TP-only AITER collective" >&2
  exit 1
}
68 | 189 |
|
@@ -91,11 +212,19 @@ elif [ "$EP_SIZE" -gt 1 ]; then |
91 | 212 | PARALLEL_ARGS+=(--enable-expert-parallel) |
92 | 213 | fi |
93 | 214 |
|
# Extra scheduler flags for vllm serve; empty unless the run is a
# long-input, high-concurrency point.
declare -a SCHEDULER_ARGS=()
if (( ISL >= 8192 && CONC >= 16 )); then
  # A 32K token budget keeps long-prefill chunks large enough that decode is
  # not starved at the measured 8k1k c16/c128/c256 and 32k1k c16 points.
  SCHEDULER_ARGS=(--max-num-batched-tokens 32768)
fi
| 221 | + |
94 | 222 | start_gpu_monitor |
95 | 223 |
|
96 | 224 | set -x |
97 | 225 | vllm serve "$MODEL" --port "$PORT" \ |
98 | 226 | "${PARALLEL_ARGS[@]}" \ |
| 227 | + "${SCHEDULER_ARGS[@]}" \ |
99 | 228 | --block-size 128 \ |
100 | 229 | --no-enable-prefix-caching \ |
101 | 230 | --language-model-only \ |
|
0 commit comments