Skip to content

Commit eccb81f

Browse files
committed
Restructure testing for 3-phase A/B/C comparison, test beyond M=16
Testing instructions now cover: - Phase A: Baseline (upstream kernel, no changes) - Phase B: Kernel optimizations only (M=1 fused, L=1) - Phase C: Full optimization (M<=16 fused, L=16) bench_vllm_sweep.py updated: - Default concurrency sweep includes 24, 32, 48, 64 (beyond M=16 threshold) - Configurable model (mistral7b, llama8b, qwen9b) - Shows fused/split path per concurrency level Models limited to Mistral-7B, Llama-8B, Qwen3.5-9B for consistency. Made-with: Cursor
1 parent 74b9920 commit eccb81f

3 files changed

Lines changed: 190 additions & 184 deletions

File tree

TESTING_OPTIMIZATIONS.md

Lines changed: 142 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -1,228 +1,227 @@
11
# Testing ROCm 4-bit Kernel Optimizations
22

3-
This document describes how to test the RDNA/CDNA kernel optimizations for `kgemm_4bit_inference_naive` on different AMD GPU hardware.
3+
This document describes how to test the RDNA/CDNA kernel optimizations on different AMD GPU hardware. The testing is structured in 3 phases to isolate the impact of each change.
44

55
## What Changed
66

7-
1. **Kernel optimizations** (`csrc/kernels.cu`): Float compute path, fully unrolled dequant+FMA, replicated quant_map for bank-conflict reduction, B data prefetching, `__launch_bounds__` guard
8-
2. **Fused multi-batch GEMM** (`csrc/kernels.cu`): N-loop wrapping K-loop enables fused 4-bit matmul for M>1 (small batch sizes up to M=16)
9-
3. **Fixed dispatch threshold** (`bitsandbytes/autograd/_functions.py`): `FUSED_4BIT_M_LIMIT = 16` routes M<=16 through fused kernel instead of dequant+GEMM fallback. Critical for vLLM serving with concurrent requests.
10-
4. **Multi-row Python backend** (`bitsandbytes/backends/cuda/ops.py`, `bitsandbytes/_ops.py`): `gemv_4bit` accepts A with multiple rows
7+
1. **Kernel optimizations** (`csrc/kernels.cu`): Float compute path, fully unrolled dequant+FMA, replicated quant_map, B data prefetching, `__launch_bounds__`
8+
2. **Fused multi-batch GEMM** (`csrc/kernels.cu`): N-loop enables fused 4-bit matmul for M>1
9+
3. **Dispatch threshold** (`bitsandbytes/autograd/_functions.py`): `FUSED_4BIT_M_LIMIT = 16` routes M<=16 through fused kernel instead of dequant+GEMM fallback
10+
4. **Multi-row backend** (`bitsandbytes/backends/cuda/ops.py`, `bitsandbytes/_ops.py`): `gemv_4bit` accepts A with multiple rows
1111

1212
## Prerequisites
1313

1414
```bash
15-
# ROCm 7.x with HIP support
16-
# PyTorch with ROCm (2.9+)
17-
# Python 3.10+
18-
15+
# ROCm 7.x with HIP support, PyTorch with ROCm (2.9+), Python 3.10+
1916
pip install pytest einops scipy transformers accelerate
20-
# Optional for e2e model benchmarks:
21-
pip install unsloth
17+
pip install unsloth # for e2e model benchmarks
18+
# For vLLM testing: install vLLM with ROCm support
2219
```
2320

2421
## Build
2522

2623
```bash
27-
# Build for your specific GPU (replace gfx1151 with your arch)
24+
# Build for your GPU (replace gfx1151 with your arch)
2825
cmake -B build -DBUILD_HIP=ON -DBNB_ROCM_ARCH="gfx1151" -DROCM_VERSION="713"
2926
cmake --build build -j$(nproc)
30-
31-
# The .so is placed directly into bitsandbytes/
32-
# Verify:
3327
python -c "import bitsandbytes; print(bitsandbytes.__version__)"
3428
```
3529

36-
## Testing Steps
30+
## Test Models
3731

38-
### 1. Correctness Tests (required -- catch regressions)
32+
All benchmarks use these 3 models:
33+
- **Mistral-7B**: `unsloth/mistral-7b-instruct-v0.3-bnb-4bit` (~4 GB)
34+
- **Llama-8B**: `unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit` (~5.5 GB)
35+
- **Qwen3.5-9B**: `Qwen/Qwen3.5-9B` with `quantization='bitsandbytes'` (~6 GB)
3936

