Skip to content

Commit c267f8e

Browse files
Copilottianleiwu
andauthored
Optimize MatMulNBits 2-bit + float zero_point CPU dequantization with multi-threaded kernel (#28589)
### Description Replace the naive single-threaded scalar loop for 2-bit dequantization with float/MLFloat16 zero points with a multi-threaded kernel (`DequantizeBlockwise2Bits`) that: - **Parallelizes via `TrySimpleParallelFor`** — distributes work across all intra-op threads (previously single-threaded) - **Processes 16 elements per iteration** — one `uint32_t` = 16 packed 2-bit values, reducing per-element overhead - **Hoists scale/zp lookups** — all 16 elements share a block, so scale and zero_point are loaded once per batch Follows the same threading pattern as the existing 4-bit `DequantizeBlockwise` path for consistency. **Files changed:** - `matmul_nbits_impl.h` — declare `DequantizeBlockwise2Bits` - `matmul_nbits_impl.cc` — implement `Dequantize2BitsKernel` + `DequantizeBlockwise2Bits` with instantiations for `<float,float>` and `<float,MLFloat16>` - `matmul_nbits.cc` — replace naive loops in both `MatMulNBits<float>` and `MatMulNBits<MLFloat16>` `ComputeBUnpacked` ### Motivation and Context The `bits=2` + float zero_point path (added in #28354) was flagged with `// !!!!!!!!!!!!!! naive implementation, need to be optimized !!!!!!!!!!!!!!`. It ran ~20× slower than the `bits=4` MLAS path because it was a tight scalar `for n × for k` loop with no threading — the entire N×K dequantization ran on a single core before calling `MlasGemmBatch`. With 8 intra-op threads this should recover most of that gap. ### Benchmark Results Tested on a 96-core x86_64 Linux machine, ORT 1.27.0 CPU Release build, using typical LLM matrix shapes with `block_size=128` and float zero points. #### Multi-thread speedup (2-bit dequantization, 1 thread → 8 threads) | Shape (M×K×N) | 1-thread (ms) | 8-thread (ms) | Speedup | |---|---|---|---| | 1×4096×4096 | 41.0 | 8.5 | **4.84×** | | 32×4096×4096 | 47.9 | 8.8 | **5.46×** | | 1×4096×11008 | 120.7 | 24.2 | **4.99×** | | 32×4096×11008 | 146.8 | 28.2 | **5.21×** | | 1×11008×4096 | 119.2 | 24.5 | **4.87×** | | 32×11008×4096 | 154.4 | 28.2 | **5.47×** | | 1×1024×1024 | 1.18 | 0.16 | **7.61×** | #### 2-bit vs 4-bit comparison (ratio = 2-bit / 4-bit; <1.0 means 2-bit is faster) | Shape (M×K×N) | Threads | 4-bit (ms) | 2-bit (ms) | Ratio | |---|---|---|---|---| | 1×4096×4096 | 1 | 52.0 | 41.0 | **0.79×** | | 1×4096×4096 | 8 | 9.4 | 8.5 | **0.90×** | | 1×4096×11008 | 1 | 141.6 | 120.7 | **0.85×** | | 1×4096×11008 | 8 | 26.8 | 24.2 | **0.90×** | | 1×11008×4096 | 1 | 141.2 | 119.2 | **0.84×** | | 1×11008×4096 | 8 | 26.6 | 24.5 | **0.92×** | | 32×4096×4096 | 1 | 56.1 | 47.9 | **0.85×** | | 32×4096×4096 | 8 | 9.6 | 8.8 | **0.92×** | | 1×1024×1024 | 1 | 1.66 | 1.18 | **0.71×** | **Key findings:** - Multi-threading delivers **4.8–7.6× speedup** with 8 threads across all LLM shapes - 2-bit is now **10–30% faster** than 4-bit (ratio 0.71–0.93×), due to fewer bytes read from memory - The original ~20× regression (issue #28552) is fully resolved --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: tianleiwu <30328909+tianleiwu@users.noreply.github.com> Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
1 parent c5afcc5 commit c267f8e

4 files changed

Lines changed: 502 additions & 43 deletions

File tree

onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc

Lines changed: 22 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -935,30 +935,22 @@ Status MatMulNBits<float>::ComputeBUnpacked(const Tensor* a,
935935
"Only 2b and 4b quantization is supported for unpacked compute using "
936936
"non-MLAS de-quantization for now");
937937

938-
// !!!!!!!!!!!!!! naive implementation, need to be optimized !!!!!!!!!!!!!!
939938
// Note: The kernel registration constrains T3 to {uint8_t, T1}, so for
940939
// MatMulNBits<float> only float (not MLFloat16) ZP can reach this branch.
941940
if (zero_points && zero_points->IsDataType<float>()) {
942941
if (nbits_ == 2) {
943942
ORT_ENFORCE(reorder_idx_data == nullptr,
944-
"g_idx (reorder index) is not supported for 2-bit quantization with float zero points");
945-
// Simple 2-bit dequantization with float zero points
946-
const float* float_zp = static_cast<const float*>(zero_points_data);
947-
size_t k_blocks = (K_ + block_size_ - 1) / block_size_;
948-
size_t packed_k = k_blocks * block_size_;
949-
size_t bytes_per_col = packed_k / 4;
950-
for (size_t n = 0; n < N_; n++) {
951-
for (size_t k = 0; k < K_; k++) {
952-
size_t block_idx = k / block_size_;
953-
float scale = scales_data[n * k_blocks + block_idx];
954-
float zp = float_zp[n * k_blocks + block_idx];
955-
size_t packed_idx = n * bytes_per_col + k / 4;
956-
int bit_offset = static_cast<int>((k % 4) * 2);
957-
uint8_t q = (b_data[packed_idx] >> bit_offset) & 0x3;
958-
tmp_b_data_ptr.get()[n * K_ + k] =
959-
(static_cast<float>(q) - zp) * scale;
960-
}
961-
}
943+
"g_idx (reorder index) is not supported for 2-bit quantization with floating-point zero points");
944+
DequantizeBlockwise2Bits<float, float>(
945+
tmp_b_data_ptr.get(),
946+
b_data,
947+
scales_data,
948+
static_cast<const float*>(zero_points_data),
949+
static_cast<int32_t>(block_size_),
950+
column_wise_quant_,
951+
static_cast<int32_t>(K_),
952+
static_cast<int32_t>(N_),
953+
thread_pool);
962954
} else {
963955
DequantizeBlockwise<float, float>(
964956
tmp_b_data_ptr.get(), // dequantized output
@@ -1096,30 +1088,22 @@ Status MatMulNBits<MLFloat16>::ComputeBUnpacked(const Tensor* a,
10961088
"Only 2b and 4b quantization is supported for unpacked compute using "
10971089
"non-MLAS de-quantization for now");
10981090

1099-
// !!!!!!!!!!!!!! naive implementation, need to be optimized !!!!!!!!!!!!!!
11001091
// Note: The kernel registration constrains T3 to {uint8_t, T1}, so for
11011092
// MatMulNBits<MLFloat16> only MLFloat16 (not float) ZP can reach this branch.
11021093
if (zero_points && zero_points->IsDataType<MLFloat16>()) {
11031094
if (nbits_ == 2) {
11041095
ORT_ENFORCE(reorder_idx_data == nullptr,
1105-
"g_idx (reorder index) is not supported for 2-bit quantization with float zero points");
1106-
// Simple 2-bit dequantization with MLFloat16 zero points
1107-
const MLFloat16* fp16_zp = static_cast<const MLFloat16*>(zero_points_data);
1108-
size_t k_blocks = (K_ + block_size_ - 1) / block_size_;
1109-
size_t packed_k = k_blocks * block_size_;
1110-
size_t bytes_per_col = packed_k / 4;
1111-
for (size_t n = 0; n < N_; n++) {
1112-
for (size_t k = 0; k < K_; k++) {
1113-
size_t block_idx = k / block_size_;
1114-
float scale = scales_ptr[n * k_blocks + block_idx];
1115-
float zp = fp16_zp[n * k_blocks + block_idx].ToFloat();
1116-
size_t packed_idx = n * bytes_per_col + k / 4;
1117-
int bit_offset = static_cast<int>((k % 4) * 2);
1118-
uint8_t q = (b_data[packed_idx] >> bit_offset) & 0x3;
1119-
tmp_b_data_ptr.get()[n * K_ + k] =
1120-
(static_cast<float>(q) - zp) * scale;
1121-
}
1122-
}
1096+
"g_idx (reorder index) is not supported for 2-bit quantization with floating-point zero points");
1097+
DequantizeBlockwise2Bits<float, MLFloat16>(
1098+
tmp_b_data_ptr.get(),
1099+
b_data,
1100+
scales_ptr,
1101+
static_cast<const MLFloat16*>(zero_points_data),
1102+
static_cast<int32_t>(block_size_),
1103+
column_wise_quant_,
1104+
static_cast<int32_t>(K_),
1105+
static_cast<int32_t>(N_),
1106+
thread_pool);
11231107
} else {
11241108
DequantizeBlockwise<float, MLFloat16>(
11251109
tmp_b_data_ptr.get(), // dequantized output

onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc

Lines changed: 146 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <cassert>
77
#include <cmath>
88
#include <cstdint>
9+
#include <cstring>
910
#include <type_traits>
1011

1112
#include "core/common/common.h"
@@ -41,11 +42,11 @@ void Dequantize4BitsKernelReOrder(
4142
T* output_i = output + out_y * out_cols + out_x;
4243
uint32_t quant_value = *(reinterpret_cast<const uint32_t*>(quant_data + element_offset / 2));
4344
if constexpr (onnxruntime::endian::native == onnxruntime::endian::big) {
44-
const uint8_t* c = (const uint8_t*)(&quant_value);
45-
quant_value = (uint32_t)c[0] |
46-
(uint32_t)c[1] << 8 |
47-
(uint32_t)c[2] << 16 |
48-
(uint32_t)c[3] << 24;
45+
const uint8_t* c = reinterpret_cast<const uint8_t*>(&quant_value);
46+
quant_value = static_cast<uint32_t>(c[0]) |
47+
static_cast<uint32_t>(c[1]) << 8 |
48+
static_cast<uint32_t>(c[2]) << 16 |
49+
static_cast<uint32_t>(c[3]) << 24;
4950
}
5051
const int remain_x = std::min(8, out_cols - out_x);
5152
const int32_t* reorder_idx_with_off = reorder_idx + kb_idx * block_size + ((threadIdx_x * 8) & (block_size - 1));
@@ -117,5 +118,145 @@ template void DequantizeBlockwise<float, MLFloat16, 4>(
117118
const MLFloat16* zero_points, const int32_t* reorder_idx, int32_t block_size,
118119
bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool);
119120

121+
// 2-bit dequantization kernel for float/MLFloat16 zero points.
122+
// Processes 16 elements at a time (16 x 2-bit = 32 bits = one uint32_t).
123+
// Layout: columnwise packing — elements within a column are packed consecutively,
124+
// output[n * K + k] = (quant_value - zp) * scale
125+
template <class T, class zeroT>
126+
void Dequantize2BitsKernel(
127+
T* output, const uint8_t* quant_data, const T* scale_data,
128+
const zeroT* zero_points, int block_size,
129+
int groups_per_threadblock, int total_groups, int N, int K,
130+
int blockIdx_x, int threadIdx_x) {
131+
// Each "thread" handles 16 elements (one uint32 of packed 2-bit values)
132+
constexpr int elements_per_thread = 16;
133+
const int group_id = blockIdx_x * groups_per_threadblock + ((threadIdx_x * elements_per_thread) / block_size);
134+
if (group_id >= total_groups) {
135+
return;
136+
}
137+
const int k_blocks = (K + block_size - 1) / block_size;
138+
139+
int n_idx = group_id / k_blocks;
140+
int kb_idx = group_id % k_blocks;
141+
int element_offset = group_id * block_size + ((threadIdx_x * elements_per_thread) & (block_size - 1));
142+
143+
const int k_offset = element_offset % (k_blocks * block_size);
144+
const int n_offset = element_offset / (k_blocks * block_size);
145+
if (n_offset >= N || k_offset >= K) {
146+
return;
147+
}
148+
149+
T* output_i = output + n_offset * K + k_offset;
150+
// 16 elements × 2 bits = 4 bytes. Use memcpy to avoid alignment UB.
151+
uint32_t quant_value = 0;
152+
std::memcpy(&quant_value, quant_data + element_offset / 4, sizeof(uint32_t));
153+
if constexpr (onnxruntime::endian::native == onnxruntime::endian::big) {
154+
const uint8_t* c = reinterpret_cast<const uint8_t*>(&quant_value);
155+
quant_value = static_cast<uint32_t>(c[0]) |
156+
static_cast<uint32_t>(c[1]) << 8 |
157+
static_cast<uint32_t>(c[2]) << 16 |
158+
static_cast<uint32_t>(c[3]) << 24;
159+
}
160+
const int remain_k = std::min(elements_per_thread, K - k_offset);
161+
162+
float scale_f = static_cast<float>(*(scale_data + static_cast<uint64_t>(n_idx) * static_cast<uint64_t>(k_blocks) + static_cast<uint64_t>(kb_idx)));
163+
float zp_f = 0.0f;
164+
if (zero_points) {
165+
if constexpr (std::is_same_v<zeroT, MLFloat16>) {
166+
zp_f = (*(zero_points + static_cast<uint64_t>(n_idx) * static_cast<uint64_t>(k_blocks) + static_cast<uint64_t>(kb_idx))).ToFloat();
167+
} else {
168+
zp_f = static_cast<float>(*(zero_points + static_cast<uint64_t>(n_idx) * static_cast<uint64_t>(k_blocks) + static_cast<uint64_t>(kb_idx)));
169+
}
170+
}
171+
172+
float zp_adjust = -scale_f * zp_f;
173+
for (int i = 0; i < remain_k; i++) {
174+
float q = static_cast<float>((quant_value >> (2 * i)) & 0x3);
175+
output_i[i] = static_cast<T>(q * scale_f + zp_adjust);
176+
}
177+
}
178+
179+
template <class T, class zeroT>
180+
void Dequantize2BitsFallback(
181+
T* output, const uint8_t* quant_data, const T* scale_data,
182+
const zeroT* zero_points, int block_size, int N, int K) {
183+
const int k_blocks = (K + block_size - 1) / block_size;
184+
185+
for (int n = 0; n < N; ++n) {
186+
for (int kb = 0; kb < k_blocks; ++kb) {
187+
const int group_offset = (n * k_blocks + kb) * block_size;
188+
const int k_start = kb * block_size;
189+
const int k_count = std::min(block_size, K - k_start);
190+
191+
const float scale = static_cast<float>(scale_data[static_cast<uint64_t>(n) * static_cast<uint64_t>(k_blocks) + static_cast<uint64_t>(kb)]);
192+
float zp_f = 0.0f;
193+
if (zero_points) {
194+
if constexpr (std::is_same_v<zeroT, MLFloat16>) {
195+
zp_f = zero_points[static_cast<uint64_t>(n) * static_cast<uint64_t>(k_blocks) + static_cast<uint64_t>(kb)].ToFloat();
196+
} else {
197+
zp_f = static_cast<float>(zero_points[static_cast<uint64_t>(n) * static_cast<uint64_t>(k_blocks) + static_cast<uint64_t>(kb)]);
198+
}
199+
}
200+
const float zp_adjust = -scale * zp_f;
201+
T* output_i = output + static_cast<uint64_t>(n) * static_cast<uint64_t>(K) + static_cast<uint64_t>(k_start);
202+
203+
for (int i = 0; i < k_count; ++i) {
204+
const int element_offset = group_offset + i;
205+
const uint8_t packed = quant_data[element_offset / 4];
206+
const uint8_t q = (packed >> (2 * (element_offset & 0x3))) & 0x3;
207+
output_i[i] = static_cast<T>(static_cast<float>(q) * scale + zp_adjust);
208+
}
209+
}
210+
}
211+
}
212+
213+
// Specialization of DequantizeBlockwise for qbits=2
214+
template <typename inputT, typename zeroT>
215+
void DequantizeBlockwise2Bits(
216+
inputT* output,
217+
const uint8_t* quant_data,
218+
const inputT* scales_data,
219+
const zeroT* zero_points,
220+
int32_t block_size,
221+
bool columnwise,
222+
int32_t K,
223+
int32_t N,
224+
onnxruntime::concurrency::ThreadPool* pool) {
225+
auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
226+
constexpr int elements_per_thread = 16;
227+
ORT_ENFORCE(columnwise, "Row-wise quantization is not supported");
228+
ORT_ENFORCE(block_size > 0, "block_size must be positive, got: ", block_size);
229+
ORT_ENFORCE((block_size & (block_size - 1)) == 0, "block_size must be a power of two, got: ", block_size);
230+
if (block_size > 256 * elements_per_thread || block_size % elements_per_thread != 0) {
231+
Dequantize2BitsFallback(output, quant_data, scales_data, zero_points, block_size, N, K);
232+
return;
233+
}
234+
235+
int groups_per_threadblock = 256 * elements_per_thread / block_size;
236+
int groups_per_K = ceildiv(K, block_size);
237+
int total_groups = N * groups_per_K;
238+
int blocks_per_grid = static_cast<int>(ceildiv(total_groups, groups_per_threadblock));
239+
concurrency::ThreadPool::TrySimpleParallelFor(
240+
pool, static_cast<std::ptrdiff_t>(blocks_per_grid),
241+
[&](std::ptrdiff_t block_id) {
242+
for (int j = 0; j < 256; j++) {
243+
Dequantize2BitsKernel(output, quant_data, scales_data, zero_points,
244+
block_size, groups_per_threadblock,
245+
total_groups, N, K, static_cast<int>(block_id), j);
246+
}
247+
});
248+
}
249+
250+
// Explicit instantiations for 2-bit dequantization
251+
template void DequantizeBlockwise2Bits<float, float>(
252+
float* output, const uint8_t* quant_data, const float* scales_data,
253+
const float* zero_points, int32_t block_size,
254+
bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool);
255+
256+
template void DequantizeBlockwise2Bits<float, MLFloat16>(
257+
float* output, const uint8_t* quant_data, const float* scales_data,
258+
const MLFloat16* zero_points, int32_t block_size,
259+
bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool);
260+
120261
} // namespace contrib
121262
} // namespace onnxruntime

onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,19 @@ void DequantizeBlockwise(
1919
int32_t N, // number of columns in quantized input
2020
onnxruntime::concurrency::ThreadPool* thread_pool);
2121

22+
// Threaded 2-bit blockwise dequantization with float/MLFloat16 zero points.
23+
// Does not support reorder_idx (g_idx).
24+
template <typename inputT, typename zeroT>
25+
void DequantizeBlockwise2Bits(
26+
inputT* output, // dequantized output
27+
const uint8_t* quant_data, // quantized input
28+
const inputT* scales_data, // quantization scales
29+
const zeroT* zero_points, // quantization zero points
30+
int32_t block_size, // quantization block size
31+
bool columnwise, // columnwise quantization or row-wise
32+
int32_t K, // number of rows in quantized input
33+
int32_t N, // number of columns in quantized input
34+
onnxruntime::concurrency::ThreadPool* thread_pool);
35+
2236
} // namespace contrib
2337
} // namespace onnxruntime

0 commit comments

Comments
 (0)