@@ -66,6 +66,18 @@ def _build_weights(
 
         return weights
 
+    @staticmethod
+    def _get_unsqueeze_dim(node: torch.fx.Node) -> int:
+        """Return the unsqueeze dim if node is an unsqueeze_copy, else -1."""
+        if (
+            node.op == "call_function"
+            and node.target == exir_ops.edge.aten.unsqueeze_copy.default
+        ):
+            return node.args[1]
+        return -1
+
+    _BHSD_TO_BSHD_PERM = [0, 2, 1, 3]
+
     @staticmethod
     def _trace_through_unsqueeze(node: torch.fx.Node) -> torch.fx.Node:
         """If node is an unsqueeze_copy, return its input. Otherwise return node as-is."""
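
For context on why the unsqueeze dim identifies the layout: in the rotary-embedding pattern this pass matches, cos/sin are unsqueezed once (shared between the q and k branches) and then broadcast against the activations. A minimal shape sketch, with made-up sizes and a hypothetical `(B, S, D)` cos table:

```python
import torch

B, H, S, D = 2, 4, 8, 16          # hypothetical batch/heads/seq/head_dim
cos = torch.randn(B, S, D)        # assumed per-token rotary table shape

q_bhsd = torch.randn(B, H, S, D)  # BHSD activations
q_bshd = torch.randn(B, S, H, D)  # BSHD activations

# unsqueeze at dim 1 -> (B, 1, S, D): broadcasts against BHSD.
assert (q_bhsd * cos.unsqueeze(1)).shape == (B, H, S, D)
# unsqueeze at dim 2 -> (B, S, 1, D): broadcasts against BSHD.
assert (q_bshd * cos.unsqueeze(2)).shape == (B, S, H, D)
```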
@@ -76,6 +88,33 @@ def _trace_through_unsqueeze(node: torch.fx.Node) -> torch.fx.Node:
             return node.args[0]
         return node
 
+    @staticmethod
+    def _trace_through_permute(node: torch.fx.Node) -> torch.fx.Node | None:
+        """If node is a permute_copy that swaps dims 1 and 2, return its input."""
+        if (
+            node.op == "call_function"
+            and node.target == exir_ops.edge.aten.permute_copy.default
+            and list(node.args[1]) == [0, 2, 1, 3]
+        ):
+            return node.args[0]
+        return None
+
+    def _get_layout(self, cos_unsqueezed: torch.fx.Node) -> str | None:
+        """Determine the tensor layout from the cos unsqueeze dimension.
+
+        Returns "BSHD", "BHSD", or None if the layout cannot be determined.
+        """
+        unsqueeze_dim = self._get_unsqueeze_dim(cos_unsqueezed)
+        if unsqueeze_dim == -1:
+            return None
+        ndim = len(cos_unsqueezed.meta["val"].shape)
+        normalized = unsqueeze_dim if unsqueeze_dim >= 0 else unsqueeze_dim + ndim
+        if normalized == 2:
+            return "BSHD"
+        if normalized == 1:
+            return "BHSD"
+        return None
+
     def create_rope(
         self,
         graph_module: torch.fx.GraphModule,
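
The normalization in `_get_layout` also covers negative unsqueeze dims. A standalone mirror of that logic (hypothetical helper, illustration only):

```python
# Hypothetical stand-in mirroring _get_layout's dim normalization.
def layout_from_unsqueeze_dim(unsqueeze_dim: int, ndim: int) -> str | None:
    if unsqueeze_dim == -1:
        return None  # sentinel from _get_unsqueeze_dim: not an unsqueeze_copy
    normalized = unsqueeze_dim if unsqueeze_dim >= 0 else unsqueeze_dim + ndim
    return {2: "BSHD", 1: "BHSD"}.get(normalized)

assert layout_from_unsqueeze_dim(2, 4) == "BSHD"
assert layout_from_unsqueeze_dim(-3, 4) == "BHSD"  # -3 + 4 == 1
assert layout_from_unsqueeze_dim(3, 4) is None     # unrecognized layout
```

Note that the `-1` sentinel doubles as "not an unsqueeze", so a literal `unsqueeze(x, -1)` is conservatively rejected as well, which is safe since neither layout places the unsqueeze at the last dim.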
@@ -90,9 +129,22 @@ def create_rope(
         sin_unsqueezed = match.placeholder_nodes[2]
         output_node = match.returning_nodes[0]
 
-        # Trace back through unsqueeze to get raw cos/sin for weight construction.
-        # The pattern excludes unsqueeze ops (they're shared between q/k RoPE),
-        # so the matched placeholders are the unsqueeze outputs.
+        # xnn_define_rope expects NTHC (batch, tokens, heads, channels) input.
+        # BSHD (unsqueeze_dim=2) maps directly to NTHC.
+        # BHSD (unsqueeze_dim=1) requires tracing through the BSHD→BHSD permute
+        # to recover the BSHD input, then re-permuting the output back to BHSD.
+        layout = self._get_layout(cos_unsqueezed)
+        if layout == "BSHD":
+            rope_input = x_node
+        elif layout == "BHSD":
+            rope_input = self._trace_through_permute(x_node)
+            if rope_input is None:
+                logger.debug("Skipping RoPE fusion: BHSD but x is not a permute_copy")
+                return
+        else:
+            logger.debug("Skipping RoPE fusion: unrecognized layout")
+            return
+
         cos_node = self._trace_through_unsqueeze(cos_unsqueezed)
         sin_node = self._trace_through_unsqueeze(sin_unsqueezed)
 
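
A shape-level sketch of the BHSD branch, with made-up sizes: the fused op consumes the BSHD tensor recovered by `_trace_through_permute`, and because `[0, 2, 1, 3]` is its own inverse, the single `_BHSD_TO_BSHD_PERM` constant serves both directions:

```python
import torch

B, S, H, D = 2, 8, 4, 16
x_bshd = torch.randn(B, S, H, D)       # input recovered by _trace_through_permute
x_bhsd = x_bshd.permute(0, 2, 1, 3)    # what the matched subgraph actually consumed

# [0, 2, 1, 3] swaps dims 1 and 2, so applying it twice is the identity.
assert x_bhsd.permute(0, 2, 1, 3).equal(x_bshd)

# The fused rope output is BSHD-shaped; re-permuting restores the BHSD
# shape that the replaced subgraph's consumers expect.
assert torch.empty(x_bshd.shape).permute(0, 2, 1, 3).shape == x_bhsd.shape
```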
@@ -102,12 +154,23 @@ def create_rope(
         rope_node = graph_module.graph.create_node(
             "call_function",
             torch.ops.xnnpack.rope.default,
-            args=(x_node, weights),
+            args=(rope_input, weights),
         )
-
-        rope_node.meta["val"] = torch.empty_like(x_node.meta["val"])
-
-        output_node.replace_all_uses_with(rope_node)
+        rope_node.meta["val"] = torch.empty_like(rope_input.meta["val"])
+
+        if layout == "BHSD":
+            permute_node = graph_module.graph.call_function(
+                exir_ops.edge.aten.permute_copy.default,
+                args=(rope_node, self._BHSD_TO_BSHD_PERM),
+            )
+            permute_node.meta["val"] = rope_node.meta["val"].permute(
+                self._BHSD_TO_BSHD_PERM
+            )
+            result_node = permute_node
+        else:
+            result_node = rope_node
+
+        output_node.replace_all_uses_with(result_node)
         graph_module.graph.eliminate_dead_code()
 
     # override
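
The surgery itself follows the usual FX rewrite shape: build the replacement chain from the matched subgraph's input, repoint the old output's users, then sweep the dead match. A torch-only sketch with stand-in ops (no executorch; `torch.neg` stands in for both the fused rope node and the restoring permute):

```python
import operator
import torch
import torch.fx

def f(x):
    y = x * 2.0        # stand-in for the matched RoPE subgraph
    return y + 1.0     # downstream consumer

gm = torch.fx.symbolic_trace(f)
g = gm.graph
x_ph = next(n for n in g.nodes if n.op == "placeholder")
y = next(n for n in g.nodes if n.target is operator.mul)

# Build the replacement from the subgraph *input*, mirroring create_rope:
# the fused node first, then the layout-restoring node behind it.
with g.inserting_after(y):
    fused = g.call_function(torch.neg, args=(x_ph,))
with g.inserting_after(fused):
    restored = g.call_function(torch.neg, args=(fused,))

y.replace_all_uses_with(restored)   # downstream now reads the new chain
g.eliminate_dead_code()             # the orphaned mul is swept away
gm.recompile()
assert torch.equal(gm(torch.ones(3)), torch.ones(3) + 1.0)  # neg(neg(x)) + 1
```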