@@ -819,68 +819,86 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
819819
820820
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class FuseMulScalarIntoDequantPass(RemoveOrReplacePassInterface):
    """
    Looks for the pattern where aten.mul.Scalar is multiplying the
    outputs of dequantize. If found, updates the dequant scale
    to reflect the multiplication and removes the mul node.
    """

    @property
    def targets(self) -> list[EdgeOpOverload]:
        # The pass iterates over mul.Scalar nodes and fuses them upward
        # into their producing dequantize node.
        return [exir_ops.edge.aten.mul.Scalar]

    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
        """Fuse ``dequant -> mul.Scalar`` into a single dequant with a scaled
        scale. Returns True iff the graph was modified."""
        mul_node = node
        input_nodes = mul_node.all_input_nodes
        # The mul must consume exactly one node, and that node (the dequant)
        # must have no other users — otherwise the fusion would change the
        # value seen by those other users.
        if len(input_nodes) != 1 or len(input_nodes[0].users) != 1:
            return False

        dequant_node = input_nodes[0]

        if dequant_node.target not in [
            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
            exir_ops.edge.cadence.dequantize_per_tensor.default,
        ]:
            return False

        # The multiplier must be a compile-time scalar (not a Node) so it can
        # be folded into the dequant scale.
        if len(mul_node.args) <= 1 or isinstance(mul_node.args[1], torch.fx.Node):
            return False

        new_deq_args = list(dequant_node.args)
        assert isinstance(dequant_node.args[1], Number)
        assert isinstance(mul_node.args[1], Number)
        # Capture the pre-fusion scale now: once dequant_node.args is
        # overwritten below, args[1] holds the NEW scale and the log message
        # would misleadingly read "from NEW to NEW".
        old_scale = dequant_node.args[1]
        # pyre-ignore[58]: Unsupported operand *
        new_deq_args[1] = dequant_node.args[1] * mul_node.args[1]

        # Log before mutating/erasing so both nodes and the true old scale
        # are still intact when interpolated.
        logging.debug(
            f"Fused {dequant_node} and {mul_node} into {dequant_node}. "
            f"Updated scale from {old_scale} to {new_deq_args[1]}"
        )

        # Replace all uses of mul with the dequant node.
        mul_node.replace_all_uses_with(dequant_node)
        # Update the dequant node's args with the new scale.
        dequant_node.args = tuple(new_deq_args)

        # Erase the now-dead mul node.
        mul_node.graph.erase_node(mul_node)

        return True
