1212from typing import List , Optional , Tuple , Union
1313
1414import torch
15- from executorch .backends .cadence .aot .quantizer .utils import get_bias_qparams
16-
15+ from executorch .backends .cadence .aot .pass_utils import get_arg , replace_with_op
16+ from executorch .backends .cadence .aot .quantizer .pattern_utils import (
17+ DQ_PER_TENSOR ,
18+ find_quant_user ,
19+ fuse_conv ,
20+ fuse_linear ,
21+ fuse_matmul ,
22+ insert_node_with_meta ,
23+ )
24+ from executorch .backends .cadence .aot .quantizer .utils import (
25+ check_out_zero_point_is_min_range ,
26+ get_bias_qparams ,
27+ )
1728from torch import fx
1829from torch ._ops import OpOverload
1930from torchao .quantization .pt2e .quantizer import (
@@ -131,6 +142,41 @@ def get_anchors(
131142 def replacement_op (self ) -> OpOverload :
132143 return torch .ops .cadence .quantized_linear .per_tensor
133144
145+ def fuse (self , gm : fx .GraphModule , anchor_node : fx .Node ) -> fx .Node | None :
146+ assert anchor_node .target == torch .ops .aten .addmm .default
147+ # addmm(bias, input, weight)
148+ bias_node = anchor_node .args [0 ]
149+ assert isinstance (bias_node , fx .Node )
150+ dq_input = get_arg (anchor_node , "mat1" , fx .Node )
151+ if dq_input .target != DQ_PER_TENSOR :
152+ return None
153+ dq_weight = get_arg (anchor_node , "mat2" , fx .Node )
154+ if dq_weight .target != DQ_PER_TENSOR :
155+ return None
156+ quant_node = find_quant_user (anchor_node )
157+ if quant_node is None :
158+ return None
159+ dq_bias = bias_node if bias_node .target == DQ_PER_TENSOR else None
160+ weight_q = get_arg (dq_weight , "input" , fx .Node )
161+ transposed = insert_node_with_meta (
162+ gm ,
163+ torch .ops .aten .transpose .int ,
164+ (weight_q , 0 , 1 ),
165+ None ,
166+ anchor_node ,
167+ weight_q ,
168+ )
169+ return fuse_linear (
170+ gm ,
171+ dq_input ,
172+ dq_weight ,
173+ dq_bias ,
174+ quant_node ,
175+ anchor_node ,
176+ self .replacement_op (),
177+ weight_q = transposed ,
178+ )
179+
134180
135181class AddPattern (QuantizationPattern ):
136182 def partition_types (self ) -> List [OpOverload ]:
@@ -169,6 +215,33 @@ def get_anchors(
169215 def replacement_op (self ) -> OpOverload :
170216 return torch .ops .cadence .quantized_add .per_tensor
171217
218+ def fuse (self , gm : fx .GraphModule , anchor_node : fx .Node ) -> fx .Node | None :
219+ # Skip if alpha kwarg is present — changes add semantics.
220+ if anchor_node .kwargs :
221+ return None
222+ dq0 = anchor_node .args [0 ]
223+ if not isinstance (dq0 , fx .Node ) or dq0 .target != DQ_PER_TENSOR :
224+ return None
225+ dq1 = anchor_node .args [1 ]
226+ if not isinstance (dq1 , fx .Node ) or dq1 .target != DQ_PER_TENSOR :
227+ return None
228+ quant_node = find_quant_user (anchor_node )
229+ if quant_node is None :
230+ return None
231+ args = (
232+ get_arg (dq0 , "input" , fx .Node ),
233+ get_arg (dq0 , "scale" , float ),
234+ get_arg (dq0 , "zero_point" , int ),
235+ get_arg (dq1 , "input" , fx .Node ),
236+ get_arg (dq1 , "scale" , float ),
237+ get_arg (dq1 , "zero_point" , int ),
238+ get_arg (quant_node , "scale" , float ),
239+ get_arg (quant_node , "zero_point" , int ),
240+ )
241+ return replace_with_op (
242+ gm , anchor_node , self .replacement_op (), args , {}, quant_node
243+ )
244+
172245
173246# This is a base class for Add+ReLU fusion, since it can be used with two different relu aten ops
174247class AddReluBasePattern (QuantizationPattern ):
@@ -212,6 +285,46 @@ def get_anchors(
212285 def replacement_op (self ) -> OpOverload :
213286 return torch .ops .cadence .quantized_add .per_tensor
214287
288+ def anchor_ops (self ) -> tuple [OpOverload , ...]:
289+ return (torch .ops .aten .add .Tensor ,)
290+
291+ def fuse (self , gm : fx .GraphModule , anchor_node : fx .Node ) -> fx .Node | None :
292+ add_users = list (anchor_node .users )
293+ if len (add_users ) != 1 :
294+ return None
295+ relu_node = add_users [0 ]
296+ if relu_node .target != self .partition_types ()[1 ]:
297+ return None
298+ if len (anchor_node .kwargs ) > 0 :
299+ return None
300+ dq0 = anchor_node .args [0 ]
301+ if not isinstance (dq0 , fx .Node ) or dq0 .target != DQ_PER_TENSOR :
302+ return None
303+ dq1 = anchor_node .args [1 ]
304+ if not isinstance (dq1 , fx .Node ) or dq1 .target != DQ_PER_TENSOR :
305+ return None
306+ quant_node = find_quant_user (relu_node )
307+ if quant_node is None :
308+ return None
309+ if not check_out_zero_point_is_min_range (
310+ get_arg (quant_node , "zero_point" , int ),
311+ get_arg (quant_node , "dtype" , torch .dtype ),
312+ ):
313+ return None
314+ args = (
315+ get_arg (dq0 , "input" , fx .Node ),
316+ get_arg (dq0 , "scale" , float ),
317+ get_arg (dq0 , "zero_point" , int ),
318+ get_arg (dq1 , "input" , fx .Node ),
319+ get_arg (dq1 , "scale" , float ),
320+ get_arg (dq1 , "zero_point" , int ),
321+ get_arg (quant_node , "scale" , float ),
322+ get_arg (quant_node , "zero_point" , int ),
323+ )
324+ return replace_with_op (
325+ gm , anchor_node , self .replacement_op (), args , {}, quant_node
326+ )
327+
215328
216329# Add + regular relu op fusion
217330class AddReluPattern0 (AddReluBasePattern ):
@@ -250,6 +363,18 @@ def replacement_op(self) -> OpOverload:
250363 # we just need to change the name of the op
251364 return torch .ops .cadence .quantized_matmul .default
252365
366+ def fuse (self , gm : fx .GraphModule , anchor_node : fx .Node ) -> fx .Node | None :
367+ dq0 = anchor_node .args [0 ]
368+ if not isinstance (dq0 , fx .Node ) or dq0 .target != DQ_PER_TENSOR :
369+ return None
370+ dq1 = anchor_node .args [1 ]
371+ if not isinstance (dq1 , fx .Node ) or dq1 .target != DQ_PER_TENSOR :
372+ return None
373+ quant_node = find_quant_user (anchor_node )
374+ if quant_node is None :
375+ return None
376+ return fuse_matmul (gm , anchor_node , dq0 , dq1 , quant_node , self .replacement_op ())
377+
253378
254379class CatPattern (QuantizationPattern ):
255380 def partition_types (self ) -> List [OpOverload ]:
@@ -299,6 +424,25 @@ def get_anchors(
299424 def replacement_op (self ) -> OpOverload :
300425 return torch .ops .aten .cat .default
301426
427+ def fuse (self , gm : fx .GraphModule , anchor_node : fx .Node ) -> fx .Node | None :
428+ cat_inputs = anchor_node .args [0 ]
429+ if not isinstance (cat_inputs , (list , tuple )) or not cat_inputs :
430+ return None
431+ inputs_q = []
432+ for inp in cat_inputs :
433+ if not isinstance (inp , fx .Node ) or inp .target != DQ_PER_TENSOR :
434+ return None
435+ inputs_q .append (get_arg (inp , "input" , fx .Node ))
436+ quant_node = find_quant_user (anchor_node )
437+ if quant_node is None :
438+ return None
439+ dim = get_arg (anchor_node , "dim" , int )
440+ args = (inputs_q ,)
441+ kwargs = {"dim" : dim }
442+ return replace_with_op (
443+ gm , anchor_node , self .replacement_op (), args , kwargs , quant_node
444+ )
445+
302446
303447class Conv1dPattern (QuantizationPattern ):
304448 def partition_types (self ) -> List [OpOverload ]:
@@ -341,6 +485,18 @@ def get_anchors(
341485 def replacement_op (self ) -> OpOverload :
342486 return torch .ops .cadence .quantized_conv1d_ncl .per_tensor
343487
488+ def fuse (self , gm : fx .GraphModule , anchor_node : fx .Node ) -> fx .Node | None :
489+ dq_input = anchor_node .args [0 ]
490+ if not isinstance (dq_input , fx .Node ) or dq_input .target != DQ_PER_TENSOR :
491+ return None
492+ dq_weight = anchor_node .args [1 ]
493+ if not isinstance (dq_weight , fx .Node ) or dq_weight .target != DQ_PER_TENSOR :
494+ return None
495+ quant_node = find_quant_user (anchor_node )
496+ if quant_node is None :
497+ return None
498+ return fuse_conv (self , gm , anchor_node , dq_input , dq_weight , quant_node )
499+
344500
345501class Conv2dPattern (QuantizationPattern ):
346502 def partition_types (self ) -> List [OpOverload ]:
@@ -383,6 +539,18 @@ def get_anchors(
383539 def replacement_op (self ) -> OpOverload :
384540 return torch .ops .cadence .quantized_conv2d_nchw .per_tensor
385541
542+ def fuse (self , gm : fx .GraphModule , anchor_node : fx .Node ) -> fx .Node | None :
543+ dq_input = anchor_node .args [0 ]
544+ if not isinstance (dq_input , fx .Node ) or dq_input .target != DQ_PER_TENSOR :
545+ return None
546+ dq_weight = anchor_node .args [1 ]
547+ if not isinstance (dq_weight , fx .Node ) or dq_weight .target != DQ_PER_TENSOR :
548+ return None
549+ quant_node = find_quant_user (anchor_node )
550+ if quant_node is None :
551+ return None
552+ return fuse_conv (self , gm , anchor_node , dq_input , dq_weight , quant_node )
553+
386554
387555class LayerNormPattern (QuantizationPattern ):
388556 def partition_types (self ) -> List [OpOverload ]:
@@ -421,6 +589,61 @@ def get_anchors(
421589 def replacement_op (self ) -> OpOverload :
422590 return torch .ops .cadence .quantized_layer_norm .per_tensor
423591
592+ def fuse (self , gm : fx .GraphModule , anchor_node : fx .Node ) -> fx .Node | None :
593+ dq_input = anchor_node .args [0 ]
594+ if not isinstance (dq_input , fx .Node ) or dq_input .target != DQ_PER_TENSOR :
595+ return None
596+ quant_node = find_quant_user (anchor_node )
597+ if quant_node is None :
598+ return None
599+ scale = get_arg (dq_input , "scale" , float )
600+ zero_point = get_arg (dq_input , "zero_point" , int )
601+ normalized_shape = anchor_node .args [1 ]
602+ assert isinstance (normalized_shape , list )
603+ weight = (
604+ anchor_node .args [2 ]
605+ if len (anchor_node .args ) > 2 and anchor_node .args [2 ]
606+ else None
607+ )
608+ bias = (
609+ anchor_node .args [3 ]
610+ if len (anchor_node .args ) > 3 and anchor_node .args [3 ]
611+ else None
612+ )
613+ input_q = get_arg (dq_input , "input" , fx .Node )
614+ # Default weight=1 and bias=0 must be float32 — cadence::quantized_layer_norm
615+ # expects float affine parameters, not quantized values.
616+ if not weight :
617+ weight = insert_node_with_meta (
618+ gm ,
619+ torch .ops .aten .full .default ,
620+ (normalized_shape , 1 ),
621+ {"dtype" : torch .float32 },
622+ anchor_node ,
623+ input_q ,
624+ )
625+ if not bias :
626+ bias = insert_node_with_meta (
627+ gm ,
628+ torch .ops .aten .full .default ,
629+ (normalized_shape , 0 ),
630+ {"dtype" : torch .float32 },
631+ anchor_node ,
632+ input_q ,
633+ )
634+ args = (input_q , scale , zero_point )
635+ kwargs = {
636+ "normalized_shape" : normalized_shape ,
637+ "weight" : weight ,
638+ "bias" : bias ,
639+ "eps" : get_arg (anchor_node , "eps" , float ),
640+ "output_scale" : get_arg (quant_node , "scale" , float ),
641+ "output_zero_point" : get_arg (quant_node , "zero_point" , int ),
642+ }
643+ return replace_with_op (
644+ gm , anchor_node , self .replacement_op (), args , kwargs , quant_node
645+ )
646+
424647
425648class LinearPattern (QuantizationPattern ):
426649 def partition_types (self ) -> List [OpOverload ]:
@@ -463,6 +686,31 @@ def get_anchors(
463686 def replacement_op (self ) -> OpOverload :
464687 return torch .ops .cadence .quantized_linear .per_tensor
465688
689+ def fuse (self , gm : fx .GraphModule , anchor_node : fx .Node ) -> fx .Node | None :
690+ dq_input = anchor_node .args [0 ]
691+ if not isinstance (dq_input , fx .Node ) or dq_input .target != DQ_PER_TENSOR :
692+ return None
693+ dq_weight = anchor_node .args [1 ]
694+ if not isinstance (dq_weight , fx .Node ) or dq_weight .target != DQ_PER_TENSOR :
695+ return None
696+ quant_node = find_quant_user (anchor_node )
697+ if quant_node is None :
698+ return None
699+ dq_bias : fx .Node | None = None
700+ if len (anchor_node .args ) > 2 :
701+ bias_arg = anchor_node .args [2 ]
702+ if isinstance (bias_arg , fx .Node ) and bias_arg .target == DQ_PER_TENSOR :
703+ dq_bias = bias_arg
704+ return fuse_linear (
705+ gm ,
706+ dq_input ,
707+ dq_weight ,
708+ dq_bias ,
709+ quant_node ,
710+ anchor_node ,
711+ self .replacement_op (),
712+ )
713+
466714
467715class MatmulPattern (QuantizationPattern ):
468716 def partition_types (self ) -> List [OpOverload ]:
@@ -488,6 +736,18 @@ def replacement_op(self) -> OpOverload:
488736 # TODO: T240804887 This is actually a per-tensor variant, we just need to change the name of the op
489737 return torch .ops .cadence .quantized_matmul .default
490738
739+ def fuse (self , gm : fx .GraphModule , anchor_node : fx .Node ) -> fx .Node | None :
740+ dq0 = anchor_node .args [0 ]
741+ if not isinstance (dq0 , fx .Node ) or dq0 .target != DQ_PER_TENSOR :
742+ return None
743+ dq1 = anchor_node .args [1 ]
744+ if not isinstance (dq1 , fx .Node ) or dq1 .target != DQ_PER_TENSOR :
745+ return None
746+ quant_node = find_quant_user (anchor_node )
747+ if quant_node is None :
748+ return None
749+ return fuse_matmul (gm , anchor_node , dq0 , dq1 , quant_node , self .replacement_op ())
750+
491751
492752class MaxPool2dPattern (QuantizationPattern ):
493753 """
0 commit comments