Skip to content

Commit e91a5c3

Browse files
authored
Extend DQ→MatMulNBits fusion to support 2/8-bit weights and Cast(fp16→fp32) patterns (microsoft#27614)
### Description Extends the QDQ selector-action `DQ → MatMul → MatMulNBits` fusion in two ways: **1. Support 2-bit and 8-bit quantized weights** The existing fusion only handled 4-bit (`Int4x2`/`UInt4x2`) DQ weights. This PR broadens it to also support 2-bit (`Int2x4`/`UInt2x4`) and 8-bit (`int8`/`uint8`) quantized weights. - qdq_selectors.cc: Added `Is2BitIntType`, `Is8BitIntType`, and `IsNBitsIntType` helpers. Updated `DQMatMulNodeGroupSelector::Check` to accept 2/4/8-bit weight types. - qdq_actions.cc: Added `DQWeightBits` and `IsDQWeightSigned` helpers to dispatch the correct bit-width and signedness for MLAS transpose and MatMulNBits attributes. - `q4_dq.cpp` (MLAS): Added 8-bit `GetElem`/`SetElem` specializations and an 8-bit `TransposeColumnWiseQuantized` path. Added 6 new template instantiations for 2-bit (signed/unsigned, float/float16) and 8-bit (signed/unsigned, float/float16). **2. Handle `Cast(fp16→fp32)` between DQ and MatMul (FP16 model fusion)** FP16 models often have `DQ(int4→fp16) → Cast(fp16→fp32) → MatMul(fp32)` patterns that the existing selector couldn't match. This PR adds a new `DQCastMatMulToMatMulNBitsSelector` / `DQCastMatMulToMatMulNBitsAction` pair that: - Matches the `DQ → Cast(fp16→fp32) → MatMul` pattern on input B. - Creates a `MatMulNBits` node operating in the DQ scale dtype (fp16). - Always inserts `Cast` on input A (to DQ dtype) and `Cast` on output (DQ dtype to MatMul output dtype), relying on ORT's existing `CastElimination` optimizer to remove redundant back-to-back casts in subsequent passes. - Removes the original DQ, Cast (on B), and MatMul nodes. ### Motivation and Context - Many quantized models (e.g., from Olive, AutoAWQ) use 2-bit or 8-bit quantization, but the `DQ → MatMulNBits` fusion only supported 4-bit weights, leaving these models unoptimized. - FP16 models produce `DQ(→fp16) → Cast(fp16→fp32) → MatMul` patterns because the DQ output type matches the scale type (fp16), but the MatMul operates in fp32. 
Without handling the intermediate Cast, the fusion was blocked entirely for these models.
1 parent b8f5f1a commit e91a5c3

7 files changed

Lines changed: 1034 additions & 179 deletions

File tree

onnxruntime/core/mlas/lib/q4_dq.cpp

Lines changed: 274 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -663,6 +663,9 @@ struct BlockwiseQDQQuantizer {
663663
return (val >> (idx << 1)) & 0x3;
664664
} else if constexpr (qbits == 4) {
665665
return (val >> (idx << 2)) & 0xF;
666+
} else if constexpr (qbits == 8) {
667+
(void)idx;
668+
return val;
666669
}
667670
}
668671

@@ -674,6 +677,10 @@ struct BlockwiseQDQQuantizer {
674677
} else if constexpr (qbits == 4) {
675678
auto shift = idx << 2;
676679
return ((val & 0xF) << shift) | (dst & (~(0xF << shift)));
680+
} else if constexpr (qbits == 8) {
681+
(void)idx;
682+
(void)dst;
683+
return val;
677684
}
678685
}
679686

