@@ -88,6 +88,56 @@ def _trace_through_unsqueeze(node: torch.fx.Node) -> torch.fx.Node:
8888 return node .args [0 ]
8989 return node
9090
@staticmethod
def _find_trig_source(node: torch.fx.Node) -> torch.fx.Node | None:
    """Return the cos/sin node reached by peeling unsqueeze_copy ops off ``node``.

    Starting at ``node``, follows the first argument of each
    ``unsqueeze_copy`` producer. The walk inspects at most 10 nodes.

    Returns:
        The first ``aten.cos`` / ``aten.sin`` call found, or ``None`` when
        the chain hits a non-``call_function`` node, any other op, or the
        10-node budget runs out.
    """
    trig_targets = (
        exir_ops.edge.aten.cos.default,
        exir_ops.edge.aten.sin.default,
    )
    unsqueeze_target = exir_ops.edge.aten.unsqueeze_copy.default

    cursor = node
    remaining = 10  # upper bound on the number of nodes inspected
    while remaining > 0:
        remaining -= 1
        if cursor.op != "call_function":
            return None
        target = cursor.target
        if target in trig_targets:
            return cursor
        if target != unsqueeze_target:
            # Any op other than unsqueeze_copy ends the search.
            return None
        cursor = cursor.args[0]
    return None
108+
@classmethod
def _is_doubled_cat(cls, trig_node: torch.fx.Node) -> bool:
    """Check that a cos/sin node's input is ``cat([x, x])`` along the last dim.

    Both concatenated operands must be the *same* node object, and the
    concatenation must run along the last dimension of the result.
    Concatenating along any other axis duplicates data on that axis
    instead, so the two halves of the last axis would not be identical.

    NOTE(review): this assumes the fused RoPE consumer splits the
    frequencies into halves along the last dimension — confirm against
    the fused op's definition.

    Args:
        trig_node: The ``aten.cos`` / ``aten.sin`` node to inspect.

    Returns:
        ``True`` only when the doubled-cat pattern is structurally
        verified; ``False`` whenever any part of it cannot be confirmed.
    """
    if not trig_node.args:
        return False
    cat_node = trig_node.args[0]
    if (
        not isinstance(cat_node, torch.fx.Node)
        or cat_node.op != "call_function"
        or cat_node.target != exir_ops.edge.aten.cat.default
        or not cat_node.args
    ):
        return False

    tensors = cat_node.args[0]
    if (
        not isinstance(tensors, (list, tuple))
        or len(tensors) != 2
        or tensors[0] is not tensors[1]
    ):
        return False

    # aten.cat's ``dim`` defaults to 0 when it is not given.
    if len(cat_node.args) > 1:
        dim = cat_node.args[1]
    else:
        dim = cat_node.kwargs.get("dim", 0)
    if not isinstance(dim, int):
        return False
    if dim == -1:
        return True

    # Any other dim needs the output rank to decide whether it is the last
    # axis. Presumably export stores a fake tensor under meta["val"] —
    # verify; without it the pattern cannot be proven, so fail closed.
    shape = getattr(cat_node.meta.get("val"), "shape", None)
    if shape is None:
        return False
    rank = len(shape)
    return rank > 0 and dim % rank == rank - 1
120+
@classmethod
def _has_doubled_freqs(
    cls,
    cos_unsqueezed: torch.fx.Node,
    sin_unsqueezed: torch.fx.Node,
) -> bool:
    """Report whether both cos and sin inputs come from a doubled ``cat``.

    Each argument is resolved to its cos/sin producer via
    ``_find_trig_source``. Both producers must then pass
    ``_is_doubled_cat``, which is the structural evidence that the first
    and second halves of the frequencies are identical.

    Returns:
        ``False`` if either producer cannot be found or either fails the
        doubled-cat check; ``True`` otherwise.
    """
    # Resolve both producers up front, cos first, before any cat check.
    producers = [
        cls._find_trig_source(unsqueezed)
        for unsqueezed in (cos_unsqueezed, sin_unsqueezed)
    ]
    if any(producer is None for producer in producers):
        return False

    # all() stops at the first failure, so sin is only checked if cos passes.
    return all(cls._is_doubled_cat(producer) for producer in producers)
140+
91141 @staticmethod
92142 def _trace_through_permute (node : torch .fx .Node ) -> torch .fx .Node | None :
93143 """If node is a permute_copy that swaps dims 1 and 2, return its input."""
@@ -148,6 +198,10 @@ def create_rope(
148198 cos_node = self ._trace_through_unsqueeze (cos_unsqueezed )
149199 sin_node = self ._trace_through_unsqueeze (sin_unsqueezed )
150200
201+ if not self ._has_doubled_freqs (cos_unsqueezed , sin_unsqueezed ):
202+ logger .debug ("Skipping RoPE fusion: cannot verify doubled frequencies" )
203+ return
204+
151205 weights = self ._build_weights (graph_module , cos_node , sin_node , output_node )
152206
153207 with graph_module .graph .inserting_before (output_node ):
0 commit comments