diff --git a/onnxscript/rewriter/models/_rotary_embedding_models.py b/onnxscript/rewriter/models/_rotary_embedding_models.py index 3709cd04f7..8b981837f4 100644 --- a/onnxscript/rewriter/models/_rotary_embedding_models.py +++ b/onnxscript/rewriter/models/_rotary_embedding_models.py @@ -168,3 +168,42 @@ def get_ort_inputs(self): def partial_rotary_test_case(): return _PartialRotaryTestCase() + + +def _make_partial_rotary_script(mismatched: bool = False): + """Generate a partial rotary embedding script with matching or mismatched slice boundaries. + + Args: + mismatched: If True, the second slice starts at 33 instead of 32, + creating a gap that should prevent PartialRotaryEmbedding fusion. + """ + + @script() + def partial_rotary(position_ids, query): + inv_freqs = op.Constant(value=inv_freqs_value) # [1, rd/2, 1] + position_ids_3d = op.Unsqueeze(position_ids, 1) # [B, 1, S] + position_ids_3d_float = op.Cast(position_ids_3d, to=1) + matmul = op.MatMul(inv_freqs, position_ids_3d_float) # [B, rd/2, S] + transpose = op.Transpose(matmul, perm=[0, 2, 1]) # [B, S, rd/2] + cat = op.Concat(transpose, transpose, axis=-1) # [B, S, rd] + cos_3d = op.Cos(cat) # [B, S, rd] + sin_3d = op.Sin(cat) # [B, S, rd] + # Split the query for partial embedding + to_embed = op.Slice(query, [0], [32], [3], [1]) + if mismatched: + unembedded = op.Slice(query, [33], [9223372036854775807], [3], [1]) + else: + unembedded = op.Slice(query, [32], [9223372036854775807], [3], [1]) + cos_4d = op.Unsqueeze(cos_3d, [1]) # [B, 1, S, rd] + sin_4d = op.Unsqueeze(sin_3d, [1]) # [B, 1, S, rd] + to_embed_times_cos = op.Mul(to_embed, cos_4d) + to_embed_x = op.Slice(to_embed, [0], [16], [3], [1]) + to_embed_y = op.Slice(to_embed, [16], [9223372036854775807], [3], [1]) + minus_to_embed_y = op.Neg(to_embed_y) + to_embed_rotated_90 = op.Concat(minus_to_embed_y, to_embed_x, axis=-1) + to_embed_rotated_90_times_sin = op.Mul(to_embed_rotated_90, sin_4d) + embedded = op.Add(to_embed_times_cos, to_embed_rotated_90_times_sin) + final = op.Concat(embedded, unembedded, axis=-1) + return final + + return partial_rotary diff --git a/onnxscript/rewriter/ort_fusions/cos_sin_cache_extended_test.py b/onnxscript/rewriter/ort_fusions/cos_sin_cache_extended_test.py new file mode 100644 index 0000000000..f948a5b765 --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/cos_sin_cache_extended_test.py @@ -0,0 +1,122 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Extended tests for cos_sin_cache fusion. + +Adds coverage for: non-constant inv_freq (negative — dynamic inv_freq prevents cache precomputation). +""" + +from __future__ import annotations + +import unittest + +import numpy +import onnx_ir as ir + +from onnxscript import FLOAT, INT64, optimizer, script, values +from onnxscript import opset18 as op +from onnxscript.rewriter.ort_fusions._test_utils import ort_run +from onnxscript.rewriter.ort_fusions.cos_sin_cache import fuse_cos_sin_cache +from onnxscript.rewriter.ort_fusions.rotary_embedding import fuse_rotary_embedding + +msft_op = values.Opset("com.microsoft", 1) + + +class CosSinCacheExtendedTest(unittest.TestCase): + """Extended tests for cos_sin_cache fusion.""" + + def test_non_constant_inv_freq_no_cache_fusion(self): + """When inv_freq is a graph input (not constant), cos_sin_cache should not fuse. + + The cos_sin_cache fusion relies on inv_freq being a constant to precompute + the cos/sin lookup table. When inv_freq is dynamic, no fusion should apply. 
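+
+        A rough numpy sketch of the table such a cache would hold (names
+        illustrative, mirroring the model below):
+
+            freqs = (inv_freq @ position_ids).transpose(0, 2, 1)
+            emb = numpy.concatenate([freqs, freqs], axis=-1)
+            cos, sin = numpy.cos(emb), numpy.sin(emb)  # precomputable only if inv_freq is constant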
+ """ + + @script() + def model_with_dynamic_inv_freq( + x: FLOAT[1, 4, 8, 8], + position_ids: INT64[1, 8], + inv_freq: FLOAT[1, 4, 1], + ) -> FLOAT[1, 4, 8, 8]: + # inv_freq is a graph input, not a constant + position_ids_expanded = op.Unsqueeze(position_ids, [1]) + position_ids_float = op.Cast(position_ids_expanded, to=ir.DataType.FLOAT) + freqs = op.MatMul(inv_freq, position_ids_float) + freqs_t = op.Transpose(freqs, perm=[0, 2, 1]) + emb = op.Concat(freqs_t, freqs_t, axis=-1) + cos = op.Cos(emb) + sin = op.Sin(emb) + cos_4d = op.Unsqueeze(cos, [1]) + sin_4d = op.Unsqueeze(sin, [1]) + + x1 = op.Slice(x, [0], [4], [3], [1]) + x2 = op.Slice(x, [4], [8], [3], [1]) + minus_x2 = op.Neg(x2) + rotated_x = op.Concat(minus_x2, x1, axis=-1) + result = op.Add(x * cos_4d, rotated_x * sin_4d) + return result + + model_proto = model_with_dynamic_inv_freq.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + optimizer.optimize(model) + + # Rotary embedding fusion should still work + re_count = fuse_rotary_embedding(model) + self.assertGreater(re_count, 0, "RotaryEmbedding fusion should succeed.") + + # cos_sin_cache fusion should NOT work because inv_freq is not constant + cache_count = fuse_cos_sin_cache(model) + self.assertEqual( + cache_count, 0, "cos_sin_cache should NOT fuse with dynamic inv_freq." + ) + + def test_constant_inv_freq_does_fuse(self): + """Sanity check: constant inv_freq allows cos_sin_cache fusion.""" + + @script() + def model_with_const_inv_freq( + x: FLOAT[1, 4, 8, 8], position_ids: INT64[1, 8] + ) -> FLOAT[1, 4, 8, 8]: + inv_freq = op.Constant(value_floats=[1.0, 2.0, 3.0, 4.0]) + inv_freq_3d = op.Unsqueeze(inv_freq, [0, 2]) + position_ids_expanded = op.Unsqueeze(position_ids, [1]) + position_ids_float = op.Cast(position_ids_expanded, to=ir.DataType.FLOAT) + freqs = op.MatMul(inv_freq_3d, position_ids_float) + freqs_t = op.Transpose(freqs, perm=[0, 2, 1]) + emb = op.Concat(freqs_t, freqs_t, axis=-1) + cos = op.Cos(emb) + sin = op.Sin(emb) + cos_4d = op.Unsqueeze(cos, [1]) + sin_4d = op.Unsqueeze(sin, [1]) + + x1 = op.Slice(x, [0], [4], [3], [1]) + x2 = op.Slice(x, [4], [8], [3], [1]) + minus_x2 = op.Neg(x2) + rotated_x = op.Concat(minus_x2, x1, axis=-1) + result = op.Add(x * cos_4d, rotated_x * sin_4d) + return result + + model_proto = model_with_const_inv_freq.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + optimizer.optimize(model) + + inputs = { + "x": numpy.random.rand(1, 4, 8, 8).astype(numpy.float32), + "position_ids": numpy.arange(8, dtype=numpy.int64).reshape(1, 8), + } + original_outputs = ort_run("original", model, inputs) + + re_count = fuse_rotary_embedding(model) + self.assertGreater(re_count, 0) + cache_count = fuse_cos_sin_cache(model) + self.assertGreater(cache_count, 0, "cos_sin_cache should fuse with constant inv_freq.") + + # Numerical validation + fused_outputs = ort_run("fused", model, inputs) + numpy.testing.assert_allclose( + original_outputs[0], fused_outputs[0], rtol=1e-5, atol=1e-5 + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/ort_fusions/gqa_extended_test.py b/onnxscript/rewriter/ort_fusions/gqa_extended_test.py new file mode 100644 index 0000000000..1d073f2db6 --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/gqa_extended_test.py @@ -0,0 +1,142 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Extended tests for GQA ort_fusions fusion. + +Adds coverage for: missing RotaryEmbedding (negative — no GQA fusion). 
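+
+For orientation, the kv-head expansion that the GQA source pattern (and the
+test below) constructs is, shape-wise (sketch):
+
+    [B, Hkv, S, Dh] --Unsqueeze--> [B, Hkv, 1, S, Dh]
+                    --Expand-----> [B, Hkv, G, S, Dh]
+                    --Reshape----> [B, Hkv*G, S, Dh]   # Hkv*G == H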
+"""
+
+from __future__ import annotations
+
+import math
+import unittest
+
+import onnx
+import onnx_ir as ir
+import onnx_ir.passes.common.shape_inference as shape_inference
+
+from onnxscript import FLOAT, optimizer, script, values
+from onnxscript import opset18 as op
+from onnxscript.rewriter.ort_fusions.gqa import fuse_gqa
+from onnxscript.rewriter.ort_fusions.sdpa import fuse_sdpa
+
+msft_op = values.Opset("com.microsoft", 1)
+
+
+class GQAOrtFusionExtendedTest(unittest.TestCase):
+    """Extended negative test for GQA ort_fusion."""
+
+    def test_no_rotary_embedding_no_gqa_fusion(self):
+        """GQA source pattern without RotaryEmbedding → SDPA may fuse but GQA should NOT.
+
+        The ort_fusions GQA rule requires RotaryEmbedding nodes to fuse the full GQA
+        pattern (query/key rotary + kv-cache concat + expand + SDPA → GQA).
+        Without rotary embedding, the GQA-specific fusion should not trigger.
+        """
+        head_size = 16
+        num_heads = 20
+        kv_num_heads = 10
+        hidden_size = head_size * num_heads
+        kv_hidden_size = head_size * kv_num_heads
+        num_groups = num_heads // kv_num_heads
+
+        H = [num_heads]
+        Hkv = [kv_num_heads]
+        Dh = [head_size]
+        G = [num_groups]
+        minus_1 = [-1]
+
+        # Fourth root of head_size: q and k are each divided by it, so the
+        # attention scores get the usual 1/sqrt(head_size) scaling.
+        scale_factor = math.sqrt(math.sqrt(head_size))
+
+        @script()
+        def gqa_no_rotary(query, key, value, past_key, past_value):
+            B = op.Shape(query, start=0, end=1)
+            S = op.Shape(query, start=1, end=2)
+            past_seq_length = op.Shape(past_key, start=2, end=3)
+            total_seq_length = op.Add(past_seq_length, S)
+
+            shape_BSHDh = op.Concat(B, S, minus_1, Dh, axis=0)
+            shape_BSHkvDh = op.Concat(B, S, minus_1, Dh, axis=0)
+            shape_BSD = op.Concat(B, S, minus_1, axis=0)
+            shape_BHkvGSDh = op.Concat(B, Hkv, G, total_seq_length, Dh, axis=0)
+            shape_BHSDh = op.Concat(B, H, total_seq_length, Dh, axis=0)
+
+            query_BSHDh = op.Reshape(query, shape_BSHDh)
+            query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3])
+
+            key_BSHkvDh = op.Reshape(key, shape_BSHkvDh)
+            key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3])
+
+            value_BSHkvDh = op.Reshape(value, shape_BSHkvDh)
+            value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3])
+
+            # NO RotaryEmbedding here — just use key/query directly
+            key_seq = op.Concat(past_key, key_BHkvSDh, axis=-2)
+            value_seq = op.Concat(past_value, value_BHkvSDh, axis=-2)
+
+            key_unsq = op.Unsqueeze(key_seq, [2])
+            key_exp = op.Expand(key_unsq, shape_BHkvGSDh)
+            key_rsh = op.Reshape(key_exp, shape_BHSDh)
+
+            value_unsq = op.Unsqueeze(value_seq, [2])
+            value_exp = op.Expand(value_unsq, shape_BHkvGSDh)
+            value_rsh = op.Reshape(value_exp, shape_BHSDh)
+
+            # Attention
+            key_t = op.Transpose(key_rsh, perm=[0, 1, 3, 2])
+            divisor = op.Constant(value_float=scale_factor)
+            scaled_q = op.Div(query_BHSDh, divisor)
+            scaled_k = op.Div(key_t, divisor)
+            score = op.MatMul(scaled_q, scaled_k)
+            weight = op.Softmax(score, axis=-1)
+            attn = op.MatMul(weight, value_rsh)
+
+            attn_t = op.Transpose(attn, perm=[0, 2, 1, 3])
+            attn_out = op.Reshape(attn_t, shape_BSD)
+
+            return attn_out, key_seq, value_seq
+
+        D = hidden_size
+        Dkv_val = kv_hidden_size
+        Dh_val = head_size
+        Hkv_val = kv_num_heads
+
+        input_types = (
+            FLOAT["B", "S", D],
+            FLOAT["B", "S", Dkv_val],
+            FLOAT["B", "S", Dkv_val],
+            FLOAT["B", Hkv_val, "P", Dh_val],
+            FLOAT["B", Hkv_val, "P", Dh_val],
+        )
+        output_types = (
+            FLOAT["B", "S", D],
+            FLOAT["B", Hkv_val, "T", Dh_val],
+            FLOAT["B", Hkv_val, "T", Dh_val],
+        )
+
+        source_model = gqa_no_rotary.to_model_proto(
+            input_types=input_types,
+            output_types=output_types,
+        )
+
+        # Add value_info for shapes needed by fusion
+        query_BSHDh_vi = onnx.helper.make_tensor_value_info(
+            "query_BSHDh", onnx.TensorProto.FLOAT, ["B", "S", num_heads, head_size]
+        )
+        key_BSHkvDh_vi = onnx.helper.make_tensor_value_info(
+            "key_BSHkvDh", onnx.TensorProto.FLOAT, ["B", "S", kv_num_heads, head_size]
+        )
+        source_model.graph.value_info.extend([query_BSHDh_vi, key_BSHkvDh_vi])
+
+        model = ir.serde.from_proto(source_model)
+        inferred = shape_inference.infer_shapes(model)
+        optimizer.optimize(inferred)
+
+        # SDPA might fuse, but GQA should not (no RotaryEmbedding)
+        fuse_sdpa(inferred, debug=False)
+        count = fuse_gqa(inferred, debug=False)
+        self.assertEqual(count, 0, "GQA fusion should NOT succeed without RotaryEmbedding.")
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/onnxscript/rewriter/ort_fusions/gqa_packed_qkv_extended_test.py b/onnxscript/rewriter/ort_fusions/gqa_packed_qkv_extended_test.py
new file mode 100644
index 0000000000..862cef7cec
--- /dev/null
+++ b/onnxscript/rewriter/ort_fusions/gqa_packed_qkv_extended_test.py
@@ -0,0 +1,100 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""Extended tests for GQA packed QKV fusion.
+
+Adds coverage for: misaligned slice boundaries (negative).
+"""
+
+from __future__ import annotations
+
+import unittest
+
+import onnx_ir as ir
+import onnx_ir.passes.common.shape_inference as shape_inference
+
+from onnxscript import FLOAT, INT32, optimizer, script, values
+from onnxscript import opset18 as op
+from onnxscript.rewriter.ort_fusions.gqa_packed_qkv import fuse_qkv_gqa
+
+msft_op = values.Opset("com.microsoft", 1)
+
+
+class PackedQKVExtendedTest(unittest.TestCase):
+    """Extended tests for GQA packed QKV fusion."""
+
+    def test_misaligned_slice_boundaries_no_fusion(self):
+        """Slice boundaries don't align with head sizes → should NOT fuse.
+
+        With q_num_heads=20, kv_num_heads=10, head_size=16:
+        hidden_size = 16*(20 + 2*10) = 640
+        q: [0, 320), k: [320, 480), v: [480, 640)
+        We intentionally misalign the k slice to [300, 460) instead.
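+
+        For reference, the aligned packing the fusion expects (sketch):
+
+            packed_qkv[..., 0:320]    # query: Hq * Dh = 20 * 16
+            packed_qkv[..., 320:480]  # key:   Hkv * Dh = 10 * 16
+            packed_qkv[..., 480:640]  # value: Hkv * Dh = 10 * 16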
+ """ + Hq = 20 + Hkv = 10 + Dh = 16 + D = Dh * (Hq + 2 * Hkv) # 640 + + @script() + def gqa_misaligned( + packed_qkv, + past_key, + past_value, + seqlens_k, + total_sequence_length, + cos, + sin, + ): + # Correct q slice + query = op.Slice(packed_qkv, [0], [320], [2], [1]) + # WRONG: misaligned key slice (should be [320, 480)) + key = op.Slice(packed_qkv, [300], [460], [2], [1]) + # Value slice based on wrong offset + value = op.Slice(packed_qkv, [460], [640], [2], [1]) + + attn, pk, pv = msft_op.GroupQueryAttention( + query, + key, + value, + past_key, + past_value, + seqlens_k, + total_sequence_length, + cos, + sin, + num_heads=Hq, + kv_num_heads=Hkv, + do_rotary=1, + rotary_interleaved=0, + ) + return attn, pk, pv + + input_types = ( + FLOAT["B", "S", D], + FLOAT["B", Hkv, "P", Dh], + FLOAT["B", Hkv, "P", Dh], + INT32["B"], + INT32[1], + FLOAT["max_seqlen", Dh // 2], + FLOAT["max_seqlen", Dh // 2], + ) + output_types = ( + FLOAT["B", "S", D], + FLOAT["B", Hkv, "T", Dh], + FLOAT["B", Hkv, "T", Dh], + ) + + model_proto = gqa_misaligned.to_model_proto( + input_types=input_types, output_types=output_types + ) + model = ir.serde.from_proto(model_proto) + inferred = shape_inference.infer_shapes(model) + optimizer.optimize(inferred) + + count = fuse_qkv_gqa(inferred, debug=False) + self.assertEqual(count, 0, "Should NOT fuse with misaligned slice boundaries.") + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/rules/common/_cast_constant_of_shape_extended_test.py b/onnxscript/rewriter/rules/common/_cast_constant_of_shape_extended_test.py new file mode 100644 index 0000000000..8a5896e3ab --- /dev/null +++ b/onnxscript/rewriter/rules/common/_cast_constant_of_shape_extended_test.py @@ -0,0 +1,115 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Extended tests for CastConstantOfShape rules. + +Adds coverage for: additional target dtypes (int32, float64, bfloat16), +the no-value variant with int target, and a same-dtype (identity) cast. 
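+
+The rule folds the Cast into ConstantOfShape by re-typing its value attribute,
+e.g. in onnx textual syntax (sketch):
+
+    constant = ConstantOfShape <value: tensor = float[1] {1.}>(shape)
+    output = Cast <to = 6>(constant)   # to = 6 is INT32
+
+becomes
+
+    output = ConstantOfShape <value: tensor = int32[1] {1}>(shape)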
+"""
+
+from __future__ import annotations
+
+import unittest
+
+import onnx.parser
+import onnx_ir as ir
+
+from onnxscript.rewriter.rules.common import _cast_constant_of_shape


class CastConstantOfShapeExtendedTest(unittest.TestCase):
    """Extended tests for CastConstantOfShape rewrite rules."""

    def _apply(self, model_text: str, expected_count: int = 1):
        """Parse ONNX text, apply rule, and return the model."""
        model_proto = onnx.parser.parse_model(model_text)
        model = ir.serde.deserialize_model(model_proto)
        count = _cast_constant_of_shape.rules.apply_to_model(model)
        self.assertEqual(count, expected_count)
        return model

    # --- Positive: additional dtypes ---

    def test_cast_to_int32(self):
        """ConstantOfShape(float) → Cast(to=INT32) fuses."""
        model = self._apply(
            """
            <ir_version: 10, opset_import: [ "" : 18]>
            agraph (int64[2] input_x) => (int32[1, 4] output)
            {
                constant = ConstantOfShape <value: tensor = float[1] {1.}>(input_x)
                output = Cast <to = 6>(constant)
            }
            """
        )
        self.assertEqual(len(model.graph), 1)
        # dtype 6 = INT32
        self.assertEqual(model.graph[0].attributes["value"].value.dtype, 6)

    def test_cast_to_float64(self):
        """ConstantOfShape(float) → Cast(to=DOUBLE) fuses."""
        model = self._apply(
            """
            <ir_version: 10, opset_import: [ "" : 18]>
            agraph (int64[2] input_x) => (double[1, 4] output)
            {
                constant = ConstantOfShape <value: tensor = float[1] {1.}>(input_x)
                output = Cast <to = 11>(constant)
            }
            """
        )
        self.assertEqual(len(model.graph), 1)
        # dtype 11 = DOUBLE
        self.assertEqual(model.graph[0].attributes["value"].value.dtype, 11)

    def test_cast_to_bfloat16(self):
        """ConstantOfShape(float) → Cast(to=BFLOAT16) fuses."""
        model = self._apply(
            """
            <ir_version: 10, opset_import: [ "" : 18]>
            agraph (int64[2] input_x) => (bfloat16[1, 4] output)
            {
                constant = ConstantOfShape <value: tensor = float[1] {1.}>(input_x)
                output = Cast <to = 16>(constant)
            }
            """
        )
        self.assertEqual(len(model.graph), 1)
        # dtype 16 = BFLOAT16
        self.assertEqual(model.graph[0].attributes["value"].value.dtype, 16)

    def test_without_value_cast_to_int32(self):
        """ConstantOfShape (no value) → Cast(to=INT32) fuses with zero default."""
        model = self._apply(
            """
            <ir_version: 10, opset_import: [ "" : 18]>
            agraph (int64[2] input_x) => (int32[1, 4] output)
            {
                constant = ConstantOfShape (input_x)
                output = Cast <to = 6>(constant)
            }
            """
        )
        self.assertEqual(len(model.graph), 1)
        self.assertEqual(model.graph[0].attributes["value"].value.dtype, 6)

    def test_same_dtype_cast_still_fuses(self):
        """Cast to same dtype as ConstantOfShape is still a valid fusion (removes Cast)."""
        model = self._apply(
            """
            <ir_version: 10, opset_import: [ "" : 18]>
            agraph (int64[2] input_x) => (float[1, 4] output)
            {
                constant = ConstantOfShape <value: tensor = float[1] {1.}>(input_x)
                output = Cast <to = 1>(constant)
            }
            """
        )
        # Cast should be removed (fused), leaving only ConstantOfShape
        self.assertEqual(len(model.graph), 1)
        self.assertEqual(model.graph[0].op_type, "ConstantOfShape")
        self.assertEqual(model.graph[0].attributes["value"].value.dtype, 1)


if __name__ == "__main__":
    unittest.main()
diff --git a/onnxscript/rewriter/rules/common/_fuse_conv_affine_extended_test.py b/onnxscript/rewriter/rules/common/_fuse_conv_affine_extended_test.py
new file mode 100644
index 0000000000..2d301b2808
--- /dev/null
+++ b/onnxscript/rewriter/rules/common/_fuse_conv_affine_extended_test.py
@@ -0,0 +1,155 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""Extended tests for ConvAffineFusion and AffineConvFusion rules.
+
+Adds coverage for: non-constant weight (negative), non-scalar scale (negative),
+padded pre-conv affine (negative), non-constant bias (negative), positive with
+numerical validation.
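+
+The post-conv rewrite relies on the affine identity (s, o per-tensor scalars):
+
+    Conv(x, W, b) * s + o  ==  Conv(x, W * s, b * s + o)
+
+which is only sound when W and b are constants the rule can rewrite in place
+and s, o broadcast uniformly over every output channel.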
+""" + +from __future__ import annotations + +import unittest + +import numpy as np +import onnx_ir as ir + +from onnxscript import FLOAT, script +from onnxscript import opset18 as op +from onnxscript.rewriter import rewrite, testing +from onnxscript.rewriter.rules.common import ( + affine_conv_fusion_rule, + conv_affine_fusion_rule, +) + +# Constants used in @script() models +_W_ONES = np.ones((3, 3, 3, 3), dtype=np.float32) +_B_ONES = np.ones((3,), dtype=np.float32) + + +class FuseConvAffineExtendedTest(unittest.TestCase): + """Extended tests for ConvAffineFusion and AffineConvFusion.""" + + def _clone(self, model: ir.Model) -> ir.Model: + return ir.from_proto(ir.to_proto(model)) + + # --- Negative: non-constant weight --- + + def test_conv_affine_non_constant_weight_no_fusion(self): + """Non-constant weight (graph input) -> check rejects (w must be constant).""" + + @script() + def model_fn( + x: FLOAT[1, 3, 32, 32], + w: FLOAT[3, 3, 3, 3], + ) -> FLOAT[1, 3, 32, 32]: + b = op.Constant(value=_B_ONES) + scale = op.Constant(value_float=2.0) + offset = op.Constant(value_float=3.0) + conv_out = op.Conv(x, w, b, pads=[1, 1, 1, 1]) + return (conv_out * scale) + offset + + model_proto = model_fn.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + rewritten = self._clone(model) + rewritten = rewrite(rewritten, pattern_rewrite_rules=[conv_affine_fusion_rule]) + self.assertEqual(model.graph.num_nodes(), rewritten.graph.num_nodes()) + + # --- Negative: non-scalar scale --- + + def test_conv_affine_non_scalar_scale_no_fusion(self): + """Vector scale -> check rejects (scale must be scalar).""" + + @script() + def model_fn(x: FLOAT[1, 3, 32, 32]) -> FLOAT[1, 3, 32, 32]: + w = op.Constant(value=_W_ONES) + b = op.Constant(value=_B_ONES) + scale = op.Constant(value_floats=[2.0, 3.0, 4.0]) # vector, not scalar + offset = op.Constant(value_float=3.0) + conv_out = op.Conv(x, w, b, pads=[1, 1, 1, 1]) + return (conv_out * scale) + offset + + model_proto = model_fn.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + rewritten = self._clone(model) + rewritten = rewrite(rewritten, pattern_rewrite_rules=[conv_affine_fusion_rule]) + self.assertEqual(model.graph.num_nodes(), rewritten.graph.num_nodes()) + + # --- Negative: non-constant bias --- + + def test_conv_affine_non_constant_bias_no_fusion(self): + """Non-constant bias (graph input) -> check rejects (b must be constant).""" + + @script() + def model_fn( + x: FLOAT[1, 3, 32, 32], + b: FLOAT[3], + ) -> FLOAT[1, 3, 32, 32]: + w = op.Constant(value=_W_ONES) + scale = op.Constant(value_float=2.0) + offset = op.Constant(value_float=3.0) + conv_out = op.Conv(x, w, b, pads=[1, 1, 1, 1]) + return (conv_out * scale) + offset + + model_proto = model_fn.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + rewritten = self._clone(model) + rewritten = rewrite(rewritten, pattern_rewrite_rules=[conv_affine_fusion_rule]) + self.assertEqual(model.graph.num_nodes(), rewritten.graph.num_nodes()) + + # --- Negative: pre-conv affine with padding --- + + def test_affine_conv_with_padding_no_fusion(self): + """Pre-conv affine + padded Conv -> AffineConvFusion must NOT match. + + AffineConvFusion pattern requires pads=[0,0,0,0]. 
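+
+        Reasoning sketch: with zero pads, Conv(x * s + o, W, b) folds into
+        Conv(x, W * s, b') with b' = b + o * sum(W) per output channel; padded
+        border positions contribute literal zeros rather than the offset o,
+        so the identity no longer holds once pads are non-zero.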
+ """ + + @script() + def model_fn(x: FLOAT[1, 3, 32, 32]) -> FLOAT[1, 3, 32, 32]: + w = op.Constant(value=_W_ONES) + b = op.Constant(value=_B_ONES) + scale = op.Constant(value_float=2.0) + offset = op.Constant(value_float=3.0) + affine = (x * scale) + offset + return op.Conv(affine, w, b, pads=[1, 1, 1, 1]) # non-zero pads + + model_proto = model_fn.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + rewritten = self._clone(model) + rewritten = rewrite(rewritten, pattern_rewrite_rules=[affine_conv_fusion_rule]) + self.assertEqual(model.graph.num_nodes(), rewritten.graph.num_nodes()) + + # --- Positive: conv-affine fusion with all-constant operands --- + + def test_conv_affine_positive_fuses(self): + """Standard Conv -> Mul(scalar) -> Add(scalar) with constant w, b fuses correctly.""" + + @script() + def model_fn(x: FLOAT[1, 3, 32, 32]) -> FLOAT[1, 3, 32, 32]: + w = op.Constant(value=_W_ONES) + b = op.Constant(value=_B_ONES) + scale = op.Constant(value_float=2.0) + offset = op.Constant(value_float=3.0) + conv_out = op.Conv(x, w, b, pads=[1, 1, 1, 1]) + return (conv_out * scale) + offset + + model_proto = model_fn.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + rewritten = self._clone(model) + rewritten = rewrite(rewritten, pattern_rewrite_rules=[conv_affine_fusion_rule]) + # Mul and Add should be fused — neither should remain in the graph + rewritten_ops = [n.op_type for n in rewritten.graph] + self.assertNotIn("Mul", rewritten_ops) + self.assertNotIn("Add", rewritten_ops) + self.assertIn("Conv", rewritten_ops) + + # Numerical validation + rng = np.random.default_rng(42) + inputs = [rng.random((1, 3, 32, 32), dtype=np.float32)] + testing.assert_numerically_equal(model, rewritten, inputs) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/rules/common/_redundant_scatter_nd_extended_test.py b/onnxscript/rewriter/rules/common/_redundant_scatter_nd_extended_test.py new file mode 100644 index 0000000000..1b88d3cdd4 --- /dev/null +++ b/onnxscript/rewriter/rules/common/_redundant_scatter_nd_extended_test.py @@ -0,0 +1,120 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Extended tests for redundant ScatterND rules. + +Adds coverage for: dynamic full-range scatter on axis=0 (positive), +partial scatter (negative for static), and shape mismatch (negative for static). 
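+
+A ScatterND is redundant when its indices enumerate every position along the
+scattered axis exactly once and updates matches data's shape, e.g. (sketch):
+
+    indices = np.arange(data.shape[0]).reshape(-1, 1)
+    ScatterND(data, indices, updates) == updates  # a complete overwrite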
+"""
+
+import unittest
+
+import numpy as np
+import onnx.parser
+import onnx_ir as ir
+import onnxruntime
+from onnx_ir.passes.common import CheckerPass, ShapeInferencePass
+
+from onnxscript import FLOAT, optimizer, script
+from onnxscript import opset18 as op
+from onnxscript.rewriter.rules.common import _redundant_scatter_nd
+
+N = "N"
+shape_inference = ShapeInferencePass()
+onnx_check = CheckerPass(True)
+
+
+class RedundantScatterNdExtendedTest(unittest.TestCase):
+    """Extended tests for redundant ScatterND rewrite rules."""
+
+    # --- Positive: axis=0 dynamic ---
+
+    def test_dynamic_indices_axis_0(self):
+        """Full-range scatter on axis=0 → should be eliminated."""
+
+        @script()
+        def model_fn(data: FLOAT[N, 16], updates: FLOAT[N, 16]) -> FLOAT[N, 16]:
+            axis = op.Constant(value_int=0)
+            shape = op.Shape(data, start=0)
+            dim = op.Gather(shape, axis, axis=0)
+            full_range = op.Range(0, dim, 1)
+            full_range_2d = op.Unsqueeze(full_range, [-1])
+            scattered = op.ScatterND(data, full_range_2d, updates, reduction="none")
+            return scattered
+
+        model_proto = model_fn.to_model_proto()
+        model = ir.serde.deserialize_model(model_proto)
+        onnx_check(model)
+        shape_inference(model)
+        optimizer.fold_constants(model)
+        count = _redundant_scatter_nd.rules.apply_to_model(model)
+        self.assertEqual(count, 1)
+
+        # Verify numerical equivalence
+        inputs = {
+            "data": np.random.rand(8, 16).astype(np.float32),
+            "updates": np.random.rand(8, 16).astype(np.float32),
+        }
+        original_session = onnxruntime.InferenceSession(
+            model_proto.SerializeToString(), providers=["CPUExecutionProvider"]
+        )
+        original_outputs = original_session.run(None, inputs)
+        optimized_proto = ir.serde.serialize_model(model)
+        optimized_session = onnxruntime.InferenceSession(
+            optimized_proto.SerializeToString(), providers=["CPUExecutionProvider"]
+        )
+        optimized_outputs = optimized_session.run(None, inputs)
+        np.testing.assert_allclose(original_outputs[0], optimized_outputs[0])
+
+    # --- Negative: partial indices (static) ---
+
+    def test_static_partial_indices_no_fusion(self):
+        """Indices that don't cover full dim → should NOT be eliminated."""
+        model_proto = onnx.parser.parse_model(
+            """
+            <ir_version: 10, opset_import: [ "" : 18]>
+            agraph (float[8, 16] data, float[4, 16] updates) => (float[8, 16] output)
+            {
+                output = ScatterND (data, indices, updates)
+            }
+            """
+        )
+        # Only scatter to first 4 rows of 8
+        indices = np.arange(4).reshape(4, 1).astype(np.int64)
+        model = ir.serde.deserialize_model(model_proto)
+        indices_value = model.graph[0].inputs[1]
+        indices_value.const_value = ir.Tensor(name="indices", value=indices)
+        model.graph.initializers["indices"] = indices_value
+
+        count = _redundant_scatter_nd.rules.apply_to_model(model)
+        self.assertEqual(count, 0)
+        op_types = [n.op_type for n in model.graph]
+        self.assertIn("ScatterND", op_types)
+
+    # --- Negative: shape mismatch (static) ---
+
+    def test_static_shape_mismatch_no_fusion(self):
+        """data.shape != updates.shape → should NOT be eliminated."""
+        model_proto = onnx.parser.parse_model(
+            """
+            <ir_version: 10, opset_import: [ "" : 18]>
+            agraph (float[8, 16] data, float[8, 32] updates) => (float[8, 16] output)
+            {
+                output = ScatterND (data, indices, updates)
+            }
+            """
+        )
+        indices = np.arange(8).reshape(8, 1).astype(np.int64)
+        model = ir.serde.deserialize_model(model_proto)
+        indices_value = model.graph[0].inputs[1]
+        indices_value.const_value = ir.Tensor(name="indices", value=indices)
+        model.graph.initializers["indices"] = indices_value
+
+        count = _redundant_scatter_nd.rules.apply_to_model(model)
+        self.assertEqual(count, 0)
+        op_types = [n.op_type for n in model.graph]
+        self.assertIn("ScatterND", op_types)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/onnxscript/rewriter/rules/fusion/_gqa_extended_test.py b/onnxscript/rewriter/rules/fusion/_gqa_extended_test.py
new file mode 100644
index 0000000000..8686b8bafb
--- /dev/null
+++ b/onnxscript/rewriter/rules/fusion/_gqa_extended_test.py
@@ -0,0 +1,129 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""Extended tests for GQA fusion rule (rules/fusion variant).
+
+Adds coverage for: mismatched group parameter (negative).
+"""
+
+from __future__ import annotations
+
+import unittest
+
+import numpy
+import onnx
+import onnx_ir as ir
+from packaging import version
+
+from onnxscript import FLOAT, optimizer, script, values
+from onnxscript.rewriter.rules.fusion._gqa import fuse_gqa
+from onnxscript.rewriter.testing import assert_numerically_equal
+
+op = values.Opset("", 23)
+
+# Config: H=8, Hkv=4, D=64, G=2
+H = [8]
+Hkv = [4]
+D = [64]
+G_CORRECT = [2]  # H / Hkv = 8 / 4 = 2
+G_WRONG = [3]  # Wrong group count
+
+
+@script(ir_version=10)
+def _gqa_wrong_group(
+    query_BHSD: FLOAT[2, 8, 4, 64],
+    key_BHkvSD: FLOAT[2, 4, 4, 64],
+    value_BHkvSD: FLOAT[2, 4, 4, 64],
+    past_key_BHkvPD: FLOAT[2, 4, 8, 64],
+    past_value_BHkvPD: FLOAT[2, 4, 8, 64],
+) -> FLOAT[2, 8, 4, 64]:
+    """GQA pattern with wrong group count — should NOT fuse."""
+    present_key_BHkvStD = op.Concat(past_key_BHkvPD, key_BHkvSD, axis=-2)
+    present_key_BHkv1StD = op.Unsqueeze(present_key_BHkvStD, 2)
+    B = op.Shape(query_BHSD, start=0, end=1)
+    T = op.Shape(present_key_BHkvStD, start=2, end=3)
+
+    # Use G_WRONG instead of G_CORRECT — expand shape will be [B, Hkv, 3, S+P, D]
+    expand_shape = op.Concat(B, Hkv, G_WRONG, T, D, axis=0)
+    present_key_BHkvGStD = op.Expand(present_key_BHkv1StD, expand_shape)
+
+    # Reshape target would be [B, Hkv*3, S+P, D] = [B, 12, ...] not [B, 8, ...]
+    H_wrong = [12]  # Hkv * G_WRONG = 4 * 3 = 12
+    reshape_shape = op.Concat(B, H_wrong, T, D, axis=0)
+    present_key_BHStD = op.Reshape(present_key_BHkvGStD, reshape_shape)
+
+    present_value_BHkvStD = op.Concat(past_value_BHkvPD, value_BHkvSD, axis=-2)
+    present_value_BHkv1StD = op.Unsqueeze(present_value_BHkvStD, 2)
+    present_value_BHkvGStD = op.Expand(present_value_BHkv1StD, expand_shape)
+    present_value_BHStD = op.Reshape(present_value_BHkvGStD, reshape_shape)
+
+    attention_BHSDh = op.Attention(
+        query_BHSD,
+        present_key_BHStD,
+        present_value_BHStD,
+    )
+    return attention_BHSDh
+
+
+class GQAFusionExtendedTest(unittest.TestCase):
+    """Extended tests for GQA fusion."""
+
+    def test_gqa_wrong_group_count_no_fusion(self):
+        """Group count doesn't match H/Hkv — GQA fusion should fail."""
+        model_proto = _gqa_wrong_group.to_model_proto()
+        model = ir.serde.deserialize_model(model_proto)
+        optimizer.optimize(model)
+        count = fuse_gqa(model)
+        self.assertEqual(
+            count, 0, "GQA fusion should NOT succeed with mismatched group count."
+ ) + + def test_basic_gqa_positive_with_different_head_config(self): + """GQA with H=6, Hkv=2, G=3, D=32 — should still fuse.""" + H_cfg = [6] + Hkv_cfg = [2] + D_cfg = [32] + G_cfg = [3] + + @script(ir_version=10) + def gqa_alt( + q: FLOAT[1, 6, 4, 32], + k: FLOAT[1, 2, 4, 32], + v: FLOAT[1, 2, 4, 32], + pk: FLOAT[1, 2, 8, 32], + pv: FLOAT[1, 2, 8, 32], + ) -> FLOAT[1, 6, 4, 32]: + pk_cat = op.Concat(pk, k, axis=-2) + pk_unsq = op.Unsqueeze(pk_cat, 2) + B = op.Shape(q, start=0, end=1) + T = op.Shape(pk_cat, start=2, end=3) + expand_shape = op.Concat(B, Hkv_cfg, G_cfg, T, D_cfg, axis=0) + pk_exp = op.Expand(pk_unsq, expand_shape) + reshape_shape = op.Concat(B, H_cfg, T, D_cfg, axis=0) + pk_rsh = op.Reshape(pk_exp, reshape_shape) + + pv_cat = op.Concat(pv, v, axis=-2) + pv_unsq = op.Unsqueeze(pv_cat, 2) + pv_exp = op.Expand(pv_unsq, expand_shape) + pv_rsh = op.Reshape(pv_exp, reshape_shape) + + attn = op.Attention(q, pk_rsh, pv_rsh) + return attn + + model_proto = gqa_alt.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + optimizer.optimize(model) + count = fuse_gqa(model) + self.assertGreater(count, 0, "GQA fusion should succeed with H=6, Hkv=2, G=3, D=32.") + + # Verify numerical equivalence if onnx version supports it + onnx_ver = version.parse(onnx.__version__) + if onnx_ver >= version.parse("1.19.1") and not ( + onnx_ver.is_prerelease or onnx_ver.is_devrelease + ): + optimizer.remove_unused_nodes(model) + rewritten_proto = ir.serde.serialize_model(model) + assert_numerically_equal(model_proto, rewritten_proto, use_reference=True) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/rules/fusion/_rotary_embedding_extended_test.py b/onnxscript/rewriter/rules/fusion/_rotary_embedding_extended_test.py new file mode 100644 index 0000000000..e809a0d389 --- /dev/null +++ b/onnxscript/rewriter/rules/fusion/_rotary_embedding_extended_test.py @@ -0,0 +1,121 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Extended tests for RotaryEmbedding and PartialRotaryEmbedding fusion rules. + +Adds coverage for: PartialRotaryEmbedding23 positive test, boundary mismatch +negative, and "rotary_embedding_dim already set" negative. 
+"""
+
+from __future__ import annotations
+
+import unittest
+
+import numpy
+import onnx
+import onnx_ir as ir
+from packaging.version import Version
+
+import onnxscript
+import onnxscript.rewriter.testing
+from onnxscript.rewriter.models import _rotary_embedding_models
+from onnxscript.rewriter.rules.fusion import _rotary_embedding
+
+_ROTARY_DIM = 32
+
+_INPUT_TYPES = (
+    onnxscript.INT64["Batchsize", "Sequence"],
+    onnxscript.FLOAT["Batchsize", 32, "Sequence", 80],
+)
+_OUTPUT_TYPES = (onnxscript.FLOAT["Batchsize", 32, "Sequence", 80],)
+
+
+class PartialRotaryEmbeddingExtendedTest(unittest.TestCase):
+    """Extended tests for PartialRotaryEmbedding23 fusion rule."""
+
+    def _get_partial_model(self, *, mismatched: bool = False) -> ir.Model:
+        """Get a fresh partial rotary embedding model."""
+        script_fn = _rotary_embedding_models._make_partial_rotary_script(mismatched=mismatched)
+        model_proto = script_fn.to_model_proto(
+            input_types=_INPUT_TYPES, output_types=_OUTPUT_TYPES
+        )
+        return ir.serde.deserialize_model(model_proto)
+
+    def test_partial_rotary_embedding_fused(self):
+        """Full rotary embedding + partial concat → fuse into RotaryEmbedding with rotary_embedding_dim."""
+        model = self._get_partial_model()
+        model.graph.opset_imports[""] = 23
+
+        original_proto = ir.serde.serialize_model(model)
+
+        onnxscript.optimizer.optimize(model)
+        # First fuse the base rotary embedding
+        count = _rotary_embedding.fuse_rotary_embedding(model)
+        self.assertGreater(count, 0, "Base RotaryEmbedding fusion should succeed first.")
+
+        # Then fuse partial rotary embedding
+        count_partial = _rotary_embedding.fuse_partial_rotary_embedding(model)
+        self.assertGreater(count_partial, 0, "PartialRotaryEmbedding fusion should succeed.")
+
+        # Verify RotaryEmbedding has rotary_embedding_dim attribute
+        rope_nodes = [n for n in model.graph if n.op_type == "RotaryEmbedding"]
+        self.assertTrue(len(rope_nodes) > 0, "Should have RotaryEmbedding node.")
+        rope_node = rope_nodes[0]
+        self.assertIn("rotary_embedding_dim", rope_node.attributes)
+        self.assertEqual(rope_node.attributes["rotary_embedding_dim"].value, _ROTARY_DIM)
+
+        # Numerical validation via reference implementation (if onnx version supports it)
+        rewritten_proto = ir.serde.serialize_model(model)
+        onnx_ver = Version(onnx.__version__)
+        if onnx_ver >= Version("1.19.1") and not (
+            onnx_ver.is_devrelease or onnx_ver.is_prerelease
+        ):
+            inputs = {
+                "query": numpy.random.rand(1, 32, 8, 80).astype(numpy.float32),
+                "position_ids": numpy.arange(8, dtype=numpy.int64).reshape(1, 8),
+            }
+            onnxscript.rewriter.testing.assert_numerically_equal(
+                original_proto, rewritten_proto, args=inputs, use_reference=True
+            )
+
+    def test_partial_rotary_mismatched_boundaries_no_fusion(self):
+        """When end1 != start2 in partial slice, PartialRotaryEmbedding should NOT fuse."""
+        model = self._get_partial_model(mismatched=True)
+        model.graph.opset_imports[""] = 23
+
+        onnxscript.optimizer.optimize(model)
+        # First fuse the base rotary embedding
+        count = _rotary_embedding.fuse_rotary_embedding(model)
+        self.assertGreater(count, 0)
+
+        # Partial fusion should fail because end1=32 but start2=33
+        count_partial = _rotary_embedding.fuse_partial_rotary_embedding(model)
+        self.assertEqual(count_partial, 0, "Should NOT fuse with mismatched boundaries.")
+
+    def test_partial_rotary_already_has_dim_attr_no_fusion(self):
+        """If RotaryEmbedding already has rotary_embedding_dim, partial fusion should NOT apply."""
+        model = self._get_partial_model()
+        model.graph.opset_imports[""] = 23
+
+        onnxscript.optimizer.optimize(model)
+        # Fuse base rotary embedding
+        count = _rotary_embedding.fuse_rotary_embedding(model)
+        self.assertGreater(count, 0)
+
+        # Add rotary_embedding_dim attribute to the RotaryEmbedding node
+        for node in model.graph:
+            if node.op_type == "RotaryEmbedding":
+                node.attributes["rotary_embedding_dim"] = ir.AttrInt64(
+                    "rotary_embedding_dim", 16
+                )
+                break
+
+        # Partial fusion should refuse to fuse
+        count_partial = _rotary_embedding.fuse_partial_rotary_embedding(model)
+        self.assertEqual(
+            count_partial, 0, "Should NOT fuse when rotary_embedding_dim already set."
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()