@@ -813,21 +820,185 @@ struct BlockwiseQDQQuantizer {
813820
src_zero_points || signed_quant || dst_zero_points,
814821
"Unsigned quant types without zero points must allocate zero points with value 0."
815822
);
816-
// Must avoid multiple thread write to a single byte, which means the starting index
817-
// of a thread block must be even. To achieve that, we need to customize the thread
818-
// block size based on the parity of columns.
819-
if (columns & 1) {
820-
TransposeColumnWiseQuantizedPackUnaligned(
821-
src_weights, src_scales, src_zero_points,
822-
dst_weights, dst_scales, dst_zero_points,
823-
rows, columns, quant_block_size, thread_pool
823+
824+
if constexpr (qbits == 8) {
825+
// 8-bit: each element is one byte, no sub-byte packing needed.
826+
// Simple byte-level transpose from [rows, columns] to [columns, k_blocks, block_size].
827+
auto row_quant_blk_num = (rows + quant_block_size - 1) / quant_block_size;
828+
auto dst_bytes_per_quant_blk = quant_block_size; // 8 bits = 1 byte per element
829+
auto dstT_num_row = row_quant_blk_num * dst_bytes_per_quant_blk;
830+
831+
// Transpose weights: src [rows, columns] -> dst [columns, k_blocks, block_size]
832+
MlasTryBatchParallel(
833+
thread_pool, static_cast<ptrdiff_t>(row_quant_blk_num * columns),
834+
[&](ptrdiff_t thread_blk_idx) {
835+
auto row_blk = static_cast<int32_t>(thread_blk_idx / columns);
836+
auto col = static_cast<int32_t>(thread_blk_idx % columns);
837+
838+
auto src_row_start = row_blk * quant_block_size;
839+
auto src_row_end = std::min(src_row_start + quant_block_size, rows);
840+
841+
auto dst_base = col * dstT_num_row + row_blk * dst_bytes_per_quant_blk;
842+
for (auto r = src_row_start; r < src_row_end; ++r) {
843+
auto src_val = src_weights[r * columns + col];
844+
if constexpr (signed_quant) {
845+
src_val ^= 0x80; // INT8 -> UINT8: add 128
846+
}
847+
dst_weights[dst_base + (r - src_row_start)] = src_val;
848+
}
849+
// Zero-pad remaining bytes in the last block if rows % block_size != 0
850+
for (auto r = src_row_end - src_row_start; r < quant_block_size; ++r) {
851+
dst_weights[dst_base + r] = signed_quant ? 0x80 : 0;
852+
}
853+
}
824854
);
825-
} else {
826-
TransposeColumnWiseQuantizedPackAligned(
827-
src_weights, src_scales, src_zero_points,
828-
dst_weights, dst_scales, dst_zero_points,
829-
rows, columns, quant_block_size, thread_pool
855+
856+
// Transpose scales: src [k_blocks, columns] -> dst [columns, k_blocks]
857+
MlasTryBatchParallel(
858+
thread_pool, static_cast<ptrdiff_t>(columns),
859+
[&](ptrdiff_t col) {
860+
auto src_idx = static_cast<int32_t>(col);
861+
auto dst_idx = static_cast<int32_t>(col) * row_quant_blk_num;
862+
for (int32_t i = 0; i < row_quant_blk_num; ++i, ++dst_idx, src_idx += columns) {
863+
dst_scales[dst_idx] = src_scales[src_idx];
864+
}
865+
}
830866
);
867+
868+
// Transpose zero points: src [k_blocks, columns] -> dst [columns, k_blocks]
869+
// For 8-bit, zero points are byte-aligned (1 byte each), no packing needed.
870+
if (src_zero_points && dst_zero_points) {
871+
MlasTryBatchParallel(
872+
thread_pool, static_cast<ptrdiff_t>(columns),
873+
[&](ptrdiff_t col) {
874+
auto src_idx = static_cast<int32_t>(col);
875+
auto dst_idx = static_cast<int32_t>(col) * row_quant_blk_num;
876+
for (int32_t i = 0; i < row_quant_blk_num; ++i, ++dst_idx, src_idx += columns) {
877+
auto zp = src_zero_points[src_idx];
878+
if constexpr (signed_quant) {
879+
zp ^= 0x80; // INT8 -> UINT8
880+
}
881+
dst_zero_points[dst_idx] = zp;
882+
}
883+
}
884+
);
885+
}
886+
} else if constexpr (qbits == 2) {
887+
// 2-bit: 4 elements per byte. Element-by-element transpose.
888+
constexpr int32_t kPackSize = 4;
889+
auto row_quant_blk_num = (rows + quant_block_size - 1) / quant_block_size;
890+
auto packed_src_cols = (columns + kPackSize - 1) / kPackSize;
891+
auto dst_bytes_per_quant_blk = (quant_block_size + kPackSize - 1) / kPackSize;
892+
auto dstT_num_row = row_quant_blk_num * dst_bytes_per_quant_blk;
893+
894+
// Transpose weights: src [rows, ceil(columns/4)] -> dst [columns, k_blocks, ceil(block_size/4)]
895+
// Each thread handles one (row_block, column) pair writing to non-overlapping dst ranges.
896+
MlasTryBatchParallel(
897+
thread_pool, static_cast<ptrdiff_t>(row_quant_blk_num * columns),
898+
[&](ptrdiff_t thread_blk_idx) {
899+
auto row_blk = static_cast<int32_t>(thread_blk_idx / columns);
900+
auto col = static_cast<int32_t>(thread_blk_idx % columns);
901+
902+
auto src_row_start = row_blk * quant_block_size;
903+
auto src_row_end = std::min(src_row_start + quant_block_size, rows);
904+
905+
auto dst_base = col * dstT_num_row + row_blk * dst_bytes_per_quant_blk;
906+
907+
// Zero destination bytes for this block
908+
for (int32_t b = 0; b < dst_bytes_per_quant_blk; ++b) {
909+
dst_weights[dst_base + b] = 0;
910+
}
911+
912+
for (auto r = src_row_start; r < src_row_end; ++r) {
913+
// Extract 2-bit value from source
914+
auto src_byte_idx = r * packed_src_cols + col / kPackSize;
915+
auto src_bit_shift = (col % kPackSize) * 2;
916+
uint8_t val = (src_weights[src_byte_idx] >> src_bit_shift) & 0x3;
917+
918+
if constexpr (signed_quant) {
919+
val ^= 0x2; // int2[-2,1] -> uint2[0,3]
920+
}
921+
922+
// Place in destination
923+
auto r_in_blk = r - src_row_start;
924+
auto dst_byte_off = r_in_blk / kPackSize;
925+
auto dst_bit_shift = (r_in_blk % kPackSize) * 2;
926+
dst_weights[dst_base + dst_byte_off] |= (val << dst_bit_shift);
927+
}
928+
929+
// Zero-pad remaining positions (unsigned equivalent of 0)
930+
if constexpr (signed_quant) {
931+
for (auto r_in_blk = src_row_end - src_row_start;
932+
r_in_blk < quant_block_size; ++r_in_blk) {
933+
auto dst_byte_off = r_in_blk / kPackSize;
934+
auto dst_bit_shift = (r_in_blk % kPackSize) * 2;
935+
dst_weights[dst_base + dst_byte_off] |= (0x2 << dst_bit_shift);
936+
}
937+
}
938+
}
939+
);
940+
941+
// Transpose scales: src [k_blocks, columns] -> dst [columns, k_blocks]
942+
MlasTryBatchParallel(
943+
thread_pool, static_cast<ptrdiff_t>(columns),
944+
[&](ptrdiff_t col) {
945+
auto src_idx = static_cast<int32_t>(col);
946+
auto dst_idx = static_cast<int32_t>(col) * row_quant_blk_num;
947+
for (int32_t i = 0; i < row_quant_blk_num; ++i, ++dst_idx, src_idx += columns) {
948+
dst_scales[dst_idx] = src_scales[src_idx];
949+
}
950+
}
951+
);
952+
953+
// Transpose zero points: src [k_blocks, ceil(columns/4)] -> dst [columns, ceil(k_blocks/4)]
954+
if (src_zero_points && dst_zero_points) {
955+
auto packed_src_zp_cols = (columns + kPackSize - 1) / kPackSize;
956+
auto zp_dst_bytes_per_col = (row_quant_blk_num + kPackSize - 1) / kPackSize;
957+
958+
MlasTryBatchParallel(
959+
thread_pool, static_cast<ptrdiff_t>(columns),
960+
[&](ptrdiff_t col_idx) {
961+
auto col = static_cast<int32_t>(col_idx);
962+
auto dst_base = col * zp_dst_bytes_per_col;
963+
964+
for (int32_t b = 0; b < zp_dst_bytes_per_col; ++b) {
965+
dst_zero_points[dst_base + b] = 0;
966+
}
967+
968+
for (int32_t blk = 0; blk < row_quant_blk_num; ++blk) {
969+
auto src_byte_idx = blk * packed_src_zp_cols + col / kPackSize;
970+
auto src_bit_shift = (col % kPackSize) * 2;
971+
uint8_t val = (src_zero_points[src_byte_idx] >> src_bit_shift) & 0x3;
972+
973+
if constexpr (signed_quant) {
974+
val ^= 0x2;
975+
}
976+
977+
auto dst_byte_off = blk / kPackSize;
978+
auto dst_bit_shift = (blk % kPackSize) * 2;
979+
dst_zero_points[dst_base + dst_byte_off] |= (val << dst_bit_shift);
980+
}
981+
}
982+
);
983+
}
984+
} else {
985+
// 4-bit sub-byte types: use packing-aware transpose paths.
986+
// Must avoid multiple thread write to a single byte, which means the starting index
987+
// of a thread block must be even. To achieve that, we need to customize the thread
988+
// block size based on the parity of columns.
989+
if (columns & 1) {
990+
TransposeColumnWiseQuantizedPackUnaligned(
991+
src_weights, src_scales, src_zero_points,
992+
dst_weights, dst_scales, dst_zero_points,
993+
rows, columns, quant_block_size, thread_pool
994+
);
995+
} else {
996+
TransposeColumnWiseQuantizedPackAligned(
997+
src_weights, src_scales, src_zero_points,
998+
dst_weights, dst_scales, dst_zero_points,
999+
rows, columns, quant_block_size, thread_pool
1000+
);
1001+
}
8311002
}
8321003
}
8331004

@@ -2184,3 +2355,93 @@ MlasQDQTransposeBlockwiseQuantized<MLAS_FP16, 4, false>(
21842355
int quant_block_size,
21852356
MLAS_THREADPOOL* thread_pool
21862357
);
2358+
2359+
template void
2360+
MlasQDQTransposeBlockwiseQuantized<float, 8, true>(
2361+
const uint8_t* src_weights,
2362+
const float* src_scales,
2363+
const uint8_t* src_zero_points,
2364+
uint8_t* dst_weights,
2365+
float* dst_scales,
2366+
uint8_t* dst_zero_points,
2367+
bool columnwise,
2368+
int rows,
2369+
int columns,
2370+
int quant_block_size,
2371+
MLAS_THREADPOOL* thread_pool
2372+
);
2373+
2374+
template void
2375+
MlasQDQTransposeBlockwiseQuantized<float, 8, false>(
2376+
const uint8_t* src_weights,
2377+
const float* src_scales,
2378+
const uint8_t* src_zero_points,
2379+
uint8_t* dst_weights,
2380+
float* dst_scales,
2381+
uint8_t* dst_zero_points,
2382+
bool columnwise,
2383+
int rows,
2384+
int columns,
2385+
int quant_block_size,
2386+
MLAS_THREADPOOL* thread_pool
2387+
);
2388+
2389+
template void
2390+
MlasQDQTransposeBlockwiseQuantized<MLAS_FP16, 8, true>(
2391+
const uint8_t* src_weights,
2392+
const MLAS_FP16* src_scales,
2393+
const uint8_t* src_zero_points,
2394+
uint8_t* dst_weights,
2395+
MLAS_FP16* dst_scales,
2396+
uint8_t* dst_zero_points,
2397+
bool columnwise,
2398+
int rows,
2399+
int columns,
2400+
int quant_block_size,
2401+
MLAS_THREADPOOL* thread_pool
2402+
);
2403+
2404+
template void
2405+
MlasQDQTransposeBlockwiseQuantized<MLAS_FP16, 8, false>(
2406+
const uint8_t* src_weights,
2407+
const MLAS_FP16* src_scales,
2408+
const uint8_t* src_zero_points,
2409+
uint8_t* dst_weights,
2410+
MLAS_FP16* dst_scales,
2411+
uint8_t* dst_zero_points,
2412+
bool columnwise,
2413+
int rows,
2414+
int columns,
2415+
int quant_block_size,
2416+
MLAS_THREADPOOL* thread_pool
2417+
);
2418+
2419+
template void
2420+
MlasQDQTransposeBlockwiseQuantized<MLAS_FP16, 2, true>(
2421+
const uint8_t* src_weights,
2422+
const MLAS_FP16* src_scales,
2423+
const uint8_t* src_zero_points,
2424+
uint8_t* dst_weights,
2425+
MLAS_FP16* dst_scales,
2426+
uint8_t* dst_zero_points,
2427+
bool columnwise,
2428+
int rows,
2429+
int columns,
2430+
int quant_block_size,
2431+
MLAS_THREADPOOL* thread_pool
2432+
);
2433+
2434+
template void
2435+
MlasQDQTransposeBlockwiseQuantized<MLAS_FP16, 2, false>(
2436+
const uint8_t* src_weights,
2437+
const MLAS_FP16* src_scales,
2438+
const uint8_t* src_zero_points,
2439+
uint8_t* dst_weights,
2440+
MLAS_FP16* dst_scales,
2441+
uint8_t* dst_zero_points,
2442+
bool columnwise,
2443+
int rows,
2444+
int columns,
2445+
int quant_block_size,
2446+
MLAS_THREADPOOL* thread_pool
2447+
);

0 commit comments

Comments (0)