Skip to content

Commit a1ebbcc

Browse files
authored
Fix ConvertMmToBmmPass for quantized (int8/int16) mm ops (#18974)
Differential Revision: D99857137 Pull Request resolved: #18974
1 parent 1c11601 commit a1ebbcc

1 file changed

Lines changed: 12 additions & 4 deletions

File tree

backends/arm/_passes/mm_to_bmm_pass.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,15 @@ def call(self, graph_module: torch.fx.GraphModule):
4444
op="call_function", target=exir_ops.edge.aten.mm.default
4545
)
4646
for node in node_list:
47+
mm_fake = get_first_fake_tensor(node)
48+
4749
# Unsqueeze input tensors to rank 3
4850
for input_node in node.args:
4951
if not isinstance(input_node, Node):
5052
continue
5153

52-
shape = get_first_fake_tensor(input_node).shape
53-
rank = len(shape)
54+
input_fake = get_first_fake_tensor(input_node)
55+
rank = len(input_fake.shape)
5456
if rank != 2:
5557
raise RuntimeError(f"Input tensor has rank {rank}, must be 2")
5658

@@ -65,6 +67,7 @@ def call(self, graph_module: torch.fx.GraphModule):
6567
input_node, # Input is node's original input
6668
0,
6769
)
70+
unsqueeze_before.meta["val"] = input_fake.unsqueeze(0)
6871
node.replace_input_with(input_node, unsqueeze_before)
6972

7073
# Replace mm node with bmm
@@ -76,10 +79,15 @@ def call(self, graph_module: torch.fx.GraphModule):
7679
inherit_qparams=True,
7780
)
7881
bmm_node.args = node.args
82+
# Manually set output meta: same as mm but with batch dim.
83+
# This avoids re-executing bmm on FakeTensors, which fails
84+
# for quantized (int8/int16) inputs since aten.bmm only
85+
# supports float32 FakeTensor propagation.
86+
bmm_node.meta["val"] = mm_fake.unsqueeze(0)
7987
node.replace_all_uses_with(bmm_node)
8088
graph.erase_node(node)
8189

82-
# Unsqueeze output tensor to rank 3
90+
# Squeeze output tensor back to rank 2
8391
with graph.inserting_after(bmm_node):
8492
squeeze_after = create_node(
8593
graph,
@@ -91,6 +99,7 @@ def call(self, graph_module: torch.fx.GraphModule):
9199
bmm_node,
92100
[0],
93101
)
102+
squeeze_after.meta["val"] = mm_fake
94103
original_users = [
95104
user for user in bmm_node.users if user != squeeze_after
96105
]
@@ -101,6 +110,5 @@ def call(self, graph_module: torch.fx.GraphModule):
101110

102111
if modified_graph:
103112
graph_module.recompile()
104-
graph_module = super().call(graph_module).graph_module
105113

106114
return PassResult(graph_module, modified_graph)

0 commit comments

Comments (0)