870870
871871@register_cadence_pass (CadencePassAttribute (opt_level = 1 ))
872- class FuseMulTensorIntoQuantPass (ExportPass ):
872+ class FuseMulTensorIntoQuantPass (RemoveOrReplacePassInterface ):
873873 """
874874 Looks for the pattern where aten.mul.Tensor is followed by quant node.
875875 If found, updates the quant scale to reflect the multiplication and
876876 removes the mul node.
877877 """
878878
879- def attempt_fusion (
880- self , graph_module : torch .fx .GraphModule , mul_node : torch .fx .Node
881- ) -> None :
882- if len (mul_node .args ) != 2 or len (mul_node .users ) != 1 :
883- return
879+ @property
880+ def targets (self ) -> list [EdgeOpOverload ]:
881+ return [exir_ops .edge .aten .mul .Tensor ]
882+
883+ def maybe_remove_or_replace (self , node : torch .fx .Node ) -> bool :
884+
885+ mul_node = node
886+ if len (mul_node .users ) != 1 :
887+ return False
888+
889+ user = next (iter (mul_node .users ))
890+ user_input_nodes = user .all_input_nodes
891+ if len (user_input_nodes ) != 1 :
892+ return False
893+
894+ if user .target not in [
895+ exir_ops .edge .quantized_decomposed .quantize_per_tensor .default ,
896+ exir_ops .edge .cadence .quantize_per_tensor .default ,
897+ ]:
898+ return False
899+
900+ # Alias for readability.
901+ quant_node = user
884902
885903 first_arg = cast (torch .fx .Node , mul_node .args [0 ])
886904 second_arg = cast (torch .fx .Node , mul_node .args [1 ])
@@ -896,22 +914,11 @@ def attempt_fusion(
896914 input_node = second_arg
897915 else :
898916 # Full node is not found, skip.
899- return
917+ return False
900918
901919 # Ensure that the mul op does not do any broadcasting.
902- if input_node .meta ["val" ].shape != mul_node .meta ["val" ].shape :
903- return
904-
905- mul_user = list (mul_node .users .keys ())[0 ]
906-
907- # Ensure only the expected quant ops are using the current mul op.
908- if mul_user .target not in {
909- exir_ops .edge .quantized_decomposed .quantize_per_tensor .default ,
910- exir_ops .edge .cadence .quantize_per_tensor .default ,
911- }:
912- return
913-
914- quant_node = mul_user
920+ if input_node .meta ["val" ].shape != node .meta ["val" ].shape :
921+ return False
915922
916923 # Calculate the new scale value.
917924 old_scale = quant_node .args [1 ]
@@ -925,42 +932,41 @@ def attempt_fusion(
925932 new_scale = old_scale / mul_scalar
926933 q = zp + x / new_scale
927934 """
935+
936+ # Cannot fuse if either value is zero:
937+ # - mul_scalar == 0 would cause division by zero computing new_scale
938+ # - old_scale == 0 would result in new_scale = 0, causing division by zero during quantization
939+ if mul_scalar == 0 or old_scale == 0 :
940+ return False
928941 new_scale = float (old_scale ) / float (mul_scalar )
929942
930943 logging .debug (
931- f"Fused { mul_node } and { full_node } into { quant_node } . Updated scale from { quant_node .args [1 ]} to { new_scale } "
944+ f"Fused { node } and { full_node } into { quant_node } . Updated scale from { quant_node .args [1 ]} to { new_scale } "
932945 )
933946
934947 # Update quant node input and scale.
935948 old_quant_input = cast (torch .fx .Node , quant_node .args [0 ])
936- new_quant_input = cast ( torch . fx . Node , mul_node . args [ 0 ])
949+ new_quant_input = input_node
937950 quant_node .replace_input_with (old_quant_input , new_quant_input )
938951 quant_node .update_arg (1 , new_scale )
939952
940- def call (self , graph_module : torch .fx .GraphModule ) -> PassResult :
941- for node in graph_module .graph .find_nodes (
942- op = "call_function" , target = exir_ops .edge .aten .mul .Tensor
943- ):
944- self .attempt_fusion (graph_module , node )
945- graph_module .graph .eliminate_dead_code ()
946- return super ().call (graph_module )
953+ return True
947954
948955
949956@register_cadence_pass (CadencePassAttribute (opt_level = 1 ))
950- class FuseMulTensorIntoDequantPass (ExportPass ):
957+ class FuseMulTensorIntoDequantPass (RemoveOrReplacePassInterface ):
951958 """
952959 Looks for the pattern where aten.mul is multiplying the outputs of dequantize
953960 and aten.full, or vice versa. If found, updates the dequant scale to reflect
954961 the multiplication and removes the full and mul nodes.
955962 """
956963
957- def attempt_fusion (
958- self , graph_module : torch .fx .GraphModule , node : torch .fx .Node
959- ) -> None :
960- if node .target != exir_ops .edge .aten .mul .Tensor :
961- return
964+ @property
965+ def targets (self ) -> list [EdgeOpOverload ]:
966+ return [exir_ops .edge .aten .mul .Tensor ]
962967
963- # ensure that one of the args to mul is dequantize and the other is aten.full
968+ def maybe_remove_or_replace (self , node : torch .fx .Node ) -> bool :
969+ # Ensure that one of the args to mul is dequantize and the other is aten.full
964970 dequant_nodes = [
965971 arg
966972 for arg in node .args
@@ -980,14 +986,14 @@ def attempt_fusion(
980986 ]
981987
982988 if len (dequant_nodes ) != 1 or len (multiplier_nodes ) != 1 :
983- return
989+ return False
984990
985991 deq_node = dequant_nodes [0 ]
986992 mplier_node = multiplier_nodes [0 ]
987993
988- # ensure that dequant and full don't have any other users
994+ # Ensure that dequant and full don't have any other users
989995 if len (deq_node .users ) > 1 or len (mplier_node .users ) > 1 :
990- return
996+ return False
991997
992998 new_deq_args = list (deq_node .args )
993999 assert isinstance (deq_node .args [1 ], Number )
@@ -999,18 +1005,16 @@ def attempt_fusion(
9991005 f"Fused { node } and { mplier_node } into { deq_node } . Updated scale from { deq_node .args [1 ]} to { new_deq_args [1 ]} "
10001006 )
10011007
1008+ # Replace all uses of the mul node with the dequant node
10021009 node .replace_all_uses_with (deq_node )
1010+ # Update the dequant node's args with the new scale
10031011 deq_node .args = tuple (new_deq_args )
10041012
1005- graph_module . graph . erase_node ( node )
1006- graph_module .graph .erase_node (mplier_node )
1007- graph_module . recompile ( )
1013+ # Erase the mul and full nodes
1014+ node .graph .erase_node (node )
1015+ node . graph . erase_node ( mplier_node )
10081016
1009- def call (self , graph_module : torch .fx .GraphModule ) -> PassResult :
1010- for node in graph_module .graph .nodes :
1011- self .attempt_fusion (graph_module , node )
1012- result = super ().call (graph_module )
1013- return result
1017+ return True
10141018
10151019
10161020@register_cadence_pass (CadencePassAttribute (opt_level = 1 ))
0 commit comments