40-
```bash
41-
# Core gemv_4bit tests (60 tests, ~5s)
42-
python -m pytest tests/test_ops.py -k "test_gemv_4bit" -v
37+
---
4338

44-
# Full Linear4bit tests (243 tests, ~60s)
45-
python -m pytest tests/test_linear4bit.py -v
39+
## Phase A: Baseline (no kernel changes)
4640

47-
# Full functional tests for gemv_4bit (192 tests, ~20s)
48-
# Note: ~24 fp32-specific threshold tests may fail on AMD due to
49-
# NVIDIA-calibrated thresholds -- this is pre-existing, not a regression.
50-
python -m pytest tests/test_functional.py -k "test_gemv_4bit" -v
51-
```
41+
Test upstream `main` to establish baseline numbers on your GPU.
5242

53-
### 2. Kernel Microbenchmark (bandwidth measurement)
43+
```bash
44+
# Revert to upstream kernel
45+
git stash push -m "optimized" -- csrc/kernels.cu csrc/kernels.cuh csrc/ops.cu \
46+
bitsandbytes/backends/cuda/ops.py bitsandbytes/_ops.py bitsandbytes/autograd/_functions.py
47+
cmake --build build -j$(nproc)
48+
```
5449

55-
This measures the raw kernel performance at 70B MLP dimensions (N=28672, K=8192):
50+
### A1. Correctness
51+
```bash
52+
python -m pytest tests/test_ops.py -k "test_gemv_4bit" -v
53+
python -m pytest tests/test_linear4bit.py -v
54+
```
5655

56+
### A2. Kernel microbenchmark
5757
```bash
5858
python bench_quick.py
59-
# Expected output: "<time> µs | <BW> GB/s | <pct>% peak"
59+
# Record: baseline_us, baseline_bw
6060
```
6161

62-
Run this on both `main` (baseline) and this branch (optimized) to measure the kernel-level speedup.
62+
### A3. vLLM serving (all requests go through dequant+GEMM for M>1)
63+
```bash
64+
export PYTHONPATH=<venv>/lib/python3.12/site-packages/_rocm_sdk_devel/share/amd_smi
65+
export FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE
66+
export TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1
6367

64-
For detailed bandwidth + profiler output:
68+
python bench_vllm_sweep.py --limit 1
69+
# Record throughput at reqs=1,2,4,8,16,24,32
70+
```
6571

72+
### A4. Restore optimized branch
6673
```bash
67-
python bench_gemv_4bit.py
74+
git stash pop
75+
cmake --build build -j$(nproc)
6876
```
6977

70-
### 3. A/B Comparison: Baseline vs Optimized
71-
72-
```bash
73-
# === Step 1: Benchmark optimized (this branch) ===
74-
python bench_quick.py
75-
# Record: time_optimized, bw_optimized
78+
---
7679

77-
# === Step 2: Benchmark baseline (upstream main) ===
78-
git stash push -m "temp" -- csrc/kernels.cu csrc/kernels.cuh csrc/ops.cu \
79-
bitsandbytes/backends/cuda/ops.py bitsandbytes/_ops.py bitsandbytes/autograd/_functions.py
80-
cmake --build build -j$(nproc)
81-
python bench_quick.py
82-
# Record: time_baseline, bw_baseline
80+
## Phase B: Kernel optimizations only (M=1 fused, original dispatch)
8381

84-
# === Step 3: Restore optimized ===
85-
git stash pop
86-
cmake --build build -j$(nproc)
82+
Test the kernel-level improvements without the M>1 dispatch change.
8783

88-
# === Step 4: Compare ===
89-
# Speedup = time_baseline / time_optimized
84+
```bash
85+
# Set limit to 1 (same behavior as upstream: only M=1 uses fused kernel)
86+
sed -i "s/FUSED_4BIT_M_LIMIT = [0-9]*/FUSED_4BIT_M_LIMIT = 1/" bitsandbytes/autograd/_functions.py
9087
```
9188

92-
### 4. End-to-End Model Benchmarks
89+
### B1. Correctness
90+
```bash
91+
python -m pytest tests/test_ops.py -k "test_gemv_4bit" -v
92+
python -m pytest tests/test_linear4bit.py -v
93+
```
9394

94-
Test with pre-quantized 4-bit models via Unsloth:
95+
### B2. Kernel microbenchmark
96+
```bash
97+
python bench_quick.py
98+
# Record: optimized_us, optimized_bw
99+
# Compare: speedup = baseline_us / optimized_us
100+
```
95101

102+
### B3. Single-user decode (HuggingFace)
96103
```bash
97104
export TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1
98105

