@@ -44,13 +44,15 @@ def call(self, graph_module: torch.fx.GraphModule):
4444 op = "call_function" , target = exir_ops .edge .aten .mm .default
4545 )
4646 for node in node_list :
47+ mm_fake = get_first_fake_tensor (node )
48+
4749 # Unsqueeze input tensors to rank 3
4850 for input_node in node .args :
4951 if not isinstance (input_node , Node ):
5052 continue
5153
52- shape = get_first_fake_tensor (input_node ). shape
53- rank = len (shape )
54+ input_fake = get_first_fake_tensor (input_node )
55+ rank = len (input_fake . shape )
5456 if rank != 2 :
5557 raise RuntimeError (f"Input tensor has rank { rank } , must be 2" )
5658
@@ -65,6 +67,7 @@ def call(self, graph_module: torch.fx.GraphModule):
6567 input_node , # Input is node's original input
6668 0 ,
6769 )
70+ unsqueeze_before .meta ["val" ] = input_fake .unsqueeze (0 )
6871 node .replace_input_with (input_node , unsqueeze_before )
6972
7073 # Replace mm node with bmm
@@ -76,10 +79,15 @@ def call(self, graph_module: torch.fx.GraphModule):
7679 inherit_qparams = True ,
7780 )
7881 bmm_node .args = node .args
82+ # Manually set output meta: same as mm but with batch dim.
83+ # This avoids re-executing bmm on FakeTensors, which fails
84+ # for quantized (int8/int16) inputs since aten.bmm only
85+ # supports float32 FakeTensor propagation.
86+ bmm_node .meta ["val" ] = mm_fake .unsqueeze (0 )
7987 node .replace_all_uses_with (bmm_node )
8088 graph .erase_node (node )
8189
82- # Unsqueeze output tensor to rank 3
90+ # Squeeze output tensor back to rank 2
8391 with graph .inserting_after (bmm_node ):
8492 squeeze_after = create_node (
8593 graph ,
@@ -91,6 +99,7 @@ def call(self, graph_module: torch.fx.GraphModule):
9199 bmm_node ,
92100 [0 ],
93101 )
102+ squeeze_after .meta ["val" ] = mm_fake
94103 original_users = [
95104 user for user in bmm_node .users if user != squeeze_after
96105 ]
@@ -101,6 +110,5 @@ def call(self, graph_module: torch.fx.GraphModule):
101110
102111 if modified_graph :
103112 graph_module .recompile ()
104- graph_module = super ().call (graph_module ).graph_module
105113
106114 return PassResult (graph_module , modified_graph )
0 commit comments