From 9ad32b1d90ef35c6780e2023aabc115e2f33496e Mon Sep 17 00:00:00 2001 From: zjli2013 Date: Tue, 31 Mar 2026 17:19:15 +0800 Subject: [PATCH] feat: add AMD ROCm support via AITER CK flash attention backend Add AMD Instinct MI300X GPU support using AITER Composable Kernel (CK) flash attention backend, which compiles to native AMD ISA and delivers ~25% better steady-state performance than Triton kernels. Dispatch priority: FA3 > AITER CK > FA2 > RuntimeError. Tested on MI300X (gfx942) with ROCm 7.2 + AITER 0.1.13, producing identical outputs to the NVIDIA FA3 path. Made-with: Cursor --- Matrix-Game-3/README.md | 19 +++++++++++++-- Matrix-Game-3/wan/modules/attention.py | 32 ++++++++++++++++++++++---- 2 files changed, 45 insertions(+), 6 deletions(-) diff --git a/Matrix-Game-3/README.md b/Matrix-Game-3/README.md index 136affc..725caa8 100644 --- a/Matrix-Game-3/README.md +++ b/Matrix-Game-3/README.md @@ -22,12 +22,13 @@ In addition, the model trained on a combination of unreal and real-world data, a ## Requirements It supports one gpu or multi-gpu inference. We tested this repo on the following setup: -* A/H series GPUs are tested. +* NVIDIA A/H series GPUs are tested. +* AMD Instinct MI300X GPUs are also supported (ROCm 7.x + AITER). * Linux operating system. * 64 GB RAM. ## ⚙️ Quick Start -### Installation +### Installation (NVIDIA) Create a conda environment and install dependencies: ``` conda create -n matrix-game-3.0 python=3.12 -y @@ -39,6 +40,20 @@ cd Matrix-Game-3.0 pip install -r requirements.txt ``` +### Installation (AMD ROCm) +For AMD GPUs (e.g. MI300X) with ROCm 7.x: +```bash +conda create -n matrix-game-3.0 python=3.10 -y +conda activate matrix-game-3.0 +pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.2 +git clone https://github.com/SkyworkAI/Matrix-Game-3.0.git +cd Matrix-Game-3.0 +grep -v -E "^torch|flash.attn" requirements.txt | pip install -r /dev/stdin +pip install opencv-python-headless +# Install AITER (AMD flash attention CK backend) +pip install aiter # or: git clone https://github.com/ROCm/aiter && cd aiter && git submodule update --init 3rdparty/composable_kernel && pip install . +``` + ### Model Download ``` pip install "huggingface_hub[cli]" diff --git a/Matrix-Game-3/wan/modules/attention.py b/Matrix-Game-3/wan/modules/attention.py index 9c1f7e2..4648539 100644 --- a/Matrix-Game-3/wan/modules/attention.py +++ b/Matrix-Game-3/wan/modules/attention.py @@ -13,10 +13,18 @@ except ModuleNotFoundError: FLASH_ATTN_3_AVAILABLE = False +try: + import importlib as _il + _aiter_mha = _il.import_module('aiter.ops.mha') + _aiter_flash_attn_varlen = _aiter_mha.flash_attn_varlen_func + AITER_AVAILABLE = True +except Exception: + AITER_AVAILABLE = False + try: import flash_attn FLASH_ATTN_2_AVAILABLE = True -except ModuleNotFoundError: +except Exception: FLASH_ATTN_2_AVAILABLE = False import warnings @@ -110,8 +118,19 @@ def half(x): softmax_scale=softmax_scale, causal=causal, deterministic=deterministic).unflatten(0, (b, lq)) - else: - assert FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE + elif AITER_AVAILABLE: + cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum( + 0, dtype=torch.int32).to(q.device, non_blocking=True) + cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum( + 0, dtype=torch.int32).to(q.device, non_blocking=True) + x = _aiter_flash_attn_varlen( + q=q, k=k, v=v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=lq, max_seqlen_k=lk, + softmax_scale=softmax_scale, causal=causal, + ).unflatten(0, (b, lq)) + elif FLASH_ATTN_2_AVAILABLE: x = flash_attn.flash_attn_varlen_func( q=q, k=k, @@ -127,6 +146,11 @@ def half(x): causal=causal, window_size=window_size, deterministic=deterministic).unflatten(0, (b, lq)) + else: + raise RuntimeError( + 'No flash attention backend available. Install one of: ' + 'flash-attn (NVIDIA), or aiter (AMD ROCm 7.x: pip install aiter)' + ) return x.type(out_dtype) @@ -149,7 +173,7 @@ def attention( version=None, ): global _WARNED_FA_DISABLED - if version != '0' and (FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE): + if version != '0' and (FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE or AITER_AVAILABLE): return flash_attention( q=q, k=k,