99-
# 7B model (~4 GB VRAM)
100106
python bench_e2e.py --model unsloth/mistral-7b-instruct-v0.3-bnb-4bit \
101-
--method unsloth --prompt-tokens 128 512 --max-new-tokens 128 --runs 3
102-
103-
# 8B model (~5.5 GB VRAM)
107+
--method unsloth --prompt-tokens 128 --max-new-tokens 128 --runs 3
104108
python bench_e2e.py --model unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit \
105-
--method unsloth --prompt-tokens 128 512 --max-new-tokens 128 --runs 3
106-
107-
# 14B model (~10 GB VRAM)
108-
python bench_e2e.py --model unsloth/phi-4-unsloth-bnb-4bit \
109-
--method unsloth --prompt-tokens 128 512 --max-new-tokens 128 --runs 3
109+
--method unsloth --prompt-tokens 128 --max-new-tokens 128 --runs 3
110110
```
111111

112-
Or via HuggingFace (quantizes at load time):
113-
112+
### B4. vLLM serving (M>1 still goes through dequant+GEMM)
114113
```bash
115-
python bench_e2e.py --model meta-llama/Llama-3.1-8B-Instruct \
116-
--method hf --prompt-tokens 128 512 --max-new-tokens 128 --runs 3
114+
python bench_vllm_sweep.py --limit 1
115+
# Should match Phase A3 results (same dispatch behavior, faster M=1 kernel)
117116
```
118117

119-
### 5. Multi-batch Fused Path Validation
118+
---
120119

121-
Verify the fused M>1 path works correctly and is faster than dequant+GEMM:
120+
## Phase C: Full optimization (M<=16 fused dispatch)
122121

123-
```python
124-
import torch, time
125-
import bitsandbytes as bnb
126-
import bitsandbytes.functional as F
122+
Test the complete optimization including the M>1 fused path.
127123

128-
N, K = 28672, 8192
129-
w_fp = torch.randn(N, K, dtype=torch.bfloat16, device='cuda')
130-
w_4bit, qs = F.quantize_4bit(w_fp, quant_type='nf4')
124+
```bash
125+
# Set limit to 16
126+
sed -i "s/FUSED_4BIT_M_LIMIT = [0-9]*/FUSED_4BIT_M_LIMIT = 16/" bitsandbytes/autograd/_functions.py
127+
```
131128

132-
for M in [1, 2, 4, 8, 16]:
133-
x = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
134-
with torch.inference_mode():
135-
for _ in range(10):
136-
bnb.matmul_4bit(x, w_4bit.t(), qs)
137-
torch.cuda.synchronize()
138-
t0 = time.perf_counter()
139-
for _ in range(30):
140-
bnb.matmul_4bit(x, w_4bit.t(), qs)
141-
torch.cuda.synchronize()
142-
us = (time.perf_counter() - t0) / 30 * 1e6
143-
print(f"M={M}: {us:.0f} us")
129+
### C1. Correctness
130+
```bash
131+
python -m pytest tests/test_ops.py -k "test_gemv_4bit" -v
132+
python -m pytest tests/test_linear4bit.py -v
144133
```
145134

146-
### 6. Fused vs Split Crossover Analysis
135+
### C2. Kernel microbenchmark (should match Phase B)
136+
```bash
137+
python bench_quick.py
138+
```
147139

140+
### C3. Fused vs split crossover
148141
```bash
149142
python bench_crossover.py
150-
# Outputs: per-weight-size crossover M, CSV for comparison
143+
# Verify: fused is faster than split for M<=16 on your GPU
151144
```
152145

153-
### 7. vLLM Serving Benchmark (critical for M>1 validation)
154-
155-
The fused M>1 path has its biggest impact in vLLM, where concurrent requests produce M=num_active_requests at each decode step. Test with:
156-
146+
### C4. vLLM serving with M<=16 fused
157147
```bash
158-
export PYTHONPATH=<path_to_venv>/lib/python3.12/site-packages/_rocm_sdk_devel/share/amd_smi
159-
export FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE
160-
export TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1
161-
162-
# Quick sweep: compares L=1 (baseline) vs L=16 (optimized) across concurrency levels
163-
python bench_vllm_sweep.py --limit 1
164148
python bench_vllm_sweep.py --limit 16
149+
# Compare against Phase A3/B4: expect 2-5x improvement at reqs=2-8
150+
```
165151

