Skip to content

Commit 4f9fc5c

Browse files
committed
Improve rope matching soundness
1 parent 8adf8e2 commit 4f9fc5c

2 files changed

Lines changed: 55 additions & 1 deletion

File tree

backends/xnnpack/_passes/convert_to_rope.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,56 @@ def _trace_through_unsqueeze(node: torch.fx.Node) -> torch.fx.Node:
8888
return node.args[0]
8989
return node
9090

91+
@staticmethod
92+
def _find_trig_source(node: torch.fx.Node) -> torch.fx.Node | None:
93+
"""Walk backwards through unsqueeze_copy ops to find cos/sin op."""
94+
current = node
95+
for _ in range(10):
96+
if current.op != "call_function":
97+
return None
98+
if current.target in (
99+
exir_ops.edge.aten.cos.default,
100+
exir_ops.edge.aten.sin.default,
101+
):
102+
return current
103+
if current.target == exir_ops.edge.aten.unsqueeze_copy.default:
104+
current = current.args[0]
105+
continue
106+
return None
107+
return None
108+
109+
@classmethod
110+
def _is_doubled_cat(cls, trig_node: torch.fx.Node) -> bool:
111+
"""Check that a cos/sin node's input is cat(x, x) with identical args."""
112+
cat_node = trig_node.args[0]
113+
if (
114+
cat_node.op != "call_function"
115+
or cat_node.target != exir_ops.edge.aten.cat.default
116+
):
117+
return False
118+
tensors = cat_node.args[0]
119+
return len(tensors) == 2 and tensors[0] is tensors[1]
120+
121+
@classmethod
122+
def _has_doubled_freqs(
123+
cls,
124+
cos_unsqueezed: torch.fx.Node,
125+
sin_unsqueezed: torch.fx.Node,
126+
) -> bool:
127+
"""Verify that cos/sin frequencies are doubled (first half == second half).
128+
129+
Traces back through unsqueeze_copy ops to find the cos/sin producer,
130+
then verifies its input is cat(x, x) where both args are the same
131+
node — a structural proof that the first and second halves are identical.
132+
"""
133+
cos_trig = cls._find_trig_source(cos_unsqueezed)
134+
sin_trig = cls._find_trig_source(sin_unsqueezed)
135+
136+
if cos_trig is None or sin_trig is None:
137+
return False
138+
139+
return cls._is_doubled_cat(cos_trig) and cls._is_doubled_cat(sin_trig)
140+
91141
@staticmethod
92142
def _trace_through_permute(node: torch.fx.Node) -> torch.fx.Node | None:
93143
"""If node is a permute_copy that swaps dims 1 and 2, return its input."""
@@ -148,6 +198,10 @@ def create_rope(
148198
cos_node = self._trace_through_unsqueeze(cos_unsqueezed)
149199
sin_node = self._trace_through_unsqueeze(sin_unsqueezed)
150200

201+
if not self._has_doubled_freqs(cos_unsqueezed, sin_unsqueezed):
202+
logger.debug("Skipping RoPE fusion: cannot verify doubled frequencies")
203+
return
204+
151205
weights = self._build_weights(graph_module, cos_node, sin_node, output_node)
152206

153207
with graph_module.graph.inserting_before(output_node):

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
216216
if not self.check_common_constraints(node, ep):
217217
return False
218218

219-
num_tensors = len(node.all_input_nodes)
219+
num_tensors = len(node.args[0])
220220

221221
if not (num_tensors >= 2):
222222
why(

0 commit comments

Comments
 (0)