diff --git a/backends/arm/_passes/mm_to_bmm_pass.py b/backends/arm/_passes/mm_to_bmm_pass.py
index ea7e61da54a..a2b71c52288 100644
--- a/backends/arm/_passes/mm_to_bmm_pass.py
+++ b/backends/arm/_passes/mm_to_bmm_pass.py
@@ -44,13 +44,15 @@ def call(self, graph_module: torch.fx.GraphModule):
             op="call_function", target=exir_ops.edge.aten.mm.default
         )
         for node in node_list:
+            mm_fake = get_first_fake_tensor(node)
+
             # Unsqueeze input tensors to rank 3
             for input_node in node.args:
                 if not isinstance(input_node, Node):
                     continue
 
-                shape = get_first_fake_tensor(input_node).shape
-                rank = len(shape)
+                input_fake = get_first_fake_tensor(input_node)
+                rank = len(input_fake.shape)
                 if rank != 2:
                     raise RuntimeError(f"Input tensor has rank {rank}, must be 2")
 
@@ -65,6 +67,7 @@ def call(self, graph_module: torch.fx.GraphModule):
                         input_node,  # Input is node's original input
                         0,
                     )
+                    unsqueeze_before.meta["val"] = input_fake.unsqueeze(0)
                     node.replace_input_with(input_node, unsqueeze_before)
 
             # Replace mm node with bmm
@@ -76,10 +79,15 @@ def call(self, graph_module: torch.fx.GraphModule):
                     inherit_qparams=True,
                 )
                 bmm_node.args = node.args
+                # Manually set output meta: same as mm but with batch dim.
+                # This avoids re-executing bmm on FakeTensors, which fails
+                # for quantized (int8/int16) inputs since aten.bmm only
+                # supports float32 FakeTensor propagation.
+                bmm_node.meta["val"] = mm_fake.unsqueeze(0)
                 node.replace_all_uses_with(bmm_node)
                 graph.erase_node(node)
 
-            # Unsqueeze output tensor to rank 3
+            # Squeeze output tensor back to rank 2
             with graph.inserting_after(bmm_node):
                 squeeze_after = create_node(
                     graph,
@@ -91,6 +99,7 @@ def call(self, graph_module: torch.fx.GraphModule):
                         bmm_node,
                         [0],
                     )
+                squeeze_after.meta["val"] = mm_fake
                 original_users = [
                     user for user in bmm_node.users if user != squeeze_after
                 ]
@@ -101,6 +110,5 @@ def call(self, graph_module: torch.fx.GraphModule):
 
         if modified_graph:
             graph_module.recompile()
-            graph_module = super().call(graph_module).graph_module
 
         return PassResult(graph_module, modified_graph)
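
Not part of the patch: a minimal standalone sketch of the rewrite this pass performs and of the shapes the manually assigned meta["val"] entries mirror. The tensor names (a, b) and sizes are illustrative assumptions, not taken from the pass itself.

import torch

# mm on rank-2 tensors is equivalent to bmm on rank-3 tensors with a
# leading batch dimension of 1, followed by squeezing that dimension away.
a = torch.randn(4, 8)
b = torch.randn(8, 16)

mm_out = torch.mm(a, b)                              # shape (4, 16)
bmm_out = torch.bmm(a.unsqueeze(0), b.unsqueeze(0))  # shape (1, 4, 16)
squeezed = bmm_out.squeeze(0)                        # shape (4, 16)

assert torch.allclose(mm_out, squeezed)

# The pass assigns FakeTensors with exactly these shapes instead of
# re-propagating through bmm:
#   unsqueeze_before.meta["val"] = input_fake.unsqueeze(0)  # input + batch dim
#   bmm_node.meta["val"]         = mm_fake.unsqueeze(0)     # mm output + batch dim
#   squeeze_after.meta["val"]    = mm_fake                  # original mm output shape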