# Testing ROCm 4-bit Kernel Optimizations

This document describes how to test the RDNA/CDNA kernel optimizations for `kgemm_4bit_inference_naive` on different AMD GPU hardware. Testing is structured in three phases to isolate the impact of each change.

## What Changed

1. **Kernel optimizations** (`csrc/kernels.cu`): float compute path, fully unrolled dequant+FMA, replicated quant_map for bank-conflict reduction, B data prefetching, `__launch_bounds__` guard
2. **Fused multi-batch GEMM** (`csrc/kernels.cu`): an N-loop wrapping the K-loop enables fused 4-bit matmul for M>1
3. **Dispatch threshold** (`bitsandbytes/autograd/_functions.py`): `FUSED_4BIT_M_LIMIT = 16` routes M<=16 through the fused kernel instead of the dequant+GEMM fallback; critical for vLLM serving with concurrent requests (see the sketch after this list)
4. **Multi-row backend** (`bitsandbytes/backends/cuda/ops.py`, `bitsandbytes/_ops.py`): `gemv_4bit` accepts A with multiple rows
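
A minimal sketch of the new dispatch rule, written with the same public calls used elsewhere in this guide so it runs as-is. Inside the library the fused branch goes through `gemv_4bit`, and the real dispatch also checks grads, dtypes, and backend support:

```python
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

FUSED_4BIT_M_LIMIT = 16  # new threshold; upstream behavior was effectively 1

def matmul_4bit_sketch(A, w_4bit, quant_state):
    """Illustrative stand-in for the dispatch in bitsandbytes/autograd/_functions.py."""
    M = A.numel() // A.shape[-1]  # rows of A = tokens decoded this step
    if M <= FUSED_4BIT_M_LIMIT:
        # fused path: B is dequantized inside the GEMM kernel, no intermediate
        return bnb.matmul_4bit(A, w_4bit.t(), quant_state)
    # split fallback: materialize the full bf16 weight, then a plain GEMM
    return A @ F.dequantize_4bit(w_4bit, quant_state).t()
```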

## Prerequisites

```bash
# ROCm 7.x with HIP support, PyTorch with ROCm (2.9+), Python 3.10+
pip install pytest einops scipy transformers accelerate
pip install unsloth  # for e2e model benchmarks
# For vLLM testing: install vLLM with ROCm support
```

## Build

```bash
# Build for your GPU (replace gfx1151 with your arch)
cmake -B build -DBUILD_HIP=ON -DBNB_ROCM_ARCH="gfx1151" -DROCM_VERSION="713"
cmake --build build -j$(nproc)

# The .so is placed directly into bitsandbytes/; verify the import works:
python -c "import bitsandbytes; print(bitsandbytes.__version__)"
```
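
Optionally confirm that the running GPU matches the arch you built for; `gcnArchName` is exposed by ROCm builds of PyTorch:

```python
import torch

# Should print the value you passed to -DBNB_ROCM_ARCH, e.g. "gfx1151"
print(torch.cuda.get_device_properties(0).gcnArchName)
```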

## Test Models

All benchmarks use these 3 models:
- **Mistral-7B**: `unsloth/mistral-7b-instruct-v0.3-bnb-4bit` (~4 GB)
- **Llama-8B**: `unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit` (~5.5 GB)
- **Qwen3.5-9B**: `Qwen/Qwen3.5-9B` with `quantization='bitsandbytes'` (~6 GB)
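
A quick load smoke test for the pre-quantized checkpoints, assuming Hub access; transformers reads the 4-bit quantization config stored in the checkpoint, so no `BitsAndBytesConfig` is needed:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "unsloth/mistral-7b-instruct-v0.3-bnb-4bit"
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda")

inputs = tok("Hello", return_tensors="pt").to("cuda")
print(tok.decode(model.generate(**inputs, max_new_tokens=8)[0]))
```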

---

## Phase A: Baseline (no kernel changes)

Test upstream `main` to establish baseline numbers on your GPU.

```bash
# Revert to upstream kernel
git stash push -m "optimized" -- csrc/kernels.cu csrc/kernels.cuh csrc/ops.cu \
  bitsandbytes/backends/cuda/ops.py bitsandbytes/_ops.py bitsandbytes/autograd/_functions.py
cmake --build build -j$(nproc)
```

### A1. Correctness
```bash
python -m pytest tests/test_ops.py -k "test_gemv_4bit" -v
python -m pytest tests/test_linear4bit.py -v
```

### A2. Kernel microbenchmark
```bash
python bench_quick.py
# Record: baseline_us, baseline_bw
```
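
`bench_quick.py` prints `<time> µs | <BW> GB/s | <pct>% peak` at 70B MLP dimensions (N=28672, K=8192). To reproduce the core measurement by hand, a minimal version adapted from the multi-batch snippet carried in an earlier revision of this guide (the script itself may differ in warmup counts and bandwidth accounting):

```python
import time
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

N, K = 28672, 8192
w = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")
w_4bit, qs = F.quantize_4bit(w, quant_type="nf4")
x = torch.randn(1, K, dtype=torch.bfloat16, device="cuda")  # M=1 decode

with torch.inference_mode():
    for _ in range(10):  # warmup
        bnb.matmul_4bit(x, w_4bit.t(), qs)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(30):
        bnb.matmul_4bit(x, w_4bit.t(), qs)
    torch.cuda.synchronize()
print(f"{(time.perf_counter() - t0) / 30 * 1e6:.0f} us")
```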

### A3. vLLM serving (all requests go through dequant+GEMM for M>1)
```bash
export PYTHONPATH=<venv>/lib/python3.12/site-packages/_rocm_sdk_devel/share/amd_smi
export FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE
export TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1

python bench_vllm_sweep.py --limit 1
# Record throughput at reqs=1,2,4,8,16,24,32
```

### A4. Restore optimized branch
```bash
git stash pop
cmake --build build -j$(nproc)
```

---

## Phase B: Kernel optimizations only (M=1 fused, original dispatch)

Test the kernel-level improvements without the M>1 dispatch change.

```bash
# Set limit to 1 (same behavior as upstream: only M=1 uses the fused kernel)
sed -i "s/FUSED_4BIT_M_LIMIT = [0-9]*/FUSED_4BIT_M_LIMIT = 1/" bitsandbytes/autograd/_functions.py
```

### B1. Correctness
```bash
python -m pytest tests/test_ops.py -k "test_gemv_4bit" -v
python -m pytest tests/test_linear4bit.py -v
```

### B2. Kernel microbenchmark
```bash
python bench_quick.py
# Record: optimized_us, optimized_bw
# Compare: speedup = baseline_us / optimized_us
```

### B3. Single-user decode (HuggingFace)
```bash
export TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1

python bench_e2e.py --model unsloth/mistral-7b-instruct-v0.3-bnb-4bit \
  --method unsloth --prompt-tokens 128 --max-new-tokens 128 --runs 3
python bench_e2e.py --model unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit \
  --method unsloth --prompt-tokens 128 --max-new-tokens 128 --runs 3
```

### B4. vLLM serving (M>1 still goes through dequant+GEMM)
```bash
python bench_vllm_sweep.py --limit 1
# Should match Phase A3 results (same dispatch behavior, faster M=1 kernel)
```

---

## Phase C: Full optimization (M<=16 fused dispatch)

Test the complete optimization, including the M>1 fused path. This is where vLLM benefits: during continuous batching, each decode step calls `matmul_4bit` with M equal to the number of active requests, and for M<=16 the fused kernel avoids writing and reading a ~469 MB bf16 intermediate at 70B MLP dimensions.

```bash
# Set limit to 16
sed -i "s/FUSED_4BIT_M_LIMIT = [0-9]*/FUSED_4BIT_M_LIMIT = 16/" bitsandbytes/autograd/_functions.py
```
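
To spot-check that M>1 now takes the fused path and stays fast, the multi-batch validation loop from an earlier revision of this guide still applies:

```python
import time
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

N, K = 28672, 8192
w = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")
w_4bit, qs = F.quantize_4bit(w, quant_type="nf4")

for M in (1, 2, 4, 8, 16):
    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
    with torch.inference_mode():
        for _ in range(10):  # warmup
            bnb.matmul_4bit(x, w_4bit.t(), qs)
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(30):
            bnb.matmul_4bit(x, w_4bit.t(), qs)
        torch.cuda.synchronize()
    print(f"M={M}: {(time.perf_counter() - t0) / 30 * 1e6:.0f} us")
```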

### C1. Correctness
```bash
python -m pytest tests/test_ops.py -k "test_gemv_4bit" -v
python -m pytest tests/test_linear4bit.py -v
```

### C2. Kernel microbenchmark (should match Phase B)
```bash
python bench_quick.py
```

### C3. Fused vs split crossover
```bash
python bench_crossover.py
# Verify: fused is faster than split for M<=16 on your GPU
```
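
"Split" here means dequantize-then-GEMM. `bench_crossover.py` automates the sweep; the sketch below just illustrates what the two paths compute, and doubles as a numerical check (expect a small max difference, not bitwise equality):

```python
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

N, K, M = 28672, 8192, 8
w = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")
w_4bit, qs = F.quantize_4bit(w, quant_type="nf4")
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")

fused = bnb.matmul_4bit(x, w_4bit.t(), qs)     # fused kernel (M<=16)
split = x @ F.dequantize_4bit(w_4bit, qs).t()  # materializes the N x K bf16 weight
print((fused - split).abs().max())
```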

### C4. vLLM serving with M<=16 fused
```bash
python bench_vllm_sweep.py --limit 16
# Compare against Phase A3/B4: expect 2-5x improvement at reqs=2-8
```

### C5. vLLM regression check at high concurrency (M>16)
Test concurrency levels above the M=16 threshold to ensure no regressions:
```bash
python bench_vllm_full.py --limit 16 --models mistral7b \
  --max-tokens 128 --eager --concurrency 1 2 4 8 16 24 32 48 64
```
At reqs>16 the split path is used; verify that throughput at reqs=24,32,48,64 matches Phase A3.

### C6. Multi-model vLLM validation
```bash
# Test all 3 models at key concurrency levels
for MODEL in mistral7b llama8b; do
  python bench_vllm_full.py --limit 16 --models $MODEL \
    --max-tokens 128 --eager --concurrency 1 2 4 8 16 24 32
done

# Qwen3.5-9B (quantized at load time, not pre-quantized)
python bench_vllm_full.py --limit 16 --models qwen27b \
  --max-tokens 128 --eager --concurrency 1 2 4 8 16 24 32
```

Note: for Qwen3.5-9B, edit the model registry in `bench_vllm_full.py` so that its entry (invoked as `qwen27b` above) points at `Qwen/Qwen3.5-9B` if needed.

---

## Expected Results (gfx1151 reference)

Reference numbers are from gfx1151 (Radeon 8060S, 40 CUs, ~210 GB/s peak memory bandwidth); results will vary by architecture.

### Kernel Microbenchmark (70B MLP dims: N=28672, K=8192, M=1)

| Phase | Time | BW (incl. absmax) | Speedup |
|-------|------|-------------------|---------|
| A (baseline) | 1133 us | 117 GB/s | -- |
| B (kernel opt) | 740 us | 178 GB/s | 1.53x |
| C (same kernel) | 740 us | 178 GB/s | 1.53x |
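
The bandwidth figures can be sanity-checked by hand, assuming fp32 absmax at the default blocksize of 64: the kernel reads N*K/2 bytes of packed 4-bit weights plus N*K/64 fp32 absmax values, about 132 MB in total, so 740 us works out to roughly 178 GB/s and 1133 us to 117 GB/s.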

### vLLM Serving (Mistral-7B, tok/s)

| Reqs | Phase A (baseline) | Phase B (L=1) | Phase C (L=16) | C vs A |
|------|-------------------|---------------|----------------|--------|
| 1 | ~34 | ~34 | ~36 | 1.06x |
| 2 | ~10 | ~10 | **~54** | **5.2x** |
| 4 | ~20 | ~20 | **~69** | **3.4x** |
| 8 | ~40 | ~40 | **~76** | **1.9x** |
| 16 | ~80 | ~80 | ~80 | 1.0x |
| 24 | ~112 | ~112 | ~112 | 1.0x |
| 32 | ~149 | ~149 | ~149 | 1.0x |

Phase A and Phase B should produce identical results at reqs>1 (same dispatch), and Phase C should match Phase A/B at reqs>16 (the split path is used for M>16).

---

## Environment Variables

```bash
# Required for ROCm attention kernels
export TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1

# Required for vLLM ROCm platform detection
export PYTHONPATH=<venv>/lib/python3.12/site-packages/_rocm_sdk_devel/share/amd_smi
export FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE

# Optional: force offline model loading (skip HF API calls)
export HF_HUB_OFFLINE=1
```

## Reporting Results

Please include:
1. GPU: `rocminfo | grep "Name:" | head -5`
2. Software: `python -c "import torch; print(torch.__version__)"` and `hipcc --version | head -1`
3. Phase A: `bench_quick.py` output + `bench_vllm_sweep.py --limit 1` output
4. Phase B: `bench_quick.py` output
5. Phase C: `bench_vllm_sweep.py --limit 16` output + `bench_vllm_full.py` output at reqs=24,32
6. Correctness: pass/fail counts from `pytest tests/test_ops.py -k test_gemv_4bit` and `pytest tests/test_linear4bit.py`
7. Any regressions: fused slower than split, test failures, or throughput drops at reqs>16