1414
1515namespace onnxruntime {
1616bool WhereDummyDq::SatisfyCondition (const Graph& graph, const Node& node) const {
17+ // This transformer targets a very specific pattern around `Where` when used in a QDQ graph:
18+ // cond, DQ(xq), const_scalar -> Where -> Q(yq)
19+ // or
20+ // cond, const_scalar, DQ(xq) -> Where -> Q(yq)
21+ //
22+ // When one `Where` branch is a scalar initializer (no producer node), WhereNodeGroupSelector
23+ // requires both data branches to be produced by DQ nodes so the `Where` can be grouped into a
24+ // single node-unit. We insert a "dummy" DQ for the scalar branch to satisfy that requirement.
1725 if (!(node.OpType () == " Where" )) {
1826 return false ;
1927 }
28+
29+ // ONNX Where inputs: [0]=condition, [1]=X, [2]=Y
2030 const auto & where_inputs = node.InputDefs ();
31+ const auto & where_outputs = node.OutputDefs ();
32+
2133 const Node* parent_node_1 = graph.GetProducerNode (where_inputs[1 ]->Name ());
2234 const Node* parent_node_2 = graph.GetProducerNode (where_inputs[2 ]->Name ());
2335
24- bool is_p1_dq = (parent_node_1 && parent_node_1->OpType () == QDQ::DQOpName);
25- bool is_p2_dq = (parent_node_2 && parent_node_2->OpType () == QDQ::DQOpName);
36+ // Only apply when the `Where` output is immediately consumed by a single QuantizeLinear.
37+ // If there are multiple consumers (or not a Q), inserting an extra DQ would not help form a
38+ // clean QDQ node-unit and may create additional overhead.
39+ std::vector<const Node*> child_nodes = graph.GetConsumerNodes (where_outputs[0 ]->Name ());
40+ if (child_nodes.size () != 1 || child_nodes[0 ]->OpType () != QDQ::QOpName) {
41+ return false ;
42+ }
2643
27- // WhereDummyDq focus on WhereOp with one DQ input and one scalar initializer input
28- if (is_p1_dq && !parent_node_2) {
29- return (where_inputs[2 ]->Shape ()->dim_size () == 0 );
44+ const bool is_p1_dq = (parent_node_1 && parent_node_1->OpType () == QDQ::DQOpName);
45+ const bool is_p2_dq = (parent_node_2 && parent_node_2->OpType () == QDQ::DQOpName);
46+
47+ // We require exactly one branch to be fed by a DQ and the other branch to be a scalar initializer
48+ // (represented as a NodeArg with rank 0 shape and no producer node).
49+ if (is_p1_dq && graph_utils::IsConstantInitializer (graph, where_inputs[2 ]->Name (), true )) {
50+ return where_inputs[2 ]->HasTensorOrScalarShape () ? (where_inputs[2 ]->Shape ()->dim_size () == 0 ) : false ;
3051 }
31- if (!parent_node_1 && is_p2_dq) {
32- return ( where_inputs[1 ]->Shape ()->dim_size () == 0 );
52+ if (graph_utils::IsConstantInitializer (graph, where_inputs[ 1 ]-> Name (), true ) && is_p2_dq) {
53+ return where_inputs[ 1 ]-> HasTensorOrScalarShape () ? ( where_inputs[1 ]->Shape ()->dim_size () == 0 ) : false ;
3354 }
55+
3456 return false ;
3557}
3658
3759Status WhereDummyDq::InsertDummyDQ (Node& node, Graph& graph, bool & modified, const logging::Logger& logger) const {
60+ // Inserts a DeQuantizeLinear node on the scalar initializer branch of `Where` so that both
61+ // data branches (X and Y) are produced by DQ nodes, enabling downstream QDQ grouping.
3862 const auto & where_inputs = node.InputDefs ();
63+ const auto & where_outputs = node.OutputDefs ();
3964 const Node* parent_node_1 = graph.GetProducerNode (where_inputs[1 ]->Name ());
4065 const Node* parent_node_2 = graph.GetProducerNode (where_inputs[2 ]->Name ());
66+ const Node* child_node = graph.GetConsumerNodes (where_outputs[0 ]->Name ())[0 ];
4167
42- // With SatisfyCondition, we must have one DQ and one initializer
68+ // From SatisfyCondition():
69+ // - exactly one of parent_node_1/parent_node_2 is a DQ node
70+ // - the other input is a scalar initializer (rank-0 tensor) with no producer node
4371 const Node* dq_node = parent_node_1 ? parent_node_1 : parent_node_2;
44- int const_idx = parent_node_1 ? 2 : 1 ;
72+ const int const_idx = parent_node_1 ? 2 : 1 ;
73+
74+ // Guardrail: only insert dummy DQ when the quantized dtype matches the output Q's dtype.
75+ // If they differ, we cannot safely synthesize quantization parameters.
76+ const int32_t dt_input = dq_node->InputDefs ()[0 ]->TypeAsProto ()->tensor_type ().elem_type ();
77+ const int32_t dt_output = child_node->OutputDefs ()[0 ]->TypeAsProto ()->tensor_type ().elem_type ();
78+ if (dt_input != dt_output) {
79+ LOGS (logger, WARNING) << " WhereDummyDq: skip inserting dummy DQ due to mismatched quantized dtype between input DQ "
80+ " and output Q. DQ input dtype="
81+ << dt_input << " , Q output dtype=" << dt_output;
82+ return Status::OK ();
83+ }
4584
4685 const ONNX_NAMESPACE::TensorProto* dq_node_scale_proto = nullptr ;
47- graph.GetInitializedTensor (dq_node->InputDefs ()[1 ]->Name (), dq_node_scale_proto);
86+ if (!graph.GetInitializedTensor (dq_node->InputDefs ()[1 ]->Name (), dq_node_scale_proto) ||
87+ dq_node_scale_proto == nullptr ) {
88+ LOGS (logger, WARNING) << " WhereDummyDq expects dq branch to have an initializer scale. "
89+ << " DQ: " << dq_node->Name ();
90+ return Status::OK ();
91+ };
92+
4893 const ONNX_NAMESPACE::TensorProto* dq_node_zp_proto = nullptr ;
49- graph.GetInitializedTensor (dq_node->InputDefs ()[2 ]->Name (), dq_node_zp_proto);
94+ if (!graph.GetInitializedTensor (dq_node->InputDefs ()[2 ]->Name (), dq_node_zp_proto) ||
95+ dq_node_zp_proto == nullptr ) {
96+ LOGS (logger, WARNING) << " WhereDummyDq expects dq branch to have an initializer zero point. "
97+ << " DQ: " << dq_node->Name ();
98+ return Status::OK ();
99+ };
50100
101+ // Create initializers for the dummy DQ input triplet: (xq, scale, zero_point).
102+ // We choose values so that DeQuantizeLinear(dummy_xq, dummy_scale, dummy_zp) reconstructs
103+ // the original scalar float value as closely as possible.
104+ //
105+ // Note: We only support float scalar constants currently.
51106 // Dummy data initializer.
52- ONNX_NAMESPACE::TensorProto dummy_data_proto ;
53- dummy_data_proto .set_name (graph.GenerateNodeArgName (node.Name () + " _dummy_data " ));
107+ ONNX_NAMESPACE::TensorProto dummy_xq_proto ;
108+ dummy_xq_proto .set_name (graph.GenerateNodeArgName (node.Name () + " _dummy_xq " ));
54109 // Set data type to dq node's zp dtype
55- dummy_data_proto .set_data_type (dq_node_zp_proto->data_type ());
110+ dummy_xq_proto .set_data_type (dq_node_zp_proto->data_type ());
56111
57112 // Dummy zero point initializer.
58113 ONNX_NAMESPACE::TensorProto dummy_zp_proto;
59114 dummy_zp_proto.set_name (graph.GenerateNodeArgName (node.Name () + " _dummy_zp" ));
60115 dummy_zp_proto.set_data_type (dq_node_zp_proto->data_type ());
61116
117+ // Dummy scale initializer.
118+ ONNX_NAMESPACE::TensorProto dummy_scale_proto;
119+ dummy_scale_proto.set_name (graph.GenerateNodeArgName (node.Name () + " _dummy_scale" ));
120+ dummy_scale_proto.set_data_type (dq_node_scale_proto->data_type ());
121+
122+ // Get original float input
123+ const ONNX_NAMESPACE::TensorProto* const_node_data_proto = nullptr ;
124+ graph.GetInitializedTensor (where_inputs[const_idx]->Name (), const_node_data_proto);
125+ Initializer initializer (graph, *const_node_data_proto, graph.ModelPath ());
126+ if (dq_node_scale_proto->data_type () != const_node_data_proto->data_type ()) {
127+ // WhereDummyDq fills the const value to the dummy DQ's scale
128+ LOGS (logger, WARNING) << " Currently only support existing DQ's scale with same datatype as scalar. "
129+ << " DQ: " << dq_node->Name () << " , scalar(const): " << where_inputs[const_idx]->Name ();
130+ return Status::OK ();
131+ }
132+ float dummy_xf = 0 ;
133+ switch (initializer.data_type ()) {
134+ case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
135+ dummy_xf = *initializer.data <float >();
136+ break ;
137+ }
138+ default :
139+ LOGS (logger, WARNING) << " Unsupported dtype of constant input. "
140+ << " DQ: " << dq_node->Name () << " , scalar(const): " << where_inputs[const_idx]->Name ();
141+ return Status::OK ();
142+ }
143+
144+ // TensorProto stores INT8/UINT8/INT16/UINT16 values via `int32_data`.
145+ // Keep values in-range for unsigned cases (0..255 / 0..65535) before writing.
146+ int32_t dummy_zp_i32 = 0 ;
147+ int32_t dummy_xq_i32 = 0 ;
148+ float dummy_scale = 1 .0f ;
149+
62150 switch (dummy_zp_proto.data_type ()) {
63151 case ONNX_NAMESPACE::TensorProto_DataType_INT8: {
64- int8_t zp = 0 ;
65- int8_t dummy_data = 1 ;
66- utils::SetRawDataInTensorProto (dummy_zp_proto, &zp, 1 );
67- utils::SetRawDataInTensorProto (dummy_data_proto, &dummy_data, 1 );
152+ dummy_zp_i32 = 0 ;
153+ dummy_xq_i32 = (dummy_xf > 0 ) ? 127 : ((dummy_xf == 0 ) ? dummy_zp_i32 : -128 );
154+ dummy_scale = (dummy_xf == 0 ) ? 1 : (float )dummy_xf / (dummy_xq_i32 - dummy_zp_i32);
68155 break ;
69156 }
70157 case ONNX_NAMESPACE::TensorProto_DataType_UINT8: {
71- uint8_t zp = 0 ;
72- uint8_t dummy_data = 1 ;
73- utils::SetRawDataInTensorProto (dummy_zp_proto, &zp, 1 );
74- utils::SetRawDataInTensorProto (dummy_data_proto, &dummy_data, 1 );
158+ dummy_zp_i32 = 127 ;
159+ dummy_xq_i32 = (dummy_xf > 0 ) ? 255 : ((dummy_xf == 0 ) ? dummy_zp_i32 : 0 );
160+ dummy_scale = (dummy_xf == 0 ) ? 1 : (float )dummy_xf / (dummy_xq_i32 - dummy_zp_i32);
75161 break ;
76162 }
77163 case ONNX_NAMESPACE::TensorProto_DataType_INT16: {
78- int16_t zp = 0 ;
79- int16_t dummy_data = 1 ;
80- utils::SetRawDataInTensorProto (dummy_zp_proto, &zp, 2 );
81- utils::SetRawDataInTensorProto (dummy_data_proto, &dummy_data, 2 );
164+ dummy_zp_i32 = 0 ;
165+ dummy_xq_i32 = (dummy_xf > 0 ) ? 32767 : ((dummy_xf == 0 ) ? dummy_zp_i32 : -32768 );
166+ dummy_scale = (dummy_xf == 0 ) ? 1 : (float )dummy_xf / (dummy_xq_i32 - dummy_zp_i32);
82167 break ;
83168 }
84169 case ONNX_NAMESPACE::TensorProto_DataType_UINT16: {
85- uint16_t zp = 0 ;
86- uint16_t dummy_data = 1 ;
87- utils::SetRawDataInTensorProto (dummy_zp_proto, &zp, 2 );
88- utils::SetRawDataInTensorProto (dummy_data_proto, &dummy_data, 2 );
170+ dummy_zp_i32 = 32767 ;
171+ dummy_xq_i32 = (dummy_xf > 0 ) ? 65535 : ((dummy_xf == 0 ) ? dummy_zp_i32 : 0 );
172+ dummy_scale = (dummy_xf == 0 ) ? 1 : (float )dummy_xf / (dummy_xq_i32 - dummy_zp_i32);
89173 break ;
90174 }
91175 default :
92- LOGS (logger, WARNING) << " Currently support existing DQ's zero point with INT8, UINT8, INT16, UINT16" ;
176+ LOGS (logger, WARNING) << " Currently support existing DQ's zero point with INT8, UINT8, INT16, UINT16. "
177+ << " DQ: " << dq_node->Name () << " , scalar(const): " << where_inputs[const_idx]->Name ();
93178 return Status::OK ();
94179 }
95180
96- // Set dummy scale to the original value
97- const ONNX_NAMESPACE::TensorProto* const_node_data_proto = nullptr ;
98- graph.GetInitializedTensor (where_inputs[const_idx]->Name (), const_node_data_proto);
99- Initializer initializer (graph, *const_node_data_proto, graph.ModelPath ());
100- if (dq_node_scale_proto->data_type () != const_node_data_proto->data_type ()) {
101- // WhereDummyDq fills the const value to the dummy DQ's scale
102- LOGS (logger, WARNING) << " Currently only support existing DQ's scale with same datatype as scalar" ;
103- return Status::OK ();
104- }
105-
106- // Dummy scale initializer.
107- ONNX_NAMESPACE::TensorProto dummy_scale_proto;
108- dummy_scale_proto.set_name (graph.GenerateNodeArgName (node.Name () + " _dummy_scale" ));
109- dummy_scale_proto.set_data_type (dq_node_scale_proto->data_type ());
110- switch (initializer.data_type ()) {
111- case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
112- float * where_const_scalar = initializer.data <float >();
113- utils::SetRawDataInTensorProto (dummy_scale_proto, where_const_scalar, sizeof (float ));
114- break ;
115- }
116- default :
117- LOGS (logger, WARNING) << " Currently support scalar with FLOAT" ;
118- return Status::OK ();
119- }
181+ dummy_zp_proto.add_int32_data (dummy_zp_i32);
182+ dummy_xq_proto.add_int32_data (dummy_xq_i32);
183+ dummy_scale_proto.add_float_data (dummy_scale);
120184
121- // Start editing the graph
122- NodeArg& dummy_data_arg = graph_utils::AddInitializerWithOrtValue (graph, dummy_data_proto);
185+ // Start editing the graph:
186+ // - add the initializers
187+ // - add a DeQuantizeLinear node consuming them
188+ // - rewire the scalar branch of `Where` to use the DQ output
189+ // - drop the original scalar initializer if it becomes unused
190+ NodeArg& dummy_xq_arg = graph_utils::AddInitializerWithOrtValue (graph, dummy_xq_proto);
123191 NodeArg& dummy_scale_arg = graph_utils::AddInitializerWithOrtValue (graph, dummy_scale_proto);
124192 NodeArg& dummy_zp_arg = graph_utils::AddInitializerWithOrtValue (graph, dummy_zp_proto);
125193
@@ -132,7 +200,7 @@ Status WhereDummyDq::InsertDummyDQ(Node& node, Graph& graph, bool& modified, con
132200 graph.GenerateNodeArgName (node.Name () + " _dummy_dq" ),
133201 QDQ::DQOpName,
134202 " DeQuantizeLinear from WhereDummyDq GraphTransformer" ,
135- {&dummy_data_arg , &dummy_scale_arg, &dummy_zp_arg},
203+ {&dummy_xq_arg , &dummy_scale_arg, &dummy_zp_arg},
136204 {&dummy_dq_arg},
137205 node,
138206 nullptr ,
0 commit comments