@@ -318,6 +318,7 @@ void jit_uni_x8s8s32x_1x1_conv_kernel_vmm_t<isa, Vmm>::reduce_loop(
318318 const int32_t *p_sum_zp
319319 = (sum_idx != -1 ) ? &p.entry_ [sum_idx].sum .zero_point : nullptr ;
320320 mov (ptr[rsp + reg_bcast_data_off], reg_bcast_data);
321+ mov (reg_ptr_scales, ptr[rsp + reg_ptr_sum_scale_off]);
321322 if (p_sum_scale && *p_sum_scale != 1 .f ) {
322323 mov (ptr[rsp + reg_load_data_off], reg_load_data);
323324 mov (reg_ptr_sum_scale, reinterpret_cast <size_t >(p_sum_scale));
@@ -332,6 +333,14 @@ void jit_uni_x8s8s32x_1x1_conv_kernel_vmm_t<isa, Vmm>::reduce_loop(
332333 }
333334 const bool mask_flag = mask_flag_in && i_load == load_loop_blk - 1 ;
334335 const int load_size = mask_flag ? get_tail_size () : simd_w;
336+ const auto ptr_scales_offset
337+ = jcp.is_oc_scale * (sizeof (float ) * jcp.oc_block * i_load);
338+ if (jcp.with_bias ) {
339+ if (jcp.signed_input || jcp.dst_scale || jcp.with_input_zp )
340+ mov (reg_bias_data, ptr[rsp + reg_bias_data_off]);
341+ cvt2ps (jcp.bia_dt , vmm_bias, reg_bias_data,
342+ jcp.typesize_bia * jcp.oc_block * i_load, load_size);
343+ }
335344 if (jcp.signed_input || jcp.with_input_zp ) {
336345 mov (reg_comp_data, ptr[rsp + reg_comp_data_off]);
337346 cvt2ps (data_type::s32, vmm_comp, reg_comp_data,
@@ -347,75 +356,12 @@ void jit_uni_x8s8s32x_1x1_conv_kernel_vmm_t<isa, Vmm>::reduce_loop(
347356 uni_vcvtdq2ps (vmm_zp_comp, vmm_zp_comp);
348357 }
349358
350- // TODO: scales support is done not in the most optimal way.
351- // If there're two free Vmm registers, one can be used to store
352- // scale_adjust value permanently, the second one can re-use data
353- // from it and multiply by src_scale that can be obtained at the
354- // point of scales loading. Then it can be used when multiplying
355- // by wei_scales. And further re-used for dst scales to avoid
356- // reading from the same address, but reading from the Vmm instead.
357- // This would save 1st and 3rd sections for every output Vmm.
358- //
359- // If only one Vmm is found, it will add scale_adjust overhead per
360- // src_scale loading, but the second part of the idea holds.
361- //
362- // Note: attempts to identify these Vmms were not taken.
363-
364- // `avx2` is less flexible ISA in terms of tail and broadcast handling.
365- // Thus, need to save scales values in Vmm registers.
366- bool is_vmm_scales_set = false ;
367- if (jcp.with_src_scales ) {
368- mov (reg_src_scales, ptr[rsp + reg_src_scales_off]);
369- uni_vbroadcastss (vmm_scales, ptr[reg_src_scales]);
370- is_vmm_scales_set = true ;
371- }
372- if (jcp.with_wei_scales ) {
373- mov (reg_wei_scales, ptr[rsp + reg_wei_scales_off]);
374-
375- if (!jcp.is_oc_scale ) {
376- uni_vbroadcastss (vmm_scales_tmp, ptr[reg_wei_scales]);
377- } else {
378- int scale_offset = jcp.is_oc_scale
379- * (sizeof (float ) * jcp.oc_block * i_load);
380- if (mask_flag) {
381- uni_vpxor (
382- vmm_scales_tmp, vmm_scales_tmp, vmm_scales_tmp);
383- cvt2ps (data_type::f32 , vmm_scales_tmp, reg_wei_scales,
384- scale_offset, get_tail_size ());
385- } else {
386- uni_vmovups (vmm_scales_tmp,
387- ptr[reg_wei_scales + scale_offset]);
388- }
389- }
390- if (is_vmm_scales_set) {
391- uni_vmulps (vmm_scales, vmm_scales, vmm_scales_tmp);
392- } else {
393- uni_vmovups (vmm_scales, vmm_scales_tmp);
394- }
395- is_vmm_scales_set = true ;
396- }
397- if (jcp.wei_adj_scale != 1 .f ) {
398- mov (reg_scale_adjust, float2int (1 .f / jcp.wei_adj_scale ));
399- auto vmm_scale_adjust = vmm_scales_tmp;
400- auto xmm_scale_adjust = Xmm (vmm_scale_adjust.getIdx ());
401- uni_vmovq (xmm_scale_adjust, reg_scale_adjust);
402- uni_vbroadcastss (vmm_scale_adjust, xmm_scale_adjust);
403- if (is_vmm_scales_set) {
404- uni_vmulps (vmm_scales, vmm_scales, vmm_scale_adjust);
405- } else {
406- uni_vmovups (vmm_scales, vmm_scale_adjust);
407- }
408- is_vmm_scales_set = true ;
409- }
410-
411- // The order of this load is important. `vmm_bias` is used as a
412- // temporary vector register for scales. Load bias data into it
413- // after scales are processed.
414- if (jcp.with_bias ) {
415- if (jcp.signed_input || jcp.with_dst_scales )
416- mov (reg_bias_data, ptr[rsp + reg_bias_data_off]);
417- cvt2ps (jcp.bia_dt , vmm_bias, reg_bias_data,
418- jcp.typesize_bia * jcp.oc_block * i_load, load_size);
359+ if (mask_flag) {
360+ uni_vpxor (vmm_scale, vmm_scale, vmm_scale);
361+ cvt2ps (data_type::f32 , vmm_scale, reg_ptr_scales,
362+ ptr_scales_offset, get_tail_size ());
363+ } else {
364+ uni_vmovups (vmm_scale, ptr[reg_ptr_scales + ptr_scales_offset]);
419365 }
420366
421367 for (int i_ur = 0 ; i_ur < ur; ++i_ur) {
@@ -424,23 +370,23 @@ void jit_uni_x8s8s32x_1x1_conv_kernel_vmm_t<isa, Vmm>::reduce_loop(
424370 if (jcp.signed_input || jcp.with_input_zp ) uni_vaddps (r, r, vmm_comp);
425371 if (jcp.src_zero_point ) uni_vaddps (r, r, vmm_zp_comp);
426372
427- if (is_vmm_scales_set) uni_vmulps (r, r, vmm_scales );
373+ uni_vmulps (r, r, vmm_scale );
428374
429375 if (jcp.with_bias ) uni_vaddps (r, r, vmm_bias);
430376 }
431377 }
432378
433379 apply_postops (ur, load_loop_blk, mask_flag_in, p_sum_scale, p_sum_zp);
434380
435- if (jcp.with_dst_scales ) {
436- mov (reg_dst_scales , ptr[rsp + reg_dst_scales_off ]);
437- uni_vbroadcastss (vmm_dst_scales , ptr[reg_dst_scales ]);
381+ if (jcp.dst_scale ) {
382+ mov (reg_ptr_dst_scale , ptr[rsp + reg_dst_scale_off ]);
383+ uni_vmovups (vmm_dst_scale , ptr[reg_ptr_dst_scale ]);
438384
439385 /* Apply dst scale to accumulator */
440386 for (int i_load = 0 ; i_load < load_loop_blk; ++i_load) {
441387 for (int i_ur = 0 ; i_ur < ur; ++i_ur) {
442388 const auto r = vreg_accum (load_loop_blk, i_load, i_ur);
443- uni_vmulps (r, r, vmm_dst_scales );
389+ uni_vmulps (r, r, vmm_dst_scale );
444390 }
445391 }
446392 }
@@ -604,23 +550,18 @@ void jit_uni_x8s8s32x_1x1_conv_kernel_vmm_t<isa, Vmm>::generate() {
604550 mov (reg_src_zero_point, ptr[param1 + GET_OFF (src_zero_point)]);
605551 mov (ptr[rsp + reg_src_zero_point_off], reg_src_zero_point);
606552 }
607- if (jcp.with_src_scales ) {
608- mov (reg_src_scales, ptr[param1 + GET_OFF (src_scales)]);
609- mov (ptr[rsp + reg_src_scales_off], reg_src_scales);
610- }
611- if (jcp.with_wei_scales ) {
612- mov (reg_wei_scales, ptr[param1 + GET_OFF (wei_scales)]);
613- mov (ptr[rsp + reg_wei_scales_off], reg_wei_scales);
614- }
615- if (jcp.with_dst_scales ) {
616- if (!jcp.signed_input && !jcp.with_input_zp ) mov (ptr[rsp + reg_bias_data_off], reg_bias_data);
617- mov (reg_dst_scales, ptr[param1 + GET_OFF (dst_scales)]);
618- mov (ptr[rsp + reg_dst_scales_off], reg_dst_scales);
553+ if (jcp.dst_scale ) {
554+ if (!jcp.signed_input && !jcp.with_input_zp )
555+ mov (ptr[rsp + reg_bias_data_off], reg_bias_data);
556+ mov (reg_ptr_dst_scale, ptr[param1 + GET_OFF (dst_scale)]);
557+ mov (ptr[rsp + reg_dst_scale_off], reg_ptr_dst_scale);
619558 }
620559 if (jcp.dst_zero_point ) {
621560 mov (reg_dst_zero_point, ptr[param1 + GET_OFF (dst_zero_point)]);
622561 mov (ptr[rsp + reg_dst_zero_point_off], reg_dst_zero_point);
623562 }
563+ mov (reg_ptr_scales, ptr[param1 + GET_OFF (scales)]);
564+ mov (ptr[rsp + reg_ptr_sum_scale_off], reg_ptr_scales);
624565 mov (reg_bcast_data, ptr[param1 + GET_OFF (bcast_data)]);
625566 mov (reg_load_data, ptr[param1 + GET_OFF (load_data)]);
626567 mov (reg_output_data, ptr[param1 + GET_OFF (output_data)]);
@@ -636,11 +577,11 @@ void jit_uni_x8s8s32x_1x1_conv_kernel_vmm_t<isa, Vmm>::generate() {
636577 bcast_loop (load_loop_blk);
637578 add (reg_load_data, load_loop_blk * jcp.load_loop_load_step );
638579 if (jcp.with_bias ) {
639- if (jcp.signed_input || jcp.with_dst_scales || jcp.with_input_zp )
580+ if (jcp.signed_input || jcp.dst_scale || jcp.with_input_zp )
640581 mov (reg_bias_data, ptr[rsp + reg_bias_data_off]);
641582 add (reg_bias_data,
642583 load_loop_blk * jcp.load_block * jcp.typesize_bia );
643- if (jcp.signed_input || jcp.with_dst_scales || jcp.with_input_zp )
584+ if (jcp.signed_input || jcp.dst_scale || jcp.with_input_zp )
644585 mov (ptr[rsp + reg_bias_data_off], reg_bias_data);
645586 }
646587 if (jcp.signed_input || jcp.with_input_zp ) {
@@ -656,13 +597,11 @@ void jit_uni_x8s8s32x_1x1_conv_kernel_vmm_t<isa, Vmm>::generate() {
656597 mov (ptr[rsp + reg_zp_compensation_off], reg_zp_compensation);
657598 }
658599 mov (ptr[rsp + reg_bcast_data_off], reg_bcast_data);
659- if (jcp.with_wei_scales ) {
660- mov (reg_wei_scales, ptr[rsp + reg_wei_scales_off]);
661- add (reg_wei_scales,
662- jcp.is_oc_scale * load_loop_blk * jcp.load_block
663- * sizeof (float ));
664- mov (ptr[rsp + reg_wei_scales_off], reg_wei_scales);
665- }
600+ mov (reg_ptr_scales, ptr[rsp + reg_ptr_sum_scale_off]);
601+ add (reg_ptr_scales,
602+ jcp.is_oc_scale * load_loop_blk * jcp.load_block
603+ * sizeof (float ));
604+ mov (ptr[rsp + reg_ptr_sum_scale_off], reg_ptr_scales);
666605 mov (reg_bcast_data, ptr[rsp + reg_bcast_data_off]);
667606 add (reg_output_data, load_loop_blk * jcp.load_block * jcp.typesize_out );
668607 sub (reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step );
@@ -1023,11 +962,10 @@ status_t jit_uni_x8s8s32x_1x1_conv_kernel_t<isa>::init_conf(
1023962 // minimum size of load dim chunk for work distribution within threads
1024963 jcp.nb_load_chunk = 1 ;
1025964
1026- jcp.is_oc_scale = attr.scales_ .get_mask (DNNL_ARG_WEIGHTS) > 0 ;
1027- jcp.with_src_scales = !attr.scales_ .get (DNNL_ARG_SRC).has_default_values ();
1028- jcp.with_wei_scales
1029- = !attr.scales_ .get (DNNL_ARG_WEIGHTS).has_default_values ();
1030- jcp.with_dst_scales = !attr.scales_ .get (DNNL_ARG_DST).has_default_values ();
965+ const auto &wei_scales = attr.scales_ .get (DNNL_ARG_WEIGHTS);
966+ const auto &dst_scales = attr.scales_ .get (DNNL_ARG_DST);
967+ jcp.is_oc_scale = wei_scales.get_mask () > 0 ;
968+ jcp.dst_scale = !dst_scales.has_default_values ();
1031969
1032970 jcp.wei_adj_scale
1033971 = (weights_d.extra ().flags & memory_extra_flags::scale_adjust)
@@ -1043,11 +981,12 @@ void jit_uni_x8s8s32x_1x1_conv_kernel_t<isa>::init_scratchpad(
1043981 const jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) {
1044982 using namespace dnnl ::impl::memory_tracking::names;
1045983
1046- if (jcp. with_dst_scales ) {
1047- // See brgemm_types.hpp comment for `with_dst_scales`.
1048- scratchpad. book (key_conv_dst_scales,
1049- static_cast <size_t >(jcp.nthr ) * sizeof ( float ), 4096 ) ;
984+ dim_t count = 8 ;
985+ if (!attr. scales_ . has_default_values (DNNL_ARG_WEIGHTS)) {
986+ const int wei_mask = attr. scales_ . get_mask (DNNL_ARG_WEIGHTS);
987+ if (wei_mask > 0 ) count = static_cast <dim_t >(jcp.oc ) * jcp. ngroups ;
1050988 }
989+ scratchpad.book <float >(key_conv_adjusted_scales, count);
1051990}
1052991
1053992template struct jit_uni_x8s8s32x_1x1_conv_kernel_vmm_t <avx2, Ymm>;
0 commit comments