Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 702e475

Browse files
author
DominikaJedynak
authored
[v1.x] Fix for fc with sum when types are incompatible (#21042)
* Type sum fix * Incompatible fc and sum type fix * Clang formatting
1 parent 1eeda33 commit 702e475

1 file changed

Lines changed: 33 additions & 4 deletions

File tree

src/operator/subgraph/mkldnn/mkldnn_fc.cc

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,8 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
135135
// which make check (req[out_index] == kWriteInplace) useless.
136136
auto in_mkl_mem = static_cast<const mkldnn::memory*>(in_data[idx.sum].GetMKLDNNData());
137137
auto out_mkl_mem = static_cast<const mkldnn::memory*>(out_data[out_index].GetMKLDNNData());
138-
if (in_mkl_mem->get_data_handle() == out_mkl_mem->get_data_handle()) {
138+
if (in_mkl_mem->get_data_handle() == out_mkl_mem->get_data_handle()
139+
&& in_data[idx.sum].dtype() == out_data[out_index].dtype()) {
139140
inplace_ = true;
140141
}
141142
}
@@ -146,8 +147,8 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
146147
auto in_mkl_mem = static_cast<const mkldnn::memory*>(in_data[idx.sum].GetMKLDNNData());
147148
auto out_mkl_mem = static_cast<const mkldnn::memory*>(out_data[out_index].GetMKLDNNData());
148149
if (out_data[out_index].dtype() == mshadow::kInt32) {
149-
auto mem_desc = in_mkl_mem->get_desc();
150-
auto this_dtype = get_mkldnn_type(mshadow::kInt32);
150+
auto mem_desc = in_mkl_mem->get_desc();
151+
auto this_dtype = get_mkldnn_type(mshadow::kInt32);
151152
mem_desc.data.data_type = static_cast<mkldnn_data_type_t>(this_dtype);
152153
mkldnn_mem_ptr tmp_mem(new mkldnn::memory(
153154
mem_desc, CpuEngine::Get()->get_engine(), out_mkl_mem->get_data_handle()));
@@ -156,6 +157,27 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
156157
mkldnn::reorder(*in_mkl_mem, *tmp_mem),
157158
{{MKLDNN_ARG_FROM, *in_mkl_mem}, {MKLDNN_ARG_TO, *tmp_mem}});
158159
output = NDArray(tmp_mem);
160+
} else if (in_data[idx.sum].dtype() == mshadow::kUint8 &&
161+
out_data[out_index].dtype() == mshadow::kInt8) {
162+
auto sum_mem_desc = in_mkl_mem->get_desc();
163+
auto out_dtype = get_mkldnn_type(mshadow::kInt8);
164+
sum_mem_desc.data.data_type =
165+
static_cast<mkldnn_data_type_t>(out_dtype);
166+
mkldnn_mem_ptr tmp_mem(
167+
new mkldnn::memory(sum_mem_desc, CpuEngine::Get()->get_engine(),
168+
out_mkl_mem->get_data_handle()));
169+
MKLDNNStream::Get()->RegisterMem(tmp_mem);
170+
const float u8_reorder_scale = 0.5;
171+
std::vector<float> reorder_scale = {u8_reorder_scale};
172+
mkldnn::primitive_attr reorder_attr;
173+
reorder_attr.set_output_scales(0, reorder_scale);
174+
const auto reorder_pd = mkldnn::reorder::primitive_desc(
175+
CpuEngine::Get()->get_engine(), in_mkl_mem->get_desc(),
176+
CpuEngine::Get()->get_engine(), sum_mem_desc, reorder_attr);
177+
MKLDNNStream::Get()->RegisterPrimArgs(
178+
mkldnn::reorder(reorder_pd),
179+
{{MKLDNN_ARG_FROM, *in_mkl_mem}, {MKLDNN_ARG_TO, *tmp_mem}});
180+
output = NDArray(tmp_mem);
159181
} else {
160182
mkldnn_mem_ptr tmp_mem(new mkldnn::memory(in_mkl_mem->get_desc(),
161183
CpuEngine::Get()->get_engine(),
@@ -393,6 +415,12 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
393415
float sum_in_scale =
394416
GetQuantizeScale(in_data[idx.sum].dtype(), cached_sum_min_, cached_sum_max_);
395417
mkldnn_param.sum_scale = out_scale / sum_in_scale;
418+
if (in_data[idx.sum].dtype() == mshadow::kUint8 &&
419+
out_data[out_index].dtype() == mshadow::kInt8) {
420+
// In this case, reorder with scale 0.5 is used on in_data[idx.sum] to
421+
// scale it to s8 range, so sum_scale has to be rescaled as well
422+
mkldnn_param.sum_scale *= 2.0;
423+
}
396424
}
397425
} // if (mkldnn_param.quantized)
398426

@@ -659,7 +687,8 @@ static bool SgMKLDNNFCInferType(const nnvm::NodeAttrs& attrs,
659687
} else {
660688
if (full_param.mkldnn_param.min_calib_range.has_value() &&
661689
full_param.mkldnn_param.max_calib_range.has_value()) {
662-
if (IsOutputUint8(full_param)) {
690+
if (IsOutputUint8(full_param) &&
691+
(!idx.IsSumExist() || in_types->at(idx.sum) == mshadow::kUint8)) {
663692
TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kUint8);
664693
} else {
665694
TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt8);

0 commit comments

Comments
 (0)