Skip to content

Commit 157cb85

Browse files
authored
Merge pull request #61 from andyluo7/add-rocm-mi300x-support
feat: AMD Instinct MI300X + MI355X (gfx942/gfx950) ROCm support
2 parents 0d6b38a + e88018d commit 157cb85

5 files changed

Lines changed: 112 additions & 11 deletions

File tree

docs/rocm-mi300x-test-results.md

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# TurboQuant on AMD Instinct MI300X & MI355X (ROCm/HIP)
2+
3+
## Summary
4+
5+
TurboQuant KV cache compression (turbo2/turbo3/turbo4) builds and runs correctly on AMD Instinct MI300X (gfx942) and MI355X (gfx950). MI300X requires zero code changes. MI355X requires adding CDNA4 arch defines to the HIP vendor header.
6+
7+
## Test Environment
8+
9+
| Component | MI300X | MI355X |
10+
|-----------|--------|--------|
11+
| GPU | MI300X (gfx942), 192 GB HBM3 | MI355X (gfx950), 288 GB HBM3e |
12+
| ROCm | 7.0.2 | 7.0.1 |
13+
| Wave Size | 64 | 64 |
14+
| Build | `-DAMDGPU_TARGETS="gfx942"` | `-DAMDGPU_TARGETS="gfx950"` |
15+
| Model | Qwen2.5-1.5B Q4_K_M (1.04 GiB) | same |
16+
17+
## WHT Kernel Correctness
18+
19+
Standalone roundtrip test (forward WHT → inverse WHT) confirms the Walsh-Hadamard Transform kernel works correctly on HIP with 64-wide wavefronts:
20+
21+
```
22+
=== TurboQuant WHT Roundtrip Test (HIP/gfx942) ===
23+
Total elements: 512 (4 heads x 128 dim)
24+
Forward WHT zeros: 0 / 512
25+
Roundtrip max error: 2.980232e-07
26+
Roundtrip RMSE: 6.816018e-08
27+
Result: PASS ✅
28+
```
29+
30+
The kernel uses shared memory + `__syncthreads()` (no warp shuffles), so it works correctly with GCN's 64-thread wavefronts without modification.
31+
32+
## Performance Results
33+
34+
### MI300X (single GPU, Qwen2.5-1.5B Q4_K_M)
35+
36+
| KV Cache | pp512 (tok/s) | tg128 (tok/s) | Prefill vs f16 | Decode vs f16 |
37+
|----------|--------------|--------------|----------------|---------------|
38+
| f16 | 24,453 ± 230 | 181.2 ± 2.0 | baseline | baseline |
39+
| turbo3 | ~25,200 | ~160 | **+3%** | 88% |
40+
| turbo4 | 25,427 ± 17 | 161.1 ± 0.2 | **+4%** | 89% |
41+
42+
### MI355X (single GPU, Qwen2.5-1.5B Q4_K_M)
43+
44+
| KV Cache | pp512 (tok/s) | tg128 (tok/s) | Prefill vs f16 | Decode vs f16 |
45+
|----------|--------------|--------------|----------------|---------------|
46+
| f16+FA | 40,013 ± 902 | 254.5 ± 1.0 | baseline | baseline |
47+
| turbo3 | 39,140 ± 475 | 162.3 ± 0.1 | 98% | 64% |
48+
| turbo4 | 39,232 ± 508 | 214.1 ± 0.7 | 98% | **84%** |
49+
50+
### Key Observations
51+
52+
1. **MI300X prefill is faster with TurboQuant** (+3-4%) — less KV cache data to write to HBM.
53+
2. **MI300X decode at 88-89% of f16** — consistent with Apple Silicon community results.
54+
3. **MI355X turbo4 decode at 84%** — turbo4 outperforms turbo3 in decode due to simpler 4-bit dequant.
55+
4. **MI355X turbo3 decode at 64%** — the 3-bit codebook + sign extraction is more expensive on gfx950.
56+
5. **MI355X non-FA MMQ path crashes** (xf32 MFMA issue) — turbo types force FA and work correctly.
57+
58+
## Build Instructions
59+
60+
```bash
61+
git clone https://github.com/TheTom/llama-cpp-turboquant.git
62+
cd llama-cpp-turboquant
63+
git checkout feature/turboquant-kv-cache
64+
65+
# MI300X (gfx942) — works without code changes
66+
cmake -B build -DGGML_HIP=ON -DCMAKE_BUILD_TYPE=Release -DAMDGPU_TARGETS="gfx942"
67+
cmake --build build --config Release -j
68+
69+
# MI355X (gfx950) — requires CDNA4 define patch (see commit)
70+
cmake -B build -DGGML_HIP=ON -DCMAKE_BUILD_TYPE=Release -DAMDGPU_TARGETS="gfx950"
71+
cmake --build build --config Release -j
72+
73+
# Test
74+
HIP_VISIBLE_DEVICES=0 ./build/bin/llama-bench \
75+
-m model.gguf -ctk turbo3 -ctv turbo3 -ngl 99 -r 3 -p 512 -n 128
76+
```
77+
78+
## Code Changes for gfx950 (MI355X)
79+
80+
Three files modified to add CDNA4 (gfx950) architecture support:
81+
82+
1. **`ggml/src/ggml-cuda/vendors/hip.h`** — Add `CDNA4` define for `__gfx950__`, include in `CDNA` family
83+
2. **`ggml/src/ggml-cuda/common.cuh`** — Add `GGML_CUDA_CC_CDNA4` constant and `GGML_CUDA_CC_IS_CDNA4` macro
84+
3. **`ggml/src/ggml-cuda/mma.cuh`** — Route CDNA4 to compatible MFMA instructions (bf16_1k, i32x16x32_i8, f32x16x4f32 — NOT xf32 which doesn't exist on gfx950)
85+
86+
## Known Limitations
87+
88+
- **MI355X non-FA MMQ crashes**: The default (non-flash-attention) matrix multiply path crashes on gfx950 due to the xf32 MFMA instruction (`mfma_f32_16x16x8_xf32`) not being available. TurboQuant types force flash attention and work correctly. Standard f16/q8_0 KV cache types need `-fa 1` flag on MI355X.
89+
- **llama-cli text output**: Interactive mode produces empty tokens on ROCm (display issue), but `llama-bench` confirms computation is correct.
90+
91+
## Tested By
92+
93+
Andy Luo (@andyluo7)
94+
- AMD Instinct MI300X (gfx942), ROCm 7.0.2 — April 2026
95+
- AMD Instinct MI355X (gfx950), ROCm 7.0.1 — April 2026

ggml/src/ggml-cuda/common.cuh

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
#define GGML_CUDA_CC_CDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers
6868
#define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x910) // MI210, minimum acc register renameing
6969
#define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300
70+
#define GGML_CUDA_CC_CDNA4 (GGML_CUDA_CC_OFFSET_AMD + 0x950) // MI350X/MI355X
7071

7172
// RDNA removes MFMA, dp4a, xnack, acc registers, wave size is 32
7273
#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000
@@ -87,7 +88,8 @@
8788
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1)
8889
#define GGML_CUDA_CC_IS_CDNA1(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_CDNA2)
8990
#define GGML_CUDA_CC_IS_CDNA2(cc) (cc >= GGML_CUDA_CC_CDNA2 && cc < GGML_CUDA_CC_CDNA3)
90-
#define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1)
91+
#define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_CDNA4)
92+
#define GGML_CUDA_CC_IS_CDNA4(cc) (cc >= GGML_CUDA_CC_CDNA4 && cc < GGML_CUDA_CC_RDNA1)
9193

