Skip to content

Commit 7412efe

Browse files
authored
Fix issue where the dequant parent node of the mul node was not checked
Differential Revision: D91337631 Pull Request resolved: #16832
1 parent 2d6c2ef commit 7412efe

2 files changed

Lines changed: 156 additions & 92 deletions

File tree

backends/cadence/aot/fuse_ops.py

Lines changed: 84 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -819,68 +819,86 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
819819

820820

821821
@register_cadence_pass(CadencePassAttribute(opt_level=1))
822-
class FuseMulScalarIntoDequantPass(ExportPass):
822+
class FuseMulScalarIntoDequantPass(RemoveOrReplacePassInterface):
823823
"""
824824
Looks for the pattern where aten.mul.Scalar is multiplying the
825825
outputs of dequantize. If found, updates the dequant scale
826826
to reflect the multiplication and removes the mul node.
827827
"""
828828

829-
def attempt_fusion(
830-
self, graph_module: torch.fx.GraphModule, node: torch.fx.Node
831-
) -> None:
832-
if node.target not in {
833-
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
834-
exir_ops.edge.cadence.dequantize_per_tensor.default,
835-
}:
836-
return
829+
@property
830+
def targets(self) -> list[EdgeOpOverload]:
831+
return [exir_ops.edge.aten.mul.Scalar]
837832

838-
# ensure that the single user of dequant is aten.mul.Scalar
839-
user = list(node.users.keys())[0]
840-
if len(node.users) != 1 or user.target != exir_ops.edge.aten.mul.Scalar:
841-
return
833+
def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
834+
# Ensure that the single user of dequant is aten.mul.Scalar
835+
mul_node = node
836+
input_nodes = mul_node.all_input_nodes
837+
if len(input_nodes) != 1 or len(input_nodes[0].users) != 1:
838+
return False
842839

843-
# ensure that the other arg to mul is a node (i.e. not a constant)
844-
if len(user.args) > 1 and isinstance(user.args[1], torch.fx.Node):
845-
return
840+
dequant_node = input_nodes[0]
846841

847-
new_deq_args = list(node.args)
848-
assert isinstance(node.args[1], Number)
849-
assert isinstance(user.args[1], Number)
850-
# pyre-ignore[58]: Unsupported operand *
851-
new_deq_args[1] = node.args[1] * user.args[1]
842+
if dequant_node.target not in [
843+
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
844+
exir_ops.edge.cadence.dequantize_per_tensor.default,
845+
]:
846+
return False
852847

853-
logging.debug(
854-
f"Fused {node} and {user} into {node}. Updated scale from {node.args[1]} to {new_deq_args[1]}"
855-
)
848+
if len(mul_node.args) <= 1 or isinstance(mul_node.args[1], torch.fx.Node):
849+
return False
856850

857-
user.replace_all_uses_with(node)
858-
node.args = tuple(new_deq_args)
851+
new_deq_args = list(dequant_node.args)
852+
assert isinstance(dequant_node.args[1], Number)
853+
assert isinstance(mul_node.args[1], Number)
854+
# pyre-ignore[58]: Unsupported operand *
855+
new_deq_args[1] = dequant_node.args[1] * mul_node.args[1]
859856

860-
graph_module.graph.erase_node(user)
857+
# Replace all uses of mul with the dequant node
858+
mul_node.replace_all_uses_with(dequant_node)
859+
# Update the dequant node's args with the new scale
860+
dequant_node.args = tuple(new_deq_args)
861861

862-
graph_module.recompile()
862+
# Erase the mul node
863+
mul_node.graph.erase_node(mul_node)
863864

864-
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
865-
for node in graph_module.graph.nodes:
866-
self.attempt_fusion(graph_module, node)
867-
result = super().call(graph_module)
868-
return result
865+
logging.debug(
866+
f"Fused {dequant_node} and {mul_node} into {dequant_node}. Updated scale from {dequant_node.args[1]} to {new_deq_args[1]}"
867+
)
868+
return True
869869

870870

871871
@register_cadence_pass(CadencePassAttribute(opt_level=1))
872-
class FuseMulTensorIntoQuantPass(ExportPass):
872+
class FuseMulTensorIntoQuantPass(RemoveOrReplacePassInterface):
873873
"""
874874
Looks for the pattern where aten.mul.Tensor is followed by quant node.
875875
If found, updates the quant scale to reflect the multiplication and
876876
removes the mul node.
877877
"""
878878