166-
# Full benchmark with specific models
167-
python bench_vllm_full.py --limit 16 --models mistral7b --max-tokens 256 --concurrency 1 2 4 8 16
152+
### C5. vLLM regression check at high concurrency (M>16)
153+
Test concurrency levels above the M=16 threshold to ensure no regressions:
154+
```bash
155+
python bench_vllm_full.py --limit 16 --models mistral7b \
156+
--max-tokens 128 --eager --concurrency 1 2 4 8 16 24 32 48 64
168157
```
158+
At reqs>16, the split path is used. Verify throughput at reqs=24,32,48,64 matches Phase A3.
169159

170-
To A/B test the dispatch threshold, edit `FUSED_4BIT_M_LIMIT` in `bitsandbytes/autograd/_functions.py` before each run -- vLLM forks worker processes that read the constant at import time.
160+
### C6. Multi-model vLLM validation
161+
```bash
162+
# Test all 3 models at key concurrency levels
163+
for MODEL in mistral7b llama8b; do
164+
python bench_vllm_full.py --limit 16 --models $MODEL \
165+
--max-tokens 128 --eager --concurrency 1 2 4 8 16 24 32
166+
done
167+
168+
# Qwen3.5-9B (quantized at load time, not pre-quantized)
169+
python bench_vllm_full.py --limit 16 --models qwen27b \
170+
--max-tokens 128 --eager --concurrency 1 2 4 8 16 24 32
171+
```
171172

172-
## Expected Results by GPU
173+
Note: for Qwen3.5-9B in `bench_vllm_full.py`, edit the model registry to use `Qwen/Qwen3.5-9B` if needed.
173174

174-
Results will vary by architecture. Reference numbers from gfx1151 (Radeon 8060S, 40 CUs, 210 GB/s peak):
175+
---
175176

176-
### Kernel Microbenchmark
177+
## Expected Results (gfx1151 reference)
177178

178-
| Metric | Baseline | Optimized | Speedup |
179-
|--------|----------|-----------|---------|
180-
| Kernel (M=1, 70B dims) | 1133 us | 740 us | 1.53x |
181-
| Kernel BW (incl absmax) | 117 GB/s | 178 GB/s | 85% peak |
179+
### Kernel Microbenchmark (70B MLP dims)
182180

183-
### HuggingFace Single-User Decode
181+
| Phase | Time | BW | Speedup |
182+
|-------|------|------|---------|
183+
| A (baseline) | 1133 us | 117 GB/s | -- |
184+
| B (kernel opt) | 740 us | 178 GB/s | 1.53x |
185+
| C (same kernel) | 740 us | 178 GB/s | 1.53x |
184186

185-
| Model | Baseline | Optimized | Speedup |
186-
|-------|----------|-----------|---------|
187-
| Mistral-7B | 17.7 tok/s | 27.9 tok/s | 1.53x |
188-
| Phi-4 (14B) | 10.5 tok/s | 15.0 tok/s | 1.43x |
187+
### vLLM Serving (Mistral-7B, tok/s)
189188

190-
### vLLM Serving (Mistral-7B, 256 tokens, compiled mode)
189+
| Reqs | Phase A (baseline) | Phase B (L=1) | Phase C (L=16) | C vs A |
190+
|------|-------------------|---------------|----------------|--------|
191+
| 1 | ~34 | ~34 | ~36 | 1.06x |
192+
| 2 | ~10 | ~10 | **~54** | **5.2x** |
193+
| 4 | ~20 | ~20 | **~69** | **3.4x** |
194+
| 8 | ~40 | ~40 | **~76** | **1.9x** |
195+
| 16 | ~80 | ~80 | ~80 | 1.0x |
196+
| 24 | ~112 | ~112 | ~112 | 1.0x |
197+
| 32 | ~149 | ~149 | ~149 | 1.0x |
191198

