From c745e4d9b06a7372148f47531ae17471cbaa618c Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 20 May 2026 16:04:22 +0000 Subject: [PATCH 1/6] Initial plan From 321ce21f44b92dda50e56e0e1b33bae0326ad13e Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 20 May 2026 16:37:57 +0000 Subject: [PATCH 2/6] Optimize MatMulNBits 2-bit + float zero_point dequantization with multi-threaded kernel Replace the naive single-threaded scalar loop for 2-bit quantization with float/MLFloat16 zero points with a multi-threaded implementation using TrySimpleParallelFor. The new DequantizeBlockwise2Bits function processes 16 elements (one uint32 of packed 2-bit values) per iteration and distributes work across available threads, matching the parallelism pattern used by the existing 4-bit DequantizeBlockwise path. Agent-Logs-Url: https://github.com/microsoft/onnxruntime/sessions/76231b1d-cdea-427a-8824-29293b1d02eb Co-authored-by: tianleiwu <30328909+tianleiwu@users.noreply.github.com> --- .../cpu/quantization/matmul_nbits.cc | 55 ++++------ .../cpu/quantization/matmul_nbits_impl.cc | 103 ++++++++++++++++++ .../cpu/quantization/matmul_nbits_impl.h | 14 +++ 3 files changed, 137 insertions(+), 35 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 3ad1e0c53657f..16c439b1722a3 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -931,30 +931,22 @@ Status MatMulNBits::ComputeBUnpacked(const Tensor* a, "Only 2b and 4b quantization is supported for unpacked compute using " "non-MLAS de-quantization for now"); - // !!!!!!!!!!!!!! naive implementation, need to be optimized !!!!!!!!!!!!!! // Note: The kernel registration constrains T3 to {uint8_t, T1}, so for // MatMulNBits only float (not MLFloat16) ZP can reach this branch. if (zero_points && zero_points->IsDataType()) { if (nbits_ == 2) { ORT_ENFORCE(reorder_idx_data == nullptr, "g_idx (reorder index) is not supported for 2-bit quantization with float zero points"); - // Simple 2-bit dequantization with float zero points - const float* float_zp = static_cast(zero_points_data); - size_t k_blocks = (K_ + block_size_ - 1) / block_size_; - size_t packed_k = k_blocks * block_size_; - size_t bytes_per_col = packed_k / 4; - for (size_t n = 0; n < N_; n++) { - for (size_t k = 0; k < K_; k++) { - size_t block_idx = k / block_size_; - float scale = scales_data[n * k_blocks + block_idx]; - float zp = float_zp[n * k_blocks + block_idx]; - size_t packed_idx = n * bytes_per_col + k / 4; - int bit_offset = static_cast((k % 4) * 2); - uint8_t q = (b_data[packed_idx] >> bit_offset) & 0x3; - tmp_b_data_ptr.get()[n * K_ + k] = - (static_cast(q) - zp) * scale; - } - } + DequantizeBlockwise2Bits( + tmp_b_data_ptr.get(), + b_data, + scales_data, + static_cast(zero_points_data), + static_cast(block_size_), + column_wise_quant_, + static_cast(K_), + static_cast(N_), + thread_pool); } else { DequantizeBlockwise( tmp_b_data_ptr.get(), // dequantized output @@ -1099,23 +1091,16 @@ Status MatMulNBits::ComputeBUnpacked(const Tensor* a, if (nbits_ == 2) { ORT_ENFORCE(reorder_idx_data == nullptr, "g_idx (reorder index) is not supported for 2-bit quantization with float zero points"); - // Simple 2-bit dequantization with MLFloat16 zero points - const MLFloat16* fp16_zp = static_cast(zero_points_data); - size_t k_blocks = (K_ + block_size_ - 1) / block_size_; - size_t packed_k = k_blocks * block_size_; - size_t bytes_per_col = packed_k / 4; - for (size_t n = 0; n < N_; n++) { - for (size_t k = 0; k < K_; k++) { - size_t block_idx = k / block_size_; - float scale = scales_ptr[n * k_blocks + block_idx]; - float zp = fp16_zp[n * k_blocks + block_idx].ToFloat(); - size_t packed_idx = n * bytes_per_col + k / 4; - int bit_offset = static_cast((k % 4) * 2); - uint8_t q = (b_data[packed_idx] >> bit_offset) & 0x3; - tmp_b_data_ptr.get()[n * K_ + k] = - (static_cast(q) - zp) * scale; - } - } + DequantizeBlockwise2Bits( + tmp_b_data_ptr.get(), + b_data, + scales_ptr, + static_cast(zero_points_data), + static_cast(block_size_), + column_wise_quant_, + static_cast(K_), + static_cast(N_), + thread_pool); } else { DequantizeBlockwise( tmp_b_data_ptr.get(), // dequantized output diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc index e9ef220a2187e..290c70b2554aa 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc @@ -117,5 +117,108 @@ template void DequantizeBlockwise( const MLFloat16* zero_points, const int32_t* reorder_idx, int32_t block_size, bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool); +// 2-bit dequantization kernel for float/MLFloat16 zero points. +// Processes 16 elements at a time (16 x 2-bit = 32 bits = one uint32_t). +// Layout: columnwise packing — elements within a column are packed consecutively, +// output[n * K + k] = (quant_value - zp) * scale +template +void Dequantize2BitsKernel( + T* output, const uint8_t* quant_data, const T* scale_data, + const zeroT* zero_points, int block_size, + int groups_per_threadblock, int total_groups, int N, int K, + int blockIdx_x, int threadIdx_x) { + // Each "thread" handles 16 elements (one uint32 of packed 2-bit values) + constexpr int elements_per_thread = 16; + const int group_id = blockIdx_x * groups_per_threadblock + ((threadIdx_x * elements_per_thread) / block_size); + if (group_id >= total_groups) { + return; + } + const int k_blocks = (K + block_size - 1) / block_size; + + int n_idx = group_id / k_blocks; + int kb_idx = group_id % k_blocks; + int element_offset = group_id * block_size + ((threadIdx_x * elements_per_thread) & (block_size - 1)); + + const int k_offset = element_offset % (k_blocks * block_size); + const int n_offset = element_offset / (k_blocks * block_size); + if (n_offset >= N || k_offset >= K) { + return; + } + + T* output_i = output + n_offset * K + k_offset; + // 16 elements × 2 bits = 4 bytes + uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset / 4)); + if constexpr (onnxruntime::endian::native == onnxruntime::endian::big) { + const uint8_t* c = (const uint8_t*)(&quant_value); + quant_value = (uint32_t)c[0] | + (uint32_t)c[1] << 8 | + (uint32_t)c[2] << 16 | + (uint32_t)c[3] << 24; + } + const int remain_k = std::min(elements_per_thread, K - k_offset); + + T scale = *(scale_data + static_cast(n_idx) * static_cast(k_blocks) + static_cast(kb_idx)); + float zp_f = 0.0f; + if (zero_points) { + if constexpr (std::is_same_v) { + zp_f = (*(zero_points + static_cast(n_idx) * static_cast(k_blocks) + static_cast(kb_idx))).ToFloat(); + } else { + zp_f = static_cast(*(zero_points + static_cast(n_idx) * static_cast(k_blocks) + static_cast(kb_idx))); + } + } + + if constexpr (std::is_same_v) { + T zp_adjust = -scale * MLFloat16(zp_f); + for (int i = 0; i < remain_k; i++) { + output_i[i] = static_cast((quant_value >> (2 * i)) & 0x3) * scale + zp_adjust; + } + } else { + T zp_adjust = -scale * zp_f; + for (int i = 0; i < remain_k; i++) { + output_i[i] = T((quant_value >> (2 * i)) & 0x3) * scale + zp_adjust; + } + } +} + +// Specialization of DequantizeBlockwise for qbits=2 +template +void DequantizeBlockwise2Bits( + inputT* output, + const uint8_t* quant_data, + const inputT* scales_data, + const zeroT* zero_points, + int32_t block_size, + bool, + int32_t K, + int32_t N, + onnxruntime::concurrency::ThreadPool* pool) { + auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; + constexpr int elements_per_thread = 16; + int groups_per_threadblock = 256 * elements_per_thread / block_size; + int groups_per_K = ceildiv(K, block_size); + int total_groups = N * groups_per_K; + int blocks_per_grid = static_cast(ceildiv(total_groups, groups_per_threadblock)); + concurrency::ThreadPool::TrySimpleParallelFor( + pool, static_cast(blocks_per_grid), + [&](std::ptrdiff_t block_id) { + for (int j = 0; j < 256; j++) { + Dequantize2BitsKernel(output, quant_data, scales_data, zero_points, + block_size, groups_per_threadblock, + total_groups, N, K, static_cast(block_id), j); + } + }); +} + +// Explicit instantiations for 2-bit dequantization +template void DequantizeBlockwise2Bits( + float* output, const uint8_t* quant_data, const float* scales_data, + const float* zero_points, int32_t block_size, + bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool); + +template void DequantizeBlockwise2Bits( + float* output, const uint8_t* quant_data, const float* scales_data, + const MLFloat16* zero_points, int32_t block_size, + bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool); + } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h index be77ec03d006b..7e82d1b32a025 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h @@ -19,5 +19,19 @@ void DequantizeBlockwise( int32_t N, // number of columns in quantized input onnxruntime::concurrency::ThreadPool* thread_pool); +// Threaded 2-bit blockwise dequantization with float/MLFloat16 zero points. +// Does not support reorder_idx (g_idx). +template +void DequantizeBlockwise2Bits( + inputT* output, // dequantized output + const uint8_t* quant_data, // quantized input + const inputT* scales_data, // quantization scales + const zeroT* zero_points, // quantization zero points + int32_t block_size, // quantization block size + bool, // columnwise quantization or row-wise + int32_t K, // number of rows in quantized input + int32_t N, // number of columns in quantized input + onnxruntime::concurrency::ThreadPool* thread_pool); + } // namespace contrib } // namespace onnxruntime From 6d03671a2cf433b7fda2cfc1aab18549b01a6a9e Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 20 May 2026 11:14:01 -0700 Subject: [PATCH 3/6] add benchmark --- .../cpu/quantization/matmul_nbits_impl.h | 16 +- .../python/quantization/bench_matmul_2bits.py | 318 ++++++++++++++++++ 2 files changed, 326 insertions(+), 8 deletions(-) create mode 100644 onnxruntime/test/python/quantization/bench_matmul_2bits.py diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h index 7e82d1b32a025..71cdaf0f2fb5a 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h @@ -23,14 +23,14 @@ void DequantizeBlockwise( // Does not support reorder_idx (g_idx). template void DequantizeBlockwise2Bits( - inputT* output, // dequantized output - const uint8_t* quant_data, // quantized input - const inputT* scales_data, // quantization scales - const zeroT* zero_points, // quantization zero points - int32_t block_size, // quantization block size - bool, // columnwise quantization or row-wise - int32_t K, // number of rows in quantized input - int32_t N, // number of columns in quantized input + inputT* output, // dequantized output + const uint8_t* quant_data, // quantized input + const inputT* scales_data, // quantization scales + const zeroT* zero_points, // quantization zero points + int32_t block_size, // quantization block size + bool, // columnwise quantization or row-wise + int32_t K, // number of rows in quantized input + int32_t N, // number of columns in quantized input onnxruntime::concurrency::ThreadPool* thread_pool); } // namespace contrib diff --git a/onnxruntime/test/python/quantization/bench_matmul_2bits.py b/onnxruntime/test/python/quantization/bench_matmul_2bits.py new file mode 100644 index 0000000000000..5922ca776d884 --- /dev/null +++ b/onnxruntime/test/python/quantization/bench_matmul_2bits.py @@ -0,0 +1,318 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +""" +Benchmark for MatMulNBits 2-bit dequantization performance on CPU. + +This benchmark measures the performance improvement from the multi-threaded +DequantizeBlockwise2Bits kernel (PR #28589 / issue #28552) compared to +baseline. It exercises the MatMulNBits operator with 2-bit quantization +and float zero points on the CPU execution provider. + +Usage: + python bench_matmul_2bits.py [--warmup N] [--repeats N] [--threads N] +""" + +import argparse +import time + +import numpy as np +from onnx import TensorProto, helper, numpy_helper + +import onnxruntime as ort + + +def create_matmul_nbits_model( + M: int, + K: int, + N: int, + block_size: int, + bits: int = 2, + has_zero_point: bool = True, +) -> bytes: + """ + Creates an ONNX model with a single MatMulNBits node. + + The model structure: + input A [M, K] (float32) -> MatMulNBits -> output [M, N] (float32) + + With quantized weight B [N, K] packed as 2-bit or 4-bit values. + + Args: + M: Batch/sequence dimension. + K: Input features (rows of weight matrix). + N: Output features (columns of weight matrix). + block_size: Quantization block size along K. + bits: Number of quantization bits (2 or 4). + has_zero_point: Whether to include float zero points. + + Returns: + Serialized ONNX model bytes. + """ + k_blocks = (K + block_size - 1) // block_size + + # Input: A [M, K] + input_a = helper.make_tensor_value_info("A", TensorProto.FLOAT, [M, K]) + + # Output + output = helper.make_tensor_value_info("output", TensorProto.FLOAT, [M, N]) + + # Weight B: packed values as uint8, shape [N, k_blocks, blob_size] + elements_per_byte = 8 // bits # 4 for 2-bit, 2 for 4-bit + blob_size = block_size // elements_per_byte + b_data = np.random.randint(0, 256, size=(N, k_blocks, blob_size), dtype=np.uint8) + b_initializer = numpy_helper.from_array(b_data, name="B") + + # Scales: [N, k_blocks] as float32 + scales_data = np.random.uniform(0.001, 0.1, size=(N, k_blocks)).astype(np.float32) + scales_initializer = numpy_helper.from_array(scales_data, name="scales") + + initializers = [b_initializer, scales_initializer] + input_names = ["A", "B", "scales"] + + if has_zero_point: + # Float zero points: [N, k_blocks] as float32 + zp_data = np.random.uniform(0.0, 3.0, size=(N, k_blocks)).astype(np.float32) + zp_initializer = numpy_helper.from_array(zp_data, name="zero_points") + initializers.append(zp_initializer) + input_names.append("zero_points") + + # MatMulNBits node + node = helper.make_node( + "MatMulNBits", + inputs=input_names, + outputs=["output"], + name="MatMulNBits_0", + domain="com.microsoft", + bits=bits, + block_size=block_size, + K=K, + N=N, + ) + + graph = helper.make_graph( + [node], + "matmul_nbits_2bit_bench", + [input_a], + [output], + initializer=initializers, + ) + + model = helper.make_model( + graph, + opset_imports=[helper.make_opsetid("", 21), helper.make_opsetid("com.microsoft", 1)], + ) + model.ir_version = 9 + + return model.SerializeToString() + + +def benchmark_matmul_nbits( + M: int, + K: int, + N: int, + block_size: int, + bits: int, + num_threads: int, + warmup: int = 5, + repeats: int = 50, + has_zero_point: bool = True, +) -> dict: + """ + Benchmark MatMulNBits with n-bit quantization on CPU. + + Returns: + Dictionary with timing results. + """ + model_bytes = create_matmul_nbits_model(M, K, N, block_size, bits, has_zero_point) + + sess_options = ort.SessionOptions() + sess_options.intra_op_num_threads = num_threads + sess_options.inter_op_num_threads = 1 + sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL + + session = ort.InferenceSession( + model_bytes, + sess_options, + providers=["CPUExecutionProvider"], + ) + + # Create input + input_a = np.random.randn(M, K).astype(np.float32) + feeds = {"A": input_a} + + # Warmup + for _ in range(warmup): + session.run(None, feeds) + + # Benchmark + latencies = [] + for _ in range(repeats): + start = time.perf_counter() + session.run(None, feeds) + end = time.perf_counter() + latencies.append(end - start) + + latencies_ms = [t * 1000 for t in latencies] + return { + "M": M, + "K": K, + "N": N, + "block_size": block_size, + "bits": bits, + "threads": num_threads, + "has_zp": has_zero_point, + "mean_ms": np.mean(latencies_ms), + "median_ms": np.median(latencies_ms), + "min_ms": np.min(latencies_ms), + "max_ms": np.max(latencies_ms), + "std_ms": np.std(latencies_ms), + } + + +def main(): + parser = argparse.ArgumentParser(description="Benchmark MatMulNBits 2-bit dequantization on CPU") + parser.add_argument("--warmup", type=int, default=5, help="Number of warmup iterations") + parser.add_argument("--repeats", type=int, default=50, help="Number of benchmark iterations") + parser.add_argument( + "--threads", + type=int, + nargs="+", + default=[1, 2, 4, 8], + help="Thread counts to benchmark", + ) + parser.add_argument("--m", type=int, nargs="+", default=[1, 32], help="M dimensions (batch)") + parser.add_argument("--bits", type=int, nargs="+", default=[2, 4], help="Quantization bits to compare") + args = parser.parse_args() + + # Typical LLM weight shapes + configs = [ + # (K, N, block_size) — typical LLM layers + (4096, 4096, 128), # hidden projection + (4096, 11008, 128), # FFN up/gate + (11008, 4096, 128), # FFN down + # Smaller shapes for quick validation + (1024, 1024, 128), + (4096, 4096, 32), + ] + + print("=" * 110) + print("MatMulNBits 2-bit vs 4-bit Dequantization Benchmark (float zero points, CPU)") + print(f"ORT version: {ort.__version__}") + print(f"Warmup: {args.warmup}, Repeats: {args.repeats}") + print("=" * 110) + print() + + header = f"{'Bits':>4} {'M':>5} {'K':>6} {'N':>6} {'BS':>4} {'Thr':>4} {'Mean(ms)':>10} {'Med(ms)':>10} {'Min(ms)':>10} {'Std(ms)':>10}" + print(header) + print("-" * len(header)) + + results = [] + for k, n, block_size in configs: + for m in args.m: + for bits in args.bits: + for num_threads in args.threads: + try: + result = benchmark_matmul_nbits( + M=m, + K=k, + N=n, + block_size=block_size, + bits=bits, + num_threads=num_threads, + warmup=args.warmup, + repeats=args.repeats, + has_zero_point=True, + ) + results.append(result) + print( + f"{result['bits']:>4} {result['M']:>5} {result['K']:>6} {result['N']:>6} " + f"{result['block_size']:>4} {result['threads']:>4} " + f"{result['mean_ms']:>10.3f} {result['median_ms']:>10.3f} " + f"{result['min_ms']:>10.3f} {result['std_ms']:>10.3f}" + ) + except Exception as e: + print(f" FAILED: bits={bits} M={m} K={k} N={n} bs={block_size} threads={num_threads}: {e}") + + print() # Blank line between config groups + + # Summary: compare 2-bit vs 4-bit and show multi-thread speedup + print("\n" + "=" * 110) + print("Speedup Summary") + print("=" * 110) + + # Multi-thread speedup for 2-bit + print("\n--- 2-bit: Multi-thread speedup (vs 1 thread) ---") + header2 = f"{'M':>5} {'K':>6} {'N':>6} {'BS':>4} {'1-thr(ms)':>10} {'best-thr':>9} {'best(ms)':>10} {'Speedup':>8}" + print(header2) + print("-" * len(header2)) + + for k, n, block_size in configs: + for m in args.m: + group = [ + r + for r in results + if r["K"] == k and r["N"] == n and r["block_size"] == block_size and r["M"] == m and r["bits"] == 2 + ] + if not group: + continue + single = next((r for r in group if r["threads"] == 1), None) + if single is None: + continue + best = min(group, key=lambda r: r["median_ms"]) + speedup = single["median_ms"] / best["median_ms"] if best["median_ms"] > 0 else 0 + print( + f"{m:>5} {k:>6} {n:>6} {block_size:>4} " + f"{single['median_ms']:>10.3f} {best['threads']:>9} " + f"{best['median_ms']:>10.3f} {speedup:>7.2f}x" + ) + + # 2-bit vs 4-bit comparison (same thread count) + if 4 in args.bits and 2 in args.bits: + print("\n--- 2-bit vs 4-bit comparison (same thread count) ---") + header3 = f"{'M':>5} {'K':>6} {'N':>6} {'BS':>4} {'Thr':>4} {'4-bit(ms)':>10} {'2-bit(ms)':>10} {'Ratio':>8}" + print(header3) + print("-" * len(header3)) + + for k, n, block_size in configs: + for m in args.m: + for num_threads in args.threads: + r2 = next( + ( + r + for r in results + if r["K"] == k + and r["N"] == n + and r["block_size"] == block_size + and r["M"] == m + and r["bits"] == 2 + and r["threads"] == num_threads + ), + None, + ) + r4 = next( + ( + r + for r in results + if r["K"] == k + and r["N"] == n + and r["block_size"] == block_size + and r["M"] == m + and r["bits"] == 4 + and r["threads"] == num_threads + ), + None, + ) + if r2 and r4: + ratio = r2["median_ms"] / r4["median_ms"] if r4["median_ms"] > 0 else 0 + print( + f"{m:>5} {k:>6} {n:>6} {block_size:>4} {num_threads:>4} " + f"{r4['median_ms']:>10.3f} {r2['median_ms']:>10.3f} {ratio:>7.2f}x" + ) + + +if __name__ == "__main__": + main() From 6649f74719ee717bb9f69319fd68b3c464ad7f9f Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 20 May 2026 12:23:06 -0700 Subject: [PATCH 4/6] address feedbacks --- onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc | 1 - .../contrib_ops/cpu/quantization/matmul_nbits_impl.cc | 3 +++ .../test/python/quantization/bench_matmul_2bits.py | 8 +++++--- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 16c439b1722a3..b822f37199f15 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -1084,7 +1084,6 @@ Status MatMulNBits::ComputeBUnpacked(const Tensor* a, "Only 2b and 4b quantization is supported for unpacked compute using " "non-MLAS de-quantization for now"); - // !!!!!!!!!!!!!! naive implementation, need to be optimized !!!!!!!!!!!!!! // Note: The kernel registration constrains T3 to {uint8_t, T1}, so for // MatMulNBits only MLFloat16 (not float) ZP can reach this branch. if (zero_points && zero_points->IsDataType()) { diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc index 290c70b2554aa..cd828d67583e0 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc @@ -194,6 +194,9 @@ void DequantizeBlockwise2Bits( onnxruntime::concurrency::ThreadPool* pool) { auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; constexpr int elements_per_thread = 16; + ORT_ENFORCE(block_size > 0 && block_size <= 256 * elements_per_thread && block_size % elements_per_thread == 0, + "block_size must be positive, at most ", 256 * elements_per_thread, + ", and a multiple of ", elements_per_thread, ", got: ", block_size); int groups_per_threadblock = 256 * elements_per_thread / block_size; int groups_per_K = ceildiv(K, block_size); int total_groups = N * groups_per_K; diff --git a/onnxruntime/test/python/quantization/bench_matmul_2bits.py b/onnxruntime/test/python/quantization/bench_matmul_2bits.py index 5922ca776d884..01b166d7667f6 100644 --- a/onnxruntime/test/python/quantization/bench_matmul_2bits.py +++ b/onnxruntime/test/python/quantization/bench_matmul_2bits.py @@ -7,10 +7,12 @@ """ Benchmark for MatMulNBits 2-bit dequantization performance on CPU. -This benchmark measures the performance improvement from the multi-threaded -DequantizeBlockwise2Bits kernel (PR #28589 / issue #28552) compared to -baseline. It exercises the MatMulNBits operator with 2-bit quantization +This benchmark measures the performance of the multi-threaded +DequantizeBlockwise2Bits kernel (PR #28589 / issue #28552). +It exercises the MatMulNBits operator with 2-bit quantization and float zero points on the CPU execution provider. +To compare against a baseline, run this script on two different builds +and compare the reported latencies. Usage: python bench_matmul_2bits.py [--warmup N] [--repeats N] [--threads N] From cfe5ce8f69d0ff8e7db2a205e59620f9fc481f83 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 22 May 2026 18:20:07 +0000 Subject: [PATCH 5/6] Address MatMulNBits review feedback --- .../cpu/quantization/matmul_nbits.cc | 4 +- .../cpu/quantization/matmul_nbits_impl.cc | 47 +++++++++++++++++-- .../cpu/quantization/matmul_nbits_impl.h | 2 +- 3 files changed, 46 insertions(+), 7 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index b822f37199f15..3979977557d2d 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -936,7 +936,7 @@ Status MatMulNBits::ComputeBUnpacked(const Tensor* a, if (zero_points && zero_points->IsDataType()) { if (nbits_ == 2) { ORT_ENFORCE(reorder_idx_data == nullptr, - "g_idx (reorder index) is not supported for 2-bit quantization with float zero points"); + "g_idx (reorder index) is not supported for 2-bit quantization with floating-point zero points"); DequantizeBlockwise2Bits( tmp_b_data_ptr.get(), b_data, @@ -1089,7 +1089,7 @@ Status MatMulNBits::ComputeBUnpacked(const Tensor* a, if (zero_points && zero_points->IsDataType()) { if (nbits_ == 2) { ORT_ENFORCE(reorder_idx_data == nullptr, - "g_idx (reorder index) is not supported for 2-bit quantization with float zero points"); + "g_idx (reorder index) is not supported for 2-bit quantization with floating-point zero points"); DequantizeBlockwise2Bits( tmp_b_data_ptr.get(), b_data, diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc index cd828d67583e0..cfde6d233c51c 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc @@ -180,6 +180,40 @@ void Dequantize2BitsKernel( } } +template +void Dequantize2BitsFallback( + T* output, const uint8_t* quant_data, const T* scale_data, + const zeroT* zero_points, int block_size, int N, int K) { + const int k_blocks = (K + block_size - 1) / block_size; + + for (int n = 0; n < N; ++n) { + for (int kb = 0; kb < k_blocks; ++kb) { + const int group_offset = (n * k_blocks + kb) * block_size; + const int k_start = kb * block_size; + const int k_count = std::min(block_size, K - k_start); + + const T scale = scale_data[static_cast(n) * static_cast(k_blocks) + static_cast(kb)]; + float zp_f = 0.0f; + if (zero_points) { + if constexpr (std::is_same_v) { + zp_f = zero_points[static_cast(n) * static_cast(k_blocks) + static_cast(kb)].ToFloat(); + } else { + zp_f = static_cast(zero_points[static_cast(n) * static_cast(k_blocks) + static_cast(kb)]); + } + } + const T zp_adjust = -scale * zp_f; + T* output_i = output + static_cast(n) * static_cast(K) + static_cast(k_start); + + for (int i = 0; i < k_count; ++i) { + const int element_offset = group_offset + i; + const uint8_t packed = quant_data[element_offset / 4]; + const uint8_t q = (packed >> (2 * (element_offset & 0x3))) & 0x3; + output_i[i] = static_cast(q) * scale + zp_adjust; + } + } + } +} + // Specialization of DequantizeBlockwise for qbits=2 template void DequantizeBlockwise2Bits( @@ -188,15 +222,20 @@ void DequantizeBlockwise2Bits( const inputT* scales_data, const zeroT* zero_points, int32_t block_size, - bool, + bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* pool) { auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; constexpr int elements_per_thread = 16; - ORT_ENFORCE(block_size > 0 && block_size <= 256 * elements_per_thread && block_size % elements_per_thread == 0, - "block_size must be positive, at most ", 256 * elements_per_thread, - ", and a multiple of ", elements_per_thread, ", got: ", block_size); + ORT_ENFORCE(columnwise, "Row-wise quantization is not supported"); + ORT_ENFORCE(block_size > 0, "block_size must be positive, got: ", block_size); + ORT_ENFORCE((block_size & (block_size - 1)) == 0, "block_size must be a power of two, got: ", block_size); + if (block_size > 256 * elements_per_thread || block_size % elements_per_thread != 0) { + Dequantize2BitsFallback(output, quant_data, scales_data, zero_points, block_size, N, K); + return; + } + int groups_per_threadblock = 256 * elements_per_thread / block_size; int groups_per_K = ceildiv(K, block_size); int total_groups = N * groups_per_K; diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h index 71cdaf0f2fb5a..864e36ccf95d7 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h @@ -28,7 +28,7 @@ void DequantizeBlockwise2Bits( const inputT* scales_data, // quantization scales const zeroT* zero_points, // quantization zero points int32_t block_size, // quantization block size - bool, // columnwise quantization or row-wise + bool columnwise, // columnwise quantization or row-wise int32_t K, // number of rows in quantized input int32_t N, // number of columns in quantized input onnxruntime::concurrency::ThreadPool* thread_pool); From 4eda0c2cf90a88e2d689f18f325b2f82e5c8cbb2 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 22 May 2026 21:49:41 +0000 Subject: [PATCH 6/6] Fix issues: C-style casts; Alignment/strict-aliasing; MLFloat16 precision --- .../cpu/quantization/matmul_nbits_impl.cc | 48 +++++++++---------- 1 file changed, 22 insertions(+), 26 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc index cfde6d233c51c..73ea5bdfc958f 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc @@ -6,6 +6,7 @@ #include #include #include +#include #include #include "core/common/common.h" @@ -41,11 +42,11 @@ void Dequantize4BitsKernelReOrder( T* output_i = output + out_y * out_cols + out_x; uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset / 2)); if constexpr (onnxruntime::endian::native == onnxruntime::endian::big) { - const uint8_t* c = (const uint8_t*)(&quant_value); - quant_value = (uint32_t)c[0] | - (uint32_t)c[1] << 8 | - (uint32_t)c[2] << 16 | - (uint32_t)c[3] << 24; + const uint8_t* c = reinterpret_cast(&quant_value); + quant_value = static_cast(c[0]) | + static_cast(c[1]) << 8 | + static_cast(c[2]) << 16 | + static_cast(c[3]) << 24; } const int remain_x = std::min(8, out_cols - out_x); const int32_t* reorder_idx_with_off = reorder_idx + kb_idx * block_size + ((threadIdx_x * 8) & (block_size - 1)); @@ -146,18 +147,19 @@ void Dequantize2BitsKernel( } T* output_i = output + n_offset * K + k_offset; - // 16 elements × 2 bits = 4 bytes - uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset / 4)); + // 16 elements × 2 bits = 4 bytes. Use memcpy to avoid alignment UB. + uint32_t quant_value = 0; + std::memcpy(&quant_value, quant_data + element_offset / 4, sizeof(uint32_t)); if constexpr (onnxruntime::endian::native == onnxruntime::endian::big) { - const uint8_t* c = (const uint8_t*)(&quant_value); - quant_value = (uint32_t)c[0] | - (uint32_t)c[1] << 8 | - (uint32_t)c[2] << 16 | - (uint32_t)c[3] << 24; + const uint8_t* c = reinterpret_cast(&quant_value); + quant_value = static_cast(c[0]) | + static_cast(c[1]) << 8 | + static_cast(c[2]) << 16 | + static_cast(c[3]) << 24; } const int remain_k = std::min(elements_per_thread, K - k_offset); - T scale = *(scale_data + static_cast(n_idx) * static_cast(k_blocks) + static_cast(kb_idx)); + float scale_f = static_cast(*(scale_data + static_cast(n_idx) * static_cast(k_blocks) + static_cast(kb_idx))); float zp_f = 0.0f; if (zero_points) { if constexpr (std::is_same_v) { @@ -167,16 +169,10 @@ void Dequantize2BitsKernel( } } - if constexpr (std::is_same_v) { - T zp_adjust = -scale * MLFloat16(zp_f); - for (int i = 0; i < remain_k; i++) { - output_i[i] = static_cast((quant_value >> (2 * i)) & 0x3) * scale + zp_adjust; - } - } else { - T zp_adjust = -scale * zp_f; - for (int i = 0; i < remain_k; i++) { - output_i[i] = T((quant_value >> (2 * i)) & 0x3) * scale + zp_adjust; - } + float zp_adjust = -scale_f * zp_f; + for (int i = 0; i < remain_k; i++) { + float q = static_cast((quant_value >> (2 * i)) & 0x3); + output_i[i] = static_cast(q * scale_f + zp_adjust); } } @@ -192,7 +188,7 @@ void Dequantize2BitsFallback( const int k_start = kb * block_size; const int k_count = std::min(block_size, K - k_start); - const T scale = scale_data[static_cast(n) * static_cast(k_blocks) + static_cast(kb)]; + const float scale = static_cast(scale_data[static_cast(n) * static_cast(k_blocks) + static_cast(kb)]); float zp_f = 0.0f; if (zero_points) { if constexpr (std::is_same_v) { @@ -201,14 +197,14 @@ void Dequantize2BitsFallback( zp_f = static_cast(zero_points[static_cast(n) * static_cast(k_blocks) + static_cast(kb)]); } } - const T zp_adjust = -scale * zp_f; + const float zp_adjust = -scale * zp_f; T* output_i = output + static_cast(n) * static_cast(K) + static_cast(k_start); for (int i = 0; i < k_count; ++i) { const int element_offset = group_offset + i; const uint8_t packed = quant_data[element_offset / 4]; const uint8_t q = (packed >> (2 * (element_offset & 0x3))) & 0x3; - output_i[i] = static_cast(q) * scale + zp_adjust; + output_i[i] = static_cast(static_cast(q) * scale + zp_adjust); } } }