Skip to content

Commit 26f5e78

Browse files
committed
Revert "cpu: x64: jit_uni_int8_conv: handle scales inside the kernel"
This reverts commit 4a44d83.
1 parent 107c977 commit 26f5e78

8 files changed

Lines changed: 212 additions & 355 deletions

src/cpu/x64/jit_uni_x8s8s32x_1x1_conv_kernel.cpp

Lines changed: 43 additions & 104 deletions
Original file line number | Diff line number | Diff line change
@@ -318,6 +318,7 @@ void jit_uni_x8s8s32x_1x1_conv_kernel_vmm_t<isa, Vmm>::reduce_loop(
318318
const int32_t *p_sum_zp
319319
= (sum_idx != -1) ? &p.entry_[sum_idx].sum.zero_point : nullptr;
320320
mov(ptr[rsp + reg_bcast_data_off], reg_bcast_data);
321+
mov(reg_ptr_scales, ptr[rsp + reg_ptr_sum_scale_off]);
321322
if (p_sum_scale && *p_sum_scale != 1.f) {
322323
mov(ptr[rsp + reg_load_data_off], reg_load_data);
323324
mov(reg_ptr_sum_scale, reinterpret_cast<size_t>(p_sum_scale));
@@ -332,6 +333,14 @@ void jit_uni_x8s8s32x_1x1_conv_kernel_vmm_t<isa, Vmm>::reduce_loop(
332333
}
333334
const bool mask_flag = mask_flag_in && i_load == load_loop_blk - 1;
334335
const int load_size = mask_flag ? get_tail_size() : simd_w;
336+
const auto ptr_scales_offset
337+
= jcp.is_oc_scale * (sizeof(float) * jcp.oc_block * i_load);
338+
if (jcp.with_bias) {
339+
if (jcp.signed_input || jcp.dst_scale || jcp.with_input_zp)
340+
mov(reg_bias_data, ptr[rsp + reg_bias_data_off]);
341+
cvt2ps(jcp.bia_dt, vmm_bias, reg_bias_data,
342+
jcp.typesize_bia * jcp.oc_block * i_load, load_size);
343+
}
335344
if (jcp.signed_input || jcp.with_input_zp) {
336345
mov(reg_comp_data, ptr[rsp + reg_comp_data_off]);
337346
cvt2ps(data_type::s32, vmm_comp, reg_comp_data,
@@ -347,75 +356,12 @@ void jit_uni_x8s8s32x_1x1_conv_kernel_vmm_t<isa, Vmm>::reduce_loop(
347356
uni_vcvtdq2ps(vmm_zp_comp, vmm_zp_comp);
348357
}
349358

350-
// TODO: scales support is done not in the most optimal way.
351-
// If there're two free Vmm registers, one can be used to store
352-
// scale_adjust value permanently, the second one can re-use data
353-
// from it and multiply by src_scale that can be obtained at the
354-
// point of scales loading. Then it can be used when multiplying
355-
// by wei_scales. And further re-used for dst scales to avoid
356-
// reading from the same address, but reading from the Vmm instead.
357-
// This would save 1st and 3rd sections for every output Vmm.
358-
//
359-
// If only one Vmm is found, it will add scale_adjust overhead per
360-
// src_scale loading, but the second part of the idea holds.
361-
//
362-
// Note: attempts to identify these Vmms were not taken.
363-
364-
// `avx2` is less flexible ISA in terms of tail and broadcast handling.
365-
// Thus, need to save scales values in Vmm registers.
366-
bool is_vmm_scales_set = false;
367-
if (jcp.with_src_scales) {
368-
mov(reg_src_scales, ptr[rsp + reg_src_scales_off]);
369-
uni_vbroadcastss(vmm_scales, ptr[reg_src_scales]);
370-
is_vmm_scales_set = true;
371-
}
372-
if (jcp.with_wei_scales) {
373-
mov(reg_wei_scales, ptr[rsp + reg_wei_scales_off]);
374-
375-
if (!jcp.is_oc_scale) {
376-
uni_vbroadcastss(vmm_scales_tmp, ptr[reg_wei_scales]);
377-
} else {
378-
int scale_offset = jcp.is_oc_scale
379-
* (sizeof(float) * jcp.oc_block * i_load);
380-
if (mask_flag) {
381-
uni_vpxor(
382-
vmm_scales_tmp, vmm_scales_tmp, vmm_scales_tmp);
383-
cvt2ps(data_type::f32, vmm_scales_tmp, reg_wei_scales,
384-
scale_offset, get_tail_size());
385-
} else {
386-
uni_vmovups(vmm_scales_tmp,
387-
ptr[reg_wei_scales + scale_offset]);
388-
}
389-
}
390-
if (is_vmm_scales_set) {
391-
uni_vmulps(vmm_scales, vmm_scales, vmm_scales_tmp);
392-
} else {
393-
uni_vmovups(vmm_scales, vmm_scales_tmp);
394-
}
395-
is_vmm_scales_set = true;
396-
}
397-
if (jcp.wei_adj_scale != 1.f) {
398-
mov(reg_scale_adjust, float2int(1.f / jcp.wei_adj_scale));
399-
auto vmm_scale_adjust = vmm_scales_tmp;
400-
auto xmm_scale_adjust = Xmm(vmm_scale_adjust.getIdx());
401-
uni_vmovq(xmm_scale_adjust, reg_scale_adjust);
402-
uni_vbroadcastss(vmm_scale_adjust, xmm_scale_adjust);
403-
if (is_vmm_scales_set) {
404-
uni_vmulps(vmm_scales, vmm_scales, vmm_scale_adjust);
405-
} else {
406-
uni_vmovups(vmm_scales, vmm_scale_adjust);
407-
}
408-
is_vmm_scales_set = true;
409-
}
410-
411-
// The order of this load is important. `vmm_bias` is used as a
412-
// temporary vector register for scales. Load bias data into it
413-
// after scales are processed.
414-
if (jcp.with_bias) {
415-
if (jcp.signed_input || jcp.with_dst_scales)
416-
mov(reg_bias_data, ptr[rsp + reg_bias_data_off]);
417-
cvt2ps(jcp.bia_dt, vmm_bias, reg_bias_data,
418-
jcp.typesize_bia * jcp.oc_block * i_load, load_size);
359+
if (mask_flag) {
360+
uni_vpxor(vmm_scale, vmm_scale, vmm_scale);
361+
cvt2ps(data_type::f32, vmm_scale, reg_ptr_scales,
362+
ptr_scales_offset, get_tail_size());
363+
} else {
364+
uni_vmovups(vmm_scale, ptr[reg_ptr_scales + ptr_scales_offset]);
419365
}
420366

421367
for (int i_ur = 0; i_ur < ur; ++i_ur) {
@@ -424,23 +370,23 @@ void jit_uni_x8s8s32x_1x1_conv_kernel_vmm_t<isa, Vmm>::reduce_loop(
424370
if (jcp.signed_input || jcp.with_input_zp) uni_vaddps(r, r, vmm_comp);
425371
if (jcp.src_zero_point) uni_vaddps(r, r, vmm_zp_comp);
426372

427-
if (is_vmm_scales_set) uni_vmulps(r, r, vmm_scales);
373+
uni_vmulps(r, r, vmm_scale);
428374

429375
if (jcp.with_bias) uni_vaddps(r, r, vmm_bias);
430376
}
431377
}
432378

433379
apply_postops(ur, load_loop_blk, mask_flag_in, p_sum_scale, p_sum_zp);
434380

435-
if (jcp.with_dst_scales) {
436-
mov(reg_dst_scales, ptr[rsp + reg_dst_scales_off]);
437-
uni_vbroadcastss(vmm_dst_scales, ptr[reg_dst_scales]);
381+
if (jcp.dst_scale) {
382+
mov(reg_ptr_dst_scale, ptr[rsp + reg_dst_scale_off]);
383+
uni_vmovups(vmm_dst_scale, ptr[reg_ptr_dst_scale]);
438384

439385
/* Apply dst scale to accumulator */
440386
for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
441387
for (int i_ur = 0; i_ur < ur; ++i_ur) {
442388
const auto r = vreg_accum(load_loop_blk, i_load, i_ur);
443-
uni_vmulps(r, r, vmm_dst_scales);
389+
uni_vmulps(r, r, vmm_dst_scale);
444390
}
445391
}
446392
}
@@ -604,23 +550,18 @@ void jit_uni_x8s8s32x_1x1_conv_kernel_vmm_t<isa, Vmm>::generate() {
604550
mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]);
605551
mov(ptr[rsp + reg_src_zero_point_off], reg_src_zero_point);
606552
}
607-
if (jcp.with_src_scales) {
608-
mov(reg_src_scales, ptr[param1 + GET_OFF(src_scales)]);
609-
mov(ptr[rsp + reg_src_scales_off], reg_src_scales);
610-
}
611-
if (jcp.with_wei_scales) {
612-
mov(reg_wei_scales, ptr[param1 + GET_OFF(wei_scales)]);
613-
mov(ptr[rsp + reg_wei_scales_off], reg_wei_scales);
614-
}
615-
if (jcp.with_dst_scales) {
616-
if (!jcp.signed_input && !jcp.with_input_zp) mov(ptr[rsp + reg_bias_data_off], reg_bias_data);
617-
mov(reg_dst_scales, ptr[param1 + GET_OFF(dst_scales)]);
618-
mov(ptr[rsp + reg_dst_scales_off], reg_dst_scales);
553+
if (jcp.dst_scale) {
554+
if (!jcp.signed_input && !jcp.with_input_zp)
555+
mov(ptr[rsp + reg_bias_data_off], reg_bias_data);
556+
mov(reg_ptr_dst_scale, ptr[param1 + GET_OFF(dst_scale)]);
557+
mov(ptr[rsp + reg_dst_scale_off], reg_ptr_dst_scale);
619558
}
620559
if (jcp.dst_zero_point) {
621560
mov(reg_dst_zero_point, ptr[param1 + GET_OFF(dst_zero_point)]);
622561
mov(ptr[rsp + reg_dst_zero_point_off], reg_dst_zero_point);
623562
}
563+
mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
564+
mov(ptr[rsp + reg_ptr_sum_scale_off], reg_ptr_scales);
624565
mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]);
625566
mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]);
626567
mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]);
@@ -636,11 +577,11 @@ void jit_uni_x8s8s32x_1x1_conv_kernel_vmm_t<isa, Vmm>::generate() {
636577
bcast_loop(load_loop_blk);
637578
add(reg_load_data, load_loop_blk * jcp.load_loop_load_step);
638579
if (jcp.with_bias) {
639-
if (jcp.signed_input || jcp.with_dst_scales || jcp.with_input_zp)
580+
if (jcp.signed_input || jcp.dst_scale || jcp.with_input_zp)
640581
mov(reg_bias_data, ptr[rsp + reg_bias_data_off]);
641582
add(reg_bias_data,
642583
load_loop_blk * jcp.load_block * jcp.typesize_bia);
643-
if (jcp.signed_input || jcp.with_dst_scales || jcp.with_input_zp)
584+
if (jcp.signed_input || jcp.dst_scale || jcp.with_input_zp)
644585
mov(ptr[rsp + reg_bias_data_off], reg_bias_data);
645586
}
646587
if (jcp.signed_input || jcp.with_input_zp) {
@@ -656,13 +597,11 @@ void jit_uni_x8s8s32x_1x1_conv_kernel_vmm_t<isa, Vmm>::generate() {
656597
mov(ptr[rsp + reg_zp_compensation_off], reg_zp_compensation);
657598
}
658599
mov(ptr[rsp + reg_bcast_data_off], reg_bcast_data);
659-
if (jcp.with_wei_scales) {
660-
mov(reg_wei_scales, ptr[rsp + reg_wei_scales_off]);
661-
add(reg_wei_scales,
662-
jcp.is_oc_scale * load_loop_blk * jcp.load_block
663-
* sizeof(float));
664-
mov(ptr[rsp + reg_wei_scales_off], reg_wei_scales);
665-
}
600+
mov(reg_ptr_scales, ptr[rsp + reg_ptr_sum_scale_off]);
601+
add(reg_ptr_scales,
602+
jcp.is_oc_scale * load_loop_blk * jcp.load_block
603+
* sizeof(float));
604+
mov(ptr[rsp + reg_ptr_sum_scale_off], reg_ptr_scales);
666605
mov(reg_bcast_data, ptr[rsp + reg_bcast_data_off]);
667606
add(reg_output_data, load_loop_blk * jcp.load_block * jcp.typesize_out);
668607
sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
@@ -1023,11 +962,10 @@ status_t jit_uni_x8s8s32x_1x1_conv_kernel_t<isa>::init_conf(
1023962
// minimum size of load dim chunk for work distribution within threads
1024963
jcp.nb_load_chunk = 1;
1025964

1026-
jcp.is_oc_scale = attr.scales_.get_mask(DNNL_ARG_WEIGHTS) > 0;
1027-
jcp.with_src_scales = !attr.scales_.get(DNNL_ARG_SRC).has_default_values();
1028-
jcp.with_wei_scales
1029-
= !attr.scales_.get(DNNL_ARG_WEIGHTS).has_default_values();
1030-
jcp.with_dst_scales = !attr.scales_.get(DNNL_ARG_DST).has_default_values();
965+
const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS);
966+
const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST);
967+
jcp.is_oc_scale = wei_scales.get_mask() > 0;
968+
jcp.dst_scale = !dst_scales.has_default_values();
1031969

1032970
jcp.wei_adj_scale
1033971
= (weights_d.extra().flags & memory_extra_flags::scale_adjust)
@@ -1043,11 +981,12 @@ void jit_uni_x8s8s32x_1x1_conv_kernel_t<isa>::init_scratchpad(
1043981
const jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) {
1044982
using namespace dnnl::impl::memory_tracking::names;
1045983

1046-
if (jcp.with_dst_scales) {
1047-
// See brgemm_types.hpp comment for `with_dst_scales`.
1048-
scratchpad.book(key_conv_dst_scales,
1049-
static_cast<size_t>(jcp.nthr) * sizeof(float), 4096);
984+
dim_t count = 8;
985+
if (!attr.scales_.has_default_values(DNNL_ARG_WEIGHTS)) {
986+
const int wei_mask = attr.scales_.get_mask(DNNL_ARG_WEIGHTS);
987+
if (wei_mask > 0) count = static_cast<dim_t>(jcp.oc) * jcp.ngroups;
1050988
}
989+
scratchpad.book<float>(key_conv_adjusted_scales, count);
1051990
}
1052991

1053992
template struct jit_uni_x8s8s32x_1x1_conv_kernel_vmm_t<avx2, Ymm>;

src/cpu/x64/jit_uni_x8s8s32x_1x1_conv_kernel.hpp

Lines changed: 13 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -48,17 +48,15 @@ struct jit_uni_x8s8s32x_1x1_conv_kernel_vmm_t : public jit_generator_t {
4848
ker_max_reg_idx = 13,
4949
};
5050
const Xbyak::Reg64 reg_bcast_data = r8;
51+
const Xbyak::Reg64 reg_ptr_scales = r8;
5152
const Xbyak::Reg64 reg_output_data = r9;
52-
const Xbyak::Reg64 reg_src_scales = r8;
53-
const Xbyak::Reg64 reg_wei_scales = r8;
54-
const Xbyak::Reg64 reg_scale_adjust = r8;
55-
const Xbyak::Reg64 reg_dst_scales = r12;
5653
const Xbyak::Reg64 reg_load_data = r10;
5754
const Xbyak::Reg64 reg_ptr_sum_scale = r10;
5855
const Xbyak::Reg64 reg_ptr_sum_zp = rdx;
5956
const Xbyak::Reg64 reg_reduce_loop_work = r11;
6057
const Xbyak::Reg64 reg_bias_data = r12;
6158
const Xbyak::Reg64 reg_comp_data = r12;
59+
const Xbyak::Reg64 reg_ptr_dst_scale = r12;
6260
const Xbyak::Reg64 reg_init_bcast = r13;
6361
const Xbyak::Reg64 reg_store_bcast = r13;
6462
const Xbyak::Reg64 reg_reduce_loop_iter = r13;
@@ -91,9 +89,7 @@ struct jit_uni_x8s8s32x_1x1_conv_kernel_vmm_t : public jit_generator_t {
9189
const Vmm vmm_bcast = Vmm(0);
9290
const Vmm vmm_saturation = Vmm(0);
9391
/* used during scale section of store_output */
94-
const Vmm vmm_scales = Vmm(1);
95-
const Vmm vmm_scales_tmp = Vmm(3); // Has dependency on `vmm_bias`.
96-
const Vmm vmm_dst_scales = Vmm(1);
92+
const Vmm vmm_scale = Vmm(1);
9793
/* used during post_op sum section of store_output */
9894
const Vmm vmm_prev_dst = Vmm(1);
9995
/* used during bias section of store_output */
@@ -102,22 +98,23 @@ struct jit_uni_x8s8s32x_1x1_conv_kernel_vmm_t : public jit_generator_t {
10298
/* zero-point */
10399
const Vmm vmm_zp = Vmm(1);
104100
const Vmm vmm_zp_comp = Vmm(2);
101+
/* dst scale */
102+
const Vmm vmm_dst_scale = Vmm(1);
105103

106104
constexpr static int simd_w = isa == avx2 ? 8 : 4;
107105
constexpr static int reg64_size = sizeof(int64_t);
108106
constexpr static int bcast_loop_work_off = 0;
109107
constexpr static int reg_bias_data_off = 1 * reg64_size;
110108
constexpr static int reg_bcast_data_off = 2 * reg64_size;
111109
constexpr static int reg_load_data_off = 3 * reg64_size;
112-
constexpr static int reg_src_scales_off = 4 * reg64_size;
113-
constexpr static int reg_wei_scales_off = 5 * reg64_size;
114-
constexpr static int reg_dst_scales_off = 6 * reg64_size;
115-
constexpr static int reg_bcast_loop_iter_off = 7 * reg64_size;
116-
constexpr static int reg_comp_data_off = 8 * reg64_size;
117-
constexpr static int reg_zp_compensation_off = 9 * reg64_size;
118-
constexpr static int reg_src_zero_point_off = 10 * reg64_size;
119-
constexpr static int reg_dst_zero_point_off = 11 * reg64_size;
120-
constexpr static int stack_space_needed = 12 * reg64_size;
110+
constexpr static int reg_ptr_sum_scale_off = 4 * reg64_size;
111+
constexpr static int reg_bcast_loop_iter_off = 5 * reg64_size;
112+
constexpr static int reg_comp_data_off = 6 * reg64_size;
113+
constexpr static int reg_zp_compensation_off = 7 * reg64_size;
114+
constexpr static int reg_src_zero_point_off = 8 * reg64_size;
115+
constexpr static int reg_dst_zero_point_off = 9 * reg64_size;
116+
constexpr static int reg_dst_scale_off = 10 * reg64_size;
117+
constexpr static int stack_space_needed = 11 * reg64_size;
121118

122119
int vreg_accum_idx(
123120
const int load_loop_blk, const int i_load, const int i_ur);

0 commit comments

Comments
 (0)