Skip to content

Commit 30d3860

Browse files
gramalingamCopilot
andauthored
Add remaining high-priority rewriter extended tests (#2899)
## Summary Add 22 extended tests across 8 test files covering the remaining high-priority gaps identified in the rewriter test gap analysis. ### New test files | File | Tests | Rule module | |------|-------|------------| | `_cast_constant_of_shape_extended_test.py` | 5 | `_cast_constant_of_shape` | | `_fuse_conv_affine_extended_test.py` | 5 | `_fuse_conv_affine` | | `_redundant_scatter_nd_extended_test.py` | 3 | `_redundant_scatter_nd` | | `_rotary_embedding_extended_test.py` (rules/fusion) | 3 | `PartialRotaryEmbedding23Fusion` | | `_gqa_extended_test.py` (rules/fusion) | 2 | `GroupQueryAttentionFusion` | | `gqa_extended_test.py` (ort_fusions) | 1 | `gqa` ORT fusion | | `gqa_packed_qkv_extended_test.py` (ort_fusions) | 1 | `gqa` packed QKV | | `cos_sin_cache_extended_test.py` (ort_fusions) | 2 | `cos_sin_cache` | ### Key patterns used - `@script()` API with `op.Constant(value=...)` for embedded constants - Symbolic dimensions (`"B"`, `"S"`) where appropriate - Numerical validation via ORT or ONNX reference implementation - Negative tests: non-constant operands, padded convolutions, missing prerequisites, mismatched boundaries - Script generator with traced-if for clean negative test model construction ### Modified existing file - `_rotary_embedding_models.py`: Added `_make_partial_rotary_script(mismatched)` generator for clean negative test construction Follow-up to PR #2896. --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 1911741 commit 30d3860

9 files changed

Lines changed: 1043 additions & 0 deletions

onnxscript/rewriter/models/_rotary_embedding_models.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,3 +168,42 @@ def get_ort_inputs(self):
168168

169169
def partial_rotary_test_case():
170170
return _PartialRotaryTestCase()
171+
172+
173+
def _make_partial_rotary_script(mismatched: bool = False):
174+
"""Generate a partial rotary embedding script with matching or mismatched slice boundaries.
175+
176+
Args:
177+
mismatched: If True, the second slice starts at 33 instead of 32,
178+
creating a gap that should prevent PartialRotaryEmbedding fusion.
179+
"""
180+
181+
@script()
182+
def partial_rotary(position_ids, query):
183+
inv_freqs = op.Constant(value=inv_freqs_value) # [1, rd/2, 1]
184+
position_ids_3d = op.Unsqueeze(position_ids, 1) # [B, 1, S]
185+
position_ids_3d_float = op.Cast(position_ids_3d, to=1)
186+
matmul = op.MatMul(inv_freqs, position_ids_3d_float) # [B, rd/2, S]
187+
transpose = op.Transpose(matmul, perm=[0, 2, 1]) # [B, S, rd/2]
188+
cat = op.Concat(transpose, transpose, axis=-1) # [B, S, rd]
189+
cos_3d = op.Cos(cat) # [B, S, rd]
190+
sin_3d = op.Sin(cat) # [B, S, rd]
191+
# Split the query for partial embedding
192+
to_embed = op.Slice(query, [0], [32], [3], [1])
193+
if mismatched:
194+
unembedded = op.Slice(query, [33], [9223372036854775807], [3], [1])
195+
else:
196+
unembedded = op.Slice(query, [32], [9223372036854775807], [3], [1])
197+
cos_4d = op.Unsqueeze(cos_3d, [1]) # [B, 1, S, rd]
198+
sin_4d = op.Unsqueeze(sin_3d, [1]) # [B, 1, S, rd]
199+
to_embed_times_cos = op.Mul(to_embed, cos_4d)
200+
to_embed_x = op.Slice(to_embed, [0], [16], [3], [1])
201+
to_embed_y = op.Slice(to_embed, [16], [9223372036854775807], [3], [1])
202+
minus_to_embed_y = op.Neg(to_embed_y)
203+
to_embed_rotated_90 = op.Concat(minus_to_embed_y, to_embed_x, axis=-1)
204+
to_embed_rotated_90_times_sin = op.Mul(to_embed_rotated_90, sin_4d)
205+
embedded = op.Add(to_embed_times_cos, to_embed_rotated_90_times_sin)
206+
final = op.Concat(embedded, unembedded, axis=-1)
207+
return final
208+
209+
return partial_rotary
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
"""Extended tests for cos_sin_cache fusion.
5+
6+
Adds coverage for: non-constant inv_freq (negative — dynamic inv_freq prevents cache precomputation).
7+
"""
8+
9+
from __future__ import annotations
10+
11+
import unittest
12+
13+
import numpy
14+
import onnx_ir as ir
15+
16+
from onnxscript import FLOAT, INT64, optimizer, script, values
17+
from onnxscript import opset18 as op
18+
from onnxscript.rewriter.ort_fusions._test_utils import ort_run
19+
from onnxscript.rewriter.ort_fusions.cos_sin_cache import fuse_cos_sin_cache
20+
from onnxscript.rewriter.ort_fusions.rotary_embedding import fuse_rotary_embedding
21+
22+
msft_op = values.Opset("com.microsoft", 1)
23+
24+
25+
class CosSinCacheExtendedTest(unittest.TestCase):
26+
"""Extended tests for cos_sin_cache fusion."""
27+
28+
def test_non_constant_inv_freq_no_cache_fusion(self):
29+
"""When inv_freq is a graph input (not constant), cos_sin_cache should not fuse.
30+
31+
The cos_sin_cache fusion relies on inv_freq being a constant to precompute
32+
the cos/sin lookup table. When inv_freq is dynamic, no fusion should apply.
33+
"""
34+
35+
@script()
36+
def model_with_dynamic_inv_freq(
37+
x: FLOAT[1, 4, 8, 8],
38+
position_ids: INT64[1, 8],
39+
inv_freq: FLOAT[1, 4, 1],
40+
) -> FLOAT[1, 4, 8, 8]:
41+
# inv_freq is a graph input, not a constant
42+
position_ids_expanded = op.Unsqueeze(position_ids, [1])
43+
position_ids_float = op.Cast(position_ids_expanded, to=ir.DataType.FLOAT)
44+
freqs = op.MatMul(inv_freq, position_ids_float)
45+
freqs_t = op.Transpose(freqs, perm=[0, 2, 1])
46+
emb = op.Concat(freqs_t, freqs_t, axis=-1)
47+
cos = op.Cos(emb)
48+
sin = op.Sin(emb)
49+
cos_4d = op.Unsqueeze(cos, [1])
50+
sin_4d = op.Unsqueeze(sin, [1])
51+
52+
x1 = op.Slice(x, [0], [4], [3], [1])
53+
x2 = op.Slice(x, [4], [8], [3], [1])
54+
minus_x2 = op.Neg(x2)
55+
rotated_x = op.Concat(minus_x2, x1, axis=-1)
56+
result = op.Add(x * cos_4d, rotated_x * sin_4d)
57+
return result
58+
59+
model_proto = model_with_dynamic_inv_freq.to_model_proto()
60+
model = ir.serde.deserialize_model(model_proto)
61+
optimizer.optimize(model)
62+
63+
# Rotary embedding fusion should still work
64+
re_count = fuse_rotary_embedding(model)
65+
self.assertGreater(re_count, 0, "RotaryEmbedding fusion should succeed.")
66+
67+
# cos_sin_cache fusion should NOT work because inv_freq is not constant
68+
cache_count = fuse_cos_sin_cache(model)
69+
self.assertEqual(
70+
cache_count, 0, "cos_sin_cache should NOT fuse with dynamic inv_freq."
71+
)
72+
73+
def test_constant_inv_freq_does_fuse(self):
74+
"""Sanity check: constant inv_freq allows cos_sin_cache fusion."""
75+
76+
@script()
77+
def model_with_const_inv_freq(
78+
x: FLOAT[1, 4, 8, 8], position_ids: INT64[1, 8]
79+
) -> FLOAT[1, 4, 8, 8]:
80+
inv_freq = op.Constant(value_floats=[1.0, 2.0, 3.0, 4.0])
81+
inv_freq_3d = op.Unsqueeze(inv_freq, [0, 2])
82+
position_ids_expanded = op.Unsqueeze(position_ids, [1])
83+
position_ids_float = op.Cast(position_ids_expanded, to=ir.DataType.FLOAT)
84+
freqs = op.MatMul(inv_freq_3d, position_ids_float)
85+
freqs_t = op.Transpose(freqs, perm=[0, 2, 1])
86+
emb = op.Concat(freqs_t, freqs_t, axis=-1)
87+
cos = op.Cos(emb)
88+
sin = op.Sin(emb)
89+
cos_4d = op.Unsqueeze(cos, [1])
90+
sin_4d = op.Unsqueeze(sin, [1])
91+
92+
x1 = op.Slice(x, [0], [4], [3], [1])
93+
x2 = op.Slice(x, [4], [8], [3], [1])
94+
minus_x2 = op.Neg(x2)
95+
rotated_x = op.Concat(minus_x2, x1, axis=-1)
96+
result = op.Add(x * cos_4d, rotated_x * sin_4d)
97+
return result
98+
99+
model_proto = model_with_const_inv_freq.to_model_proto()
100+
model = ir.serde.deserialize_model(model_proto)
101+
optimizer.optimize(model)
102+
103+
inputs = {
104+
"x": numpy.random.rand(1, 4, 8, 8).astype(numpy.float32),
105+
"position_ids": numpy.arange(8, dtype=numpy.int64).reshape(1, 8),
106+
}
107+
original_outputs = ort_run("original", model, inputs)
108+
109+
re_count = fuse_rotary_embedding(model)
110+
self.assertGreater(re_count, 0)
111+
cache_count = fuse_cos_sin_cache(model)
112+
self.assertGreater(cache_count, 0, "cos_sin_cache should fuse with constant inv_freq.")
113+
114+
# Numerical validation
115+
fused_outputs = ort_run("fused", model, inputs)
116+
numpy.testing.assert_allclose(
117+
original_outputs[0], fused_outputs[0], rtol=1e-5, atol=1e-5
118+
)
119+
120+
121+
if __name__ == "__main__":
122+
unittest.main()
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
"""Extended tests for GQA ort_fusions fusion.
5+
6+
Adds coverage for: missing RotaryEmbedding (negative — no GQA fusion).
7+
"""
8+
9+
from __future__ import annotations
10+
11+
import math
12+
import unittest
13+
14+
import onnx
15+
import onnx_ir as ir
16+
import onnx_ir.passes.common.shape_inference as shape_inference
17+
18+
from onnxscript import FLOAT, optimizer, script, values
19+
from onnxscript import opset18 as op
20+
from onnxscript.rewriter.ort_fusions.gqa import fuse_gqa
21+
from onnxscript.rewriter.ort_fusions.sdpa import fuse_sdpa
22+
23+
msft_op = values.Opset("com.microsoft", 1)
24+
25+
26+
class GQAOrtFusionExtendedTest(unittest.TestCase):
27+
"""Extended negative test for GQA ort_fusion."""
28+
29+
def test_no_rotary_embedding_no_gqa_fusion(self):
30+
"""GQA source pattern without RotaryEmbedding → SDPA may fuse but GQA should NOT.
31+
32+
The ort_fusions GQA rule requires RotaryEmbedding nodes to fuse the full GQA
33+
pattern (query/key rotary + kv-cache concat + expand + SDPA → GQA).
34+
Without rotary embedding, the GQA-specific fusion should not trigger.
35+
"""
36+
head_size = 16
37+
num_heads = 20
38+
kv_num_heads = 10
39+
hidden_size = head_size * num_heads
40+
kv_hidden_size = head_size * kv_num_heads
41+
num_groups = num_heads // kv_num_heads
42+
43+
H = [num_heads]
44+
Hkv = [kv_num_heads]
45+
Dh = [head_size]
46+
G = [num_groups]
47+
minus_1 = [-1]
48+
49+
scale_factor = math.sqrt(math.sqrt(head_size))
50+
51+
@script()
52+
def gqa_no_rotary(query, key, value, past_key, past_value):
53+
B = op.Shape(query, start=0, end=1)
54+
S = op.Shape(query, start=1, end=2)
55+
past_seq_length = op.Shape(past_key, start=2, end=3)
56+
total_seq_length = op.Add(past_seq_length, S)
57+
58+
shape_BSHDh = op.Concat(B, S, minus_1, Dh, axis=0)
59+
shape_BSHkvDh = op.Concat(B, S, minus_1, Dh, axis=0)
60+
shape_BSD = op.Concat(B, S, minus_1, axis=0)
61+
shape_BHkvGSDh = op.Concat(B, Hkv, G, total_seq_length, Dh, axis=0)
62+
shape_BHSDh = op.Concat(B, H, total_seq_length, Dh, axis=0)
63+
64+
query_BSHDh = op.Reshape(query, shape_BSHDh)
65+
query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3])
66+
67+
key_BSHkvDh = op.Reshape(key, shape_BSHkvDh)
68+
key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3])
69+
70+
value_BSHkvDh = op.Reshape(value, shape_BSHkvDh)
71+
value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3])
72+
73+
# NO RotaryEmbedding here — just use key/query directly
74+
key_seq = op.Concat(past_key, key_BHkvSDh, axis=-2)
75+
value_seq = op.Concat(past_value, value_BHkvSDh, axis=-2)
76+
77+
key_unsq = op.Unsqueeze(key_seq, [2])
78+
key_exp = op.Expand(key_unsq, shape_BHkvGSDh)
79+
key_rsh = op.Reshape(key_exp, shape_BHSDh)
80+
81+
value_unsq = op.Unsqueeze(value_seq, [2])
82+
value_exp = op.Expand(value_unsq, shape_BHkvGSDh)
83+
value_rsh = op.Reshape(value_exp, shape_BHSDh)
84+
85+
# Attention
86+
key_t = op.Transpose(key_rsh, perm=[0, 1, 3, 2])
87+
divisor = op.Constant(value_float=scale_factor)
88+
scaled_q = op.Div(query_BHSDh, divisor)
89+
scaled_k = op.Div(key_t, divisor)
90+
score = op.MatMul(scaled_q, scaled_k)
91+
weight = op.Softmax(score, axis=-1)
92+
attn = op.MatMul(weight, value_rsh)
93+
94+
attn_t = op.Transpose(attn, perm=[0, 2, 1, 3])
95+
attn_out = op.Reshape(attn_t, shape_BSD)
96+
97+
return attn_out, key_seq, value_seq
98+
99+
D = hidden_size
100+
Dkv_val = kv_hidden_size
101+
Dh_val = head_size
102+
Hkv_val = kv_num_heads
103+
104+
input_types = (
105+
FLOAT["B", "S", D],
106+
FLOAT["B", "S", Dkv_val],
107+
FLOAT["B", "S", Dkv_val],
108+
FLOAT["B", Hkv_val, "P", Dh_val],
109+
FLOAT["B", Hkv_val, "P", Dh_val],
110+
)
111+
output_types = (
112+
FLOAT["B", "S", D],
113+
FLOAT["B", Hkv_val, "T", Dh_val],
114+
FLOAT["B", Hkv_val, "T", Dh_val],
115+
)
116+
117+
source_model = gqa_no_rotary.to_model_proto(
118+
input_types=input_types,
119+
output_types=output_types,
120+
)
121+
122+
# Add value_info for shapes needed by fusion
123+
query_BSHDh_vi = onnx.helper.make_tensor_value_info(
124+
"query_BSHDh", onnx.TensorProto.FLOAT, ["B", "S", num_heads, head_size]
125+
)
126+
key_BSHkvDh_vi = onnx.helper.make_tensor_value_info(
127+
"key_BSHkvDh", onnx.TensorProto.FLOAT, ["B", "S", kv_num_heads, head_size]
128+
)
129+
source_model.graph.value_info.extend([query_BSHDh_vi, key_BSHkvDh_vi])
130+
131+
model = ir.serde.from_proto(source_model)
132+
inferred = shape_inference.infer_shapes(model)
133+
optimizer.optimize(inferred)
134+
135+
# SDPA might fuse, but GQA should not (no RotaryEmbedding)
136+
fuse_sdpa(inferred, debug=False)
137+
count = fuse_gqa(inferred, debug=False)
138+
self.assertEqual(count, 0, "GQA fusion should NOT succeed without RotaryEmbedding.")
139+
140+
141+
if __name__ == "__main__":
142+
unittest.main()

0 commit comments

Comments
 (0)