@@ -135,7 +135,8 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
135135 // which make check (req[out_index] == kWriteInplace) useless.
136136 auto in_mkl_mem = static_cast <const mkldnn::memory*>(in_data[idx.sum ].GetMKLDNNData ());
137137 auto out_mkl_mem = static_cast <const mkldnn::memory*>(out_data[out_index].GetMKLDNNData ());
138- if (in_mkl_mem->get_data_handle () == out_mkl_mem->get_data_handle ()) {
138+ if (in_mkl_mem->get_data_handle () == out_mkl_mem->get_data_handle ()
139+ && in_data[idx.sum ].dtype () == out_data[out_index].dtype ()) {
139140 inplace_ = true ;
140141 }
141142 }
@@ -146,8 +147,8 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
146147 auto in_mkl_mem = static_cast <const mkldnn::memory*>(in_data[idx.sum ].GetMKLDNNData ());
147148 auto out_mkl_mem = static_cast <const mkldnn::memory*>(out_data[out_index].GetMKLDNNData ());
148149 if (out_data[out_index].dtype () == mshadow::kInt32 ) {
149- auto mem_desc = in_mkl_mem->get_desc ();
150- auto this_dtype = get_mkldnn_type (mshadow::kInt32 );
150+ auto mem_desc = in_mkl_mem->get_desc ();
151+ auto this_dtype = get_mkldnn_type (mshadow::kInt32 );
151152 mem_desc.data .data_type = static_cast <mkldnn_data_type_t >(this_dtype);
152153 mkldnn_mem_ptr tmp_mem (new mkldnn::memory (
153154 mem_desc, CpuEngine::Get ()->get_engine (), out_mkl_mem->get_data_handle ()));
@@ -156,6 +157,27 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
156157 mkldnn::reorder (*in_mkl_mem, *tmp_mem),
157158 {{MKLDNN_ARG_FROM, *in_mkl_mem}, {MKLDNN_ARG_TO, *tmp_mem}});
158159 output = NDArray (tmp_mem);
160+ } else if (in_data[idx.sum ].dtype () == mshadow::kUint8 &&
161+ out_data[out_index].dtype () == mshadow::kInt8 ) {
162+ auto sum_mem_desc = in_mkl_mem->get_desc ();
163+ auto out_dtype = get_mkldnn_type (mshadow::kInt8 );
164+ sum_mem_desc.data .data_type =
165+ static_cast <mkldnn_data_type_t >(out_dtype);
166+ mkldnn_mem_ptr tmp_mem (
167+ new mkldnn::memory (sum_mem_desc, CpuEngine::Get ()->get_engine (),
168+ out_mkl_mem->get_data_handle ()));
169+ MKLDNNStream::Get ()->RegisterMem (tmp_mem);
170+ const float u8_reorder_scale = 0.5 ;
171+ std::vector<float > reorder_scale = {u8_reorder_scale};
172+ mkldnn::primitive_attr reorder_attr;
173+ reorder_attr.set_output_scales (0 , reorder_scale);
174+ const auto reorder_pd = mkldnn::reorder::primitive_desc (
175+ CpuEngine::Get ()->get_engine (), in_mkl_mem->get_desc (),
176+ CpuEngine::Get ()->get_engine (), sum_mem_desc, reorder_attr);
177+ MKLDNNStream::Get ()->RegisterPrimArgs (
178+ mkldnn::reorder (reorder_pd),
179+ {{MKLDNN_ARG_FROM, *in_mkl_mem}, {MKLDNN_ARG_TO, *tmp_mem}});
180+ output = NDArray (tmp_mem);
159181 } else {
160182 mkldnn_mem_ptr tmp_mem (new mkldnn::memory (in_mkl_mem->get_desc (),
161183 CpuEngine::Get ()->get_engine (),
@@ -393,6 +415,12 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
393415 float sum_in_scale =
394416 GetQuantizeScale (in_data[idx.sum ].dtype (), cached_sum_min_, cached_sum_max_);
395417 mkldnn_param.sum_scale = out_scale / sum_in_scale;
418+ if (in_data[idx.sum ].dtype () == mshadow::kUint8 &&
419+ out_data[out_index].dtype () == mshadow::kInt8 ) {
420+ // In this case, reorder with scale 0.5 is used on in_data[idx.sum] to
421+ // scale it to s8 range, so sum_scale has to be rescaled as well
422+ mkldnn_param.sum_scale *= 2.0 ;
423+ }
396424 }
397425 } // if (mkldnn_param.quantized)
398426
@@ -659,7 +687,8 @@ static bool SgMKLDNNFCInferType(const nnvm::NodeAttrs& attrs,
659687 } else {
660688 if (full_param.mkldnn_param .min_calib_range .has_value () &&
661689 full_param.mkldnn_param .max_calib_range .has_value ()) {
662- if (IsOutputUint8 (full_param)) {
690+ if (IsOutputUint8 (full_param) &&
691+ (!idx.IsSumExist () || in_types->at (idx.sum ) == mshadow::kUint8 )) {
663692 TYPE_ASSIGN_CHECK (*out_types, 0 , mshadow::kUint8 );
664693 } else {
665694 TYPE_ASSIGN_CHECK (*out_types, 0 , mshadow::kInt8 );
0 commit comments