@@ -248,12 +248,22 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
248248@register_cadence_pass (CadencePassAttribute (opt_level = 1 ))
249249class AdvanceQuantizeOpAboveDefChainPass (ExportPass ):
250250 """
251- If the input to quantize op is linear chain of view, transpose, permute, or
252- slice ops that are trivially quantized, we can convert the pattern
253- view/transpose/permute/slice(fp32) -> quantize(int8/uint8) to
254- quantize(int8/uint8) -> view/transpose/permute/slice(int8/uint8).
255- The benefit of such reordering is that the view/transpose/permute/slice
256- will move far less data.
251+ Advances a quantize op above data-movement ops to reduce data volume.
252+
253+ Handles two cases:
254+
255+ 1. Linear chain: if the input to a quantize op is a chain of trivially
256+ quantizable ops (view, transpose, permute, slice), rewrite
257+ data_movement(fp32) -> quantize to quantize -> data_movement(quantized)
258+ so the data movement operates on smaller quantized tensors.
259+
260+ 2. Cat: if the input to a quantize op is a cat with a single user (the
261+ quantize), advance the quantize above the cat by quantizing each cat
262+ input individually. A later pass can clean up any redundant
263+ dequant-quant pairs on the inputs.
264+
265+ For the cat case, SplitDequantizedCatPass should run first to ensure
266+ each cat has at most one quantize consumer.
257267 """
258268
259269 def __init__ (self ):
@@ -302,6 +312,47 @@ def advancing_feasible(self, quant_node: torch.fx.Node):
302312 # All the conditions satisfied, we advance.
303313 return True
304314
def _advance_above_cat(
    self, quant_node: torch.fx.Node, cat_node: torch.fx.Node
) -> None:
    """Advance a quantize op above a cat by quantizing each cat input.

    Rewrites ``cat([x0, x1, ...], dim) -> quantize`` into
    ``cat([quantize(x0), quantize(x1), ...], dim)``. Every new quantize
    reuses the target and the quantization parameters (all arguments
    after the input tensor, plus any keyword arguments) of ``quant_node``.

    Args:
        quant_node: The quantize node to advance. All of its uses are
            redirected to the newly created cat, and it is erased.
        cat_node: The cat node whose inputs get quantized individually.
            Its first argument must be a list/tuple of fx nodes.

    NOTE(review): ``cat_node`` itself is not erased here. Once
    ``quant_node`` is gone it presumably has no users left (the caller
    only advances when the cat has a single user) and is expected to be
    removed by dead-code elimination -- confirm.
    """
    graph = quant_node.graph
    # Everything after the input tensor describes the quantization.
    quant_params = quant_node.args[1:]
    # Forward keyword arguments as well, so parameters supplied by keyword
    # on the original quantize node are not silently dropped.
    quant_kwargs = dict(quant_node.kwargs)

    cat_inputs = cat_node.args[0]
    assert isinstance(cat_inputs, (list, tuple))

    # Quantize each distinct input once; an input that appears several
    # times in the cat reuses the same quantize node instead of emitting
    # identical duplicates.
    quantized: dict[torch.fx.Node, torch.fx.Node] = {}
    new_inputs: list[torch.fx.Node] = []
    with graph.inserting_before(cat_node):
        for inp in cat_inputs:
            # cat concatenates tensors, so every input must be a node.
            assert isinstance(inp, torch.fx.Node)

            new_quant = quantized.get(inp)
            if new_quant is None:
                new_quant = graph.call_function(
                    # pyre-ignore[6]
                    quant_node.target,
                    args=(inp, *quant_params),
                    kwargs=quant_kwargs,
                )
                # This copies the fp32 input's meta, so meta["val"] keeps
                # the fp32 dtype rather than the quantized output dtype.
                # That's fine: nothing in this pass reads dtype from meta
                # (only shape, which is correct), and call() re-runs
                # super().call() to re-propagate fake tensors, making meta
                # dtype-consistent before we return.
                new_quant.meta = inp.meta.copy()
                quantized[inp] = new_quant
            new_inputs.append(new_quant)

    dim = get_arg(cat_node, "dim", int)
    # The new cat takes the place of the quantize node in the graph.
    with graph.inserting_before(quant_node):
        new_cat = graph.call_function(
            # pyre-ignore[6]
            cat_node.target,
            args=(new_inputs, dim),
        )
    # The new cat yields what the quantize used to yield, so it inherits
    # the quantize node's metadata.
    new_cat.meta = quant_node.meta.copy()

    quant_node.replace_all_uses_with(new_cat)
    graph.erase_node(quant_node)
305356 def advance_quantize_op (self , graph_module : torch .fx .GraphModule ) -> bool :
306357 graph = graph_module .graph
307358 modified = False
@@ -314,6 +365,17 @@ def advance_quantize_op(self, graph_module: torch.fx.GraphModule) -> bool:
314365 ):
315366 continue
316367
368+ inp = node .args [0 ]
369+ if (
370+ isinstance (inp , torch .fx .Node )
371+ and get_overload_packet (inp .target )
372+ in (exir_ops .edge .aten .cat , torch .ops .aten .cat )
373+ and len (inp .users ) == 1
374+ ):
375+ self ._advance_above_cat (node , inp )
376+ modified = True
377+ continue
378+
317379 if not self .advancing_feasible (node ):
318380 continue
319381
0 commit comments