diff --git a/benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh b/benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh index d2b01a291..cd1f7e39c 100755 --- a/benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh +++ b/benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh @@ -3,6 +3,9 @@ # MiniMax-M3 MXFP8 MI300X (gfx942) single-node vLLM recipe. # Reuses the dedicated ROCm image and converts MXFP8 MoE weights to 128x128 # block FP8 at load time. Block size 128 is mandatory for MSA sparse attention. +# The second runtime patch carries the profiled sparse-attention, indexer, MoE, +# router, and collective changes. Only TP8 enables the pinned AITER Gemma fusion; +# EP keeps the faster native collectives. # Keep the default BF16 KV cache on gfx942: the checkpoint has no calibrated # q/prob scales for ROCm FP8 attention, and vLLM's fallback scale of 1.0 # corrupts model accuracy. @@ -43,26 +46,144 @@ if [[ -z "$VLLM_PACKAGE_ROOT" || ! -d "$VLLM_PACKAGE_ROOT/vllm" ]]; then exit 1 fi -MXFP8_PATCH="$(dirname "$0")/minimaxm3_mi300x_mxfp8.patch" -if [[ ! -f "$MXFP8_PATCH" ]]; then - echo "MI300X MXFP8 patch is missing: $MXFP8_PATCH" >&2 - exit 1 -fi +apply_vllm_patch() { + local patch_label="$1" + local patch_path="$2" + local -a patch_check_args=( + --batch + --silent + -d "$VLLM_PACKAGE_ROOT" + -p1 + --dry-run + ) -PATCH_CHECK_ARGS=(--batch --silent -d "$VLLM_PACKAGE_ROOT" -p1 --dry-run) -if patch "${PATCH_CHECK_ARGS[@]}" --reverse --forward < "$MXFP8_PATCH"; then - echo "MI300X MXFP8 patch is already fully applied" -elif patch "${PATCH_CHECK_ARGS[@]}" --forward < "$MXFP8_PATCH"; then - if ! patch --batch --forward -d "$VLLM_PACKAGE_ROOT" -p1 < "$MXFP8_PATCH"; then - echo "Failed to apply the MI300X MXFP8 patch" >&2 + if [[ ! -f "$patch_path" ]]; then + echo "$patch_label patch is missing: $patch_path" >&2 exit 1 fi -else - echo "Installed vLLM is neither cleanly patchable nor fully patched" >&2 - exit 1 -fi -if ! patch "${PATCH_CHECK_ARGS[@]}" --reverse --forward < "$MXFP8_PATCH"; then - echo "MI300X MXFP8 patch verification failed" >&2 + if patch "${patch_check_args[@]}" --reverse --forward < "$patch_path"; then + echo "$patch_label patch is already fully applied" + elif patch "${patch_check_args[@]}" --forward < "$patch_path"; then + if ! patch --batch --forward -d "$VLLM_PACKAGE_ROOT" -p1 < "$patch_path"; then + echo "Failed to apply the $patch_label patch" >&2 + exit 1 + fi + else + echo "Installed vLLM cannot cleanly apply the $patch_label patch" >&2 + exit 1 + fi + if ! patch "${patch_check_args[@]}" --reverse --forward < "$patch_path"; then + echo "$patch_label patch verification failed" >&2 + exit 1 + fi +} + +download_verified() { + local url="$1" + local sha256="$2" + local output="$3" + local temporary="${output}.tmp.$$" + + if [[ -f "$output" ]] \ + && printf '%s %s\n' "$sha256" "$output" \ + | sha256sum --check --status; then + return 0 + fi + rm -f "$output" "$temporary" + if ! curl \ + --fail \ + --location \ + --retry 5 \ + --retry-delay 2 \ + --output "$temporary" \ + "$url"; then + echo "Failed to download $url" >&2 + return 1 + fi + if ! printf '%s %s\n' "$sha256" "$temporary" | sha256sum --check --status; then + echo "SHA256 verification failed for $url" >&2 + rm -f "$temporary" + return 1 + fi + mv "$temporary" "$output" +} + +setup_tp_aiter_gemma_fusion() { + export VLLM_ROCM_USE_AITER=0 + export VLLM_ROCM_USE_AITER_MOE=0 + export VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS=0 + export VLLM_ROCM_USE_AITER_FUSED_ALLREDUCE_GEMMA_RMSNORM=0 + + # The fused collective wins on the profiled TP8 decode shapes, but loses + # at both EP boundaries. Keep every other AITER backend disabled. + if [[ "$DP_ATTENTION" == "true" || "$EP_SIZE" -gt 1 || "$TP" -ne 8 ]]; then + echo "Using native collectives for MiniMax M3 EP/non-TP8" + return 0 + fi + + local aiter_commit="a40c487b3c01dc03fd3872d65b1f7404f669471f" + local cache_root="${XDG_CACHE_HOME:-$HOME/.cache}/inferencex/minimax-m3-aiter" + local aiter_archive="$cache_root/aiter-${aiter_commit}.tar.gz" + local aiter_root="/tmp/aiter-${aiter_commit}" + local flydsl_wheel="$cache_root/flydsl-0.2.1-cp312-cp312-manylinux_2_27_x86_64.whl" + + mkdir -p "$cache_root" || return 1 + download_verified \ + "https://codeload.github.com/ROCm/aiter/tar.gz/${aiter_commit}" \ + "8cf142a4210e7a6fb88211b1a521c789f652e9f819ac6a0218cdeebc18f4808d" \ + "$aiter_archive" || return 1 + download_verified \ + "https://files.pythonhosted.org/packages/59/16/c87972f06b8f9a9b6ab08b598d706b687a969750df7131fc27aebae1a87a/flydsl-0.2.1-cp312-cp312-manylinux_2_27_x86_64.whl" \ + "98aa84678a515535283bf1a4b3e491c6f38de1fe16452dc8bfa44e9bd28ca99c" \ + "$flydsl_wheel" || return 1 + + rm -rf "$aiter_root" + mkdir -p "$aiter_root" || return 1 + tar \ + --extract \ + --gzip \ + --file "$aiter_archive" \ + --directory "$aiter_root" \ + --strip-components 1 || return 1 + printf 'develop\n' > "$aiter_root/aiter/install_mode" || return 1 + python3 -m pip install \ + --disable-pip-version-check \ + --no-index \ + --no-deps \ + "$flydsl_wheel" || return 1 + + export PYTHONPATH="$aiter_root${PYTHONPATH:+:$PYTHONPATH}" + export AITER_JIT_DIR="$cache_root/jit" + export TORCH_EXTENSIONS_DIR="$cache_root/torch-extensions" + export AITER_REBUILD=0 + export MAX_JOBS=32 + export VLLM_ROCM_USE_AITER_FUSED_ALLREDUCE_GEMMA_RMSNORM=1 + mkdir -p "$AITER_JIT_DIR" "$TORCH_EXTENSIONS_DIR" + + ( + flock -w 1800 9 || { + echo "Timed out waiting for the MiniMax M3 AITER build lock" >&2 + exit 1 + } + python3 - <<'PY' +import inspect + +from aiter.dist.device_communicators.custom_all_reduce import CustomAllreduce + +assert "gemma_norm" in inspect.signature(CustomAllreduce.fused_ar_rms).parameters +PY + ) 9> "$cache_root/build.lock" || return 1 +} + +PATCH_DIR="$(dirname "$0")" +apply_vllm_patch \ + "MI300X block-FP8 conversion" \ + "$PATCH_DIR/minimaxm3_mi300x_mxfp8.patch" +apply_vllm_patch \ + "MI300X profile-guided kernels and collectives" \ + "$PATCH_DIR/minimaxm3_mi300x_profiled.patch" +if ! setup_tp_aiter_gemma_fusion; then + echo "Failed to install the pinned TP-only AITER collective" >&2 exit 1 fi @@ -91,11 +212,19 @@ elif [ "$EP_SIZE" -gt 1 ]; then PARALLEL_ARGS+=(--enable-expert-parallel) fi +SCHEDULER_ARGS=() +if (( ISL >= 8192 && CONC >= 16 )); then + # The 32K budget keeps long-prefill chunks large enough to avoid starving + # decode at the measured 8k1k c16/c128/c256 and 32k1k c16 points. + SCHEDULER_ARGS+=(--max-num-batched-tokens 32768) +fi + start_gpu_monitor set -x vllm serve "$MODEL" --port "$PORT" \ "${PARALLEL_ARGS[@]}" \ + "${SCHEDULER_ARGS[@]}" \ --block-size 128 \ --no-enable-prefix-caching \ --language-model-only \ diff --git a/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_profiled.patch b/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_profiled.patch new file mode 100644 index 000000000..701ebade6 --- /dev/null +++ b/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_profiled.patch @@ -0,0 +1,1049 @@ +diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py +index d744da0b8..73266ec07 100644 +--- a/vllm/_aiter_ops.py ++++ b/vllm/_aiter_ops.py +@@ -55,0 +56 @@ class AiterCustomAllreduceProto(Protocol): ++ disabled: bool +@@ -71,0 +73 @@ class AiterCustomAllreduceProto(Protocol): ++ gemma_norm: bool = False, +@@ -822 +824 @@ def _rocm_aiter_rmsnorm_fused_dynamic_quant_fake( +-def _rocm_aiter_fused_allreduce_rmsnorm_impl( ++def _aiter_fused_allreduce_use_1stage( +@@ -824,7 +826,4 @@ def _rocm_aiter_fused_allreduce_rmsnorm_impl( +- residual: torch.Tensor, +- weight: torch.Tensor, +- epsilon: float, +-) -> tuple[torch.Tensor, torch.Tensor]: +- aiter_ar = rocm_aiter_ops.get_aiter_allreduce() +- assert aiter_ar is not None, "aiter allreduce must be initialized" +- ++ aiter_ar: AiterCustomAllreduceProto, ++ *, ++ gemma_norm: bool, ++) -> bool: +@@ -852,0 +852,21 @@ def _rocm_aiter_fused_allreduce_rmsnorm_impl( ++ if use_1stage and gemma_norm and world_size == 8 and hidden_dim == 6144: ++ from vllm.platforms.rocm import on_gfx942 ++ ++ # MiniMax M3 profiling on TP8 MI300X shows the two-stage fused kernel ++ # is faster even for one token, where the generic heuristic picks one ++ # stage (17.65us vs 20.26us). ++ if on_gfx942(): ++ return False ++ return use_1stage ++ ++ ++def _rocm_aiter_fused_allreduce_rmsnorm( ++ input_: torch.Tensor, ++ residual: torch.Tensor, ++ weight: torch.Tensor, ++ epsilon: float, ++ *, ++ gemma_norm: bool, ++) -> tuple[torch.Tensor, torch.Tensor]: ++ aiter_ar = rocm_aiter_ops.get_aiter_allreduce() ++ assert aiter_ar is not None, "aiter allreduce must be initialized" +@@ -854 +874 @@ def _rocm_aiter_fused_allreduce_rmsnorm_impl( +- result = aiter_ar.fused_ar_rms( ++ use_1stage = _aiter_fused_allreduce_use_1stage( +@@ -856,5 +876,2 @@ def _rocm_aiter_fused_allreduce_rmsnorm_impl( +- residual, +- w=weight, +- eps=epsilon, +- registered=torch.cuda.is_current_stream_capturing(), +- use_1stage=use_1stage, ++ aiter_ar, ++ gemma_norm=gemma_norm, +@@ -861,0 +879,21 @@ def _rocm_aiter_fused_allreduce_rmsnorm_impl( ++ ++ if gemma_norm: ++ result = aiter_ar.fused_ar_rms( ++ input_, ++ residual, ++ w=weight, ++ eps=epsilon, ++ registered=torch.cuda.is_current_stream_capturing(), ++ use_1stage=use_1stage, ++ gemma_norm=True, ++ ) ++ else: ++ # Keep the legacy call signature for older AITER builds. ++ result = aiter_ar.fused_ar_rms( ++ input_, ++ residual, ++ w=weight, ++ eps=epsilon, ++ registered=torch.cuda.is_current_stream_capturing(), ++ use_1stage=use_1stage, ++ ) +@@ -865,0 +904,30 @@ def _rocm_aiter_fused_allreduce_rmsnorm_impl( ++def _rocm_aiter_fused_allreduce_rmsnorm_impl( ++ input_: torch.Tensor, ++ residual: torch.Tensor, ++ weight: torch.Tensor, ++ epsilon: float, ++) -> tuple[torch.Tensor, torch.Tensor]: ++ return _rocm_aiter_fused_allreduce_rmsnorm( ++ input_, ++ residual, ++ weight, ++ epsilon, ++ gemma_norm=False, ++ ) ++ ++ ++def _rocm_aiter_fused_allreduce_gemma_rmsnorm_impl( ++ input_: torch.Tensor, ++ residual: torch.Tensor, ++ weight: torch.Tensor, ++ epsilon: float, ++) -> tuple[torch.Tensor, torch.Tensor]: ++ return _rocm_aiter_fused_allreduce_rmsnorm( ++ input_, ++ residual, ++ weight, ++ epsilon, ++ gemma_norm=True, ++ ) ++ ++ +@@ -1456,0 +1525,2 @@ class rocm_aiter_ops: ++ VLLM_ROCM_USE_AITER_FUSED_ALLREDUCE_GEMMA_RMSNORM: Controls only the ++ fused all-reduce + Gemma RMSNorm operation. +@@ -1518,0 +1589,3 @@ class rocm_aiter_ops: ++ _FUSED_ALLREDUCE_GEMMA_RMSNORM_ENABLED = ( ++ envs.VLLM_ROCM_USE_AITER_FUSED_ALLREDUCE_GEMMA_RMSNORM ++ ) +@@ -1536,0 +1610 @@ class rocm_aiter_ops: ++ _FUSED_ALLREDUCE_GEMMA_RMSNORM_SUPPORTED: bool | None = None +@@ -1539,0 +1614 @@ class rocm_aiter_ops: ++ _CUSTOM_ALL_REDUCE_INIT_ATTEMPTED = False +@@ -1552,0 +1628,3 @@ class rocm_aiter_ops: ++ cls._FUSED_ALLREDUCE_GEMMA_RMSNORM_ENABLED = ( ++ envs.VLLM_ROCM_USE_AITER_FUSED_ALLREDUCE_GEMMA_RMSNORM ++ ) +@@ -1648,0 +1727,27 @@ class rocm_aiter_ops: ++ @classmethod ++ @if_aiter_supported ++ def has_fused_allreduce_gemma_rmsnorm(cls) -> bool: ++ if cls._FUSED_ALLREDUCE_GEMMA_RMSNORM_SUPPORTED is None: ++ try: ++ import inspect ++ ++ from aiter.dist.device_communicators.custom_all_reduce import ( ++ CustomAllreduce as AiterCustomAllreduce, ++ ) ++ ++ cls._FUSED_ALLREDUCE_GEMMA_RMSNORM_SUPPORTED = ( ++ "gemma_norm" ++ in inspect.signature(AiterCustomAllreduce.fused_ar_rms).parameters ++ ) ++ except (ImportError, AttributeError, TypeError, ValueError): ++ cls._FUSED_ALLREDUCE_GEMMA_RMSNORM_SUPPORTED = False ++ return cls._FUSED_ALLREDUCE_GEMMA_RMSNORM_SUPPORTED ++ ++ @classmethod ++ @if_aiter_supported ++ def is_fused_allreduce_gemma_rmsnorm_enabled(cls) -> bool: ++ return ( ++ cls._FUSED_ALLREDUCE_GEMMA_RMSNORM_ENABLED ++ and cls.has_fused_allreduce_gemma_rmsnorm() ++ ) ++ +@@ -1766,0 +1872,3 @@ class rocm_aiter_ops: ++ if cls._CUSTOM_ALL_REDUCE_INIT_ATTEMPTED: ++ return ++ cls._CUSTOM_ALL_REDUCE_INIT_ATTEMPTED = True +@@ -1772 +1880,5 @@ class rocm_aiter_ops: +- cls._CUSTOM_ALL_REDUCE = AiterCustomAllreduce(group, device) ++ cls._CUSTOM_ALL_REDUCE = AiterCustomAllreduce( ++ group, ++ device, ++ max_size=cls._ALL_REDUCE_MAX_SIZE, ++ ) +@@ -1784,0 +1897 @@ class rocm_aiter_ops: ++ cls._CUSTOM_ALL_REDUCE_INIT_ATTEMPTED = False +@@ -2023,0 +2137,6 @@ class rocm_aiter_ops: ++ direct_register_custom_op( ++ op_name="rocm_aiter_fused_allreduce_gemma_rmsnorm", ++ op_func=_rocm_aiter_fused_allreduce_gemma_rmsnorm_impl, ++ fake_impl=_rocm_aiter_fused_allreduce_rmsnorm_fake, ++ ) ++ +@@ -2089,0 +2209,4 @@ class rocm_aiter_ops: ++ @staticmethod ++ def get_fused_allreduce_gemma_rmsnorm_op() -> OpOverload: ++ return torch.ops.vllm.rocm_aiter_fused_allreduce_gemma_rmsnorm.default ++ +diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py +index 8bd6e9215..0531cae87 100644 +--- a/vllm/distributed/parallel_state.py ++++ b/vllm/distributed/parallel_state.py +@@ -593 +593,3 @@ class GroupCoordinator: +- if rocm_aiter_ops.is_enabled(): ++ if rocm_aiter_ops.is_enabled() or ( ++ rocm_aiter_ops.is_fused_allreduce_gemma_rmsnorm_enabled() ++ ): +@@ -2012,0 +2015,7 @@ def destroy_model_parallel(): ++ from vllm.platforms import current_platform ++ ++ if current_platform.is_rocm(): ++ with contextlib.suppress(Exception): ++ from vllm._aiter_ops import rocm_aiter_ops ++ ++ rocm_aiter_ops.destroy_aiter_allreduce() +diff --git a/vllm/envs.py b/vllm/envs.py +index b2bdf11d6..b176ae0d0 100755 +--- a/vllm/envs.py ++++ b/vllm/envs.py +@@ -123,0 +124 @@ if TYPE_CHECKING: ++ VLLM_ROCM_USE_AITER_FUSED_ALLREDUCE_GEMMA_RMSNORM: bool = False +@@ -1140,0 +1142,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ++ # Use only AITER's fused tensor-parallel all-reduce + Gemma RMSNorm. ++ # This switch is independent of VLLM_ROCM_USE_AITER so models can opt in ++ # without enabling AITER attention, GEMM, or MoE kernels. ++ "VLLM_ROCM_USE_AITER_FUSED_ALLREDUCE_GEMMA_RMSNORM": lambda: ( ++ os.getenv("VLLM_ROCM_USE_AITER_FUSED_ALLREDUCE_GEMMA_RMSNORM", "False").lower() ++ in ("true", "1") ++ ), +diff --git a/vllm/model_executor/layers/fused_allreduce_gemma_rms_norm.py b/vllm/model_executor/layers/fused_allreduce_gemma_rms_norm.py +index e49e135b2..fd038b84a 100644 +--- a/vllm/model_executor/layers/fused_allreduce_gemma_rms_norm.py ++++ b/vllm/model_executor/layers/fused_allreduce_gemma_rms_norm.py +@@ -11,4 +11,3 @@ helper drives it directly (no torch.compile pass) for models that run eager. +-Scope: attention output only, no quantization. When the flashinfer fast path is +-not applicable (TP==1, flashinfer/NVSwitch unavailable, unsupported dtype, or an +-oversize batch) it falls back to ``all_reduce`` + ``GemmaRMSNorm``, which is +-numerically identical to the unfused model path. ++When the platform-specific fast path is not applicable it falls back to ++``all_reduce`` + ``GemmaRMSNorm``, which is numerically identical to the ++unfused model path. +@@ -24,0 +24 @@ from vllm.distributed.parallel_state import ( ++from vllm.logger import init_logger +@@ -25,0 +26,3 @@ from vllm.model_executor.layers.layernorm import GemmaRMSNorm ++from vllm.platforms import current_platform ++ ++logger = init_logger(__name__) +@@ -52,0 +56,82 @@ _FI_SUPPORTED_DTYPES = (torch.bfloat16, torch.float16) ++_MI300X_M3_AITER_MAX_TOKENS = 1536 ++ ++ ++def _aiter_gemma_shape_is_profitable( ++ num_tokens: int, ++ hidden_size: int, ++ tp_size: int, ++ *, ++ on_gfx942: bool, ++) -> bool: ++ # TP8 MiniMax M3 crosses over between 1536 and 2048 tokens on MI300X. ++ # Large chunked-prefill batches are faster with NCCL plus the native norm. ++ return not ( ++ on_gfx942 ++ and tp_size == 8 ++ and hidden_size == 6144 ++ and num_tokens > _MI300X_M3_AITER_MAX_TOKENS ++ ) ++ ++ ++def initialize_aiter_fused_allreduce_gemma_rms_norm() -> bool: ++ """Initialize the opt-in AITER communicator before graph capture.""" ++ if not current_platform.is_rocm() or get_tensor_model_parallel_world_size() == 1: ++ return False ++ ++ from vllm._aiter_ops import rocm_aiter_ops ++ ++ if not rocm_aiter_ops.is_fused_allreduce_gemma_rmsnorm_enabled(): ++ return False ++ if rocm_aiter_ops.get_aiter_allreduce() is None: ++ device_index = torch.accelerator.current_device_index() ++ device = torch.device("cuda", 0 if device_index is None else device_index) ++ rocm_aiter_ops.initialize_aiter_allreduce(get_tp_group().cpu_group, device) ++ ++ aiter_ar = rocm_aiter_ops.get_aiter_allreduce() ++ if aiter_ar is None or aiter_ar.disabled: ++ logger.warning_once( ++ "AITER fused all-reduce + Gemma RMSNorm was requested but its " ++ "communicator could not be initialized; using the unfused path." ++ ) ++ return False ++ return True ++ ++ ++def _can_use_aiter( ++ hidden_states: torch.Tensor, ++ residual: torch.Tensor, ++ norm: GemmaRMSNorm, ++) -> bool: ++ if ( ++ not current_platform.is_rocm() ++ or not hidden_states.is_cuda ++ or hidden_states.dtype not in _FI_SUPPORTED_DTYPES ++ or hidden_states.dim() != 2 ++ or not hidden_states.is_contiguous() ++ or residual.shape != hidden_states.shape ++ or residual.dtype != hidden_states.dtype ++ or not residual.is_contiguous() ++ or norm.weight.shape != (hidden_states.shape[-1],) ++ or norm.weight.dtype != hidden_states.dtype ++ or not norm.weight.is_contiguous() ++ ): ++ return False ++ ++ from vllm._aiter_ops import rocm_aiter_ops ++ from vllm.platforms.rocm import on_gfx942 ++ ++ if not rocm_aiter_ops.is_fused_allreduce_gemma_rmsnorm_enabled(): ++ return False ++ if not _aiter_gemma_shape_is_profitable( ++ hidden_states.shape[0], ++ hidden_states.shape[1], ++ get_tensor_model_parallel_world_size(), ++ on_gfx942=on_gfx942(), ++ ): ++ return False ++ aiter_ar = rocm_aiter_ops.get_aiter_allreduce() ++ return bool( ++ aiter_ar is not None ++ and not aiter_ar.disabled ++ and aiter_ar.should_custom_ar(hidden_states) ++ ) +@@ -106,0 +192,2 @@ def fused_allreduce_gemma_rms_norm( ++ *, ++ allow_aiter: bool = True, +@@ -119,0 +207,10 @@ def fused_allreduce_gemma_rms_norm( ++ if allow_aiter and _can_use_aiter(hidden_states, residual, norm): ++ from vllm._aiter_ops import rocm_aiter_ops ++ ++ return rocm_aiter_ops.get_fused_allreduce_gemma_rmsnorm_op()( ++ input_=hidden_states, ++ residual=residual, ++ weight=norm.weight, ++ epsilon=norm.variance_epsilon, ++ ) ++ +diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py +index 905a9bea3..7f98a17ed 100644 +--- a/vllm/model_executor/layers/fused_moe/config.py ++++ b/vllm/model_executor/layers/fused_moe/config.py +@@ -1284,0 +1285 @@ class FusedMoEConfig: ++ mxfp8_block_fp8_on_fnuz: bool = False +diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +index b1bfee7fc..84cabece3 100644 +--- a/vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json ++++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +@@ -1,0 +2,54 @@ ++ "1": { ++ "BLOCK_SIZE_M": 16, ++ "BLOCK_SIZE_N": 128, ++ "BLOCK_SIZE_K": 256, ++ "GROUP_SIZE_M": 1, ++ "num_warps": 8, ++ "num_stages": 2, ++ "waves_per_eu": 1 ++ }, ++ "2": { ++ "BLOCK_SIZE_M": 16, ++ "BLOCK_SIZE_N": 128, ++ "BLOCK_SIZE_K": 256, ++ "GROUP_SIZE_M": 1, ++ "num_warps": 8, ++ "num_stages": 2, ++ "waves_per_eu": 2 ++ }, ++ "4": { ++ "BLOCK_SIZE_M": 16, ++ "BLOCK_SIZE_N": 128, ++ "BLOCK_SIZE_K": 256, ++ "GROUP_SIZE_M": 1, ++ "num_warps": 8, ++ "num_stages": 2, ++ "waves_per_eu": 0 ++ }, ++ "8": { ++ "BLOCK_SIZE_M": 16, ++ "BLOCK_SIZE_N": 128, ++ "BLOCK_SIZE_K": 256, ++ "GROUP_SIZE_M": 8, ++ "num_warps": 8, ++ "num_stages": 1, ++ "waves_per_eu": 0 ++ }, ++ "16": { ++ "BLOCK_SIZE_M": 16, ++ "BLOCK_SIZE_N": 128, ++ "BLOCK_SIZE_K": 256, ++ "GROUP_SIZE_M": 1, ++ "num_warps": 8, ++ "num_stages": 1, ++ "waves_per_eu": 1 ++ }, ++ "32": { ++ "BLOCK_SIZE_M": 16, ++ "BLOCK_SIZE_N": 128, ++ "BLOCK_SIZE_K": 256, ++ "GROUP_SIZE_M": 16, ++ "num_warps": 8, ++ "num_stages": 2, ++ "waves_per_eu": 1 ++ }, +@@ -5,2 +59,2 @@ +- "BLOCK_SIZE_K": 128, +- "GROUP_SIZE_M": 4, ++ "BLOCK_SIZE_K": 256, ++ "GROUP_SIZE_M": 8, +@@ -8,2 +62,2 @@ +- "num_stages": 2, +- "waves_per_eu": 2 ++ "num_stages": 1, ++ "waves_per_eu": 0 +@@ -15 +69 @@ +- "GROUP_SIZE_M": 1, ++ "GROUP_SIZE_M": 8, +@@ -17,2 +71,2 @@ +- "num_stages": 2, +- "waves_per_eu": 1 ++ "num_stages": 1, ++ "waves_per_eu": 0 +@@ -21 +75 @@ +- "BLOCK_SIZE_M": 64, ++ "BLOCK_SIZE_M": 32, +@@ -24,4 +78,4 @@ +- "GROUP_SIZE_M": 1, +- "num_warps": 4, +- "num_stages": 2, +- "waves_per_eu": 2 ++ "GROUP_SIZE_M": 8, ++ "num_warps": 2, ++ "num_stages": 1, ++ "waves_per_eu": 0 +diff --git a/vllm/model_executor/layers/fused_moe/experts/triton_moe.py b/vllm/model_executor/layers/fused_moe/experts/triton_moe.py +index d81458b37..cd771b7e9 100644 +--- a/vllm/model_executor/layers/fused_moe/experts/triton_moe.py ++++ b/vllm/model_executor/layers/fused_moe/experts/triton_moe.py +@@ -26,0 +27,3 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( ++from vllm.model_executor.layers.fused_moe.moe_fused_mul_sum import ( ++ moe_fused_mul_sum, ++) +@@ -53,0 +57,20 @@ from vllm.utils.multi_stream_utils import maybe_execute_in_parallel ++def _use_minimax_m3_mi300x_ep_route_compaction( ++ moe_config: FusedMoEConfig, ++ quant_config: FusedMoEQuantConfig, ++) -> bool: ++ """Match the profiled MiniMax M3 block-FP8 EP8 shape on gfx942.""" ++ return ( ++ current_platform.is_fp8_fnuz() ++ and moe_config.ep_size == 8 ++ and moe_config.num_experts == 128 ++ and moe_config.num_local_experts == 16 ++ and moe_config.experts_per_token == 4 ++ and moe_config.hidden_dim == 6144 ++ and moe_config.intermediate_size_per_partition == 3072 ++ and moe_config.activation == MoEActivation.SWIGLUOAI_UNINTERLEAVE ++ and not moe_config.is_lora_enabled ++ and quant_config.use_fp8_w8a8 ++ and quant_config.block_shape == [128, 128] ++ ) ++ ++ +@@ -74,0 +98,3 @@ class TritonExperts(LoRAExpertsMixin, mk.FusedMoEExpertsModular): ++ self.compact_minimax_m3_mi300x_ep_routes = ( ++ _use_minimax_m3_mi300x_ep_route_compaction(moe_config, quant_config) ++ ) +@@ -228,0 +255,3 @@ class TritonExperts(LoRAExpertsMixin, mk.FusedMoEExpertsModular): ++ compact_ep_routes = ( ++ self.compact_minimax_m3_mi300x_ep_routes and expert_map is not None ++ ) +@@ -274,0 +304,2 @@ class TritonExperts(LoRAExpertsMixin, mk.FusedMoEExpertsModular): ++ ignore_invalid_experts=compact_ep_routes, ++ num_local_experts=E if compact_ep_routes else None, +@@ -479,2 +510,12 @@ class TritonExperts(LoRAExpertsMixin, mk.FusedMoEExpertsModular): +- # separate function is required for MoE + LoRA +- self.moe_sum(intermediate_cache3, output) ++ if compact_ep_routes: ++ moe_fused_mul_sum( ++ intermediate_cache3, ++ topk_weights, ++ outputs=output, ++ topk_ids=topk_ids, ++ expert_map=expert_map, ++ apply_weights=False, ++ ) ++ else: ++ # separate function is required for MoE + LoRA ++ self.moe_sum(intermediate_cache3, output) +diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py +index 49957c8f5..6a206ba05 100644 +--- a/vllm/model_executor/layers/fused_moe/fused_moe.py ++++ b/vllm/model_executor/layers/fused_moe/fused_moe.py +@@ -1436,0 +1437 @@ def _prepare_expert_assignment( ++ num_local_experts: int | None = None, +@@ -1470,0 +1472 @@ def _prepare_expert_assignment( ++ num_local_experts=num_local_experts, +diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py +index 225484385..6b82032d9 100644 +--- a/vllm/model_executor/layers/fused_moe/layer.py ++++ b/vllm/model_executor/layers/fused_moe/layer.py +@@ -113,0 +114 @@ def FusedMoE( ++ mxfp8_block_fp8_on_fnuz: bool = False, +@@ -172,0 +174,2 @@ def FusedMoE( ++ mxfp8_block_fp8_on_fnuz: Requantize serialized MXFP8 MoE weights to ++ block FP8 on FNUZ devices +@@ -322,0 +326 @@ def FusedMoE( ++ mxfp8_block_fp8_on_fnuz=mxfp8_block_fp8_on_fnuz, +diff --git a/vllm/model_executor/layers/fused_moe/moe_align_block_size.py b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py +index 7fc8bfcf8..99404ed95 100644 +--- a/vllm/model_executor/layers/fused_moe/moe_align_block_size.py ++++ b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py +@@ -17,0 +18 @@ def moe_align_block_size( ++ num_local_experts: int | None = None, +@@ -45,0 +47,3 @@ def moe_align_block_size( ++ - num_local_experts: The number of experts retained by ``expert_map``. ++ When invalid experts are ignored, this tightens the output allocation ++ from the global expert count to the local expert count. +@@ -74 +78,14 @@ def moe_align_block_size( +- max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) ++ padding_experts = num_experts ++ if ( ++ ignore_invalid_experts ++ and expert_map is not None ++ and num_local_experts is not None ++ ): ++ if not 0 < num_local_experts <= num_experts: ++ raise ValueError( ++ "num_local_experts must be in (0, num_experts], got " ++ f"{num_local_experts} for num_experts={num_experts}." ++ ) ++ padding_experts = num_local_experts ++ ++ max_num_tokens_padded = topk_ids.numel() + padding_experts * (block_size - 1) +diff --git a/vllm/model_executor/layers/fused_moe/moe_fused_mul_sum.py b/vllm/model_executor/layers/fused_moe/moe_fused_mul_sum.py +index 768f41db8..0465112ee 100644 +--- a/vllm/model_executor/layers/fused_moe/moe_fused_mul_sum.py ++++ b/vllm/model_executor/layers/fused_moe/moe_fused_mul_sum.py +@@ -19,0 +20 @@ def moe_fused_mul_sum_kernel( ++ apply_weights: tl.constexpr, +@@ -41 +42,4 @@ def moe_fused_mul_sum_kernel( +- b_val = tl.load(b_base + n, mask=m_mask, other=0.0).to(tl.float32) ++ if apply_weights: ++ b_val = tl.load(b_base + n, mask=m_mask, other=0.0).to(tl.float32) ++ else: ++ b_val = 1.0 +@@ -140,0 +145 @@ def moe_fused_mul_sum( ++ apply_weights: bool = True, +@@ -156,0 +162,2 @@ def moe_fused_mul_sum( ++ apply_weights: Multiply each route by ``topk_weights`` before summing. ++ Set to false when the expert GEMM already applied router weights. +@@ -193,0 +201 @@ def moe_fused_mul_sum( ++ apply_weights, +diff --git a/vllm/model_executor/layers/fused_moe/router/gate_linear.py b/vllm/model_executor/layers/fused_moe/router/gate_linear.py +index f230b4d57..eb43ff702 100644 +--- a/vllm/model_executor/layers/fused_moe/router/gate_linear.py ++++ b/vllm/model_executor/layers/fused_moe/router/gate_linear.py +@@ -18,2 +18 @@ class GateLinear(ReplicatedLinear): +- 2. fp32 specialized kernel (SM90+, bf16/fp32 in, fp32 out, +- M<=32, H=3072, E=256) ++ 2. fp32 specialized kernels (SM90+ or gfx942, bf16/fp32 in, fp32 out) +@@ -35,0 +35,2 @@ class GateLinear(ReplicatedLinear): ++ ROCM_FP32_SUPPORTED_SHAPES = {(6144, 128)} ++ ROCM_FP32_MAX_TOKENS = 16 +@@ -86,0 +88,6 @@ class GateLinear(ReplicatedLinear): ++ self.allow_rocm_fp32_router_gemm = ( ++ not bias ++ and self.weight.dtype == torch.float32 ++ and _on_gfx942() ++ and (input_size, output_size) in self.ROCM_FP32_SUPPORTED_SHAPES ++ ) +@@ -124 +131 @@ class GateLinear(ReplicatedLinear): +- # Tier 2: fp32 specialized kernel (H=3072, E=256, M<=32) ++ # Tier 2a: CUDA fp32 specialized kernel (M<=32) +@@ -133,0 +141,5 @@ class GateLinear(ReplicatedLinear): ++ # Tier 2b: gfx942 M3 fp32 router kernel (BF16 input, M<=16) ++ if self.allow_rocm_fp32_router_gemm and x.dtype == torch.bfloat16: ++ output = torch.ops.vllm.rocm_fp32_router_gemm_dispatch(x, self.weight) ++ return output, None ++ +@@ -148,0 +161,9 @@ _FP32_ROUTER_GEMM_MAX_TOKENS = GateLinear.FP32_MAX_TOKENS ++_ROCM_FP32_ROUTER_GEMM_MAX_TOKENS = GateLinear.ROCM_FP32_MAX_TOKENS ++ ++ ++def _on_gfx942() -> bool: ++ if not current_platform.is_rocm(): ++ return False ++ from vllm.platforms.rocm import on_gfx942 ++ ++ return on_gfx942() +@@ -176,0 +198,32 @@ direct_register_custom_op( ++ ++ ++def rocm_fp32_router_gemm_dispatch_impl( ++ x: torch.Tensor, weight: torch.Tensor ++) -> torch.Tensor: ++ if ( ++ 0 < x.shape[0] <= _ROCM_FP32_ROUTER_GEMM_MAX_TOKENS ++ and x.dtype == torch.bfloat16 ++ and weight.dtype == torch.float32 ++ and x.is_contiguous() ++ and weight.is_contiguous() ++ ): ++ from vllm.model_executor.layers.fused_moe.router.rocm_fp32_router_gemm import ( ++ rocm_fp32_router_gemm, ++ ) ++ ++ return rocm_fp32_router_gemm(x, weight) ++ return torch.nn.functional.linear(x.float(), weight) ++ ++ ++def rocm_fp32_router_gemm_dispatch_fake( ++ x: torch.Tensor, weight: torch.Tensor ++) -> torch.Tensor: ++ return x.new_empty((x.shape[0], weight.shape[0]), dtype=torch.float32) ++ ++ ++if current_platform.is_rocm(): ++ direct_register_custom_op( ++ op_name="rocm_fp32_router_gemm_dispatch", ++ op_func=rocm_fp32_router_gemm_dispatch_impl, ++ fake_impl=rocm_fp32_router_gemm_dispatch_fake, ++ ) +diff --git a/vllm/model_executor/layers/fused_moe/router/rocm_fp32_router_gemm.py b/vllm/model_executor/layers/fused_moe/router/rocm_fp32_router_gemm.py +new file mode 100644 +index 000000000..749d64030 +--- /dev/null ++++ b/vllm/model_executor/layers/fused_moe/router/rocm_fp32_router_gemm.py +@@ -0,0 +1,84 @@ ++# SPDX-License-Identifier: Apache-2.0 ++# SPDX-FileCopyrightText: Copyright contributors to the vLLM project ++"""Small-batch BF16 activation by FP32 router projection for gfx942.""" ++ ++import torch ++ ++from vllm.triton_utils import tl, triton ++ ++_HIDDEN_SIZE = 6144 ++_NUM_EXPERTS = 128 ++_BLOCK_M = 4 ++_BLOCK_K = 256 ++ ++ ++@triton.jit ++def _rocm_fp32_router_gemm_kernel( ++ x_ptr, ++ weight_ptr, ++ output_ptr, ++ num_tokens, ++ HIDDEN_SIZE: tl.constexpr, ++ NUM_EXPERTS: tl.constexpr, ++ BLOCK_M: tl.constexpr, ++ BLOCK_K: tl.constexpr, ++): ++ expert = tl.program_id(0) ++ token_offsets = tl.program_id(1) * BLOCK_M + tl.arange(0, BLOCK_M) ++ hidden_offsets = tl.arange(0, BLOCK_K) ++ accumulator = tl.zeros((BLOCK_M,), dtype=tl.float32) ++ ++ for hidden_start in tl.static_range(0, HIDDEN_SIZE, BLOCK_K): ++ hidden_states = tl.load( ++ x_ptr ++ + token_offsets[:, None] * HIDDEN_SIZE ++ + hidden_start ++ + hidden_offsets[None, :], ++ mask=token_offsets[:, None] < num_tokens, ++ other=0.0, ++ ).to(tl.float32) ++ weight = tl.load( ++ weight_ptr + expert * HIDDEN_SIZE + hidden_start + hidden_offsets, ++ ).to(tl.float32) ++ accumulator += tl.sum(hidden_states * weight[None, :], axis=1) ++ ++ tl.store( ++ output_ptr + token_offsets * NUM_EXPERTS + expert, ++ accumulator, ++ mask=token_offsets < num_tokens, ++ ) ++ ++ ++def rocm_fp32_router_gemm( ++ hidden_states: torch.Tensor, ++ router_weight: torch.Tensor, ++) -> torch.Tensor: ++ if hidden_states.ndim != 2 or hidden_states.shape[1] != _HIDDEN_SIZE: ++ raise ValueError("hidden_states must have shape [num_tokens, 6144]") ++ if router_weight.shape != (_NUM_EXPERTS, _HIDDEN_SIZE): ++ raise ValueError("router_weight must have shape [128, 6144]") ++ if hidden_states.dtype != torch.bfloat16: ++ raise ValueError("hidden_states must be bfloat16") ++ if router_weight.dtype != torch.float32: ++ raise ValueError("router_weight must be float32") ++ if not hidden_states.is_contiguous() or not router_weight.is_contiguous(): ++ raise ValueError("hidden_states and router_weight must be contiguous") ++ ++ num_tokens = hidden_states.shape[0] ++ output = torch.empty( ++ (num_tokens, _NUM_EXPERTS), ++ dtype=torch.float32, ++ device=hidden_states.device, ++ ) ++ _rocm_fp32_router_gemm_kernel[(_NUM_EXPERTS, triton.cdiv(num_tokens, _BLOCK_M))]( ++ hidden_states, ++ router_weight, ++ output, ++ num_tokens, ++ HIDDEN_SIZE=_HIDDEN_SIZE, ++ NUM_EXPERTS=_NUM_EXPERTS, ++ BLOCK_M=_BLOCK_M, ++ BLOCK_K=_BLOCK_K, ++ num_warps=4, ++ ) ++ return output +diff --git a/vllm/model_executor/layers/fused_moe/runner/moe_runner.py b/vllm/model_executor/layers/fused_moe/runner/moe_runner.py +index cf8bd73fb..c7810720c 100644 +--- a/vllm/model_executor/layers/fused_moe/runner/moe_runner.py ++++ b/vllm/model_executor/layers/fused_moe/runner/moe_runner.py +@@ -4,0 +5 @@ from contextlib import nullcontext ++from functools import cache +@@ -59,0 +61,9 @@ logger = init_logger(__name__) ++@cache ++def _use_gfx942_fused_shared_routed_add() -> bool: ++ if not current_platform.is_rocm(): ++ return False ++ from vllm.platforms.rocm import on_gfx942 ++ ++ return on_gfx942() ++ ++ +@@ -256,0 +267 @@ class MoERunner(MoERunnerInterface): ++ reduce_results: bool = True, +@@ -263,0 +275 @@ class MoERunner(MoERunnerInterface): ++ self.reduce_results = reduce_results +@@ -403,0 +416,23 @@ class MoERunner(MoERunnerInterface): ++ def _combine_shared_and_routed_outputs( ++ self, ++ shared_output: torch.Tensor | None, ++ fused_output: torch.Tensor, ++ ) -> torch.Tensor: ++ if ( ++ shared_output is not None ++ and self.routed_scaling_factor == 2.0 ++ and shared_output.dtype == torch.bfloat16 ++ and fused_output.dtype == torch.bfloat16 ++ and self.routed_output_transform is None ++ and _use_gfx942_fused_shared_routed_add() ++ ): ++ return torch.add(shared_output, fused_output, alpha=2.0) ++ ++ shared_output, fused_output = self._maybe_apply_routed_scale_to_output( ++ shared_output, fused_output ++ ) ++ fused_output = self.apply_routed_output_transform(fused_output) ++ if shared_output is not None: ++ return shared_output + fused_output ++ return fused_output ++ +@@ -410,0 +446,9 @@ class MoERunner(MoERunnerInterface): ++ @property ++ def output_is_reduced(self) -> bool: ++ """Whether forward returns an output complete across TP/EP ranks.""" ++ return ( ++ self.moe_config.is_sequence_parallel ++ or self._fused_output_is_reduced ++ or self.reduce_results ++ ) ++ +@@ -447 +491,2 @@ class MoERunner(MoERunnerInterface): +- not self.moe_config.is_sequence_parallel ++ self.reduce_results ++ and not self.moe_config.is_sequence_parallel +@@ -700,11 +745 @@ class MoERunner(MoERunnerInterface): +- shared_output, fused_output = self._maybe_apply_routed_scale_to_output( +- shared_output, fused_output +- ) +- +- # Apply output transform (e.g. latent -> full dim) +- fused_output = self.apply_routed_output_transform(fused_output) +- +- if shared_output is not None: +- result = shared_output + fused_output +- else: +- result = fused_output ++ result = self._combine_shared_and_routed_outputs(shared_output, fused_output) +diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py +index 9b9d73f7b..0efeeca74 100644 +--- a/vllm/model_executor/layers/quantization/modelopt.py ++++ b/vllm/model_executor/layers/quantization/modelopt.py +@@ -1889 +1889,3 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): +- self.requantize_mxfp8_to_block_fp8 = current_platform.is_fp8_fnuz() ++ self.requantize_mxfp8_to_block_fp8 = ( ++ current_platform.is_fp8_fnuz() and self.moe.mxfp8_block_fp8_on_fnuz ++ ) +diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py +index de3fb059a..2a70c9270 100644 +--- a/vllm/model_executor/layers/vocab_parallel_embedding.py ++++ b/vllm/model_executor/layers/vocab_parallel_embedding.py +@@ -228,0 +229 @@ class VocabParallelEmbedding(PluggableLayer): ++ enable_tp: shard the embedding across the tensor-parallel group. +@@ -241,0 +243 @@ class VocabParallelEmbedding(PluggableLayer): ++ enable_tp: bool = True, +@@ -246,2 +248,2 @@ class VocabParallelEmbedding(PluggableLayer): +- tp_rank = get_tensor_model_parallel_rank() +- self.tp_size = get_tensor_model_parallel_world_size() ++ tp_rank = get_tensor_model_parallel_rank() if enable_tp else 0 ++ self.tp_size = get_tensor_model_parallel_world_size() if enable_tp else 1 +@@ -495,3 +497,5 @@ class VocabParallelEmbedding(PluggableLayer): +- # Reduce across all the model parallel GPUs. +- output = tensor_model_parallel_all_reduce(output_parallel) +- return output ++ # Reduce across all model-parallel GPUs only when the embedding is ++ # actually sharded. Replicated embeddings already produce full output. ++ if self.tp_size > 1: ++ return tensor_model_parallel_all_reduce(output_parallel) ++ return output_parallel +diff --git a/vllm/models/minimax_m3/amd/model.py b/vllm/models/minimax_m3/amd/model.py +index 27650c8e6..a1200da27 100644 +--- a/vllm/models/minimax_m3/amd/model.py ++++ b/vllm/models/minimax_m3/amd/model.py +@@ -38,0 +39 @@ from vllm.model_executor.layers.fused_allreduce_gemma_rms_norm import ( ++ initialize_aiter_fused_allreduce_gemma_rms_norm, +@@ -92,0 +94 @@ from vllm.multimodal import MULTIMODAL_REGISTRY ++from vllm.platforms import current_platform +@@ -119,0 +122,20 @@ def _is_moe_layer(config: PretrainedConfig, layer_id: int) -> bool: ++def _use_mi300x_replicated_input_embedding(vllm_config: VllmConfig) -> bool: ++ config = vllm_config.model_config.hf_text_config ++ parallel_config = vllm_config.parallel_config ++ if not ( ++ config.vocab_size == 200064 ++ and config.hidden_size == 6144 ++ and vllm_config.lora_config is None ++ and parallel_config.tensor_parallel_size == 8 ++ and parallel_config.pipeline_parallel_size == 1 ++ and not parallel_config.enable_expert_parallel ++ ): ++ return False ++ if not current_platform.is_rocm(): ++ return False ++ ++ from vllm.platforms.rocm import on_gfx942 ++ ++ return on_gfx942() ++ ++ +@@ -204,0 +227 @@ class MiniMaxM3MLP(nn.Module): ++ self.reduce_results = reduce_results +@@ -244,0 +268,4 @@ class MiniMaxM3MLP(nn.Module): ++ @property ++ def output_is_reduced(self) -> bool: ++ return self.reduce_results ++ +@@ -254,0 +282 @@ class MiniMaxM3MoE(nn.Module): ++ reduce_results: bool = True, +@@ -317,0 +346,2 @@ class MiniMaxM3MoE(nn.Module): ++ mxfp8_block_fp8_on_fnuz=True, ++ runner_args={"reduce_results": reduce_results}, +@@ -337,0 +368,4 @@ class MiniMaxM3MoE(nn.Module): ++ @property ++ def output_is_reduced(self) -> bool: ++ return self.experts.output_is_reduced ++ +@@ -727,0 +762,2 @@ class MiniMaxM3DecoderLayer(nn.Module): ++ reduce_ffn_results: bool = True, ++ use_aiter_fused_norm: bool = True, +@@ -730,0 +767 @@ class MiniMaxM3DecoderLayer(nn.Module): ++ self.use_aiter_fused_norm = use_aiter_fused_norm +@@ -764,0 +802 @@ class MiniMaxM3DecoderLayer(nn.Module): ++ reduce_results=reduce_ffn_results, +@@ -771,0 +810 @@ class MiniMaxM3DecoderLayer(nn.Module): ++ reduce_results=reduce_ffn_results, +@@ -787,0 +827 @@ class MiniMaxM3DecoderLayer(nn.Module): ++ input_is_reduced: bool = True, +@@ -792,0 +833,7 @@ class MiniMaxM3DecoderLayer(nn.Module): ++ elif not input_is_reduced: ++ hidden_states, residual = fused_allreduce_gemma_rms_norm( ++ hidden_states, ++ residual, ++ self.input_layernorm, ++ allow_aiter=self.use_aiter_fused_norm, ++ ) +@@ -801 +848,4 @@ class MiniMaxM3DecoderLayer(nn.Module): +- hidden_states, residual, self.post_attention_layernorm ++ hidden_states, ++ residual, ++ self.post_attention_layernorm, ++ allow_aiter=self.use_aiter_fused_norm, +@@ -806,0 +857,5 @@ class MiniMaxM3DecoderLayer(nn.Module): ++ @property ++ def ffn_output_is_reduced(self) -> bool: ++ ffn = self.block_sparse_moe if self.is_moe_layer else self.mlp ++ return ffn.output_is_reduced ++ +@@ -817,0 +873,9 @@ class MiniMaxM3Model(nn.Module): ++ self.use_aiter_fused_norm = ( ++ not vllm_config.parallel_config.enable_expert_parallel ++ and initialize_aiter_fused_allreduce_gemma_rms_norm() ++ ) ++ self.defer_ffn_allreduce = ( ++ self.use_aiter_fused_norm ++ and vllm_config.parallel_config.pipeline_parallel_size == 1 ++ and vllm_config.parallel_config.data_parallel_size == 1 ++ ) +@@ -825,0 +890 @@ class MiniMaxM3Model(nn.Module): ++ enable_tp=not _use_mi300x_replicated_input_embedding(vllm_config), +@@ -834,0 +900,2 @@ class MiniMaxM3Model(nn.Module): ++ reduce_ffn_results=not self.defer_ffn_allreduce, ++ use_aiter_fused_norm=self.use_aiter_fused_norm, +@@ -854,0 +922 @@ class MiniMaxM3Model(nn.Module): ++ hidden_states_are_reduced = True +@@ -857 +925,18 @@ class MiniMaxM3Model(nn.Module): +- hidden_states, residual = layer(positions, hidden_states, residual) ++ hidden_states, residual = layer( ++ positions, ++ hidden_states, ++ residual, ++ input_is_reduced=hidden_states_are_reduced, ++ ) ++ hidden_states_are_reduced = layer.ffn_output_is_reduced ++ ++ if hidden_states_are_reduced: ++ hidden_states, _ = self.norm(hidden_states, residual) ++ else: ++ assert residual is not None ++ hidden_states, _ = fused_allreduce_gemma_rms_norm( ++ hidden_states, ++ residual, ++ self.norm, ++ allow_aiter=self.use_aiter_fused_norm, ++ ) +@@ -859 +943,0 @@ class MiniMaxM3Model(nn.Module): +- hidden_states, _ = self.norm(hidden_states, residual) +diff --git a/vllm/models/minimax_m3/common/ops/index_topk.py b/vllm/models/minimax_m3/common/ops/index_topk.py +index c32ff38d9..283e2fbb5 100644 +--- a/vllm/models/minimax_m3/common/ops/index_topk.py ++++ b/vllm/models/minimax_m3/common/ops/index_topk.py +@@ -371 +371,11 @@ def _decode_index_score_kernel( +- kq = tl.dot(k, q) * sm_scale_log2e # [N,H] ++ if num_idx_heads == 1: ++ # Single index head + single query token (the common M3 decode case): ++ # tl.dot here is a degenerate [N,D] x [D,1] GEMV that produces one ++ # useful output column out of the >=16-wide MFMA tile (>90% wasted) ++ # and bloats accumulator VGPRs. Replace it with a vectorized fp32 ++ # multiply + reduce over the head dim (numerically equivalent). ++ q_vec = tl.sum(q, axis=1).to(tl.float32) ++ kq = tl.sum(k.to(tl.float32) * q_vec[None, :], axis=1)[:, None] ++ else: ++ kq = tl.dot(k, q) ++ kq *= sm_scale_log2e +@@ -373 +383 @@ def _decode_index_score_kernel( +- score = tl.max(kq, axis=0) # [H] ++ score = tl.max(kq, axis=0) +diff --git a/vllm/models/minimax_m3/common/ops/sparse_attn.py b/vllm/models/minimax_m3/common/ops/sparse_attn.py +index 40287e166..8d823bff7 100644 +--- a/vllm/models/minimax_m3/common/ops/sparse_attn.py ++++ b/vllm/models/minimax_m3/common/ops/sparse_attn.py +@@ -65,2 +65,3 @@ def _sparse_attn_num_stages_kwarg() -> dict: +- "BLOCK_SIZE_QH": lambda args: args["BLOCK_SIZE_Q"] +- * triton.next_power_of_2(args["gqa_group_size"]), ++ "BLOCK_SIZE_QH": lambda args: ( ++ args["BLOCK_SIZE_Q"] * triton.next_power_of_2(args["gqa_group_size"]) ++ ), +@@ -124 +124,0 @@ def _gqa_sparse_fwd_kernel( +- off_n = tl.arange(0, BLOCK_SIZE_K) +@@ -142,6 +141,0 @@ def _gqa_sparse_fwd_kernel( +- off_q = ( +- tl.arange(0, BLOCK_SIZE_Q)[:, None] +- + pid_q_j * BLOCK_SIZE_Q +- + prefix_len +- - tl.arange(0, BLOCK_SIZE_K)[None, :] +- ) +@@ -152 +146,3 @@ def _gqa_sparse_fwd_kernel( +- for _ in range(real_topk): ++ SUB_K: tl.constexpr = BLOCK_SIZE_K // 4 ++ NUM_SUB: tl.constexpr = BLOCK_SIZE_K // SUB_K ++ for _ in tl.range(real_topk): +@@ -157,39 +153,46 @@ def _gqa_sparse_fwd_kernel( +- pos = c + off_n +- pos_mask = pos < seq_len +- k = tl.load( +- kv_cache_ptr +- + page * stride_kv_blk +- + 0 * stride_kv_kv +- + off_n[None, :] * stride_kv_pos +- + pid_kh * stride_kv_h +- + off_d[:, None] * stride_kv_d, +- mask=d_mask[:, None] & pos_mask[None, :], +- other=0.0, +- ) +- if USE_FP8: +- k = k.to(q.dtype) +- qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_H, BLOCK_SIZE_K), dtype=tl.float32) +- # causal: q_abs_pos - k_off >= block_start (c) +- qk += tl.where(off_q[:, None, :] >= c, 0, float("-inf")) +- qk = tl.reshape(qk, BLOCK_SIZE_QH, BLOCK_SIZE_K) +- qk += tl.dot(q, k) * sm_scale_log2e +- qk += tl.where(pos_mask[None, :], 0, float("-inf")) +- m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) +- p = tl.exp2(qk - m_ij[:, None]) +- l_ij = tl.sum(p, axis=1) +- acc_o = acc_o * tl.exp2(m_i - m_ij)[:, None] +- v = tl.load( +- kv_cache_ptr +- + page * stride_kv_blk +- + 1 * stride_kv_kv +- + off_n[:, None] * stride_kv_pos +- + pid_kh * stride_kv_h +- + off_d[None, :] * stride_kv_d, +- mask=pos_mask[:, None] & d_mask[None, :], +- other=0.0, +- ) +- if USE_FP8: +- v = v.to(q.dtype) +- acc_o += tl.dot(p.to(v.dtype), v) +- m_i = m_ij +- lse_i = m_ij + tl.log2(tl.exp2(lse_i - m_ij) + l_ij) ++ kv_base = kv_cache_ptr + page * stride_kv_blk + pid_kh * stride_kv_h ++ ++ for sub_i in range(NUM_SUB): ++ off_sub = tl.arange(0, SUB_K) + sub_i * SUB_K ++ pos_sub = c + off_sub ++ pos_mask_sub = pos_sub < seq_len ++ k_sub = tl.load( ++ kv_base ++ + 0 * stride_kv_kv ++ + off_sub[None, :] * stride_kv_pos ++ + off_d[:, None] * stride_kv_d, ++ mask=d_mask[:, None] & pos_mask_sub[None, :], ++ other=0.0, ++ ) ++ if USE_FP8: ++ k_sub = k_sub.to(q.dtype) ++ off_q_sub = ( ++ tl.arange(0, BLOCK_SIZE_Q)[:, None] ++ + pid_q_j * BLOCK_SIZE_Q ++ + prefix_len ++ - off_sub[None, :] ++ ) ++ qk_sub_3d = tl.zeros( ++ (BLOCK_SIZE_Q, BLOCK_SIZE_H, SUB_K), dtype=tl.float32 ++ ) ++ qk_sub_3d += tl.where(off_q_sub[:, None, :] >= c, 0, float("-inf")) ++ qk_sub = tl.reshape(qk_sub_3d, BLOCK_SIZE_QH, SUB_K) ++ qk_sub += tl.dot(q, k_sub) * sm_scale_log2e ++ qk_sub += tl.where(pos_mask_sub[None, :], 0, float("-inf")) ++ m_ij = tl.maximum(m_i, tl.max(qk_sub, axis=1)) ++ p_sub = tl.exp2(qk_sub - m_ij[:, None]) ++ l_ij = tl.sum(p_sub, axis=1) ++ acc_o = acc_o * tl.exp2(m_i - m_ij)[:, None] ++ v_sub = tl.load( ++ kv_base ++ + 1 * stride_kv_kv ++ + off_sub[:, None] * stride_kv_pos ++ + off_d[None, :] * stride_kv_d, ++ mask=pos_mask_sub[:, None] & d_mask[None, :], ++ other=0.0, ++ ) ++ if USE_FP8: ++ v_sub = v_sub.to(q.dtype) ++ acc_o += tl.dot(p_sub.to(v_sub.dtype), v_sub) ++ m_i = m_ij ++ lse_i = m_ij + tl.log2(tl.exp2(lse_i - m_ij) + l_ij) +@@ -494,0 +498,3 @@ def minimax_m3_sparse_attn( ++ num_warps=1, ++ matrix_instr_nonkdim=16, ++ kpack=2,