Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 22 additions & 38 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
Original file line number Diff line number Diff line change
Expand Up @@ -931,30 +931,22 @@ Status MatMulNBits<float>::ComputeBUnpacked(const Tensor* a,
"Only 2b and 4b quantization is supported for unpacked compute using "
"non-MLAS de-quantization for now");

// !!!!!!!!!!!!!! naive implementation, need to be optimized !!!!!!!!!!!!!!
// Note: The kernel registration constrains T3 to {uint8_t, T1}, so for
// MatMulNBits<float> only float (not MLFloat16) ZP can reach this branch.
if (zero_points && zero_points->IsDataType<float>()) {
if (nbits_ == 2) {
ORT_ENFORCE(reorder_idx_data == nullptr,
"g_idx (reorder index) is not supported for 2-bit quantization with float zero points");
// Simple 2-bit dequantization with float zero points
const float* float_zp = static_cast<const float*>(zero_points_data);
size_t k_blocks = (K_ + block_size_ - 1) / block_size_;
size_t packed_k = k_blocks * block_size_;
size_t bytes_per_col = packed_k / 4;
for (size_t n = 0; n < N_; n++) {
for (size_t k = 0; k < K_; k++) {
size_t block_idx = k / block_size_;
float scale = scales_data[n * k_blocks + block_idx];
float zp = float_zp[n * k_blocks + block_idx];
size_t packed_idx = n * bytes_per_col + k / 4;
int bit_offset = static_cast<int>((k % 4) * 2);
uint8_t q = (b_data[packed_idx] >> bit_offset) & 0x3;
tmp_b_data_ptr.get()[n * K_ + k] =
(static_cast<float>(q) - zp) * scale;
}
}
"g_idx (reorder index) is not supported for 2-bit quantization with floating-point zero points");
DequantizeBlockwise2Bits<float, float>(
Comment thread
tianleiwu marked this conversation as resolved.
tmp_b_data_ptr.get(),
b_data,
scales_data,
static_cast<const float*>(zero_points_data),
static_cast<int32_t>(block_size_),
column_wise_quant_,
static_cast<int32_t>(K_),
static_cast<int32_t>(N_),
thread_pool);
} else {
DequantizeBlockwise<float, float>(
tmp_b_data_ptr.get(), // dequantized output
Expand Down Expand Up @@ -1092,30 +1084,22 @@ Status MatMulNBits<MLFloat16>::ComputeBUnpacked(const Tensor* a,
"Only 2b and 4b quantization is supported for unpacked compute using "
"non-MLAS de-quantization for now");

// !!!!!!!!!!!!!! naive implementation, need to be optimized !!!!!!!!!!!!!!
// Note: The kernel registration constrains T3 to {uint8_t, T1}, so for
// MatMulNBits<MLFloat16> only MLFloat16 (not float) ZP can reach this branch.
if (zero_points && zero_points->IsDataType<MLFloat16>()) {
if (nbits_ == 2) {
ORT_ENFORCE(reorder_idx_data == nullptr,
"g_idx (reorder index) is not supported for 2-bit quantization with float zero points");
// Simple 2-bit dequantization with MLFloat16 zero points
const MLFloat16* fp16_zp = static_cast<const MLFloat16*>(zero_points_data);
size_t k_blocks = (K_ + block_size_ - 1) / block_size_;
size_t packed_k = k_blocks * block_size_;
size_t bytes_per_col = packed_k / 4;
for (size_t n = 0; n < N_; n++) {
for (size_t k = 0; k < K_; k++) {
size_t block_idx = k / block_size_;
float scale = scales_ptr[n * k_blocks + block_idx];
float zp = fp16_zp[n * k_blocks + block_idx].ToFloat();
size_t packed_idx = n * bytes_per_col + k / 4;
int bit_offset = static_cast<int>((k % 4) * 2);
uint8_t q = (b_data[packed_idx] >> bit_offset) & 0x3;
tmp_b_data_ptr.get()[n * K_ + k] =
(static_cast<float>(q) - zp) * scale;
}
}
"g_idx (reorder index) is not supported for 2-bit quantization with floating-point zero points");
DequantizeBlockwise2Bits<float, MLFloat16>(
tmp_b_data_ptr.get(),
b_data,
scales_ptr,
static_cast<const MLFloat16*>(zero_points_data),
static_cast<int32_t>(block_size_),
column_wise_quant_,
static_cast<int32_t>(K_),
static_cast<int32_t>(N_),
thread_pool);
} else {
DequantizeBlockwise<float, MLFloat16>(
tmp_b_data_ptr.get(), // dequantized output
Expand Down
151 changes: 146 additions & 5 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <type_traits>

#include "core/common/common.h"
Expand Down Expand Up @@ -41,11 +42,11 @@ void Dequantize4BitsKernelReOrder(
T* output_i = output + out_y * out_cols + out_x;
uint32_t quant_value = *(reinterpret_cast<const uint32_t*>(quant_data + element_offset / 2));
if constexpr (onnxruntime::endian::native == onnxruntime::endian::big) {
const uint8_t* c = (const uint8_t*)(&quant_value);
quant_value = (uint32_t)c[0] |
(uint32_t)c[1] << 8 |
(uint32_t)c[2] << 16 |
(uint32_t)c[3] << 24;
const uint8_t* c = reinterpret_cast<const uint8_t*>(&quant_value);
quant_value = static_cast<uint32_t>(c[0]) |
static_cast<uint32_t>(c[1]) << 8 |
static_cast<uint32_t>(c[2]) << 16 |
static_cast<uint32_t>(c[3]) << 24;
}
const int remain_x = std::min(8, out_cols - out_x);
const int32_t* reorder_idx_with_off = reorder_idx + kb_idx * block_size + ((threadIdx_x * 8) & (block_size - 1));
Expand Down Expand Up @@ -117,5 +118,145 @@ template void DequantizeBlockwise<float, MLFloat16, 4>(
const MLFloat16* zero_points, const int32_t* reorder_idx, int32_t block_size,
bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool);

// 2-bit dequantization kernel for float/MLFloat16 zero points.
// Processes 16 elements at a time (16 x 2-bit = 32 bits = one uint32_t).
// Layout: columnwise packing — elements within a column are packed consecutively,
// output[n * K + k] = (quant_value - zp) * scale
template <class T, class zeroT>
void Dequantize2BitsKernel(
T* output, const uint8_t* quant_data, const T* scale_data,
const zeroT* zero_points, int block_size,
int groups_per_threadblock, int total_groups, int N, int K,
int blockIdx_x, int threadIdx_x) {
// Each "thread" handles 16 elements (one uint32 of packed 2-bit values)
constexpr int elements_per_thread = 16;
const int group_id = blockIdx_x * groups_per_threadblock + ((threadIdx_x * elements_per_thread) / block_size);
if (group_id >= total_groups) {
return;
}
const int k_blocks = (K + block_size - 1) / block_size;

int n_idx = group_id / k_blocks;
int kb_idx = group_id % k_blocks;
int element_offset = group_id * block_size + ((threadIdx_x * elements_per_thread) & (block_size - 1));

const int k_offset = element_offset % (k_blocks * block_size);
const int n_offset = element_offset / (k_blocks * block_size);
if (n_offset >= N || k_offset >= K) {
return;
}

T* output_i = output + n_offset * K + k_offset;
// 16 elements × 2 bits = 4 bytes. Use memcpy to avoid alignment UB.
uint32_t quant_value = 0;
std::memcpy(&quant_value, quant_data + element_offset / 4, sizeof(uint32_t));
if constexpr (onnxruntime::endian::native == onnxruntime::endian::big) {
const uint8_t* c = reinterpret_cast<const uint8_t*>(&quant_value);
quant_value = static_cast<uint32_t>(c[0]) |
static_cast<uint32_t>(c[1]) << 8 |
static_cast<uint32_t>(c[2]) << 16 |
static_cast<uint32_t>(c[3]) << 24;
}
const int remain_k = std::min(elements_per_thread, K - k_offset);

float scale_f = static_cast<float>(*(scale_data + static_cast<uint64_t>(n_idx) * static_cast<uint64_t>(k_blocks) + static_cast<uint64_t>(kb_idx)));
float zp_f = 0.0f;
if (zero_points) {
if constexpr (std::is_same_v<zeroT, MLFloat16>) {
zp_f = (*(zero_points + static_cast<uint64_t>(n_idx) * static_cast<uint64_t>(k_blocks) + static_cast<uint64_t>(kb_idx))).ToFloat();
} else {
zp_f = static_cast<float>(*(zero_points + static_cast<uint64_t>(n_idx) * static_cast<uint64_t>(k_blocks) + static_cast<uint64_t>(kb_idx)));
}
}

float zp_adjust = -scale_f * zp_f;
for (int i = 0; i < remain_k; i++) {
float q = static_cast<float>((quant_value >> (2 * i)) & 0x3);
output_i[i] = static_cast<T>(q * scale_f + zp_adjust);
}
}

template <class T, class zeroT>
void Dequantize2BitsFallback(
T* output, const uint8_t* quant_data, const T* scale_data,
const zeroT* zero_points, int block_size, int N, int K) {
const int k_blocks = (K + block_size - 1) / block_size;

for (int n = 0; n < N; ++n) {
for (int kb = 0; kb < k_blocks; ++kb) {
const int group_offset = (n * k_blocks + kb) * block_size;
const int k_start = kb * block_size;
const int k_count = std::min(block_size, K - k_start);

const float scale = static_cast<float>(scale_data[static_cast<uint64_t>(n) * static_cast<uint64_t>(k_blocks) + static_cast<uint64_t>(kb)]);
float zp_f = 0.0f;
if (zero_points) {
if constexpr (std::is_same_v<zeroT, MLFloat16>) {
zp_f = zero_points[static_cast<uint64_t>(n) * static_cast<uint64_t>(k_blocks) + static_cast<uint64_t>(kb)].ToFloat();
} else {
zp_f = static_cast<float>(zero_points[static_cast<uint64_t>(n) * static_cast<uint64_t>(k_blocks) + static_cast<uint64_t>(kb)]);
}
}
const float zp_adjust = -scale * zp_f;
T* output_i = output + static_cast<uint64_t>(n) * static_cast<uint64_t>(K) + static_cast<uint64_t>(k_start);

for (int i = 0; i < k_count; ++i) {
const int element_offset = group_offset + i;
const uint8_t packed = quant_data[element_offset / 4];
const uint8_t q = (packed >> (2 * (element_offset & 0x3))) & 0x3;
output_i[i] = static_cast<T>(static_cast<float>(q) * scale + zp_adjust);
}
}
}
}

// Specialization of DequantizeBlockwise for qbits=2
template <typename inputT, typename zeroT>
void DequantizeBlockwise2Bits(
inputT* output,
const uint8_t* quant_data,
const inputT* scales_data,
const zeroT* zero_points,
int32_t block_size,
bool columnwise,
int32_t K,
int32_t N,
onnxruntime::concurrency::ThreadPool* pool) {
auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
constexpr int elements_per_thread = 16;
ORT_ENFORCE(columnwise, "Row-wise quantization is not supported");
ORT_ENFORCE(block_size > 0, "block_size must be positive, got: ", block_size);
ORT_ENFORCE((block_size & (block_size - 1)) == 0, "block_size must be a power of two, got: ", block_size);
if (block_size > 256 * elements_per_thread || block_size % elements_per_thread != 0) {
Dequantize2BitsFallback(output, quant_data, scales_data, zero_points, block_size, N, K);
return;
}

int groups_per_threadblock = 256 * elements_per_thread / block_size;
int groups_per_K = ceildiv(K, block_size);
int total_groups = N * groups_per_K;
int blocks_per_grid = static_cast<int>(ceildiv(total_groups, groups_per_threadblock));
concurrency::ThreadPool::TrySimpleParallelFor(
pool, static_cast<std::ptrdiff_t>(blocks_per_grid),
[&](std::ptrdiff_t block_id) {
for (int j = 0; j < 256; j++) {
Comment thread
tianleiwu marked this conversation as resolved.
Dequantize2BitsKernel(output, quant_data, scales_data, zero_points,
block_size, groups_per_threadblock,
total_groups, N, K, static_cast<int>(block_id), j);
}
});
}

// Explicit instantiations for 2-bit dequantization
template void DequantizeBlockwise2Bits<float, float>(
float* output, const uint8_t* quant_data, const float* scales_data,
const float* zero_points, int32_t block_size,
bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool);

template void DequantizeBlockwise2Bits<float, MLFloat16>(
float* output, const uint8_t* quant_data, const float* scales_data,
const MLFloat16* zero_points, int32_t block_size,
bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool);

} // namespace contrib
} // namespace onnxruntime
14 changes: 14 additions & 0 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,19 @@ void DequantizeBlockwise(
int32_t N, // number of columns in quantized input
onnxruntime::concurrency::ThreadPool* thread_pool);

// Threaded 2-bit blockwise dequantization with float/MLFloat16 zero points.
// Does not support reorder_idx (g_idx).
template <typename inputT, typename zeroT>
void DequantizeBlockwise2Bits(
inputT* output, // dequantized output
const uint8_t* quant_data, // quantized input
const inputT* scales_data, // quantization scales
const zeroT* zero_points, // quantization zero points
int32_t block_size, // quantization block size
bool columnwise, // columnwise quantization or row-wise
int32_t K, // number of rows in quantized input
int32_t N, // number of columns in quantized input
onnxruntime::concurrency::ThreadPool* thread_pool);
Comment thread
tianleiwu marked this conversation as resolved.

} // namespace contrib
} // namespace onnxruntime
Loading
Loading