
Commit b4a1625

Andrew Pullin authored and facebook-github-bot committed
Fix ConvertMmToBmmPass for quantized (int8/int16) mm ops (pytorch#18974)
Summary:
This diff is experimental, but appears to address incomplete support for integer (int8/int16) paths through BMM; the final scope is TBD.

The pass converts a rank-2 mm into a rank-3 bmm (required by the TOSA spec) via unsqueeze/bmm/squeeze. Previously it called super().call() to re-trace the graph on FakeTensors for shape propagation, but aten.bmm rejects int8/int16 FakeTensors, so the pass failed for any quantized mm op. Since mm -> bmm is a pure shape transformation (adding a batch dimension of 1), we can set the output metadata directly: unsqueeze the mm node's FakeTensor for the bmm node, and reuse the original FakeTensor for the squeeze. There is no need to re-execute the op.

Reviewed By: digantdesai

Differential Revision: D99857137
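For reference, a minimal sketch (not part of the commit, using only standard torch ops) showing why inserting and then removing a batch dimension of 1 around the matmul is a pure shape transformation:

import torch

a = torch.randn(4, 8)
b = torch.randn(8, 3)

mm_out = torch.mm(a, b)  # rank 2: (4, 3)
# unsqueeze -> bmm -> squeeze gives the same rank-2 result
bmm_out = torch.bmm(a.unsqueeze(0), b.unsqueeze(0)).squeeze(0)

assert torch.allclose(mm_out, bmm_out)
assert mm_out.shape == bmm_out.shape == (4, 3)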
1 parent 0af4221 · commit b4a1625

1 file changed

Lines changed: 12 additions & 4 deletions

File tree

backends/arm/_passes/mm_to_bmm_pass.py

@@ -44,13 +44,15 @@ def call(self, graph_module: torch.fx.GraphModule):
             op="call_function", target=exir_ops.edge.aten.mm.default
         )
         for node in node_list:
+            mm_fake = get_first_fake_tensor(node)
+
             # Unsqueeze input tensors to rank 3
             for input_node in node.args:
                 if not isinstance(input_node, Node):
                     continue
 
-                shape = get_first_fake_tensor(input_node).shape
-                rank = len(shape)
+                input_fake = get_first_fake_tensor(input_node)
+                rank = len(input_fake.shape)
                 if rank != 2:
                     raise RuntimeError(f"Input tensor has rank {rank}, must be 2")
 
@@ -65,6 +67,7 @@ def call(self, graph_module: torch.fx.GraphModule):
                         input_node,  # Input is node's original input
                         0,
                     )
+                    unsqueeze_before.meta["val"] = input_fake.unsqueeze(0)
                     node.replace_input_with(input_node, unsqueeze_before)
 
             # Replace mm node with bmm
@@ -76,10 +79,15 @@ def call(self, graph_module: torch.fx.GraphModule):
                    inherit_qparams=True,
                )
                bmm_node.args = node.args
+                # Manually set output meta: same as mm but with batch dim.
+                # This avoids re-executing bmm on FakeTensors, which fails
+                # for quantized (int8/int16) inputs since aten.bmm only
+                # supports float32 FakeTensor propagation.
+                bmm_node.meta["val"] = mm_fake.unsqueeze(0)
                node.replace_all_uses_with(bmm_node)
                graph.erase_node(node)
 
-            # Unsqueeze output tensor to rank 3
+            # Squeeze output tensor back to rank 2
            with graph.inserting_after(bmm_node):
                squeeze_after = create_node(
                    graph,
@@ -91,6 +99,7 @@ def call(self, graph_module: torch.fx.GraphModule):
                    bmm_node,
                    [0],
                )
+                squeeze_after.meta["val"] = mm_fake
                original_users = [
                    user for user in bmm_node.users if user != squeeze_after
                ]
@@ -101,6 +110,5 @@ def call(self, graph_module: torch.fx.GraphModule):
 
        if modified_graph:
            graph_module.recompile()
-            graph_module = super().call(graph_module).graph_module
 
        return PassResult(graph_module, modified_graph)
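A hypothetical sketch (not from the patch) of the metadata trick the diff relies on: because unsqueeze is a view op, an integer FakeTensor can be unsqueezed to describe the bmm output without ever dispatching aten.bmm on int8/int16 FakeTensors:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    # Stand-in for the mm node's meta["val"] on a quantized graph.
    mm_fake = torch.empty((4, 3), dtype=torch.int8)
    bmm_val = mm_fake.unsqueeze(0)  # shape (1, 4, 3), dtype int8; metadata only
    squeeze_val = mm_fake           # the squeeze node restores the rank-2 metadata

print(bmm_val.shape, bmm_val.dtype)  # torch.Size([1, 4, 3]) torch.int8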
