Skip to content

Commit afaf212

Browse files
alankelly authored and xnnpack-bot committed
Add QD8 F32 QC8W convolution operator
The GEMM path is disabled because each GEMM microkernel supports per-row quantization, whereas convolution requires per-batch quantization. New GEMM microkernels with per-batch quantization were generated and benchmarked. The largest performance difference found was 4% faster than IGEMM. Using the existing GEMM microkernels with padded quantization parameters was slower than IGEMM. Given the quantity of new code and the associated increase in binary size, we decided that QD8_F32_QC8W will always take the IGEMM path. We can revisit this decision if a use case is found for GEMM-based dynamically quantized convolution. PiperOrigin-RevId: 573817352
1 parent 93cb2fe commit afaf212

15 files changed

Lines changed: 1937 additions & 105 deletions

BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,7 @@ xnnpack_cc_library(
570570
msvc_copts = xnnpack_msvc_std_copts(),
571571
deps = [
572572
":common",
573+
":math",
573574
":microparams",
574575
],
575576
)

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,7 @@ IF(XNNPACK_BUILD_LIBRARY)
515515
# Need C_EXTENSIONS to get constants for mmap (MAP_ANONYMOUS).
516516
SET_TARGET_PROPERTIES(memory PROPERTIES C_EXTENSIONS YES)
517517
ADD_LIBRARY(convolution-test-helpers OBJECT test/convolution-test-helpers.cc)
518+
TARGET_INCLUDE_DIRECTORIES(convolution-test-helpers PRIVATE include src)
518519
ADD_LIBRARY(post-operation OBJECT src/operators/post-operation.c)
519520
IF(XNNPACK_LIBRARY_TYPE STREQUAL "default")
520521
ADD_LIBRARY(XNNPACK ${XNNPACK_SRCS})

include/xnnpack.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3004,6 +3004,18 @@ enum xnn_status xnn_setup_convolution2d_nhwc_f32(
30043004
const float* input,
30053005
float* output);
30063006