879-
def attempt_fusion(
880-
self, graph_module: torch.fx.GraphModule, mul_node: torch.fx.Node
881-
) -> None:
882-
if len(mul_node.args) != 2 or len(mul_node.users) != 1:
883-
return
879+
@property
880+
def targets(self) -> list[EdgeOpOverload]:
881+
return [exir_ops.edge.aten.mul.Tensor]
882+
883+
def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
884+
885+
mul_node = node
886+
if len(mul_node.users) != 1:
887+
return False
888+
889+
user = next(iter(mul_node.users))
890+
user_input_nodes = user.all_input_nodes
891+
if len(user_input_nodes) != 1:
892+
return False
893+
894+
if user.target not in [
895+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
896+
exir_ops.edge.cadence.quantize_per_tensor.default,
897+
]:
898+
return False
899+
900+
# Alias for readability.
901+
quant_node = user
884902

885903
first_arg = cast(torch.fx.Node, mul_node.args[0])
886904
second_arg = cast(torch.fx.Node, mul_node.args[1])
@@ -896,22 +914,11 @@ def attempt_fusion(
896914
input_node = second_arg
897915
else:
898916
# Full node is not found, skip.
899-
return
917+
return False
900918

901919
# Ensure that the mul op does not do any broadcasting.
902-
if input_node.meta["val"].shape != mul_node.meta["val"].shape:
903-
return
904-
905-
mul_user = list(mul_node.users.keys())[0]
906-
907-
# Ensure only the expected quant ops are using the current mul op.
908-
if mul_user.target not in {
909-
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
910-
exir_ops.edge.cadence.quantize_per_tensor.default,
911-
}:
912-
return
913-
914-
quant_node = mul_user
920+
if input_node.meta["val"].shape != node.meta["val"].shape:
921+
return False
915922

916923
# Calculate the new scale value.
917924
old_scale = quant_node.args[1]
@@ -925,42 +932,41 @@ def attempt_fusion(
925932
new_scale = old_scale / mul_scalar
926933
q = zp + x / new_scale
927934
"""
935+
936+
# Cannot fuse if either value is zero:
937+
# - mul_scalar == 0 would cause division by zero computing new_scale
938+
# - old_scale == 0 would result in new_scale = 0, causing division by zero during quantization
939+
if mul_scalar == 0 or old_scale == 0:
940+
return False
928941
new_scale = float(old_scale) / float(mul_scalar)
929942

930943
logging.debug(
931-
f"Fused {mul_node} and {full_node} into {quant_node}. Updated scale from {quant_node.args[1]} to {new_scale}"
944+
f"Fused {node} and {full_node} into {quant_node}. Updated scale from {quant_node.args[1]} to {new_scale}"
932945
)
933946

934947
# Update quant node input and scale.
935948
old_quant_input = cast(torch.fx.Node, quant_node.args[0])
936-
new_quant_input = cast(torch.fx.Node, mul_node.args[0])
949+
new_quant_input = input_node
937950
quant_node.replace_input_with(old_quant_input, new_quant_input)
938951
quant_node.update_arg(1, new_scale)
939952

940-
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
941-
for node in graph_module.graph.find_nodes(
942-
op="call_function", target=exir_ops.edge.aten.mul.Tensor
943-
):
944-
self.attempt_fusion(graph_module, node)
945-
graph_module.graph.eliminate_dead_code()
946-
return super().call(graph_module)
953+
return True
947954

948955

949956
@register_cadence_pass(CadencePassAttribute(opt_level=1))
950-
class FuseMulTensorIntoDequantPass(ExportPass):
957+
class FuseMulTensorIntoDequantPass(RemoveOrReplacePassInterface):
951958
"""
952959
Looks for the pattern where aten.mul is multiplying the outputs of dequantize
953960
and aten.full, or vice versa. If found, updates the dequant scale to reflect
954961
the multiplication and removes the full and mul nodes.
955962
"""
956963

957-
def attempt_fusion(
958-
self, graph_module: torch.fx.GraphModule, node: torch.fx.Node
959-
) -> None:
960-
if node.target != exir_ops.edge.aten.mul.Tensor:
961-
return
964+
@property
965+
def targets(self) -> list[EdgeOpOverload]:
966+
return [exir_ops.edge.aten.mul.Tensor]
962967

963-
# ensure that one of the args to mul is dequantize and the other is aten.full
968+
def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
969+
# Ensure that one of the args to mul is dequantize and the other is aten.full
964970
dequant_nodes = [
965971
arg
966972
for arg in node.args
@@ -980,14 +986,14 @@ def attempt_fusion(
980986
]
981987

982988
if len(dequant_nodes) != 1 or len(multiplier_nodes) != 1:
983-
return
989+
return False
984990

985991
deq_node = dequant_nodes[0]
986992
mplier_node = multiplier_nodes[0]
987993

988-
# ensure that dequant and full don't have any other users
994+
# Ensure that dequant and full don't have any other users
989995
if len(deq_node.users) > 1 or len(mplier_node.users) > 1:
990-
return
996+
return False
991997

992998
new_deq_args = list(deq_node.args)
993999
assert isinstance(deq_node.args[1], Number)
@@ -999,18 +1005,16 @@ def attempt_fusion(
9991005
f"Fused {node} and {mplier_node} into {deq_node}. Updated scale from {deq_node.args[1]} to {new_deq_args[1]}"
10001006
)
10011007

1008+
# Replace all uses of the mul node with the dequant node
10021009
node.replace_all_uses_with(deq_node)
1010+
# Update the dequant node's args with the new scale
10031011
deq_node.args = tuple(new_deq_args)
10041012

1005-
graph_module.graph.erase_node(node)
1006-
graph_module.graph.erase_node(mplier_node)
1007-
graph_module.recompile()
1013+
# Erase the mul and full nodes
1014+
node.graph.erase_node(node)
1015+
node.graph.erase_node(mplier_node)
10081016

1009-
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
1010-
for node in graph_module.graph.nodes:
1011-
self.attempt_fusion(graph_module, node)
1012-
result = super().call(graph_module)
1013-
return result
1017+
return True
10141018

10151019

10161020
@register_cadence_pass(CadencePassAttribute(opt_level=1))

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 72 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -602,7 +602,8 @@ def test_fuse_mul_into_dequant(self) -> None:
602602
FULL_VALUE: Final[float] = 3
603603

604604
builder = GraphBuilder()
605-
x = builder.placeholder("x", torch.randn(*INPUT_SHAPE, dtype=torch.float32))
605+
x_input = torch.randint(low=0, high=255, size=INPUT_SHAPE, dtype=torch.uint8)
606+
x = builder.placeholder("x", x_input)
606607
dequant = builder.call_operator(
607608
op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
608609
args=(x, DEQUANT_SCALE, 0, 0, 255, torch.uint8),
@@ -617,8 +618,17 @@ def test_fuse_mul_into_dequant(self) -> None:
617618
)
618619
builder.output([mul])
619620
original_graph = builder.get_graph_module()
621+
gm_before = copy.deepcopy(original_graph)
622+
620623
p = FuseMulTensorIntoDequantPass()
621-
converted_graph = cast(PassResult, p(original_graph)).graph_module
624+
result = cast(PassResult, p(original_graph))
625+
self.assertTrue(result.modified)
626+
converted_graph = result.graph_module
627+
628+
# Validate numerical accuracy
629+
validate_numerics(
630+
gm_before, converted_graph, (x_input,), "FuseMulTensorIntoDequantPass"
631+
)
622632

623633
# verify that the mul and full ops were removed
624634
self.check_op_counts(
@@ -640,12 +650,49 @@ def test_fuse_mul_into_dequant(self) -> None:
640650
deq_scale = node.args[1]
641651
self.assertEqual(deq_scale, DEQUANT_SCALE * FULL_VALUE)
642652

653+
def test_fuse_mul_into_dequant_no_match(self) -> None:
654+
"""
655+
Test that FuseMulTensorIntoDequantPass does NOT modify the graph
656+
when the mul node's inputs are not dequant + full.
657+
"""
658+
INPUT_SHAPE: Final[List[int]] = [4, 32]
659+
660+
builder = GraphBuilder()
661+
# Create two regular placeholder inputs (not dequant outputs)
662+
x_input = torch.randn(*INPUT_SHAPE, dtype=torch.float32)
663+
y_input = torch.randn(*INPUT_SHAPE, dtype=torch.float32)
664+
x = builder.placeholder("x", x_input)
665+
y = builder.placeholder("y", y_input)
666+
667+
# Mul of two placeholders - no dequant node involved
668+
mul = builder.call_operator(
669+
op=exir_ops.edge.aten.mul.Tensor,
670+
args=(x, y),
671+
)
672+
builder.output([mul])
673+
original_graph = builder.get_graph_module()
674+
675+
p = FuseMulTensorIntoDequantPass()
676+
result = cast(PassResult, p(original_graph))
677+
678+
# The pass should NOT modify the graph since there's no dequant node
679+
self.assertFalse(result.modified)
680+
681+
# Verify that the mul op is still present
682+
self.check_op_counts(
683+
result.graph_module,
684+
expected_op_counts={
685+
exir_ops.edge.aten.mul.Tensor: 1,
686+
},
687+
)
688+
643689
def test_fuse_mul_scalar_into_dequant(self) -> None:
644690
dequant_scale = 0.006
645691
mul_value = 0.3
646692

647693
builder = GraphBuilder()
648-
x = builder.placeholder("x", torch.randn(2, 3, 4, dtype=torch.float32))
694+
x_input = torch.randn(2, 3, 4, dtype=torch.float32)
695+
x = builder.placeholder("x", x_input)
649696
quant = builder.call_operator(
650697
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
651698
args=(x, 1, 0, -128, 127, torch.int8),
@@ -660,8 +707,17 @@ def test_fuse_mul_scalar_into_dequant(self) -> None:
660707
)
661708
builder.output([mul_scalar])
662709
original_graph = builder.get_graph_module()
710+
gm_before = copy.deepcopy(original_graph)
711+
663712
p = FuseMulScalarIntoDequantPass()
664-
converted_graph = cast(PassResult, p(original_graph)).graph_module
713+
result = cast(PassResult, p(original_graph))
714+
self.assertTrue(result.modified)
715+
converted_graph = result.graph_module
716+
717+
# Validate numerical accuracy
718+
validate_numerics(
719+
gm_before, converted_graph, (x_input,), "FuseMulScalarIntoDequantPass"
720+
)
665721

666722
# verify that the mul and full ops were removed
667723
self.check_op_counts(
@@ -687,7 +743,8 @@ def test_fuse_mul_into_quant(self) -> None:
687743
mul_value = 10
688744

689745
builder = GraphBuilder()
690-
x = builder.placeholder("x", torch.randn(4, 32, dtype=torch.float32))
746+
x_input = torch.randn(4, 32, dtype=torch.float32)
747+
x = builder.placeholder("x", x_input)
691748
full = builder.call_operator(
692749
op=exir_ops.edge.aten.full.default,
693750
args=([1], mul_value),
@@ -702,8 +759,17 @@ def test_fuse_mul_into_quant(self) -> None:
702759
)
703760
builder.output([quant])
704761
original_graph = builder.get_graph_module()
762+
gm_before = copy.deepcopy(original_graph)
763+
705764
p = FuseMulTensorIntoQuantPass()
706-
converted_graph = cast(PassResult, p(original_graph)).graph_module
765+
result = cast(PassResult, p(original_graph))
766+
self.assertTrue(result.modified)
767+
converted_graph = result.graph_module
768+
769+
# Validate numerical accuracy
770+
validate_numerics(
771+
gm_before, converted_graph, (x_input,), "FuseMulTensorIntoQuantPass"
772+
)
707773

708774
# verify that the mul and full ops were removed
709775
self.check_op_counts(
@@ -723,12 +789,6 @@ def test_fuse_mul_into_quant(self) -> None:
723789
new_quant_scale = node.args[1]
724790
self.assertEqual(new_quant_scale, quant_scale / mul_value)
725791

726-
# verify the math is correct
727-
inp = torch.randn(4, 32, dtype=torch.float32)
728-
original_out = original_graph(inp)[0]
729-
new_out = converted_graph(inp)[0]
730-
assert torch.equal(original_out, new_out)
731-
732792
def test_fuse_then_transpose_pass(self) -> None:
733793
# Create a graph with full -> transpose -> permute -> view.
734794
builder = GraphBuilder()

0 commit comments

Comments
 (0)