Skip to content

Commit ce91376

Browse files
authored
Modify scale & offset of WhereDummyDq (#27109)
### Description <!-- Describe your changes. --> - Update `WhereDummyDq` QDQ transformer to be more selective before inserting a dummy `DequantizeLinear` around `Where`. - `SatisfyCondition` now requires the `Where` output to have exactly one consumer and that consumer must be `QuantizeLinear` (Q). Otherwise, the transform is skipped. - `InsertDummyDQ` additionally checks element type consistency between the upstream DQ input tensor type and the downstream Q output tensor type; if they differ, the transform returns without modifying the graph. - Update the implementation of `WhereDummyDq` to avoid negative or zero `scale` value. The change maps the float value to the **boundary** of integer domain to ensure the `scale` value is positive. - If `WhereOp` get a float scalar `xf` and a `DequantizeLinear` as its two inputs, `WhereDummyDq` insert DQ to ensure `xf = DQ(xq, scale, zp)` - The `xq`, `scale` and `zp` are determined with the following table. | | uint8 | uint16 | int8 | int16 | |-----------------|--------------|---------------|-------------|---------------| | xf > 0 | | | | | | xq | 255 | 65535 | 127 | 32767 | | zp | 127 | 32767 | 0 | 0 | | xf < 0 | | | | | | xq | 0 | 0 | -128 | -32768 | | zp | 127 | 32767 | 0 | 0 | | xf = 0 | | | | | | xq | 127 | 32767 | 0 | 0 | | zp | 127 | 32767 | 0 | 0 | - `scale = xf / (xq - zp)` if `xq != zp` else `1` ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> - Negative or zero scale value is not friendly for various EP and backend such as QNN-EP. - Inserting an additional DQ is only useful when it forms a valid QDQ “node unit” pattern. If the `Where` output is not followed by a single `QuantizeLinear` (e.g., multiple consumers or a non-Q consumer), adding a dummy DQ cannot create the intended pattern and may lead to non-fusible/undesired graph structures.
1 parent 7afe4c2 commit ce91376

3 files changed

Lines changed: 350 additions & 74 deletions

File tree

onnxruntime/core/optimizer/qdq_transformer/where_dummy_dq.cc

Lines changed: 126 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -14,112 +14,180 @@
1414

1515
namespace onnxruntime {
1616
bool WhereDummyDq::SatisfyCondition(const Graph& graph, const Node& node) const {
17+
// This transformer targets a very specific pattern around `Where` when used in a QDQ graph:
18+
// cond, DQ(xq), const_scalar -> Where -> Q(yq)
19+
// or
20+
// cond, const_scalar, DQ(xq) -> Where -> Q(yq)
21+
//
22+
// When one `Where` branch is a scalar initializer (no producer node), WhereNodeGroupSelector
23+
// requires both data branches to be produced by DQ nodes so the `Where` can be grouped into a
24+
// single node-unit. We insert a "dummy" DQ for the scalar branch to satisfy that requirement.
1725
if (!(node.OpType() == "Where")) {
1826
return false;
1927
}
28+
29+
// ONNX Where inputs: [0]=condition, [1]=X, [2]=Y
2030
const auto& where_inputs = node.InputDefs();
31+
const auto& where_outputs = node.OutputDefs();
32+
2133
const Node* parent_node_1 = graph.GetProducerNode(where_inputs[1]->Name());
2234
const Node* parent_node_2 = graph.GetProducerNode(where_inputs[2]->Name());
2335

24-
bool is_p1_dq = (parent_node_1 && parent_node_1->OpType() == QDQ::DQOpName);
25-
bool is_p2_dq = (parent_node_2 && parent_node_2->OpType() == QDQ::DQOpName);
36+
// Only apply when the `Where` output is immediately consumed by a single QuantizeLinear.
37+
// If there are multiple consumers (or not a Q), inserting an extra DQ would not help form a
38+
// clean QDQ node-unit and may create additional overhead.
39+
std::vector<const Node*> child_nodes = graph.GetConsumerNodes(where_outputs[0]->Name());
40+
if (child_nodes.size() != 1 || child_nodes[0]->OpType() != QDQ::QOpName) {
41+
return false;
42+
}
2643

27-
// WhereDummyDq focus on WhereOp with one DQ input and one scalar initializer input
28-
if (is_p1_dq && !parent_node_2) {
29-
return (where_inputs[2]->Shape()->dim_size() == 0);
44+
const bool is_p1_dq = (parent_node_1 && parent_node_1->OpType() == QDQ::DQOpName);
45+
const bool is_p2_dq = (parent_node_2 && parent_node_2->OpType() == QDQ::DQOpName);
46+
47+
// We require exactly one branch to be fed by a DQ and the other branch to be a scalar initializer
48+
// (represented as a NodeArg with rank 0 shape and no producer node).
49+
if (is_p1_dq && graph_utils::IsConstantInitializer(graph, where_inputs[2]->Name(), true)) {
50+
return where_inputs[2]->HasTensorOrScalarShape() ? (where_inputs[2]->Shape()->dim_size() == 0) : false;
3051
}
31-
if (!parent_node_1 && is_p2_dq) {
32-
return (where_inputs[1]->Shape()->dim_size() == 0);
52+
if (graph_utils::IsConstantInitializer(graph, where_inputs[1]->Name(), true) && is_p2_dq) {
53+
return where_inputs[1]->HasTensorOrScalarShape() ? (where_inputs[1]->Shape()->dim_size() == 0) : false;
3354
}
55+
3456
return false;
3557
}
3658

3759
Status WhereDummyDq::InsertDummyDQ(Node& node, Graph& graph, bool& modified, const logging::Logger& logger) const {
60+
// Inserts a DeQuantizeLinear node on the scalar initializer branch of `Where` so that both
61+
// data branches (X and Y) are produced by DQ nodes, enabling downstream QDQ grouping.
3862
const auto& where_inputs = node.InputDefs();
63+
const auto& where_outputs = node.OutputDefs();
3964
const Node* parent_node_1 = graph.GetProducerNode(where_inputs[1]->Name());
4065
const Node* parent_node_2 = graph.GetProducerNode(where_inputs[2]->Name());
66+
const Node* child_node = graph.GetConsumerNodes(where_outputs[0]->Name())[0];
4167

42-
// With SatisfyCondition, we must have one DQ and one initializer
68+
// From SatisfyCondition():
69+
// - exactly one of parent_node_1/parent_node_2 is a DQ node
70+
// - the other input is a scalar initializer (rank-0 tensor) with no producer node
4371
const Node* dq_node = parent_node_1 ? parent_node_1 : parent_node_2;
44-
int const_idx = parent_node_1 ? 2 : 1;
72+
const int const_idx = parent_node_1 ? 2 : 1;
73+
74+
// Guardrail: only insert dummy DQ when the quantized dtype matches the output Q's dtype.
75+
// If they differ, we cannot safely synthesize quantization parameters.
76+
const int32_t dt_input = dq_node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
77+
const int32_t dt_output = child_node->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
78+
if (dt_input != dt_output) {
79+
LOGS(logger, WARNING) << "WhereDummyDq: skip inserting dummy DQ due to mismatched quantized dtype between input DQ "
80+
"and output Q. DQ input dtype="
81+
<< dt_input << ", Q output dtype=" << dt_output;
82+
return Status::OK();
83+
}
4584

4685
const ONNX_NAMESPACE::TensorProto* dq_node_scale_proto = nullptr;
47-
graph.GetInitializedTensor(dq_node->InputDefs()[1]->Name(), dq_node_scale_proto);
86+
if (!graph.GetInitializedTensor(dq_node->InputDefs()[1]->Name(), dq_node_scale_proto) ||
87+
dq_node_scale_proto == nullptr) {
88+
LOGS(logger, WARNING) << "WhereDummyDq expects dq branch to have an initializer scale. "
89+
<< "DQ: " << dq_node->Name();
90+
return Status::OK();
91+
};
92+
4893
const ONNX_NAMESPACE::TensorProto* dq_node_zp_proto = nullptr;
49-
graph.GetInitializedTensor(dq_node->InputDefs()[2]->Name(), dq_node_zp_proto);
94+
if (!graph.GetInitializedTensor(dq_node->InputDefs()[2]->Name(), dq_node_zp_proto) ||
95+
dq_node_zp_proto == nullptr) {
96+
LOGS(logger, WARNING) << "WhereDummyDq expects dq branch to have an initializer zero point. "
97+
<< "DQ: " << dq_node->Name();
98+
return Status::OK();
99+
};
50100

101+
// Create initializers for the dummy DQ input triplet: (xq, scale, zero_point).
102+
// We choose values so that DeQuantizeLinear(dummy_xq, dummy_scale, dummy_zp) reconstructs
103+
// the original scalar float value as closely as possible.
104+
//
105+
// Note: We only support float scalar constants currently.
51106
// Dummy data initializer.
52-
ONNX_NAMESPACE::TensorProto dummy_data_proto;
53-
dummy_data_proto.set_name(graph.GenerateNodeArgName(node.Name() + "_dummy_data"));
107+
ONNX_NAMESPACE::TensorProto dummy_xq_proto;
108+
dummy_xq_proto.set_name(graph.GenerateNodeArgName(node.Name() + "_dummy_xq"));
54109
// Set data type to dq node's zp dtype
55-
dummy_data_proto.set_data_type(dq_node_zp_proto->data_type());
110+
dummy_xq_proto.set_data_type(dq_node_zp_proto->data_type());
56111

57112
// Dummy zero point initializer.
58113
ONNX_NAMESPACE::TensorProto dummy_zp_proto;
59114
dummy_zp_proto.set_name(graph.GenerateNodeArgName(node.Name() + "_dummy_zp"));
60115
dummy_zp_proto.set_data_type(dq_node_zp_proto->data_type());
61116

117+
// Dummy scale initializer.
118+
ONNX_NAMESPACE::TensorProto dummy_scale_proto;
119+
dummy_scale_proto.set_name(graph.GenerateNodeArgName(node.Name() + "_dummy_scale"));
120+
dummy_scale_proto.set_data_type(dq_node_scale_proto->data_type());
121+
122+
// Get original float input
123+
const ONNX_NAMESPACE::TensorProto* const_node_data_proto = nullptr;
124+
graph.GetInitializedTensor(where_inputs[const_idx]->Name(), const_node_data_proto);
125+
Initializer initializer(graph, *const_node_data_proto, graph.ModelPath());
126+
if (dq_node_scale_proto->data_type() != const_node_data_proto->data_type()) {
127+
// WhereDummyDq fills the const value to the dummy DQ's scale
128+
LOGS(logger, WARNING) << "Currently only support existing DQ's scale with same datatype as scalar. "
129+
<< "DQ: " << dq_node->Name() << ", scalar(const): " << where_inputs[const_idx]->Name();
130+
return Status::OK();
131+
}
132+
float dummy_xf = 0;
133+
switch (initializer.data_type()) {
134+
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
135+
dummy_xf = *initializer.data<float>();
136+
break;
137+
}
138+
default:
139+
LOGS(logger, WARNING) << "Unsupported dtype of constant input. "
140+
<< "DQ: " << dq_node->Name() << ", scalar(const): " << where_inputs[const_idx]->Name();
141+
return Status::OK();
142+
}
143+
144+
// TensorProto stores INT8/UINT8/INT16/UINT16 values via `int32_data`.
145+
// Keep values in-range for unsigned cases (0..255 / 0..65535) before writing.
146+
int32_t dummy_zp_i32 = 0;
147+
int32_t dummy_xq_i32 = 0;
148+
float dummy_scale = 1.0f;
149+
62150
switch (dummy_zp_proto.data_type()) {
63151
case ONNX_NAMESPACE::TensorProto_DataType_INT8: {
64-
int8_t zp = 0;
65-
int8_t dummy_data = 1;
66-
utils::SetRawDataInTensorProto(dummy_zp_proto, &zp, 1);
67-
utils::SetRawDataInTensorProto(dummy_data_proto, &dummy_data, 1);
152+
dummy_zp_i32 = 0;
153+
dummy_xq_i32 = (dummy_xf > 0) ? 127 : ((dummy_xf == 0) ? dummy_zp_i32 : -128);
154+
dummy_scale = (dummy_xf == 0) ? 1 : (float)dummy_xf / (dummy_xq_i32 - dummy_zp_i32);
68155
break;
69156
}
70157
case ONNX_NAMESPACE::TensorProto_DataType_UINT8: {
71-
uint8_t zp = 0;
72-
uint8_t dummy_data = 1;
73-
utils::SetRawDataInTensorProto(dummy_zp_proto, &zp, 1);
74-
utils::SetRawDataInTensorProto(dummy_data_proto, &dummy_data, 1);
158+
dummy_zp_i32 = 127;
159+
dummy_xq_i32 = (dummy_xf > 0) ? 255 : ((dummy_xf == 0) ? dummy_zp_i32 : 0);
160+
dummy_scale = (dummy_xf == 0) ? 1 : (float)dummy_xf / (dummy_xq_i32 - dummy_zp_i32);
75161
break;
76162
}
77163
case ONNX_NAMESPACE::TensorProto_DataType_INT16: {
78-
int16_t zp = 0;
79-
int16_t dummy_data = 1;
80-
utils::SetRawDataInTensorProto(dummy_zp_proto, &zp, 2);
81-
utils::SetRawDataInTensorProto(dummy_data_proto, &dummy_data, 2);
164+
dummy_zp_i32 = 0;
165+
dummy_xq_i32 = (dummy_xf > 0) ? 32767 : ((dummy_xf == 0) ? dummy_zp_i32 : -32768);
166+
dummy_scale = (dummy_xf == 0) ? 1 : (float)dummy_xf / (dummy_xq_i32 - dummy_zp_i32);
82167
break;
83168
}
84169
case ONNX_NAMESPACE::TensorProto_DataType_UINT16: {
85-
uint16_t zp = 0;
86-
uint16_t dummy_data = 1;
87-
utils::SetRawDataInTensorProto(dummy_zp_proto, &zp, 2);
88-
utils::SetRawDataInTensorProto(dummy_data_proto, &dummy_data, 2);
170+
dummy_zp_i32 = 32767;
171+
dummy_xq_i32 = (dummy_xf > 0) ? 65535 : ((dummy_xf == 0) ? dummy_zp_i32 : 0);
172+
dummy_scale = (dummy_xf == 0) ? 1 : (float)dummy_xf / (dummy_xq_i32 - dummy_zp_i32);
89173
break;
90174
}
91175
default:
92-
LOGS(logger, WARNING) << "Currently support existing DQ's zero point with INT8, UINT8, INT16, UINT16";
176+
LOGS(logger, WARNING) << "Currently support existing DQ's zero point with INT8, UINT8, INT16, UINT16. "
177+
<< "DQ: " << dq_node->Name() << ", scalar(const): " << where_inputs[const_idx]->Name();
93178
return Status::OK();
94179
}
95180

96-
// Set dummy scale to the original value
97-
const ONNX_NAMESPACE::TensorProto* const_node_data_proto = nullptr;
98-
graph.GetInitializedTensor(where_inputs[const_idx]->Name(), const_node_data_proto);
99-
Initializer initializer(graph, *const_node_data_proto, graph.ModelPath());
100-
if (dq_node_scale_proto->data_type() != const_node_data_proto->data_type()) {
101-
// WhereDummyDq fills the const value to the dummy DQ's scale
102-
LOGS(logger, WARNING) << "Currently only support existing DQ's scale with same datatype as scalar";
103-
return Status::OK();
104-
}
105-
106-
// Dummy scale initializer.
107-
ONNX_NAMESPACE::TensorProto dummy_scale_proto;
108-
dummy_scale_proto.set_name(graph.GenerateNodeArgName(node.Name() + "_dummy_scale"));
109-
dummy_scale_proto.set_data_type(dq_node_scale_proto->data_type());
110-
switch (initializer.data_type()) {
111-
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
112-
float* where_const_scalar = initializer.data<float>();
113-
utils::SetRawDataInTensorProto(dummy_scale_proto, where_const_scalar, sizeof(float));
114-
break;
115-
}
116-
default:
117-
LOGS(logger, WARNING) << "Currently support scalar with FLOAT";
118-
return Status::OK();
119-
}
181+
dummy_zp_proto.add_int32_data(dummy_zp_i32);
182+
dummy_xq_proto.add_int32_data(dummy_xq_i32);
183+
dummy_scale_proto.add_float_data(dummy_scale);
120184

121-
// Start editing the graph
122-
NodeArg& dummy_data_arg = graph_utils::AddInitializerWithOrtValue(graph, dummy_data_proto);
185+
// Start editing the graph:
186+
// - add the initializers
187+
// - add a DeQuantizeLinear node consuming them
188+
// - rewire the scalar branch of `Where` to use the DQ output
189+
// - drop the original scalar initializer if it becomes unused
190+
NodeArg& dummy_xq_arg = graph_utils::AddInitializerWithOrtValue(graph, dummy_xq_proto);
123191
NodeArg& dummy_scale_arg = graph_utils::AddInitializerWithOrtValue(graph, dummy_scale_proto);
124192
NodeArg& dummy_zp_arg = graph_utils::AddInitializerWithOrtValue(graph, dummy_zp_proto);
125193

@@ -132,7 +200,7 @@ Status WhereDummyDq::InsertDummyDQ(Node& node, Graph& graph, bool& modified, con
132200
graph.GenerateNodeArgName(node.Name() + "_dummy_dq"),
133201
QDQ::DQOpName,
134202
"DeQuantizeLinear from WhereDummyDq GraphTransformer",
135-
{&dummy_data_arg, &dummy_scale_arg, &dummy_zp_arg},
203+
{&dummy_xq_arg, &dummy_scale_arg, &dummy_zp_arg},
136204
{&dummy_dq_arg},
137205
node,
138206
nullptr,

onnxruntime/core/optimizer/qdq_transformer/where_dummy_dq.h

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,29 @@ namespace onnxruntime {
1111
@Class WhereDummyDq
1212
1313
Graph transformer that inserts a dummy DQ on Where node's initializer input
14-
to form Node Unit when Where node has one DQ and one scalar initializer input
14+
to form Node Unit when Where node has one DQ and one scalar initializer input.
15+
16+
If `Where` gets a float scalar `xf` and a `DequantizeLinear` as its two data inputs,
17+
`WhereDummyDq` inserts a dummy DQ so that `xf ≈ DQ(xq, scale, zp)`.
18+
19+
The `xq`, `zp` are chosen per the table below (by the dtype of the existing DQ's zero-point),
20+
and `scale` is computed from them.
21+
22+
We select these values in order to keep the `scale` non-negative:
23+
24+
| | uint8 | uint16 | int8 | int16 |
25+
|-----------------|--------|--------|-------|--------|
26+
| xf > 0 | | | | |
27+
| xq | 255 | 65535 | 127 | 32767 |
28+
| zp | 127 | 32767 | 0 | 0 |
29+
| xf < 0 | | | | |
30+
| xq | 0 | 0 | -128 | -32768 |
31+
| zp | 127 | 32767 | 0 | 0 |
32+
| xf = 0 | | | | |
33+
| xq | 127 | 32767 | 0 | 0 |
34+
| zp | 127 | 32767 | 0 | 0 |
35+
36+
scale = xf / (xq - zp) if (xq != zp) else 1
1537
*/
1638
class WhereDummyDq : public GraphTransformer {
1739
public:
@@ -23,4 +45,4 @@ class WhereDummyDq : public GraphTransformer {
2345
bool SatisfyCondition(const Graph& graph, const Node& node) const;
2446
Status InsertDummyDQ(Node& node, Graph& graph, bool& modified, const logging::Logger& logger) const;
2547
};
26-
} // namespace onnxruntime
48+
} // namespace onnxruntime

0 commit comments

Comments
 (0)