@@ -40,6 +40,19 @@ def _get_special_dtype(qspec: QuantArgs) -> TosaSpecialDtype | None:
4040 return None
4141
4242
def _merge_qparams(qspec_1: QuantArgs, qspec_2: QuantArgs) -> QuantArgs:
    """Resolve two differing input QuantArgs into a single spec.

    Both specs must share the same dtype; the first spec's parameters
    take precedence.

    Raises:
        RuntimeError: if the two specs have different dtypes.
    """
    if qspec_1.dtype == qspec_2.dtype:
        return qspec_1
    raise RuntimeError(
        f"Cannot merge qparams of different dtypes: {qspec_1.dtype} vs {qspec_2.dtype}"
    )
55+
4356def get_input_qparams (node : Node ) -> dict [int , QuantArgs ]:
4457 """Get the input quantization parameters from a node, set by the
4558 'FoldAndAnnotateQParamsPass'.
@@ -121,57 +134,72 @@ def __init__(
121134 super ().__init__ (* args , ** kwargs )
122135 self .exported_program = exported_program
123136
124- def fold_and_annotate_arg (
125- self , graph_module : GraphModule , node : Node , arg_list : list [Node ], i : int
126- ) -> None :
127- input_qparams = None
128- nodes_to_remove = set ()
137+ def _extract_input_params (
138+ self , arg_list : list [Node ]
139+ ) -> tuple [ Optional [ QuantArgs ], set [ Node ]] :
140+ input_qparams : Optional [ QuantArgs ] = None
141+ nodes_to_remove : set [ Node ] = set ()
129142 for arg in arg_list :
130143 if not isinstance (arg , Node ):
131- return
132-
133- arg_quant_params = None
144+ return None , set ()
145+ arg_quant : Optional [QuantArgs ] = None
134146 if arg .target in DQ_OPS :
135147 args = arg .args
136148 scales = args [1 ]
137149 if (
138- isinstance (args [ 1 ] , Node )
150+ isinstance (scales , Node )
139151 and self .exported_program is not None
140- and is_param_node (self .exported_program , args [ 1 ] )
152+ and is_param_node (self .exported_program , scales )
141153 ):
142- scales = get_param_tensor (self .exported_program , args [ 1 ] )
154+ scales = get_param_tensor (self .exported_program , scales )
143155 zps = args [2 ]
144156 if (
145- isinstance (args [ 2 ] , Node )
157+ isinstance (zps , Node )
146158 and self .exported_program is not None
147- and is_param_node (self .exported_program , args [ 2 ] )
159+ and is_param_node (self .exported_program , zps )
148160 ):
149- zps = get_param_tensor (self .exported_program , args [ 2 ] )
150- arg_quant_params = QuantArgs .from_operator (
161+ zps = get_param_tensor (self .exported_program , zps )
162+ arg_quant = QuantArgs .from_operator (
151163 arg .target , (args [0 ], scales , zps , * args [3 :])
152164 )
153- # add arg to nodes_to_remove to fold the dq-node
154165 nodes_to_remove .add (arg )
155- if input_qparams is not None and input_qparams != arg_quant_params :
156- # Two args are quantized differently
157- raise RuntimeError ("Input qparams do not match" )
158- input_qparams = arg_quant_params
159- if input_qparams is not None :
160- node .meta ["input_qparams" ][i ] = input_qparams
161- for n in nodes_to_remove :
162- if n .target not in DQ_OPS :
163- raise RuntimeError (
164- f"Expected one of { DQ_OPS } dq_op, got { n .target } "
165- )
166+ if arg_quant is not None :
167+ if input_qparams is None :
168+ input_qparams = arg_quant
169+ elif input_qparams != arg_quant :
170+ input_qparams = _merge_qparams (input_qparams , arg_quant )
171+ return input_qparams , nodes_to_remove
172+
173+ def _annotate_input_params (
174+ self ,
175+ graph_module : GraphModule ,
176+ node : Node ,
177+ index : int ,
178+ input_qparams : QuantArgs ,
179+ nodes_to_remove : set [Node ],
180+ ) -> None :
181+ node .meta ["input_qparams" ][index ] = input_qparams
182+
183+ for dq in nodes_to_remove :
184+ if dq .target not in DQ_OPS :
185+ raise RuntimeError (f"Expected one of { DQ_OPS } dq_op, got { dq .target } " )
186+ node .replace_input_with (dq , cast (Node , dq .args [0 ]))
187+ if not dq .users :
188+ graph_module .graph .erase_node (dq )
189+
190+ special = _get_special_dtype (input_qparams )
191+ if special :
192+ node .all_input_nodes [index ].meta [TosaSpecialDtype .meta_key ()] = special
166193
167- node .replace_input_with (n , cast (Node , n .args [0 ]))
168- if len (n .users ) == 0 :
169- graph_module .graph .erase_node (n )
170- special_dtype = _get_special_dtype (input_qparams )
171- if special_dtype :
172- node .all_input_nodes [i ].meta [
173- TosaSpecialDtype .meta_key ()
174- ] = special_dtype
194+ def fold_and_annotate_arg (
195+ self , graph_module : GraphModule , node : Node , arg_list : list [Node ], i : int
196+ ) -> None :
197+ input_qparams , nodes_to_remove = self ._extract_input_params (arg_list )
198+ if input_qparams is None :
199+ return
200+ self ._annotate_input_params (
201+ graph_module , node , i , input_qparams , nodes_to_remove
202+ )
175203
176204 def _handle_control_flow_node (self , node : Node , graph_module : GraphModule ):
177205 """Fold outmost quant nodes inside submodule.
0 commit comments