
[ROCm] Optimize kgemm_4bit_inference_naive for ROCm, use it for batch sizes other than 1 #1920

Open
sstamenk wants to merge 1 commit into bitsandbytes-foundation:main from sstamenk:rocm-4bit-kernel-optimization

Conversation

@sstamenk (Contributor) commented Apr 11, 2026

Based on issues raised in #1842 and pytorch#171687.

Summary

  • Optimize kgemm_4bit_inference_naive on ROCm, following the suggestions discussed in #1842 ("70B 4-bit LLM decode bottlenecked by HIP kernel (kgemm_4bit_inference_naive) efficiency: 49% vs 91% memory bandwidth on ROCm/gfx1151").
  • On gfx1201 this improves the kernel microbenchmark by up to 2.39x and end-to-end inference by up to 1.98x. On gfx1151 vLLM serving throughput improves by up to 5.13x at Req = 2.
  • The kernel optimizations neither improve nor regress performance on Nvidia GPUs; Nvidia still benefits from the vLLM serving optimizations.
  • Extend the fused 4-bit inference path to support small multi-row (M > 1) inputs instead of only the vector case.
  • Update the Python and backend dispatch path so compatible small-batch inference uses the fused kernel automatically, improving decode throughput for multi-request vLLM serving with 4-bit quantization.

Technical details

This PR makes two related changes.

Kernel optimization

Fused path support for M > 1

  • Remove the vector-only restriction in the 4-bit fused inference path.
  • Pass the real number of input rows through the Python/backend interface.
  • Update the ROCm launch configuration so the fused kernel is used for small multi-row inputs, not just M == 1.
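The dispatch change above can be sketched in a few lines. This is an illustrative Python sketch with hypothetical names (`use_fused_4bit`, `FUSED_M_THRESHOLD`), not the actual bitsandbytes dispatch code; the threshold value matches the M=8 compromise chosen later in this PR.

```python
FUSED_M_THRESHOLD = 8  # cross-GPU compromise chosen in this PR


def use_fused_4bit(m: int, threshold: int = FUSED_M_THRESHOLD) -> bool:
    """Return True when the fused 4-bit inference kernel should handle the input.

    Before this PR the condition was effectively `m == 1`; the PR relaxes it
    to small multi-row batches up to a platform-tuned threshold, falling back
    to split dequantize + GEMM above it.
    """
    return 1 <= m <= threshold


# Decode steps with a handful of concurrent requests now take the fused path:
# use_fused_4bit(1) and use_fused_4bit(4) are True; use_fused_4bit(16) is False.
```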

Up to a platform-specific crossover point, launching the fused kernel is substantially faster than falling back to split dequantize + GEMM. This matters most for serving workloads, where decode steps regularly hit small M > 1 batches.

Example measured on Strix Halo:

| M | Split | Fused | Speedup |
|---|---|---|---|
| 1 | 741 us | 741 us | 1.00x |
| 2 | 5512 us | 998 us | 5.52x |
| 4 | 5498 us | 1630 us | 3.37x |
| 8 | 5514 us | 3001 us | 1.84x |

At larger M, the fused path eventually converges with and then regresses against split dequantize + GEMM. The crossover differs by GPU:

| GPU | Crossover M |
|---|---|
| gfx1151 | 16 |
| RTX 5090 | 8-12 |
| gfx1201 | 10-12 |
| MI308X | 4-6 |

For this PR, the dispatch threshold is set to M=8 as a cross-GPU compromise. That still leaves some regressions on MI308X once reqs >= 6, but avoids the larger regressions seen at higher thresholds on other GPUs.

Testing plan

  • Run the gemm_4bit unit tests to validate correctness of the updated kernel path.
  • Run end-to-end Transformers inference in 4-bit format to validate single-user decode behavior.
  • Run end-to-end vLLM serving with multiple concurrent requests to verify the M > 1 fused-path performance gain.
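To illustrate what the correctness tests above verify: a fused kernel's output must match the split dequantize + matmul reference within quantization error. The following is a self-contained numpy sketch of that reference; the blockwise absmax 4-bit scheme here is a simplified stand-in for the real NF4/FP4 formats, and none of these names come from the bitsandbytes test suite.

```python
import numpy as np

rng = np.random.default_rng(0)


def quantize_4bit(w, block=64):
    """Blockwise absmax quantization to signed 4-bit levels (-7..7)."""
    w = w.reshape(-1, block)
    absmax = np.abs(w).max(axis=1, keepdims=True)
    q = np.clip(np.round(w / absmax * 7), -7, 7).astype(np.int8)
    return q, absmax


def dequantize_4bit(q, absmax):
    """Invert the blockwise quantization back to float32."""
    return (q.astype(np.float32) / 7) * absmax


# Small multi-row input (M > 1), matching the newly supported fused case.
M, K, N = 4, 128, 32
A = rng.standard_normal((M, K)).astype(np.float32)
W = rng.standard_normal((K, N)).astype(np.float32)

q, absmax = quantize_4bit(W)
W_deq = dequantize_4bit(q, absmax).reshape(K, N)

# A fused kernel must reproduce this split reference within quantization error.
ref = A @ W_deq
```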

Testing results

gemv_4bit unit tests

  • All tests pass on all of the tested configurations.

kgemm_4bit_inference_naive benchmark

In this table, A denotes the baseline kernel and B denotes the optimized kernel.

| GPU | Time A | Time B | BW A | BW B | Peak BW (reference) | % Peak A | % Peak B | Speedup (B vs A) |
|---|---|---|---|---|---|---|---|---|
| gfx1151 | 1133 us | 740 us | 117 GB/s | 178 GB/s | ~210 GB/s (measured) | 56% | 85% | 1.53x |
| RTX 5090 | 86 us | 84 us | 1361 GB/s | 1394 GB/s | ~1,790 GB/s | 76% | 78% | 1.02x |
| gfx1201 | 539 us | 226 us | 218 GB/s | 519 GB/s | 640 GB/s | 34% | 81% | 2.39x |
| MI308X | 656 us | 246 us | 179 GB/s | 477 GB/s | ~3,277 GB/s | 5.5% | 14.6% | 2.67x |
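The % Peak columns follow directly from effective bandwidth = bytes moved / kernel time. A small sketch of that arithmetic; the ~117 MB traffic figure below is back-derived from the gfx1201 row rather than an independently measured byte count.

```python
def pct_of_peak(bytes_moved: float, time_us: float, peak_gbps: float) -> float:
    """Effective memory bandwidth as a percentage of peak.

    bytes_moved: total bytes read + written by the kernel (problem-dependent).
    time_us:     kernel execution time in microseconds.
    peak_gbps:   peak memory bandwidth in GB/s.
    """
    gbps = bytes_moved / (time_us * 1e-6) / 1e9
    return 100.0 * gbps / peak_gbps


# gfx1201 example: ~117 MB moved in 226 us sustains ~519 GB/s,
# i.e. roughly 81% of the 640 GB/s peak.
```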

End-to-end Transformers Throughput

Strix Halo (gfx1151):

| Model | A (tok/s) | B (tok/s) | Speedup |
|---|---|---|---|
| Llama-3.3-70B-Instruct | 2.45 | 3.86 | 1.58x |
| Mistral-7B-Instruct-v0.3 | 18.3 | 27.9 | 1.53x |
| Phi-4 (14B) | 10.7 | 15.6 | 1.46x |
| DeepSeek-R1-Distill-Qwen-14B | 9.6 | 12.5 | 1.30x |
| DeepSeek-R1-Distill-Qwen-7B | 17.4 | 22.3 | 1.28x |

RTX 5090:

| Model | Phase A (tok/s) | Phase B (tok/s) | Speedup |
|---|---|---|---|
| Mistral-7B | 85.57 | 84.32 | 0.99x |
| Llama-8B | 82.43 | 80.76 | 0.98x |
| Qwen3.5-9B | 59.51 | 58.95 | 0.99x |

Radeon AI Pro R9700 (gfx1201):

| Model | Phase A (tok/s) | Phase B (tok/s) | Speedup |
|---|---|---|---|
| Mistral-7B | 38.42 | 64.32 | 1.67x |
| Llama-8B | 31.31 | 46.27 | 1.48x |
| Qwen3.5-9B | 23.83 | 32.00 | 1.34x |

MI308X (gfx942):

| Model | Phase A (tok/s) | Phase B (tok/s) | Speedup |
|---|---|---|---|
| Mistral-7B | 31.28 | 40.51 | 1.30x |
| Llama-3.1-8B | 30.97 | 40.46 | 1.31x |
| Qwen3.5-9B | 23.45 | 29.03 | 1.24x |
| Qwen3.5-35B-A3B (MoE) | 15.62 | 15.94 | 1.02x |
| Llama-3.3-70B-Instruct | 4.59 | 10.60 | 2.31x |
| Llama-3.2-90B-Vision (text-only) | 4.59 | 10.61 | 2.31x |

End-to-end vLLM Serving Throughput for Reqs > 1

In the tables below, Baseline is the unmodified kernel, while L=1, L=8, and L=16 are the optimized kernel with the fused-path dispatch threshold set to M = 1, 8, and 16 respectively. Each cell lists the measured throughput and whether the fused or split path was taken.

Strix Halo (gfx1151)

Mistral-7B

| Reqs | Baseline | L=1 | L=8 | L=16 | L=16 vs Baseline |
|---|---|---|---|---|---|
| 1 | 22.5 fused | 34.7 fused | 35.3 fused | 35.7 fused | 1.59x |
| 2 | 10.4 split | 10.4 split | 51.8 fused | 53.4 fused | 5.13x |
| 4 | 20.3 split | 20.3 split | 67.6 fused | 67.3 fused | 3.32x |
| 6 | 30.4 split | 30.5 split | 72.4 fused | 71.9 fused | 2.37x |
| 8 | 40.4 split | 40.4 split | 75.2 fused | 75.1 fused | 1.86x |
| 10 | 50.5 split | 50.6 split | 50.6 split | 76.7 fused | 1.52x |
| 12 | 60.2 split | 60.4 split | 60.4 split | 77.9 fused | 1.29x |
| 14 | 69.9 split | 70.1 split | 70.1 split | 78.7 fused | 1.13x |
| 16 | 80.1 split | 80.2 split | 80.3 split | 79.2 fused | 0.99x |

Llama-8B

| Reqs | Baseline | L=1 | L=8 | L=16 | L=16 vs Baseline |
|---|---|---|---|---|---|
| 1 | 20.7 fused | 32.0 fused | 32.0 fused | 32.0 fused | 1.55x |
| 2 | 10.4 split | 10.4 split | 48.9 fused | 47.3 fused | 4.55x |
| 4 | 20.3 split | 20.3 split | 63.7 fused | 63.6 fused | 3.13x |
| 6 | 30.3 split | 30.2 split | 68.8 fused | 68.5 fused | 2.26x |
| 8 | 40.2 split | 40.1 split | 72.4 fused | 72.4 fused | 1.80x |
| 10 | 50.2 split | 50.2 split | 50.2 split | 74.6 fused | 1.49x |
| 12 | 60.0 split | 60.0 split | 59.9 split | 75.8 fused | 1.26x |
| 14 | 69.5 split | 69.6 split | 69.6 split | 76.9 fused | 1.11x |
| 16 | 79.5 split | 79.5 split | 79.5 split | 77.2 fused | 0.97x |

Qwen3.5-9B

| Reqs | Baseline | L=1 | L=8 | L=16 | L=16 vs Baseline |
|---|---|---|---|---|---|
| 1 | 17.5 fused | 23.1 fused | 22.9 fused | 22.9 fused | 1.31x |
| 2 | 9.4 split | 9.4 split | 40.4 fused | 39.5 fused | 4.20x |
| 4 | 18.4 split | 18.4 split | 55.8 fused | 57.0 fused | 3.10x |
| 6 | 26.9 split | 27.0 split | 61.8 fused | 62.0 fused | 2.30x |
| 8 | 35.4 split | 35.5 split | 65.9 fused | 66.0 fused | 1.86x |
| 10 | 44.5 split | 44.6 split | 44.7 split | 68.7 fused | 1.54x |
| 12 | 52.0 split | 52.3 split | 52.3 split | 70.6 fused | 1.36x |
| 14 | 60.1 split | 60.3 split | 60.4 split | 72.1 fused | 1.20x |
| 16 | 69.0 split | 69.3 split | 69.5 split | 72.9 fused | 1.06x |

Llama-3.3-70B

| Reqs | Baseline | L=1 | L=8 | L=16 | L=16 vs Baseline |
|---|---|---|---|---|---|
| 1 | 2.5 fused | 4.1 fused | 4.1 fused | 4.0 fused | 1.60x |
| 2 | 1.2 split | 1.2 split | 5.9 fused | 5.8 fused | 4.83x |
| 4 | 2.4 split | - | 7.4 fused | 7.4 fused | 3.08x |
| 6 | 3.6 split | - | 7.8 fused | 7.8 fused | 2.17x |
| 8 | 4.8 split | - | 8.0 fused | 8.0 fused | 1.67x |
| 10 | 6.0 split | - | 6.0 split | 8.1 fused | 1.35x |
| 12 | 7.2 split | - | - | 8.2 fused | 1.14x |
| 14 | 8.4 split | - | - | 8.3 fused | 0.99x |
| 16 | 9.5 split | - | - | 8.3 fused | 0.87x |

RTX 5090

Mistral-7B

| Reqs | Baseline | L=1 | L=8 | L=16 | L=8 vs Baseline | L=16 vs Baseline |
|---|---|---|---|---|---|---|
| 1 | 134.9 fused | 136.0 fused | 135.3 fused | 131.4 fused | 1.00x | 0.97x |
| 2 | 104.0 split | 103.9 split | 255.5 fused | 243.8 fused | 2.46x | 2.34x |
| 4 | 204.9 split | 204.6 split | 347.0 fused | 343.1 fused | 1.69x | 1.67x |
| 6 | 283.0 split | 283.0 split | 385.1 fused | 382.0 fused | 1.36x | 1.35x |
| 8 | 375.8 split | 376.0 split | 404.5 fused | 401.9 fused | 1.08x | 1.07x |
| 9 | 422.4 split | 422.4 split | 420.9 split | 407.0 fused | 1.00x | 0.96x |
| 10 | 469.1 split | 468.4 split | 469.3 split | 411.9 fused | 1.00x | 0.88x |
| 12 | 558.4 split | 559.8 split | 558.6 split | 415.9 fused | 1.00x | 0.74x |
| 16 | 736.8 split | 736.7 split | 737.3 split | 425.5 fused | 1.00x | 0.58x |

Llama-8B

| Reqs | Baseline | L=1 | L=8 | L=16 | L=8 vs Baseline | L=16 vs Baseline |
|---|---|---|---|---|---|---|
| 1 | 136.6 fused | 134.1 fused | 133.0 fused | 133.3 fused | 0.97x | 0.98x |
| 2 | 101.5 split | 101.4 split | 251.4 fused | 245.6 fused | 2.48x | 2.42x |
| 4 | 200.0 split | 199.5 split | 333.6 fused | 330.7 fused | 1.67x | 1.65x |
| 6 | 275.7 split | 275.7 split | 373.8 fused | 374.3 fused | 1.36x | 1.36x |
| 8 | 365.9 split | 365.0 split | 394.3 fused | 395.4 fused | 1.08x | 1.08x |
| 9 | 410.7 split | 410.8 split | 411.0 split | 399.3 fused | 1.00x | 0.97x |
| 10 | 456.0 split | 456.0 split | 456.8 split | 404.5 fused | 1.00x | 0.89x |
| 12 | 544.9 split | 545.2 split | 545.4 split | 410.1 fused | 1.00x | 0.75x |
| 16 | 720.7 split | 720.5 split | 720.7 split | 420.5 fused | 1.00x | 0.58x |

Qwen3.5-9B

| Reqs | Baseline | L=1 | L=8 | L=16 | L=8 vs Baseline | L=16 vs Baseline |
|---|---|---|---|---|---|---|
| 1 | 72.3 fused | 72.6 fused | 73.4 fused | 72.2 fused | 1.02x | 1.00x |
| 2 | 100.0 split | 100.0 split | 135.3 fused | 132.4 fused | 1.35x | 1.32x |
| 4 | 188.7 split | 188.5 split | 271.1 fused | 264.8 fused | 1.44x | 1.40x |
| 6 | 280.0 split | 280.0 split | 344.4 fused | 343.2 fused | 1.23x | 1.23x |
| 8 | 370.4 split | 370.4 split | 369.3 fused | 368.2 fused | 1.00x | 0.99x |
| 9 | 415.6 split | 415.7 split | 415.6 split | 375.0 fused | 1.00x | 0.90x |
| 10 | 462.1 split | 462.1 split | 462.4 split | 382.2 fused | 1.00x | 0.83x |
| 12 | 545.8 split | 545.8 split | 545.9 split | 390.6 fused | 1.00x | 0.72x |
| 16 | 737.9 split | 738.1 split | 738.1 split | 400.8 fused | 1.00x | 0.54x |

Radeon AI Pro R9700 (gfx1201)

Mistral-7B

| Reqs | Baseline | L=1 | L=8 | L=16 | L=8 vs Baseline | L=16 vs Baseline |
|---|---|---|---|---|---|---|
| 1 | 45.5 fused | 87.2 fused | 90.0 fused | 88.1 fused | 1.98x | 1.94x |
| 2 | 34.1 split | 34.2 split | 127.9 fused | 119.6 fused | 3.75x | 3.51x |
| 4 | 68.1 split | 67.7 split | 150.4 fused | 147.1 fused | 2.21x | 2.16x |
| 8 | 134.2 split | 134.2 split | 166.6 fused | 163.9 fused | 1.24x | 1.22x |
| 9 | 151.2 split | 150.7 split | 151.0 split | 167.2 fused | 1.00x | 1.11x |
| 10 | 167.7 split | 167.0 split | 167.1 split | 167.3 fused | 1.00x | 1.00x |
| 11 | 184.1 split | 183.3 split | 183.6 split | 166.9 fused | 1.00x | 0.91x |
| 12 | 199.0 split | 198.7 split | 199.7 split | 169.3 fused | 1.00x | 0.85x |
| 16 | 263.2 split | 261.5 split | 262.6 split | 170.2 fused | 1.00x | 0.65x |

Llama-8B

| Reqs | Baseline | L=1 | L=8 | L=16 | L=8 vs Baseline | L=16 vs Baseline |
|---|---|---|---|---|---|---|
| 1 | 44.1 fused | 80.9 fused | 80.9 fused | 79.7 fused | 1.84x | 1.81x |
| 2 | 33.5 split | 33.4 split | 117.8 fused | 112.1 fused | 3.52x | 3.35x |
| 4 | 66.6 split | 66.4 split | 142.1 fused | 140.7 fused | 2.13x | 2.11x |
| 8 | 132.4 split | 132.0 split | 159.7 fused | 160.6 fused | 1.21x | 1.21x |
| 9 | 147.3 split | 147.1 split | 147.1 split | 162.5 fused | 1.00x | 1.10x |
| 10 | 163.1 split | 162.9 split | 162.8 split | 162.4 fused | 1.00x | 1.00x |
| 11 | 179.2 split | 178.4 split | 178.7 split | 160.4 fused | 1.00x | 0.90x |
| 12 | 195.1 split | 194.9 split | 194.8 split | 165.5 fused | 1.00x | 0.85x |
| 16 | 256.1 split | 255.8 split | 256.3 split | 167.5 fused | 1.00x | 0.65x |

Qwen3.5-9B

| Reqs | Baseline | L=1 | L=8 | L=16 | L=8 vs Baseline | L=16 vs Baseline |
|---|---|---|---|---|---|---|
| 1 | 9.4 fused | 10.8 fused | 10.8 fused | 10.8 fused | 1.15x | 1.15x |
| 2 | 13.4 split | 13.5 split | 19.8 fused | 19.7 fused | 1.48x | 1.47x |
| 4 | 26.8 split | 26.7 split | 36.3 fused | 36.3 fused | 1.35x | 1.35x |
| 8 | 52.7 split | 52.7 split | 61.3 fused | 61.1 fused | 1.16x | 1.16x |
| 9 | 59.3 split | 59.2 split | 59.3 split | 64.1 fused | 1.00x | 1.08x |
| 10 | 65.4 split | 65.4 split | 65.4 split | 69.2 fused | 1.00x | 1.06x |
| 11 | 72.3 split | 72.3 split | 72.2 split | 73.6 fused | 1.00x | 1.02x |
| 12 | 78.7 split | 78.6 split | 78.6 split | 78.0 fused | 1.00x | 0.99x |
| 16 | 103.2 split | 103.1 split | 103.2 split | 92.1 fused | 1.00x | 0.89x |

MI308X (gfx942)

Mistral-7B

| Reqs | Baseline | L=1 | L=8 | L=16 | L=1 vs Baseline | L=8 vs Baseline | L=16 vs Baseline |
|---|---|---|---|---|---|---|---|
| 1 | 37.6 fused | 61.3 fused | 64.2 fused | 61.7 fused | 1.63x | 1.71x | 1.64x |
| 2 | 47.3 split | 48.0 split | 98.8 fused | 92.1 fused | 1.01x | 2.09x | 1.95x |
| 4 | 94.4 split | 95.1 split | 112.8 fused | 111.1 fused | 1.01x | 1.19x | 1.18x |
| 6 | 141.3 split | 142.0 split | 120.2 fused | 118.8 fused | 1.00x | 0.85x | 0.84x |
| 8 | 188.5 split | 189.2 split | 124.0 fused | 123.2 fused | 1.00x | 0.66x | 0.65x |
| 9 | 214.0 split | 215.6 split | 192.7 split | 124.6 fused | 1.01x | 0.90x | 0.58x |
| 10 | 237.5 split | 239.0 split | 238.9 split | 125.9 fused | 1.01x | 1.01x | 0.53x |
| 12 | 284.8 split | 287.1 split | 285.6 split | 128.9 fused | 1.01x | 1.00x | 0.45x |
| 16 | 379.2 split | 382.0 split | 381.3 split | 130.1 fused | 1.01x | 1.01x | 0.34x |

Llama-8B

| Reqs | Baseline | L=1 | L=8 | L=16 | L=1 vs Baseline | L=8 vs Baseline | L=16 vs Baseline |
|---|---|---|---|---|---|---|---|
| 1 | 37.3 fused | 63.9 fused | 64.9 fused | 62.1 fused | 1.71x | 1.74x | 1.66x |
| 2 | 47.1 split | 47.6 split | 96.3 fused | 91.5 fused | 1.01x | 2.04x | 1.94x |
| 4 | 90.8 split | 93.7 split | 111.7 fused | 110.3 fused | 1.03x | 1.23x | 1.21x |
| 6 | 139.8 split | 139.0 split | 117.9 fused | 117.6 fused | 0.99x | 0.84x | 0.84x |
| 8 | 186.9 split | 185.7 split | 122.6 fused | 121.6 fused | 0.99x | 0.66x | 0.65x |
| 9 | 212.8 split | 213.2 split | 211.8 split | 123.7 fused | 1.00x | 1.00x | 0.58x |
| 10 | 235.6 split | 237.0 split | 236.9 split | 125.3 fused | 1.01x | 1.01x | 0.53x |
| 12 | 283.2 split | 283.3 split | 284.5 split | 127.0 fused | 1.00x | 1.00x | 0.45x |
| 16 | 376.5 split | 377.2 split | 377.6 split | 129.8 fused | 1.00x | 1.00x | 0.34x |

Qwen3.5-9B

| Reqs | Baseline | L=1 | L=8 | L=16 | L=1 vs Baseline | L=8 vs Baseline | L=16 vs Baseline |
|---|---|---|---|---|---|---|---|
| 1 | 30.2 fused | 30.0 fused | 30.1 fused | 29.8 fused | 0.99x | 1.00x | 0.99x |
| 2 | 40.0 split | 39.8 split | 57.6 fused | 58.8 fused | 1.00x | 1.44x | 1.47x |
| 4 | 79.7 split | 78.9 split | 96.9 fused | 96.5 fused | 0.99x | 1.22x | 1.21x |
| 6 | 119.4 split | 119.0 split | 107.0 fused | 106.2 fused | 1.00x | 0.90x | 0.89x |
| 8 | 159.4 split | 159.0 split | 112.2 fused | 112.0 fused | 1.00x | 0.70x | 0.70x |
| 9 | 179.7 split | 178.9 split | 177.8 split | 114.2 fused | 1.00x | 0.99x | 0.64x |
| 10 | 199.3 split | 182.9 split | 197.7 split | 115.9 fused | 0.92x | 0.99x | 0.58x |
| 12 | 239.6 split | 236.7 split | 224.1 split | 118.5 fused | 0.99x | 0.94x | 0.49x |
| 16 | 317.6 split | 318.1 split | 315.2 split | 122.0 fused | 1.00x | 0.99x | 0.38x |

Llama-3.3-70B

| Reqs | Baseline | L=1 | L=8 | L=16 | L=1 vs Baseline | L=8 vs Baseline | L=16 vs Baseline |
|---|---|---|---|---|---|---|---|
| 1 | 4.7 fused | 11.3 fused | 11.3 fused | 10.7 fused | 2.40x | 2.40x | 2.28x |
| 2 | 5.4 split | 5.4 split | 12.6 fused | 11.9 fused | 1.00x | 2.33x | 2.20x |
| 4 | 10.7 split | 10.7 split | 13.8 fused | 13.5 fused | 1.00x | 1.29x | 1.26x |
| 6 | 16.1 split | 16.1 split | 14.3 fused | 14.1 fused | 1.00x | 0.89x | 0.88x |
| 8 | 21.4 split | 21.4 split | 14.6 fused | 14.4 fused | 1.00x | 0.68x | 0.67x |
| 9 | 24.2 split | 24.2 split | 24.1 split | 14.7 fused | 1.00x | 1.00x | 0.61x |
| 10 | 26.9 split | 26.9 split | 26.9 split | 14.8 fused | 1.00x | 1.00x | 0.55x |
| 12 | 32.1 split | 32.1 split | 32.1 split | 14.9 fused | 1.00x | 1.00x | 0.46x |
| 16 | 42.7 split | 42.7 split | 42.7 split | 14.9 fused | 1.00x | 1.00x | 0.35x |
