
Commit 8adf8e2

Improve layout handling for RoPE fusion
1 parent aa7a4c4 commit 8adf8e2

2 files changed: 119 additions & 8 deletions


backends/xnnpack/_passes/convert_to_rope.py

Lines changed: 71 additions & 8 deletions
@@ -66,6 +66,18 @@ def _build_weights(
 
         return weights
 
+    @staticmethod
+    def _get_unsqueeze_dim(node: torch.fx.Node) -> int:
+        """Return the unsqueeze dim if node is an unsqueeze_copy, else -1."""
+        if (
+            node.op == "call_function"
+            and node.target == exir_ops.edge.aten.unsqueeze_copy.default
+        ):
+            return node.args[1]
+        return -1
+
+    _BHSD_TO_BSHD_PERM = [0, 2, 1, 3]
+
     @staticmethod
     def _trace_through_unsqueeze(node: torch.fx.Node) -> torch.fx.Node:
         """If node is an unsqueeze_copy, return its input. Otherwise return node as-is."""
@@ -76,6 +88,33 @@ def _trace_through_unsqueeze(node: torch.fx.Node) -> torch.fx.Node:
             return node.args[0]
         return node
 
+    @staticmethod
+    def _trace_through_permute(node: torch.fx.Node) -> torch.fx.Node | None:
+        """If node is a permute_copy that swaps dims 1 and 2, return its input."""
+        if (
+            node.op == "call_function"
+            and node.target == exir_ops.edge.aten.permute_copy.default
+            and list(node.args[1]) == [0, 2, 1, 3]
+        ):
+            return node.args[0]
+        return None
+
+    def _get_layout(self, cos_unsqueezed: torch.fx.Node) -> str | None:
+        """Determine the tensor layout from the cos unsqueeze dimension.
+
+        Returns "BSHD", "BHSD", or None if the layout cannot be determined.
+        """
+        unsqueeze_dim = self._get_unsqueeze_dim(cos_unsqueezed)
+        if unsqueeze_dim == -1:
+            return None
+        ndim = len(cos_unsqueezed.meta["val"].shape)
+        normalized = unsqueeze_dim if unsqueeze_dim >= 0 else unsqueeze_dim + ndim
+        if normalized == 2:
+            return "BSHD"
+        if normalized == 1:
+            return "BHSD"
+        return None
+
     def create_rope(
         self,
         graph_module: torch.fx.GraphModule,
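
Why the cos unsqueeze dim pins down the layout: HF-style cos/sin arrive as (batch, seq, head_dim) and are unsqueezed so they broadcast over the heads axis of x, so the chosen dim reveals where the heads axis sits. A minimal eager sketch of the two cases _get_layout recognizes (in the edge dialect the unsqueeze appears as aten.unsqueeze_copy.default; shapes here are illustrative):

import torch

cos = torch.randn(1, 8, 32)                # (batch, seq, head_dim)

x_bshd = torch.randn(1, 8, 4, 32)          # BSHD: heads at dim 2
assert (x_bshd * cos.unsqueeze(2)).shape == x_bshd.shape   # unsqueeze dim 2 -> "BSHD"

x_bhsd = torch.randn(1, 4, 8, 32)          # BHSD: heads at dim 1
assert (x_bhsd * cos.unsqueeze(1)).shape == x_bhsd.shape   # unsqueeze dim 1 -> "BHSD"
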
@@ -90,9 +129,22 @@ def create_rope(
         sin_unsqueezed = match.placeholder_nodes[2]
         output_node = match.returning_nodes[0]
 
-        # Trace back through unsqueeze to get raw cos/sin for weight construction.
-        # The pattern excludes unsqueeze ops (they're shared between q/k RoPE),
-        # so the matched placeholders are the unsqueeze outputs.
+        # xnn_define_rope expects NTHC (batch, tokens, heads, channels) input.
+        # BSHD (unsqueeze_dim=2) maps directly to NTHC.
+        # BHSD (unsqueeze_dim=1) requires tracing through the BSHD→BHSD permute
+        # to recover the BSHD input, then re-permuting the output back to BHSD.
+        layout = self._get_layout(cos_unsqueezed)
+        if layout == "BSHD":
+            rope_input = x_node
+        elif layout == "BHSD":
+            rope_input = self._trace_through_permute(x_node)
+            if rope_input is None:
+                logger.debug("Skipping RoPE fusion: BHSD but x is not a permute_copy")
+                return
+        else:
+            logger.debug("Skipping RoPE fusion: unrecognized layout")
+            return
+
         cos_node = self._trace_through_unsqueeze(cos_unsqueezed)
         sin_node = self._trace_through_unsqueeze(sin_unsqueezed)
 
@@ -102,12 +154,23 @@
         rope_node = graph_module.graph.create_node(
             "call_function",
             torch.ops.xnnpack.rope.default,
-            args=(x_node, weights),
+            args=(rope_input, weights),
         )
-
-        rope_node.meta["val"] = torch.empty_like(x_node.meta["val"])
-
-        output_node.replace_all_uses_with(rope_node)
+        rope_node.meta["val"] = torch.empty_like(rope_input.meta["val"])
+
+        if layout == "BHSD":
+            permute_node = graph_module.graph.call_function(
+                exir_ops.edge.aten.permute_copy.default,
+                args=(rope_node, self._BHSD_TO_BSHD_PERM),
+            )
+            permute_node.meta["val"] = rope_node.meta["val"].permute(
+                self._BHSD_TO_BSHD_PERM
+            )
+            result_node = permute_node
+        else:
+            result_node = rope_node
+
+        output_node.replace_all_uses_with(result_node)
         graph_module.graph.eliminate_dead_code()
 
     # override
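
Net effect for BHSD graphs: permute_copy → (RoPE subgraph) is rewritten to xnnpack.rope → permute_copy. This is sound because rotate-half RoPE acts per (batch, token, head) slice along the channel axis, so it commutes with the heads/seq transpose. A minimal eager check of that commutation (the cos/sin permute here stands in for the weights built by _build_weights; names and shapes are illustrative):

import torch

def rope(x, cos, sin):
    # Rotate-half RoPE along the last (channel) axis, as in the matched subgraph.
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return x * cos + torch.cat((-x2, x1), dim=-1) * sin

x = torch.randn(2, 8, 4, 32)                   # BSHD input feeding the traced permute
cos = torch.randn(2, 1, 8, 32)                 # broadcasts over heads in BHSD
sin = torch.randn(2, 1, 8, 32)

bhsd = rope(x.permute(0, 2, 1, 3), cos, sin)             # original: permute, then RoPE
rewritten = rope(
    x, cos.permute(0, 2, 1, 3), sin.permute(0, 2, 1, 3)  # RoPE in BSHD...
).permute(0, 2, 1, 3)                                    # ...then permute back
torch.testing.assert_close(bhsd, rewritten)
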

backends/xnnpack/test/ops/test_rope.py

Lines changed: 48 additions & 0 deletions
@@ -68,3 +68,51 @@ def test_fp32_rope_dynamic_seq_len(self):
             {0: None, 1: seq, 2: None},
         )
         self._test_rope(inputs, dynamic_shapes=dynamic_shapes)
+
+    class HFRopeBHSD(torch.nn.Module):
+        """HF-style RoPE with BHSD layout (transpose before/after RoPE)."""
+
+        def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
+            x = x.transpose(1, 2)
+            cos = cos.unsqueeze(1)
+            sin = sin.unsqueeze(1)
+            x1 = x[..., : x.shape[-1] // 2]
+            x2 = x[..., x.shape[-1] // 2 :]
+            rot = torch.cat((-x2, x1), dim=-1)
+            out = (x * cos) + (rot * sin)
+            return out.transpose(1, 2)
+
+    def _test_rope_bhsd(self, inputs, dynamic_shapes=None):
+        (
+            Tester(self.HFRopeBHSD(), inputs, dynamic_shapes=dynamic_shapes)
+            .export()
+            .to_edge_transform_and_lower()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs(inputs=inputs)
+        )
+
+    def test_fp32_rope_bhsd(self):
+        batch, seq_len, n_heads, head_dim = 1, 8, 4, 32
+        cos, sin = _hf_freqs(seq_len, head_dim)
+        inputs = (torch.randn(batch, seq_len, n_heads, head_dim), cos, sin)
+        self._test_rope_bhsd(inputs)
+
+    def test_fp32_rope_bhsd_large_head_dim(self):
+        batch, seq_len, n_heads, head_dim = 1, 16, 8, 128
+        cos, sin = _hf_freqs(seq_len, head_dim)
+        inputs = (torch.randn(batch, seq_len, n_heads, head_dim), cos, sin)
+        self._test_rope_bhsd(inputs)
+
+    def test_fp32_rope_bhsd_dynamic_seq_len(self):
+        batch, seq_len, n_heads, head_dim = 1, 8, 4, 32
+        cos, sin = _hf_freqs(seq_len, head_dim)
+        inputs = (torch.randn(batch, seq_len, n_heads, head_dim), cos, sin)
+        seq = Dim("seq", min=1, max=128)
+        dynamic_shapes = (
+            {0: None, 1: seq, 2: None, 3: None},
+            {0: None, 1: seq, 2: None},
+            {0: None, 1: seq, 2: None},
+        )
+        self._test_rope_bhsd(inputs, dynamic_shapes=dynamic_shapes)
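
The tests call _hf_freqs, which is defined earlier in this file and not shown in the hunk. For reading these tests standalone, a stand-in following the usual HF rotate-half convention would look roughly like this (an assumed sketch, not the file's actual helper):

import torch

def _hf_freqs(seq_len: int, head_dim: int, base: float = 10000.0):
    # Duplicate the half-dim frequencies along the channel axis (HF rotate-half style).
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)  # (seq, head_dim // 2)
    emb = torch.cat((freqs, freqs), dim=-1)                       # (seq, head_dim)
    return emb.cos().unsqueeze(0), emb.sin().unsqueeze(0)         # (1, seq, head_dim) each
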