3007+
enum xnn_status xnn_create_convolution2d_nhwc_qd8_f32_qc8w(
3008+
uint32_t input_padding_top, uint32_t input_padding_right,
3009+
uint32_t input_padding_bottom, uint32_t input_padding_left,
3010+
uint32_t kernel_height, uint32_t kernel_width, uint32_t subsampling_height,
3011+
uint32_t subsampling_width, uint32_t dilation_height,
3012+
uint32_t dilation_width, uint32_t groups, size_t group_input_channels,
3013+
size_t group_output_channels, size_t input_channel_stride,
3014+
size_t output_channel_stride, const float* kernel_scale,
3015+
const int8_t* kernel, const float* bias, float output_min, float output_max,
3016+
uint32_t flags, xnn_code_cache_t code_cache,
3017+
xnn_weights_cache_t weights_cache, xnn_operator_t* convolution_op_out);
3018+
30073019
enum xnn_status xnn_create_convolution2d_nhwc_qs8(
30083020
uint32_t input_padding_top,
30093021
uint32_t input_padding_right,
@@ -3034,6 +3046,12 @@ enum xnn_status xnn_create_convolution2d_nhwc_qs8(
30343046
xnn_weights_cache_t weights_cache,
30353047
xnn_operator_t* convolution_op_out);
30363048

3049+
enum xnn_status xnn_reshape_convolution2d_nhwc_qd8_f32_qc8w(
3050+
xnn_operator_t convolution_op, size_t batch_size, size_t input_height,
3051+
size_t input_width, size_t* workspace_size, size_t* workspace_alignment,
3052+
size_t* output_height_out, size_t* output_width_out,
3053+
pthreadpool_t threadpool);
3054+
30373055
enum xnn_status xnn_reshape_convolution2d_nhwc_qs8(
30383056
xnn_operator_t convolution_op,
30393057
size_t batch_size,
@@ -3045,6 +3063,11 @@ enum xnn_status xnn_reshape_convolution2d_nhwc_qs8(
30453063
size_t* output_width_out,
30463064
pthreadpool_t threadpool);
30473065

3066+
enum xnn_status xnn_setup_convolution2d_nhwc_qd8_f32_qc8w(
3067+
xnn_operator_t convolution_op, void* workspace, const int8_t* input,
3068+
float* output,
3069+
const struct xnn_dynamic_quantization_params* quantization_params);
3070+
30483071
enum xnn_status xnn_setup_convolution2d_nhwc_qs8(
30493072
xnn_operator_t convolution_op,
30503073
void* workspace,

src/enums/operator-type.c

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@
1414
#include <xnnpack/operator-type.h>
1515

1616

17-
static const uint16_t offset[148] = {
17+
static const uint16_t offset[149] = {
1818
0, 8, 22, 36, 50, 64, 78, 92, 119, 147, 175, 203, 230, 257, 289, 321, 339, 357, 382, 408, 424, 440, 455, 470, 492,
19-
515, 538, 561, 584, 607, 630, 653, 671, 694, 718, 736, 759, 783, 807, 831, 855, 879, 903, 927, 941, 956, 971, 997,
20-
1023, 1049, 1075, 1107, 1139, 1165, 1192, 1219, 1236, 1253, 1287, 1321, 1335, 1349, 1363, 1379, 1395, 1421, 1447,
21-
1479, 1511, 1548, 1585, 1611, 1643, 1669, 1703, 1737, 1771, 1805, 1839, 1873, 1903, 1933, 1953, 1973, 1994, 2015,
22-
2036, 2057, 2081, 2105, 2128, 2151, 2169, 2187, 2202, 2217, 2235, 2253, 2272, 2291, 2310, 2329, 2346, 2363, 2379,
23-
2395, 2423, 2451, 2479, 2507, 2534, 2561, 2578, 2619, 2660, 2678, 2696, 2714, 2732, 2747, 2763, 2779, 2797, 2815,
24-
2833, 2859, 2886, 2913, 2930, 2947, 2969, 2991, 3020, 3049, 3068, 3087, 3106, 3125, 3140, 3155, 3170, 3185, 3204,
25-
3224, 3244, 3264, 3285, 3306
19+
515, 538, 561, 584, 607, 630, 653, 671, 694, 718, 736, 759, 783, 807, 831, 855, 890, 914, 938, 962, 976, 991, 1006,
20+
1032, 1058, 1084, 1110, 1142, 1174, 1200, 1227, 1254, 1271, 1288, 1322, 1356, 1370, 1384, 1398, 1414, 1430, 1456,
21+
1482, 1514, 1546, 1583, 1620, 1646, 1678, 1704, 1738, 1772, 1806, 1840, 1874, 1908, 1938, 1968, 1988, 2008, 2029,
22+
2050, 2071, 2092, 2116, 2140, 2163, 2186, 2204, 2222, 2237, 2252, 2270, 2288, 2307, 2326, 2345, 2364, 2381, 2398,
23+
2414, 2430, 2458, 2486, 2514, 2542, 2569, 2596, 2613, 2654, 2695, 2713, 2731, 2749, 2767, 2782, 2798, 2814, 2832,
24+
2850, 2868, 2894, 2921, 2948, 2965, 2982, 3004, 3026, 3055, 3084, 3103, 3122, 3141, 3160, 3175, 3190, 3205, 3220,
25+
3239, 3259, 3279, 3299, 3320, 3341
2626
};
2727

2828
static const char data[] =
@@ -66,6 +66,7 @@ static const char data[] =
6666
"Convolution (NCHW, F32)\0"
6767
"Convolution (NHWC, F16)\0"
6868
"Convolution (NHWC, F32)\0"
69+
"Convolution (NHWC, QD8, F32, QC8W)\0"
6970
"Convolution (NHWC, QC8)\0"
7071
"Convolution (NHWC, QS8)\0"
7172
"Convolution (NHWC, QU8)\0"

src/enums/operator-type.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@
8585
string: "Convolution (NHWC, F16)"
8686
- name: xnn_operator_type_convolution_nhwc_f32
8787
string: "Convolution (NHWC, F32)"
88+
- name: xnn_operator_type_convolution_nhwc_qd8_f32_qc8w
89+
string: "Convolution (NHWC, QD8, F32, QC8W)"
8890
- name: xnn_operator_type_convolution_nhwc_qc8
8991
string: "Convolution (NHWC, QC8)"
9092
- name: xnn_operator_type_convolution_nhwc_qs8

src/operator-run.c

Lines changed: 220 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,34 @@ void xnn_compute_grouped_batch_igemm(
500500
&context->params);
501501
}
502502

503+
void xnn_compute_grouped_batch_dqigemm(
504+
const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
505+
size_t batch_index,
506+
size_t group_index,
507+
size_t mr_block_start,
508+
size_t nr_block_start,
509+
size_t mr_block_size,
510+
size_t nr_block_size)
511+
{
512+
const size_t ks = context->ks;
513+
const size_t cm_stride = context->cm_stride;
514+
515+
context->dq_ukernel.function[XNN_UARCH_DEFAULT](
516+
mr_block_size,
517+
nr_block_size,
518+
context->kc,
519+
context->ks_scaled,
520+
(const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
521+
(const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
522+
(void*) ((uintptr_t) context->c + group_index * context->gc_stride + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
523+
cm_stride,
524+
context->cn_stride,
525+
context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
526+
context->zero,
527+
&context->params,
528+
(const void*) ((uintptr_t) &context->quantization_params[batch_index]));
529+
}
530+
503531
void xnn_compute_grouped_igemm(
504532
const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
505533
size_t group_index,
@@ -526,6 +554,33 @@ void xnn_compute_grouped_igemm(
526554
&context->params);
527555
}
528556

557+
void xnn_compute_grouped_dqigemm(
558+
const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
559+
size_t group_index,
560+
size_t mr_block_start,
561+
size_t nr_block_start,
562+
size_t mr_block_size,
563+
size_t nr_block_size)
564+
{
565+
const size_t ks = context->ks;
566+
const size_t cm_stride = context->cm_stride;
567+
568+
context->dq_ukernel.function[XNN_UARCH_DEFAULT](
569+
mr_block_size,
570+
nr_block_size,
571+
context->kc,
572+
context->ks_scaled,
573+
(const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
574+
(const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
575+
(void*) ((uintptr_t) context->c + group_index * context->gc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
576+
cm_stride,
577+
context->cn_stride,
578+
context->a_offset + group_index * context->ga_stride,
579+
context->zero,
580+
&context->params,
581+
(const void*) ((uintptr_t) context->quantization_params));
582+
}
583+
529584
void xnn_compute_batch_igemm(
530585
const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
531586
size_t batch_index,
@@ -552,6 +607,33 @@ void xnn_compute_batch_igemm(
552607
&context->params);
553608
}
554609

610+
void xnn_compute_batch_dqigemm(
611+
const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
612+
size_t batch_index,
613+
size_t mr_block_start,
614+
size_t nr_block_start,
615+
size_t mr_block_size,
616+
size_t nr_block_size)
617+
{
618+
const size_t ks = context->ks;
619+
const size_t cm_stride = context->cm_stride;
620+
621+
context->dq_ukernel.function[XNN_UARCH_DEFAULT](
622+
mr_block_size,
623+
nr_block_size,
624+
context->kc,
625+
context->ks_scaled,
626+
(const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
627+
(const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
628+
(void*) ((uintptr_t) context->c + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
629+
cm_stride,
630+
context->cn_stride,
631+
context->a_offset + batch_index * context->ba_stride,
632+
context->zero,
633+
&context->params,
634+
(const void*) ((uintptr_t) &context->quantization_params[batch_index]));
635+
}
636+
555637
void xnn_compute_igemm(
556638
const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
557639
size_t mr_block_start,
@@ -577,6 +659,31 @@ void xnn_compute_igemm(
577659
&context->params);
578660
}
579661

662+
void xnn_compute_dqigemm(
663+
const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
664+
size_t mr_block_start,
665+
size_t nr_block_start,
666+
size_t mr_block_size,
667+
size_t nr_block_size)
668+
{
669+
const size_t ks = context->ks;
670+
const size_t cm_stride = context->cm_stride;
671+
672+
context->dq_ukernel.function[XNN_UARCH_DEFAULT](
673+
mr_block_size,
674+
nr_block_size,
675+
context->kc,
676+
context->ks_scaled,
677+
(const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
678+
(const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
679+
(void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
680+
cm_stride,
681+
context->cn_stride,
682+
context->a_offset,
683+
context->zero,
684+
&context->params,
685+
(const void*) ((uintptr_t) &context->quantization_params[/*mr_block_start=*/0]));
686+
}
580687
// `output_tile_start` should be a multiple of igemm.mr (tile size).
581688
void xnn_compute_conv2d_igemm_indirection(
582689
const struct conv2d_igemm_indirection_init_context context[restrict XNN_MIN_ELEMENTS(1)],
@@ -2028,7 +2135,7 @@ void xnn_compute_rope(
20282135
cm_stride,
20292136
context->cn_stride,
20302137
context->fused_params,
2031-
(const void*) ((uintptr_t) &context->quantization_params[mr_block_start]));
2138+
(const void*) ((uintptr_t) &context->quantization_params[mr_block_start]));
20322139
}
20332140

20342141
void xnn_compute_hmp_grouped_batch_igemm(
@@ -2059,6 +2166,35 @@ void xnn_compute_rope(
20592166
&context->params);
20602167
}
20612168

2169+
void xnn_compute_hmp_grouped_batch_dqigemm(
2170+
const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
2171+
uint32_t uarch_index,
2172+
size_t batch_index,
2173+
size_t group_index,
2174+
size_t mr_block_start,
2175+
size_t nr_block_start,
2176+
size_t mr_block_size,
2177+
size_t nr_block_size)
2178+
{
2179+
const size_t ks = context->ks;
2180+
const size_t cm_stride = context->cm_stride;
2181+
2182+
context->dq_ukernel.function[uarch_index](
2183+
mr_block_size,
2184+
nr_block_size,
2185+
context->kc,
2186+
context->ks_scaled,
2187+
(const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
2188+
(const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
2189+
(void*) ((uintptr_t) context->c + group_index * context->gc_stride + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
2190+
cm_stride,
2191+
context->cn_stride,
2192+
context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
2193+
context->zero,
2194+
&context->params,
2195+
(const void*) ((uintptr_t) &context->quantization_params[batch_index]));
2196+
}
2197+
20622198
void xnn_compute_hmp_grouped_igemm(
20632199
const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
20642200
uint32_t uarch_index,
@@ -2086,6 +2222,34 @@ void xnn_compute_rope(
20862222
&context->params);
20872223
}
20882224

2225+
void xnn_compute_hmp_grouped_dqigemm(
2226+
const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
2227+
uint32_t uarch_index,
2228+
size_t group_index,
2229+
size_t mr_block_start,
2230+
size_t nr_block_start,
2231+
size_t mr_block_size,
2232+
size_t nr_block_size)
2233+
{
2234+
const size_t ks = context->ks;
2235+
const size_t cm_stride = context->cm_stride;
2236+
2237+
context->dq_ukernel.function[uarch_index](
2238+
mr_block_size,
2239+
nr_block_size,
2240+
context->kc,
2241+
context->ks_scaled,
2242+
(const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
2243+
(const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
2244+
(void*) ((uintptr_t) context->c + group_index * context->gc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
2245+
cm_stride,
2246+
context->cn_stride,
2247+
context->a_offset + group_index * context->ga_stride,
2248+
context->zero,
2249+
&context->params,
2250+
(const void*) ((uintptr_t) context->quantization_params));
2251+
}
2252+
20892253
void xnn_compute_batch_hmp_igemm(
20902254
const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
20912255
uint32_t uarch_index,
@@ -2113,6 +2277,34 @@ void xnn_compute_rope(
21132277
&context->params);
21142278
}
21152279

2280+
void xnn_compute_batch_hmp_dqigemm(
2281+
const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
2282+
uint32_t uarch_index,
2283+
size_t batch_index,
2284+
size_t mr_block_start,
2285+
size_t nr_block_start,
2286+
size_t mr_block_size,
2287+
size_t nr_block_size)
2288+
{
2289+
const size_t ks = context->ks;
2290+
const size_t cm_stride = context->cm_stride;
2291+
2292+
context->dq_ukernel.function[uarch_index](
2293+
mr_block_size,
2294+
nr_block_size,
2295+
context->kc,
2296+
context->ks_scaled,
2297+
(const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
2298+
(const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
2299+
(void*) ((uintptr_t) context->c + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
2300+
cm_stride,
2301+
context->cn_stride,
2302+
context->a_offset + batch_index * context->ba_stride,
2303+
context->zero,
2304+
&context->params,
2305+
(const void*) ((uintptr_t) &context->quantization_params[batch_index]));
2306+
}
2307+
21162308
void xnn_compute_hmp_igemm(
21172309
const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
21182310
uint32_t uarch_index,
@@ -2139,6 +2331,33 @@ void xnn_compute_rope(
21392331
&context->params);
21402332
}
21412333

2334+
void xnn_compute_hmp_dqigemm(
2335+
const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
2336+
uint32_t uarch_index,
2337+
size_t mr_block_start,
2338+
size_t nr_block_start,
2339+
size_t mr_block_size,
2340+
size_t nr_block_size)
2341+
{
2342+
const size_t ks = context->ks;
2343+
const size_t cm_stride = context->cm_stride;
2344+
2345+
context->dq_ukernel.function[uarch_index](
2346+
mr_block_size,
2347+
nr_block_size,
2348+
context->kc,
2349+
context->ks_scaled,
2350+
(const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
2351+
(const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
2352+
(void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
2353+
cm_stride,
2354+
context->cn_stride,
2355+
context->a_offset,
2356+
context->zero,
2357+
&context->params,
2358+
(const void*) ((uintptr_t) context->quantization_params));
2359+
}
2360+
21422361
void xnn_compute_hmp_scaled_dot_product_attention(
21432362
const struct scaled_dot_product_attention_context context[restrict XNN_MIN_ELEMENTS(1)],
21442363
uint32_t uarch_index,

0 commit comments

Comments
 (0)