[ROCm] Optimize kgemm_4bit_inference_naive for ROCm, use it for batch sizes other than 1 (#1920)
Open

sstamenk wants to merge 1 commit into bitsandbytes-foundation:main
Based on issues raised in #1842 and pytorch#171687.
## Summary
- Optimizes `kgemm_4bit_inference_naive` on ROCm, following the suggestions discussed in #1842 ("70B 4-bit LLM decode bottlenecked by HIP kernel (kgemm_4bit_inference_naive) efficiency — 49% vs 91% memory bandwidth on ROCm/gfx1151").
- Uses the fused kernel for batched (`M > 1`) inputs instead of only the vector case.

## Technical details
This PR makes two related changes.
### Kernel optimization

Optimizes `kgemm_4bit_inference_naive` to reduce overhead in the fused dequantize + matmul path on ROCm, following the analysis in #1842.

### Fused path support for `M > 1`
Previously, the fused kernel was used only for `M == 1`. Up to a platform-specific crossover point, launching the fused kernel is substantially faster than falling back to split dequantize + GEMM. This matters most for serving workloads, where decode steps regularly hit small `M > 1` batches. A sketch contrasting the two paths follows. Example measured on Strix Halo (chart omitted).
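To make the two paths concrete, here is a minimal sketch using bitsandbytes' public Python helpers; the shapes and tolerances are illustrative assumptions, not the benchmark setup. On current `main`, `matmul_4bit` takes the fused kernel only when the effective batch is 1; with this PR, small `M > 1` batches on ROCm take it as well.

```python
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

# Illustrative decode-step shapes: M sequences, hidden size K, output size N.
M, K, N = 4, 4096, 4096
x = torch.randn(M, K, dtype=torch.float16, device="cuda")
w = torch.randn(N, K, dtype=torch.float16, device="cuda")

# Pack the weight to 4-bit (NF4) once, as Linear4bit does at load time.
w_4bit, quant_state = F.quantize_4bit(w, quant_type="nf4")

# Split path: materialize the fp16 weight, then an ordinary GEMM.
y_split = x @ F.dequantize_4bit(w_4bit, quant_state).t()

# Fused path: dequantize inside the matmul kernel. matmul_4bit picks the
# path internally based on the effective batch size.
y_fused = bnb.matmul_4bit(x, w_4bit.t(), quant_state=quant_state)

# Both paths compute the same product against the dequantized weight.
torch.testing.assert_close(y_split, y_fused, rtol=1e-2, atol=1e-1)
```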
At larger `M`, the fused path eventually converges with and then regresses against split dequantize + GEMM. The crossover differs by GPU:

| GPU | Crossover `M` |
| --- | --- |
| gfx1151 | 16 |
| RTX 5090 | 8-12 |
| gfx1201 | 10-12 |
| MI308X | 4-6 |
For this PR, the dispatch threshold is set to `M = 8` as a cross-GPU compromise. That still leaves some regressions on MI308X once `reqs >= 6`, but avoids the larger regressions seen at higher thresholds on other GPUs.
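In Python terms, the resulting dispatch rule looks roughly like the sketch below. The helper name is hypothetical and whether the comparison is `<=` or `<` is an assumption; the actual change lives in the library's dispatch logic.

```python
import torch
import bitsandbytes.functional as F

FUSED_M_THRESHOLD = 8  # cross-GPU compromise chosen in this PR

def matmul_4bit_sketch(x, w_4bit, quant_state):  # hypothetical helper
    m = x.numel() // x.shape[-1]  # effective batch: product of leading dims
    if m <= FUSED_M_THRESHOLD:
        # Fused dequantize + matmul (kgemm_4bit_inference_naive). On main
        # this kernel is used only for m == 1; this PR extends it on ROCm.
        return F.gemv_4bit(x, w_4bit.t(), state=quant_state)
    # Split path: materialize fp16 weights, then a regular GEMM.
    return torch.nn.functional.linear(x, F.dequantize_4bit(w_4bit, quant_state))
```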
## Testing plan

- Run the `gemm_4bit` unit tests to validate correctness of the updated kernel path.
- Benchmark the `M > 1` fused-path performance gain.

## Testing results
### `gemv_4bit` unit tests

(Unit-test output omitted.)
### `kgemm_4bit_inference_naive` benchmark

In this table, A denotes the baseline kernel and B denotes the optimized kernel.

| GPU | Time (A) | Time (B) | Bandwidth (A) | Bandwidth (B) | Peak bandwidth | Utilization (A) | Utilization (B) | Speedup |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| gfx1151 | 1133 us | 740 us | 117 GB/s | 178 GB/s | ~210 GB/s (measured) | 56% | 85% | 1.53x |
| RTX 5090 | 86 us | 84 us | 1361 GB/s | 1394 GB/s | ~1,790 GB/s | 76% | 78% | 1.02x |
| gfx1201 | 539 us | 226 us | 218 GB/s | 519 GB/s | 640 GB/s | 34% | 81% | 2.39x |
| MI308X | 656 us | 246 us | 179 GB/s | 477 GB/s | ~3,277 GB/s | 5.5% | 14.6% | 2.67x |
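The utilization columns are simply achieved bandwidth divided by peak. A quick check of the table's arithmetic (the byte-count comment reflects the usual accounting for a weight-bandwidth-bound 4-bit kernel and is an assumption about how the GB/s figures were derived):

```python
# utilization = achieved bandwidth / peak bandwidth
rows = {
    "gfx1151": (178, 210),    # (achieved GB/s for kernel B, peak GB/s)
    "RTX 5090": (1394, 1790),
    "gfx1201": (519, 640),
    "MI308X": (477, 3277),
}
for gpu, (achieved, peak) in rows.items():
    print(f"{gpu}: {achieved / peak:.1%}")
# -> 84.8%, 77.9%, 81.1%, 14.6%, matching the table.
# Achieved GB/s itself is bytes_moved / kernel_time, where for a 4-bit
# weight of shape (N, K) the kernel reads roughly N*K/2 bytes of packed
# weights plus the per-block absmax scales.
```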
### End-to-end Transformers Throughput

Each row below is one benchmarked model/configuration; columns are baseline throughput, optimized throughput, and the resulting speedup. (A minimal measurement harness is sketched after the tables.)

**Strix Halo (gfx1151):**

| Baseline | Optimized | Speedup |
| --- | --- | --- |
| 2.45 | 3.86 | 1.58x |
| 18.3 | 27.9 | 1.53x |
| 10.7 | 15.6 | 1.46x |
| 9.6 | 12.5 | 1.30x |
| 17.4 | 22.3 | 1.28x |

**RTX 5090:**

| Baseline | Optimized | Speedup |
| --- | --- | --- |
| 85.57 | 84.32 | 0.99x |
| 82.43 | 80.76 | 0.98x |
| 59.51 | 58.95 | 0.99x |

**Radeon AI Pro R9700 (gfx1201):**

| Baseline | Optimized | Speedup |
| --- | --- | --- |
| 38.42 | 64.32 | 1.67x |
| 31.31 | 46.27 | 1.48x |
| 23.83 | 32.00 | 1.34x |

**MI308X (gfx942):**

| Baseline | Optimized | Speedup |
| --- | --- | --- |
| 31.28 | 40.51 | 1.30x |
| 30.97 | 40.46 | 1.31x |
| 23.45 | 29.03 | 1.24x |
| 15.62 | 15.94 | 1.02x |
| 4.59 | 10.60 | 2.31x |
| 4.59 | 10.61 | 2.31x |
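A sketch of the kind of harness that produces these numbers, assuming a transformers model loaded through `BitsAndBytesConfig`; the model ID, prompt, and token counts are placeholders, not the exact benchmark setup:

```python
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "mistralai/Mistral-7B-v0.1"  # placeholder model
quant = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
                           bnb_4bit_compute_dtype=torch.float16)
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quant,
                                             device_map="cuda")

inputs = tok("The quick brown fox", return_tensors="pt").to("cuda")
new_tokens = 128

model.generate(**inputs, max_new_tokens=8)  # warm-up
torch.cuda.synchronize()
start = time.perf_counter()
model.generate(**inputs, max_new_tokens=new_tokens, min_new_tokens=new_tokens)
torch.cuda.synchronize()
print(f"{new_tokens / (time.perf_counter() - start):.2f} tokens/s")
```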
### End-to-end vLLM Serving Throughput for Reqs > 1

In the tables below, each cell shows measured throughput and, in parentheses, the dispatch path taken (fused or split). The first column is the baseline configuration; the remaining three are variants of the optimized dispatch, and the speedup columns compare variants against the baseline as labeled. An offline-throughput sketch with vLLM's bitsandbytes support follows the tables.

#### Strix Halo (gfx1151)

**Mistral-7B**
| Baseline | Variant A | Variant B | Variant C | Speedup (C vs baseline) |
| --- | --- | --- | --- | --- |
| 22.5 (fused) | 34.7 (fused) | 35.3 (fused) | 35.7 (fused) | 1.59x |
| 10.4 (split) | 10.4 (split) | 51.8 (fused) | 53.4 (fused) | 5.13x |
| 20.3 (split) | 20.3 (split) | 67.6 (fused) | 67.3 (fused) | 3.32x |
| 30.4 (split) | 30.5 (split) | 72.4 (fused) | 71.9 (fused) | 2.37x |
| 40.4 (split) | 40.4 (split) | 75.2 (fused) | 75.1 (fused) | 1.86x |
| 50.5 (split) | 50.6 (split) | 50.6 (split) | 76.7 (fused) | 1.52x |
| 60.2 (split) | 60.4 (split) | 60.4 (split) | 77.9 (fused) | 1.29x |
| 69.9 (split) | 70.1 (split) | 70.1 (split) | 78.7 (fused) | 1.13x |
| 80.1 (split) | 80.2 (split) | 80.3 (split) | 79.2 (fused) | 0.99x |

**Llama-8B**
| Baseline | Variant A | Variant B | Variant C | Speedup (C vs baseline) |
| --- | --- | --- | --- | --- |
| 20.7 (fused) | 32.0 (fused) | 32.0 (fused) | 32.0 (fused) | 1.55x |
| 10.4 (split) | 10.4 (split) | 48.9 (fused) | 47.3 (fused) | 4.55x |
| 20.3 (split) | 20.3 (split) | 63.7 (fused) | 63.6 (fused) | 3.13x |
| 30.3 (split) | 30.2 (split) | 68.8 (fused) | 68.5 (fused) | 2.26x |
| 40.2 (split) | 40.1 (split) | 72.4 (fused) | 72.4 (fused) | 1.80x |
| 50.2 (split) | 50.2 (split) | 50.2 (split) | 74.6 (fused) | 1.49x |
| 60.0 (split) | 60.0 (split) | 59.9 (split) | 75.8 (fused) | 1.26x |
| 69.5 (split) | 69.6 (split) | 69.6 (split) | 76.9 (fused) | 1.11x |
| 79.5 (split) | 79.5 (split) | 79.5 (split) | 77.2 (fused) | 0.97x |

**Qwen3.5-9B**
| Baseline | Variant A | Variant B | Variant C | Speedup (C vs baseline) |
| --- | --- | --- | --- | --- |
| 17.5 (fused) | 23.1 (fused) | 22.9 (fused) | 22.9 (fused) | 1.31x |
| 9.4 (split) | 9.4 (split) | 40.4 (fused) | 39.5 (fused) | 4.20x |
| 18.4 (split) | 18.4 (split) | 55.8 (fused) | 57.0 (fused) | 3.10x |
| 26.9 (split) | 27.0 (split) | 61.8 (fused) | 62.0 (fused) | 2.30x |
| 35.4 (split) | 35.5 (split) | 65.9 (fused) | 66.0 (fused) | 1.86x |
| 44.5 (split) | 44.6 (split) | 44.7 (split) | 68.7 (fused) | 1.54x |
| 52.0 (split) | 52.3 (split) | 52.3 (split) | 70.6 (fused) | 1.36x |
| 60.1 (split) | 60.3 (split) | 60.4 (split) | 72.1 (fused) | 1.20x |
| 69.0 (split) | 69.3 (split) | 69.5 (split) | 72.9 (fused) | 1.06x |

**Llama-3.3-70B**
| Baseline | Variant A | Variant B | Variant C | Speedup (C vs baseline) |
| --- | --- | --- | --- | --- |
| 2.5 (fused) | 4.1 (fused) | 4.1 (fused) | 4.0 (fused) | 1.60x |
| 1.2 (split) | 1.2 (split) | 5.9 (fused) | 5.8 (fused) | 4.83x |
| 2.4 (split) | - | 7.4 (fused) | 7.4 (fused) | 3.08x |
| 3.6 (split) | - | 7.8 (fused) | 7.8 (fused) | 2.17x |
| 4.8 (split) | - | 8.0 (fused) | 8.0 (fused) | 1.67x |
| 6.0 (split) | - | 6.0 (split) | 8.1 (fused) | 1.35x |
| 7.2 (split) | - | - | 8.2 (fused) | 1.14x |
| 8.4 (split) | - | - | 8.3 (fused) | 0.99x |
| 9.5 (split) | - | - | 8.3 (fused) | 0.87x |

#### RTX 5090

**Mistral-7B**
| Baseline | Variant A | Variant B | Variant C | Speedup (B vs baseline) | Speedup (C vs baseline) |
| --- | --- | --- | --- | --- | --- |
| 134.9 (fused) | 136.0 (fused) | 135.3 (fused) | 131.4 (fused) | 1.00x | 0.97x |
| 104.0 (split) | 103.9 (split) | 255.5 (fused) | 243.8 (fused) | 2.46x | 2.34x |
| 204.9 (split) | 204.6 (split) | 347.0 (fused) | 343.1 (fused) | 1.69x | 1.67x |
| 283.0 (split) | 283.0 (split) | 385.1 (fused) | 382.0 (fused) | 1.36x | 1.35x |
| 375.8 (split) | 376.0 (split) | 404.5 (fused) | 401.9 (fused) | 1.08x | 1.07x |
| 422.4 (split) | 422.4 (split) | 420.9 (split) | 407.0 (fused) | 1.00x | 0.96x |
| 469.1 (split) | 468.4 (split) | 469.3 (split) | 411.9 (fused) | 1.00x | 0.88x |
| 558.4 (split) | 559.8 (split) | 558.6 (split) | 415.9 (fused) | 1.00x | 0.74x |
| 736.8 (split) | 736.7 (split) | 737.3 (split) | 425.5 (fused) | 1.00x | 0.58x |

**Llama-8B**
| Baseline | Variant A | Variant B | Variant C | Speedup (B vs baseline) | Speedup (C vs baseline) |
| --- | --- | --- | --- | --- | --- |
| 136.6 (fused) | 134.1 (fused) | 133.0 (fused) | 133.3 (fused) | 0.97x | 0.98x |
| 101.5 (split) | 101.4 (split) | 251.4 (fused) | 245.6 (fused) | 2.48x | 2.42x |
| 200.0 (split) | 199.5 (split) | 333.6 (fused) | 330.7 (fused) | 1.67x | 1.65x |
| 275.7 (split) | 275.7 (split) | 373.8 (fused) | 374.3 (fused) | 1.36x | 1.36x |
| 365.9 (split) | 365.0 (split) | 394.3 (fused) | 395.4 (fused) | 1.08x | 1.08x |
| 410.7 (split) | 410.8 (split) | 411.0 (split) | 399.3 (fused) | 1.00x | 0.97x |
| 456.0 (split) | 456.0 (split) | 456.8 (split) | 404.5 (fused) | 1.00x | 0.89x |
| 544.9 (split) | 545.2 (split) | 545.4 (split) | 410.1 (fused) | 1.00x | 0.75x |
| 720.7 (split) | 720.5 (split) | 720.7 (split) | 420.5 (fused) | 1.00x | 0.58x |

**Qwen3.5-9B**
| Baseline | Variant A | Variant B | Variant C | Speedup (B vs baseline) | Speedup (C vs baseline) |
| --- | --- | --- | --- | --- | --- |
| 72.3 (fused) | 72.6 (fused) | 73.4 (fused) | 72.2 (fused) | 1.02x | 1.00x |
| 100.0 (split) | 100.0 (split) | 135.3 (fused) | 132.4 (fused) | 1.35x | 1.32x |
| 188.7 (split) | 188.5 (split) | 271.1 (fused) | 264.8 (fused) | 1.44x | 1.40x |
| 280.0 (split) | 280.0 (split) | 344.4 (fused) | 343.2 (fused) | 1.23x | 1.23x |
| 370.4 (split) | 370.4 (split) | 369.3 (fused) | 368.2 (fused) | 1.00x | 0.99x |
| 415.6 (split) | 415.7 (split) | 415.6 (split) | 375.0 (fused) | 1.00x | 0.90x |
| 462.1 (split) | 462.1 (split) | 462.4 (split) | 382.2 (fused) | 1.00x | 0.83x |
| 545.8 (split) | 545.8 (split) | 545.9 (split) | 390.6 (fused) | 1.00x | 0.72x |
| 737.9 (split) | 738.1 (split) | 738.1 (split) | 400.8 (fused) | 1.00x | 0.54x |

#### Radeon AI Pro R9700 (gfx1201)

**Mistral-7B**
| Baseline | Variant A | Variant B | Variant C | Speedup (B vs baseline) | Speedup (C vs baseline) |
| --- | --- | --- | --- | --- | --- |
| 45.5 (fused) | 87.2 (fused) | 90.0 (fused) | 88.1 (fused) | 1.98x | 1.94x |
| 34.1 (split) | 34.2 (split) | 127.9 (fused) | 119.6 (fused) | 3.75x | 3.51x |
| 68.1 (split) | 67.7 (split) | 150.4 (fused) | 147.1 (fused) | 2.21x | 2.16x |
| 134.2 (split) | 134.2 (split) | 166.6 (fused) | 163.9 (fused) | 1.24x | 1.22x |
| 151.2 (split) | 150.7 (split) | 151.0 (split) | 167.2 (fused) | 1.00x | 1.11x |
| 167.7 (split) | 167.0 (split) | 167.1 (split) | 167.3 (fused) | 1.00x | 1.00x |
| 184.1 (split) | 183.3 (split) | 183.6 (split) | 166.9 (fused) | 1.00x | 0.91x |
| 199.0 (split) | 198.7 (split) | 199.7 (split) | 169.3 (fused) | 1.00x | 0.85x |
| 263.2 (split) | 261.5 (split) | 262.6 (split) | 170.2 (fused) | 1.00x | 0.65x |

**Llama-8B**
| Baseline | Variant A | Variant B | Variant C | Speedup (B vs baseline) | Speedup (C vs baseline) |
| --- | --- | --- | --- | --- | --- |
| 44.1 (fused) | 80.9 (fused) | 80.9 (fused) | 79.7 (fused) | 1.84x | 1.81x |
| 33.5 (split) | 33.4 (split) | 117.8 (fused) | 112.1 (fused) | 3.52x | 3.35x |
| 66.6 (split) | 66.4 (split) | 142.1 (fused) | 140.7 (fused) | 2.13x | 2.11x |
| 132.4 (split) | 132.0 (split) | 159.7 (fused) | 160.6 (fused) | 1.21x | 1.21x |
| 147.3 (split) | 147.1 (split) | 147.1 (split) | 162.5 (fused) | 1.00x | 1.10x |
| 163.1 (split) | 162.9 (split) | 162.8 (split) | 162.4 (fused) | 1.00x | 1.00x |
| 179.2 (split) | 178.4 (split) | 178.7 (split) | 160.4 (fused) | 1.00x | 0.90x |
| 195.1 (split) | 194.9 (split) | 194.8 (split) | 165.5 (fused) | 1.00x | 0.85x |
| 256.1 (split) | 255.8 (split) | 256.3 (split) | 167.5 (fused) | 1.00x | 0.65x |

**Qwen3.5-9B**
| Baseline | Variant A | Variant B | Variant C | Speedup (B vs baseline) | Speedup (C vs baseline) |
| --- | --- | --- | --- | --- | --- |
| 9.4 (fused) | 10.8 (fused) | 10.8 (fused) | 10.8 (fused) | 1.15x | 1.15x |
| 13.4 (split) | 13.5 (split) | 19.8 (fused) | 19.7 (fused) | 1.48x | 1.47x |
| 26.8 (split) | 26.7 (split) | 36.3 (fused) | 36.3 (fused) | 1.35x | 1.35x |
| 52.7 (split) | 52.7 (split) | 61.3 (fused) | 61.1 (fused) | 1.16x | 1.16x |
| 59.3 (split) | 59.2 (split) | 59.3 (split) | 64.1 (fused) | 1.00x | 1.08x |
| 65.4 (split) | 65.4 (split) | 65.4 (split) | 69.2 (fused) | 1.00x | 1.06x |
| 72.3 (split) | 72.3 (split) | 72.2 (split) | 73.6 (fused) | 1.00x | 1.02x |
| 78.7 (split) | 78.6 (split) | 78.6 (split) | 78.0 (fused) | 1.00x | 0.99x |
| 103.2 (split) | 103.1 (split) | 103.2 (split) | 92.1 (fused) | 1.00x | 0.89x |

#### MI308X (gfx942)

**Mistral-7B**
| Baseline | Variant A | Variant B | Variant C | Speedup (A vs baseline) | Speedup (B vs baseline) | Speedup (C vs baseline) |
| --- | --- | --- | --- | --- | --- | --- |
| 37.6 (fused) | 61.3 (fused) | 64.2 (fused) | 61.7 (fused) | 1.63x | 1.71x | 1.64x |
| 47.3 (split) | 48.0 (split) | 98.8 (fused) | 92.1 (fused) | 1.01x | 2.09x | 1.95x |
| 94.4 (split) | 95.1 (split) | 112.8 (fused) | 111.1 (fused) | 1.01x | 1.19x | 1.18x |
| 141.3 (split) | 142.0 (split) | 120.2 (fused) | 118.8 (fused) | 1.00x | 0.85x | 0.84x |
| 188.5 (split) | 189.2 (split) | 124.0 (fused) | 123.2 (fused) | 1.00x | 0.66x | 0.65x |
| 214.0 (split) | 215.6 (split) | 192.7 (split) | 124.6 (fused) | 1.01x | 0.90x | 0.58x |
| 237.5 (split) | 239.0 (split) | 238.9 (split) | 125.9 (fused) | 1.01x | 1.01x | 0.53x |
| 284.8 (split) | 287.1 (split) | 285.6 (split) | 128.9 (fused) | 1.01x | 1.00x | 0.45x |
| 379.2 (split) | 382.0 (split) | 381.3 (split) | 130.1 (fused) | 1.01x | 1.01x | 0.34x |

**Llama-8B**
| Baseline | Variant A | Variant B | Variant C | Speedup (A vs baseline) | Speedup (B vs baseline) | Speedup (C vs baseline) |
| --- | --- | --- | --- | --- | --- | --- |
| 37.3 (fused) | 63.9 (fused) | 64.9 (fused) | 62.1 (fused) | 1.71x | 1.74x | 1.66x |
| 47.1 (split) | 47.6 (split) | 96.3 (fused) | 91.5 (fused) | 1.01x | 2.04x | 1.94x |
| 90.8 (split) | 93.7 (split) | 111.7 (fused) | 110.3 (fused) | 1.03x | 1.23x | 1.21x |
| 139.8 (split) | 139.0 (split) | 117.9 (fused) | 117.6 (fused) | 0.99x | 0.84x | 0.84x |
| 186.9 (split) | 185.7 (split) | 122.6 (fused) | 121.6 (fused) | 0.99x | 0.66x | 0.65x |
| 212.8 (split) | 213.2 (split) | 211.8 (split) | 123.7 (fused) | 1.00x | 1.00x | 0.58x |
| 235.6 (split) | 237.0 (split) | 236.9 (split) | 125.3 (fused) | 1.01x | 1.01x | 0.53x |
| 283.2 (split) | 283.3 (split) | 284.5 (split) | 127.0 (fused) | 1.00x | 1.00x | 0.45x |
| 376.5 (split) | 377.2 (split) | 377.6 (split) | 129.8 (fused) | 1.00x | 1.00x | 0.34x |

**Qwen3.5-9B**
| Baseline | Variant A | Variant B | Variant C | Speedup (A vs baseline) | Speedup (B vs baseline) | Speedup (C vs baseline) |
| --- | --- | --- | --- | --- | --- | --- |
| 30.2 (fused) | 30.0 (fused) | 30.1 (fused) | 29.8 (fused) | 0.99x | 1.00x | 0.99x |
| 40.0 (split) | 39.8 (split) | 57.6 (fused) | 58.8 (fused) | 1.00x | 1.44x | 1.47x |
| 79.7 (split) | 78.9 (split) | 96.9 (fused) | 96.5 (fused) | 0.99x | 1.22x | 1.21x |
| 119.4 (split) | 119.0 (split) | 107.0 (fused) | 106.2 (fused) | 1.00x | 0.90x | 0.89x |
| 159.4 (split) | 159.0 (split) | 112.2 (fused) | 112.0 (fused) | 1.00x | 0.70x | 0.70x |
| 179.7 (split) | 178.9 (split) | 177.8 (split) | 114.2 (fused) | 1.00x | 0.99x | 0.64x |
| 199.3 (split) | 182.9 (split) | 197.7 (split) | 115.9 (fused) | 0.92x | 0.99x | 0.58x |
| 239.6 (split) | 236.7 (split) | 224.1 (split) | 118.5 (fused) | 0.99x | 0.94x | 0.49x |
| 317.6 (split) | 318.1 (split) | 315.2 (split) | 122.0 (fused) | 1.00x | 0.99x | 0.38x |

**Llama-3.3-70B**
| Baseline | Variant A | Variant B | Variant C | Speedup (A vs baseline) | Speedup (B vs baseline) | Speedup (C vs baseline) |
| --- | --- | --- | --- | --- | --- | --- |
| 4.7 (fused) | 11.3 (fused) | 11.3 (fused) | 10.7 (fused) | 2.40x | 2.40x | 2.28x |
| 5.4 (split) | 5.4 (split) | 12.6 (fused) | 11.9 (fused) | 1.00x | 2.33x | 2.20x |
| 10.7 (split) | 10.7 (split) | 13.8 (fused) | 13.5 (fused) | 1.00x | 1.29x | 1.26x |
| 16.1 (split) | 16.1 (split) | 14.3 (fused) | 14.1 (fused) | 1.00x | 0.89x | 0.88x |
| 21.4 (split) | 21.4 (split) | 14.6 (fused) | 14.4 (fused) | 1.00x | 0.68x | 0.67x |
| 24.2 (split) | 24.2 (split) | 24.1 (split) | 14.7 (fused) | 1.00x | 1.00x | 0.61x |
| 26.9 (split) | 26.9 (split) | 26.9 (split) | 14.8 (fused) | 1.00x | 1.00x | 0.55x |
| 32.1 (split) | 32.1 (split) | 32.1 (split) | 14.9 (fused) | 1.00x | 1.00x | 0.46x |
| 42.7 (split) | 42.7 (split) | 42.7 (split) | 14.9 (fused) | 1.00x | 1.00x | 0.35x |
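As a rough way to reproduce numbers of this kind without a full serving stack, a recent vLLM can load bitsandbytes-quantized weights directly. The sketch below measures offline generation throughput and is an approximation of, not a substitute for, the serving benchmark above; the model ID, request count, and token budget are placeholders.

```python
import time
from vllm import LLM, SamplingParams

# vLLM's in-flight bitsandbytes quantization (4-bit).
llm = LLM(model="mistralai/Mistral-7B-v0.1", quantization="bitsandbytes",
          dtype="float16")

prompts = ["Summarize the history of GPUs."] * 8  # concurrent requests
params = SamplingParams(max_tokens=128, ignore_eos=True)

start = time.perf_counter()
outputs = llm.generate(prompts, params)
elapsed = time.perf_counter() - start

generated = sum(len(o.outputs[0].token_ids) for o in outputs)
print(f"{generated / elapsed:.1f} output tokens/s")
```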