@@ -663,6 +663,9 @@ struct BlockwiseQDQQuantizer {
             return (val >> (idx << 1)) & 0x3;
         } else if constexpr (qbits == 4) {
             return (val >> (idx << 2)) & 0xF;
+        } else if constexpr (qbits == 8) {
+            (void)idx;
+            return val;
         }
     }
 
@@ -674,6 +677,10 @@ struct BlockwiseQDQQuantizer {
         } else if constexpr (qbits == 4) {
             auto shift = idx << 2;
             return ((val & 0xF) << shift) | (dst & (~(0xF << shift)));
+        } else if constexpr (qbits == 8) {
+            (void)idx;
+            (void)dst;
+            return val;
         }
     }
 
@@ -813,21 +820,185 @@ struct BlockwiseQDQQuantizer {
             src_zero_points || signed_quant || dst_zero_points,
             "Unsigned quant types without zero points must allocate zero points with value 0."
         );
-        // Must avoid multiple thread write to a single byte, which means the starting index
-        // of a thread block must be even. To achieve that, we need to customize the thread
-        // block size based on the parity of columns.
-        if (columns & 1) {
-            TransposeColumnWiseQuantizedPackUnaligned(
-                src_weights, src_scales, src_zero_points,
-                dst_weights, dst_scales, dst_zero_points,
-                rows, columns, quant_block_size, thread_pool
+
+        if constexpr (qbits == 8) {
+            // 8-bit: each element is one byte, so no sub-byte packing is needed.
+            // Simple byte-level transpose from [rows, columns] to [columns, k_blocks, block_size].
+            auto row_quant_blk_num = (rows + quant_block_size - 1) / quant_block_size;
+            auto dst_bytes_per_quant_blk = quant_block_size;  // 8 bits = 1 byte per element
+            auto dstT_num_row = row_quant_blk_num * dst_bytes_per_quant_blk;
+
+            // Transpose weights: src [rows, columns] -> dst [columns, k_blocks, block_size]
+            MlasTryBatchParallel(
+                thread_pool, static_cast<ptrdiff_t>(row_quant_blk_num * columns),
+                [&](ptrdiff_t thread_blk_idx) {
+                    auto row_blk = static_cast<int32_t>(thread_blk_idx / columns);
+                    auto col = static_cast<int32_t>(thread_blk_idx % columns);
+
+                    auto src_row_start = row_blk * quant_block_size;
+                    auto src_row_end = std::min(src_row_start + quant_block_size, rows);
+
+                    auto dst_base = col * dstT_num_row + row_blk * dst_bytes_per_quant_blk;
+                    for (auto r = src_row_start; r < src_row_end; ++r) {
+                        auto src_val = src_weights[r * columns + col];
+                        if constexpr (signed_quant) {
+                            src_val ^= 0x80;  // INT8 -> UINT8: add 128
+                        }
+                        dst_weights[dst_base + (r - src_row_start)] = src_val;
+                    }
+                    // Pad the remaining bytes of a ragged last block (rows % block_size != 0)
+                    // with the unsigned encoding of zero.
+                    for (auto r = src_row_end - src_row_start; r < quant_block_size; ++r) {
+                        dst_weights[dst_base + r] = signed_quant ? 0x80 : 0;
+                    }
+                }
             );
-        } else {
-            TransposeColumnWiseQuantizedPackAligned(
-                src_weights, src_scales, src_zero_points,
-                dst_weights, dst_scales, dst_zero_points,
-                rows, columns, quant_block_size, thread_pool
+
+            // Transpose scales: src [k_blocks, columns] -> dst [columns, k_blocks]
+            MlasTryBatchParallel(
+                thread_pool, static_cast<ptrdiff_t>(columns),
+                [&](ptrdiff_t col) {
+                    auto src_idx = static_cast<int32_t>(col);
+                    auto dst_idx = static_cast<int32_t>(col) * row_quant_blk_num;
+                    for (int32_t i = 0; i < row_quant_blk_num; ++i, ++dst_idx, src_idx += columns) {
+                        dst_scales[dst_idx] = src_scales[src_idx];
+                    }
+                }
             );
+
+            // Transpose zero points: src [k_blocks, columns] -> dst [columns, k_blocks]
+            // For 8-bit, zero points are byte-aligned (1 byte each), so no packing is needed.
+            if (src_zero_points && dst_zero_points) {
+                MlasTryBatchParallel(
+                    thread_pool, static_cast<ptrdiff_t>(columns),
+                    [&](ptrdiff_t col) {
+                        auto src_idx = static_cast<int32_t>(col);
+                        auto dst_idx = static_cast<int32_t>(col) * row_quant_blk_num;
+                        for (int32_t i = 0; i < row_quant_blk_num; ++i, ++dst_idx, src_idx += columns) {
+                            auto zp = src_zero_points[src_idx];
+                            if constexpr (signed_quant) {
+                                zp ^= 0x80;  // INT8 -> UINT8
+                            }
+                            dst_zero_points[dst_idx] = zp;
+                        }
+                    }
+                );
+            }
+        } else if constexpr (qbits == 2) {
+            // 2-bit: 4 elements per byte; transpose element by element.
+            constexpr int32_t kPackSize = 4;
+            auto row_quant_blk_num = (rows + quant_block_size - 1) / quant_block_size;
+            auto packed_src_cols = (columns + kPackSize - 1) / kPackSize;
+            auto dst_bytes_per_quant_blk = (quant_block_size + kPackSize - 1) / kPackSize;
+            auto dstT_num_row = row_quant_blk_num * dst_bytes_per_quant_blk;
+
+            // Transpose weights: src [rows, ceil(columns/4)] -> dst [columns, k_blocks, ceil(block_size/4)]
+            // Each thread handles one (row_block, column) pair and writes a non-overlapping dst range.
+            MlasTryBatchParallel(
+                thread_pool, static_cast<ptrdiff_t>(row_quant_blk_num * columns),
+                [&](ptrdiff_t thread_blk_idx) {
+                    auto row_blk = static_cast<int32_t>(thread_blk_idx / columns);
+                    auto col = static_cast<int32_t>(thread_blk_idx % columns);
+
+                    auto src_row_start = row_blk * quant_block_size;
+                    auto src_row_end = std::min(src_row_start + quant_block_size, rows);
+
+                    auto dst_base = col * dstT_num_row + row_blk * dst_bytes_per_quant_blk;
+
+                    // Zero the destination bytes for this block
+                    for (int32_t b = 0; b < dst_bytes_per_quant_blk; ++b) {
+                        dst_weights[dst_base + b] = 0;
+                    }
+
+                    for (auto r = src_row_start; r < src_row_end; ++r) {
+                        // Extract the 2-bit value from the source
+                        auto src_byte_idx = r * packed_src_cols + col / kPackSize;
+                        auto src_bit_shift = (col % kPackSize) * 2;
+                        uint8_t val = (src_weights[src_byte_idx] >> src_bit_shift) & 0x3;
+
+                        if constexpr (signed_quant) {
+                            val ^= 0x2;  // int2 [-2, 1] -> uint2 [0, 3]
+                        }
+
+                        // Place it in the destination
+                        auto r_in_blk = r - src_row_start;
+                        auto dst_byte_off = r_in_blk / kPackSize;
+                        auto dst_bit_shift = (r_in_blk % kPackSize) * 2;
+                        dst_weights[dst_base + dst_byte_off] |= (val << dst_bit_shift);
+                    }
+
+                    // Pad the remaining positions with the unsigned equivalent of zero
+                    if constexpr (signed_quant) {
+                        for (auto r_in_blk = src_row_end - src_row_start;
+                             r_in_blk < quant_block_size; ++r_in_blk) {
+                            auto dst_byte_off = r_in_blk / kPackSize;
+                            auto dst_bit_shift = (r_in_blk % kPackSize) * 2;
+                            dst_weights[dst_base + dst_byte_off] |= (0x2 << dst_bit_shift);
+                        }
+                    }
+                }
+            );
+
+            // Transpose scales: src [k_blocks, columns] -> dst [columns, k_blocks]
+            MlasTryBatchParallel(
+                thread_pool, static_cast<ptrdiff_t>(columns),
+                [&](ptrdiff_t col) {
+                    auto src_idx = static_cast<int32_t>(col);
+                    auto dst_idx = static_cast<int32_t>(col) * row_quant_blk_num;
+                    for (int32_t i = 0; i < row_quant_blk_num; ++i, ++dst_idx, src_idx += columns) {
+                        dst_scales[dst_idx] = src_scales[src_idx];
+                    }
+                }
+            );
+
+            // Transpose zero points: src [k_blocks, ceil(columns/4)] -> dst [columns, ceil(k_blocks/4)]
+            if (src_zero_points && dst_zero_points) {
+                auto packed_src_zp_cols = (columns + kPackSize - 1) / kPackSize;
+                auto zp_dst_bytes_per_col = (row_quant_blk_num + kPackSize - 1) / kPackSize;
+
+                MlasTryBatchParallel(
+                    thread_pool, static_cast<ptrdiff_t>(columns),
+                    [&](ptrdiff_t col_idx) {
+                        auto col = static_cast<int32_t>(col_idx);
+                        auto dst_base = col * zp_dst_bytes_per_col;
+
+                        for (int32_t b = 0; b < zp_dst_bytes_per_col; ++b) {
+                            dst_zero_points[dst_base + b] = 0;
+                        }
+
+                        for (int32_t blk = 0; blk < row_quant_blk_num; ++blk) {
+                            auto src_byte_idx = blk * packed_src_zp_cols + col / kPackSize;
+                            auto src_bit_shift = (col % kPackSize) * 2;
+                            uint8_t val = (src_zero_points[src_byte_idx] >> src_bit_shift) & 0x3;
+
+                            if constexpr (signed_quant) {
+                                val ^= 0x2;
+                            }
+
+                            auto dst_byte_off = blk / kPackSize;
+                            auto dst_bit_shift = (blk % kPackSize) * 2;
+                            dst_zero_points[dst_base + dst_byte_off] |= (val << dst_bit_shift);
+                        }
+                    }
+                );
+            }
+        } else {
+            // 4-bit sub-byte types: use the packing-aware transpose paths.
+            // Must avoid multiple threads writing to a single byte, which means the starting
+            // index of a thread block must be even. To achieve that, we customize the thread
+            // block size based on the parity of columns.
+            if (columns & 1) {
+                TransposeColumnWiseQuantizedPackUnaligned(
+                    src_weights, src_scales, src_zero_points,
+                    dst_weights, dst_scales, dst_zero_points,
+                    rows, columns, quant_block_size, thread_pool
+                );
+            } else {
+                TransposeColumnWiseQuantizedPackAligned(
+                    src_weights, src_scales, src_zero_points,
+                    dst_weights, dst_scales, dst_zero_points,
+                    rows, columns, quant_block_size, thread_pool
+                );
+            }
         }
     }
 
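Note: the `^= 0x80` and `^= 0x2` conversions in this hunk work because XOR-ing the sign bit turns a two's-complement value into its offset-binary encoding, which is the same as adding half the range. A minimal check, assuming nothing beyond standard C++:

#include <cassert>
#include <cstdint>

int main() {
    // `raw ^ 0x80` equals `v + 128` for every INT8 value v: -128 -> 0, 0 -> 128, 127 -> 255.
    for (int v = -128; v <= 127; ++v) {
        uint8_t raw = static_cast<uint8_t>(v);  // two's-complement bit pattern
        assert((raw ^ 0x80) == v + 128);
    }
    // The 2-bit analogue: `bits ^ 0x2` maps the int2 range [-2, 1] onto uint2 [0, 3].
    for (int v = -2; v <= 1; ++v) {
        uint8_t bits = static_cast<uint8_t>(v) & 0x3;  // low 2 bits of two's complement
        assert(((bits ^ 0x2) & 0x3) == v + 2);
    }
    return 0;
}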
@@ -2184,3 +2355,93 @@ MlasQDQTransposeBlockwiseQuantized<MLAS_FP16, 4, false>(
     int quant_block_size,
     MLAS_THREADPOOL* thread_pool
 );
+
+template void
+MlasQDQTransposeBlockwiseQuantized<float, 8, true>(
+    const uint8_t* src_weights,
+    const float* src_scales,
+    const uint8_t* src_zero_points,
+    uint8_t* dst_weights,
+    float* dst_scales,
+    uint8_t* dst_zero_points,
+    bool columnwise,
+    int rows,
+    int columns,
+    int quant_block_size,
+    MLAS_THREADPOOL* thread_pool
+);
+
+template void
+MlasQDQTransposeBlockwiseQuantized<float, 8, false>(
+    const uint8_t* src_weights,
+    const float* src_scales,
+    const uint8_t* src_zero_points,
+    uint8_t* dst_weights,
+    float* dst_scales,
+    uint8_t* dst_zero_points,
+    bool columnwise,
+    int rows,
+    int columns,
+    int quant_block_size,
+    MLAS_THREADPOOL* thread_pool
+);
+
+template void
+MlasQDQTransposeBlockwiseQuantized<MLAS_FP16, 8, true>(
+    const uint8_t* src_weights,
+    const MLAS_FP16* src_scales,
+    const uint8_t* src_zero_points,
+    uint8_t* dst_weights,
+    MLAS_FP16* dst_scales,
+    uint8_t* dst_zero_points,
+    bool columnwise,
+    int rows,
+    int columns,
+    int quant_block_size,
+    MLAS_THREADPOOL* thread_pool
+);
+
+template void
+MlasQDQTransposeBlockwiseQuantized<MLAS_FP16, 8, false>(
+    const uint8_t* src_weights,
+    const MLAS_FP16* src_scales,
+    const uint8_t* src_zero_points,
+    uint8_t* dst_weights,
+    MLAS_FP16* dst_scales,
+    uint8_t* dst_zero_points,
+    bool columnwise,
+    int rows,
+    int columns,
+    int quant_block_size,
+    MLAS_THREADPOOL* thread_pool
+);
+
+template void
+MlasQDQTransposeBlockwiseQuantized<MLAS_FP16, 2, true>(
+    const uint8_t* src_weights,
+    const MLAS_FP16* src_scales,
+    const uint8_t* src_zero_points,
+    uint8_t* dst_weights,
+    MLAS_FP16* dst_scales,
+    uint8_t* dst_zero_points,
+    bool columnwise,
+    int rows,
+    int columns,
+    int quant_block_size,
+    MLAS_THREADPOOL* thread_pool
+);
+
+template void
+MlasQDQTransposeBlockwiseQuantized<MLAS_FP16, 2, false>(
+    const uint8_t* src_weights,
+    const MLAS_FP16* src_scales,
+    const uint8_t* src_zero_points,
+    uint8_t* dst_weights,
+    MLAS_FP16* dst_scales,
+    uint8_t* dst_zero_points,
+    bool columnwise,
+    int rows,
+    int columns,
+    int quant_block_size,
+    MLAS_THREADPOOL* thread_pool
+);
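Note: as a reading aid for the 8-bit branch, here is a compact single-threaded reference of the same [rows, columns] -> [columns, k_blocks, block_size] relayout. TransposeQ8 is a hypothetical name; the patch itself parallelizes over (row_block, column) pairs with MlasTryBatchParallel, so this sketch shows the indexing, not the MLAS implementation.

#include <cassert>
#include <cstdint>
#include <vector>

// Reference relayout of 8-bit quantized weights: src is row-major [rows, columns],
// dst is [columns, k_blocks, quant_block_size] with a ragged last block padded
// with the unsigned encoding of zero (0x80 for signed sources, 0 otherwise).
std::vector<uint8_t> TransposeQ8(const std::vector<uint8_t>& src,
                                 int rows, int columns, int quant_block_size,
                                 bool signed_quant) {
    int k_blocks = (rows + quant_block_size - 1) / quant_block_size;
    uint8_t pad = signed_quant ? 0x80 : 0x00;
    std::vector<uint8_t> dst(size_t(columns) * k_blocks * quant_block_size, pad);
    for (int col = 0; col < columns; ++col) {
        for (int r = 0; r < rows; ++r) {
            int blk = r / quant_block_size;
            int r_in_blk = r % quant_block_size;
            size_t dst_idx = (size_t(col) * k_blocks + blk) * quant_block_size + r_in_blk;
            uint8_t v = src[size_t(r) * columns + col];
            dst[dst_idx] = signed_quant ? uint8_t(v ^ 0x80) : v;  // INT8 -> UINT8 offset
        }
    }
    return dst;
}

int main() {
    // 3 rows x 2 columns, block size 2 -> 2 k_blocks per column, last block padded.
    std::vector<uint8_t> src = {1, 2, 3, 4, 5, 6};
    auto dst = TransposeQ8(src, 3, 2, 2, /*signed_quant=*/false);
    // Column 0 holds rows {1, 3, 5, pad}; column 1 holds {2, 4, 6, pad}.
    assert((dst == std::vector<uint8_t>{1, 3, 5, 0, 2, 4, 6, 0}));
    return 0;
}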