192-
| Concurrent Reqs | L=1 (baseline) | L=16 (optimized) | Speedup |
193-
|-----------------|----------------|------------------|---------|
194-
| 1 | 35.9 tok/s | 36.3 tok/s | 1.01x |
195-
| 2 | 10.4 tok/s | 54.5 tok/s | **5.24x** |
196-
| 4 | 20.7 tok/s | 69.4 tok/s | **3.35x** |
197-
| 8 | 40.7 tok/s | 76.4 tok/s | **1.88x** |
198-
| 16 | 80.5 tok/s | 80.2 tok/s | 1.00x |
199+
Phase B and Phase A should produce identical results at reqs>1 (same dispatch).
200+
Phase C should match Phase A/B at reqs>16 (split path used for M>16).
199201

200-
Validated across Mistral-7B, Llama-8B, and Qwen3.5-9B with zero regressions.
202+
---
201203

202-
## GPU Architecture Notes
204+
## Environment Variables
203205

204-
The fused dispatch uses `FUSED_4BIT_M_LIMIT = 16`. During vLLM continuous batching, each decode step calls `matmul_4bit` with `M = num_concurrent_requests`. For M<=16, the fused kernel avoids writing/reading a 469 MB bf16 intermediate, giving 2-5x throughput improvement at typical serving concurrency.
206+
```bash
207+
# Required for ROCm attention kernels
208+
export TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1
205209

206-
| Architecture | GPU Example | Notes |
207-
|---|---|---|
208-
| gfx90a | MI210 | CDNA, wave64, uses 128 threads |
209-
| gfx942 | MI300X | CDNA, wave64 |
210-
| gfx1100 | RX 7900 XTX | RDNA3, wave32 |
211-
| gfx1101 | RX 7800 XT | RDNA3, wave32 |
212-
| gfx1151 | Radeon 8060S | RDNA3.5, wave32, validated |
213-
| gfx1200 | RX 9070 XT | RDNA4, wave32 |
214-
| gfx1201 | RX 9070 | RDNA4, wave32 |
210+
# Required for vLLM ROCm platform detection
211+
export PYTHONPATH=<venv>/lib/python3.12/site-packages/_rocm_sdk_devel/share/amd_smi
212+
export FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE
215213

216-
Run `bench_crossover.py` and `bench_vllm_sweep.py` on your GPU to verify the threshold is appropriate.
214+
# Optional: force offline model loading (skip HF API calls)
215+
export HF_HUB_OFFLINE=1
216+
```
217217

218218
## Reporting Results
219219

220-
When reporting results, please include:
221-
1. `rocminfo | grep "Name:" | head -5` (GPU name and arch)
222-
2. `python -c "import torch; print(torch.__version__)"` (PyTorch version)
223-
3. `hipcc --version | head -1` (ROCm compiler version)
224-
4. Output of `bench_quick.py` for both baseline and optimized
225-
5. Output of `pytest tests/test_ops.py -k test_gemv_4bit` (pass/fail count)
226-
6. Output of `pytest tests/test_linear4bit.py` (pass/fail count)
227-
7. Output of `bench_vllm_sweep.py --limit 16` (vLLM throughput at various concurrency)
228-
8. Any regressions observed (fused slower than split, test failures)
220+
Please include:
221+
1. GPU: `rocminfo | grep "Name:" | head -5`
222+
2. Software: `python -c "import torch; print(torch.__version__)"`
223+
3. Phase A: `bench_quick.py` output + `bench_vllm_sweep.py --limit 1` output
224+
4. Phase B: `bench_quick.py` output
225+
5. Phase C: `bench_vllm_sweep.py --limit 16` output + `bench_vllm_full.py` output at reqs=24,32
226+
6. Correctness: `pytest tests/test_ops.py -k test_gemv_4bit` and `pytest tests/test_linear4bit.py`
227+
7. Any regressions: fused slower than split, test failures, or throughput drops at reqs>16

bench_vllm_full.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,7 @@
77
MODEL_REGISTRY = {
88
"mistral7b": ("unsloth/mistral-7b-instruct-v0.3-bnb-4bit", 1024),
99
"llama8b": ("unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit", 1024),
10-
"gemma12b": ("unsloth/gemma-3-12b-it-bnb-4bit", 1024),
11-
"gemma27b": ("unsloth/gemma-3-27b-it-bnb-4bit", 512),
12-
"qwen27b": ("unsloth/Qwen3.5-27B", 256),
13-
"phi4": ("unsloth/phi-4-unsloth-bnb-4bit", 512),
10+
"qwen9b": ("Qwen/Qwen3.5-9B", 512),
1411
}
1512

1613

0 commit comments

Comments
 (0)