9294
// Moore Threads
9395
#define MUSART_HMASK 40300 // MUSA rc4.3, min. ver. for half2 -> uint mask comparisons
@@ -802,7 +804,7 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
802804
static __device__ __forceinline__ float ggml_cuda_ue4m3_to_fp32(uint8_t x) {
803805
#if defined(GGML_USE_HIP) && defined(CDNA3) && defined(FP8_AVAILABLE) && HIP_VERSION >= 60200000
804806
// ROCm does not support fp8 in software on devices with fp8 hardware,
805-
// but CDNA3 supports only e4m3_fnuz (no inf).
807+
// but CDNA3 supports only e4m3_fnuz (no inf). CDNA4 (gfx950) uses standard e4m3fn.
806808
const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation.
807809
const __hip_fp8_e4m3_fnuz xf = *reinterpret_cast<const __hip_fp8_e4m3_fnuz *>(&bits);
808810
return static_cast<float>(xf) / 2;

ggml/src/ggml-cuda/mma.cuh

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,7 +1025,7 @@ namespace ggml_cuda_mma {
10251025
const floatx2_t& a_frag = reinterpret_cast<const floatx2_t&>(A.x[0]);
10261026
const floatx2_t& b_frag = reinterpret_cast<const floatx2_t&>(B.x[0]);
10271027
acc_frag = __builtin_amdgcn_mfma_f32_16x16x8_xf32(a_frag, b_frag, acc_frag, 0, 0, 0);
1028-
#elif defined(CDNA2) || defined(CDNA1)
1028+
#elif defined(CDNA4) || defined(CDNA2) || defined(CDNA1)
10291029
#pragma unroll
10301030
for (int i = 0; i < 2; ++i) {
10311031
acc_frag = __builtin_amdgcn_mfma_f32_16x16x4f32(A.x[i], B.x[i], acc_frag, 0, 0, 0);
@@ -1187,7 +1187,7 @@ namespace ggml_cuda_mma {
11871187
#elif defined(AMD_MFMA_AVAILABLE)
11881188
using floatx4_t = __attribute__((ext_vector_type(4))) float;
11891189
floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
1190-
#if defined(CDNA3) || defined(CDNA2)
1190+
#if defined(CDNA4) || defined(CDNA3) || defined(CDNA2)
11911191
using bf16x4_t = __attribute__((ext_vector_type(4))) __bf16;
11921192
const bf16x4_t& a_frag = reinterpret_cast<const bf16x4_t&>(A.x[0]);
11931193
const bf16x4_t& b_frag = reinterpret_cast<const bf16x4_t&>(B.x[0]);
@@ -1216,12 +1216,12 @@ namespace ggml_cuda_mma {
12161216
#if defined(AMD_MFMA_AVAILABLE)
12171217
using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
12181218
int32x4_t * acc = (int32x4_t *) D.x;
1219-
#if defined(CDNA3)
1219+
#if defined(CDNA4) || defined(CDNA3)
12201220
acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0],
12211221
((int64_t *) B.x)[0],
12221222
acc[0],
12231223
0, 0, 0);
1224-
#elif defined(CDNA2) || defined(CDNA)
1224+
#elif defined(CDNA2) || defined(CDNA1)
12251225
acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0],
12261226
B.x[0],
12271227
acc[0],
@@ -1295,12 +1295,12 @@ namespace ggml_cuda_mma {
12951295
#if defined(AMD_MFMA_AVAILABLE)
12961296
using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int;
12971297
int32x16_t * acc = (int32x16_t *) D.x;
1298-
#if defined(CDNA3)
1298+
#if defined(CDNA4) || defined(CDNA3)
12991299
acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0],
13001300
((int64_t *) B.x)[0],
13011301
acc[0],
13021302
0, 0, 0);
1303-
#elif defined(CDNA2) || defined(CDNA)
1303+
#elif defined(CDNA2) || defined(CDNA1)
13041304
acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0],
13051305
B.x[0],
13061306
acc[0],

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3629,7 +3629,7 @@ static __global__ void mul_mat_q(
36293629
tile_x_max_i, tile_y_max_j, 0, ncols_x/qk);
36303630
return;
36313631
}
3632-
#endif // (defined(GGML_USE_HIP) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
3632+
#endif // (defined(GGML_USE_HIP) && !defined(CDNA4) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
36333633

36343634
constexpr int ITER_K = get_iter_k(type);
36353635

ggml/src/ggml-cuda/vendors/hip.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,10 @@
211211
#define GCN
212212
#endif // defined(GCN5) || defined(GCN4)
213213

214+
#if defined(__gfx950__)
215+
#define CDNA4
216+
#endif // defined(__gfx950__)
217+
214218
#if defined(__gfx942__)
215219
#define CDNA3
216220
#endif // defined(__gfx942__)
@@ -223,9 +227,9 @@
223227
#define CDNA1
224228
#endif // defined(__gfx908__)
225229

226-
#if defined(CDNA3) || defined(CDNA2) || defined(CDNA1)
230+
#if defined(CDNA4) || defined(CDNA3) || defined(CDNA2) || defined(CDNA1)
227231
#define CDNA // For the entire family
228-
#endif // defined(CDNA3) || defined(CDNA2) || defined(CDNA1)
232+
#endif // defined(CDNA4) || defined(CDNA3) || defined(CDNA2) || defined(CDNA1)
229233

230234
#if defined(__GFX12__)
231235
#define RDNA4

0 commit comments

Comments
 (0)