Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions onnxscript/rewriter/models/_rotary_embedding_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,42 @@ def get_ort_inputs(self):

def partial_rotary_test_case():
    """Build and return a fresh ``_PartialRotaryTestCase`` instance."""
    test_case = _PartialRotaryTestCase()
    return test_case


def _make_partial_rotary_script(mismatched: bool = False):
    """Generate a partial rotary embedding script with matching or mismatched slice boundaries.

    The returned script rotates channels [0, 32) of the last axis of ``query``
    and passes the remaining channels through unchanged, concatenating both
    halves back together at the end.

    Args:
        mismatched: If True, the second slice starts at 33 instead of 32,
            creating a gap that should prevent PartialRotaryEmbedding fusion.
    """

    @script()
    def partial_rotary(position_ids, query):
        # NOTE(review): inv_freqs_value is a module-level constant defined
        # earlier in this file (outside this diff hunk).
        inv_freqs = op.Constant(value=inv_freqs_value)  # [1, rd/2, 1]
        position_ids_3d = op.Unsqueeze(position_ids, 1)  # [B, 1, S]
        position_ids_3d_float = op.Cast(position_ids_3d, to=1)  # to=1 is FLOAT
        matmul = op.MatMul(inv_freqs, position_ids_3d_float)  # [B, rd/2, S]
        transpose = op.Transpose(matmul, perm=[0, 2, 1])  # [B, S, rd/2]
        cat = op.Concat(transpose, transpose, axis=-1)  # [B, S, rd]
        cos_3d = op.Cos(cat)  # [B, S, rd]
        sin_3d = op.Sin(cat)  # [B, S, rd]
        # Split the query for partial embedding: channels [0, 32) on axis 3 get
        # the rotary embedding; the remainder (from 32, or 33 in the mismatched
        # variant, up to INT64_MAX i.e. "to the end") passes through untouched.
        to_embed = op.Slice(query, [0], [32], [3], [1])
        # NOTE(review): this is a Python-level branch on a closure bool, resolved
        # when onnxscript converts the function — confirm converter support.
        if mismatched:
            unembedded = op.Slice(query, [33], [9223372036854775807], [3], [1])
        else:
            unembedded = op.Slice(query, [32], [9223372036854775807], [3], [1])
        cos_4d = op.Unsqueeze(cos_3d, [1])  # [B, 1, S, rd]
        sin_4d = op.Unsqueeze(sin_3d, [1])  # [B, 1, S, rd]
        to_embed_times_cos = op.Mul(to_embed, cos_4d)
        # rotate_half: split the embedded slice into halves (x, y) and form
        # (-y, x), the 90-degree rotation used by rotary embedding.
        to_embed_x = op.Slice(to_embed, [0], [16], [3], [1])
        to_embed_y = op.Slice(to_embed, [16], [9223372036854775807], [3], [1])
        minus_to_embed_y = op.Neg(to_embed_y)
        to_embed_rotated_90 = op.Concat(minus_to_embed_y, to_embed_x, axis=-1)
        to_embed_rotated_90_times_sin = op.Mul(to_embed_rotated_90, sin_4d)
        embedded = op.Add(to_embed_times_cos, to_embed_rotated_90_times_sin)
        final = op.Concat(embedded, unembedded, axis=-1)
        return final

    return partial_rotary
122 changes: 122 additions & 0 deletions onnxscript/rewriter/ort_fusions/cos_sin_cache_extended_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Extended tests for cos_sin_cache fusion.

Adds coverage for: non-constant inv_freq (negative — dynamic inv_freq prevents cache precomputation).
"""

from __future__ import annotations

import unittest

import numpy
import onnx_ir as ir

from onnxscript import FLOAT, INT64, optimizer, script, values
from onnxscript import opset18 as op
from onnxscript.rewriter.ort_fusions._test_utils import ort_run
from onnxscript.rewriter.ort_fusions.cos_sin_cache import fuse_cos_sin_cache
from onnxscript.rewriter.ort_fusions.rotary_embedding import fuse_rotary_embedding

msft_op = values.Opset("com.microsoft", 1)


class CosSinCacheExtendedTest(unittest.TestCase):
    """Extended tests for cos_sin_cache fusion.

    Both tests build the same rotary-embedding source pattern; they differ only
    in whether inv_freq is a graph input (dynamic) or a Constant node.
    """

    def test_non_constant_inv_freq_no_cache_fusion(self):
        """When inv_freq is a graph input (not constant), cos_sin_cache should not fuse.

        The cos_sin_cache fusion relies on inv_freq being a constant to precompute
        the cos/sin lookup table. When inv_freq is dynamic, no fusion should apply.
        """

        @script()
        def model_with_dynamic_inv_freq(
            x: FLOAT[1, 4, 8, 8],
            position_ids: INT64[1, 8],
            inv_freq: FLOAT[1, 4, 1],
        ) -> FLOAT[1, 4, 8, 8]:
            # inv_freq is a graph input, not a constant
            position_ids_expanded = op.Unsqueeze(position_ids, [1])
            position_ids_float = op.Cast(position_ids_expanded, to=ir.DataType.FLOAT)
            freqs = op.MatMul(inv_freq, position_ids_float)
            freqs_t = op.Transpose(freqs, perm=[0, 2, 1])
            # Duplicate the frequencies along the last axis: [B, S, 4] -> [B, S, 8]
            emb = op.Concat(freqs_t, freqs_t, axis=-1)
            cos = op.Cos(emb)
            sin = op.Sin(emb)
            cos_4d = op.Unsqueeze(cos, [1])
            sin_4d = op.Unsqueeze(sin, [1])

            # rotate_half on the last axis (head size 8, halves of 4)
            x1 = op.Slice(x, [0], [4], [3], [1])
            x2 = op.Slice(x, [4], [8], [3], [1])
            minus_x2 = op.Neg(x2)
            rotated_x = op.Concat(minus_x2, x1, axis=-1)
            result = op.Add(x * cos_4d, rotated_x * sin_4d)
            return result

        model_proto = model_with_dynamic_inv_freq.to_model_proto()
        model = ir.serde.deserialize_model(model_proto)
        optimizer.optimize(model)

        # Rotary embedding fusion should still work
        re_count = fuse_rotary_embedding(model)
        self.assertGreater(re_count, 0, "RotaryEmbedding fusion should succeed.")

        # cos_sin_cache fusion should NOT work because inv_freq is not constant
        cache_count = fuse_cos_sin_cache(model)
        self.assertEqual(
            cache_count, 0, "cos_sin_cache should NOT fuse with dynamic inv_freq."
        )

    def test_constant_inv_freq_does_fuse(self):
        """Sanity check: constant inv_freq allows cos_sin_cache fusion."""

        @script()
        def model_with_const_inv_freq(
            x: FLOAT[1, 4, 8, 8], position_ids: INT64[1, 8]
        ) -> FLOAT[1, 4, 8, 8]:
            # Same pattern as above, but inv_freq is a Constant node so the
            # fusion can precompute the cos/sin cache.
            inv_freq = op.Constant(value_floats=[1.0, 2.0, 3.0, 4.0])
            inv_freq_3d = op.Unsqueeze(inv_freq, [0, 2])
            position_ids_expanded = op.Unsqueeze(position_ids, [1])
            position_ids_float = op.Cast(position_ids_expanded, to=ir.DataType.FLOAT)
            freqs = op.MatMul(inv_freq_3d, position_ids_float)
            freqs_t = op.Transpose(freqs, perm=[0, 2, 1])
            emb = op.Concat(freqs_t, freqs_t, axis=-1)
            cos = op.Cos(emb)
            sin = op.Sin(emb)
            cos_4d = op.Unsqueeze(cos, [1])
            sin_4d = op.Unsqueeze(sin, [1])

            x1 = op.Slice(x, [0], [4], [3], [1])
            x2 = op.Slice(x, [4], [8], [3], [1])
            minus_x2 = op.Neg(x2)
            rotated_x = op.Concat(minus_x2, x1, axis=-1)
            result = op.Add(x * cos_4d, rotated_x * sin_4d)
            return result

        model_proto = model_with_const_inv_freq.to_model_proto()
        model = ir.serde.deserialize_model(model_proto)
        optimizer.optimize(model)

        # Capture reference outputs before fusion for numerical comparison.
        inputs = {
            "x": numpy.random.rand(1, 4, 8, 8).astype(numpy.float32),
            "position_ids": numpy.arange(8, dtype=numpy.int64).reshape(1, 8),
        }
        original_outputs = ort_run("original", model, inputs)

        re_count = fuse_rotary_embedding(model)
        self.assertGreater(re_count, 0)
        cache_count = fuse_cos_sin_cache(model)
        self.assertGreater(cache_count, 0, "cos_sin_cache should fuse with constant inv_freq.")

        # Numerical validation
        fused_outputs = ort_run("fused", model, inputs)
        numpy.testing.assert_allclose(
            original_outputs[0], fused_outputs[0], rtol=1e-5, atol=1e-5
        )


# Allow running this test module directly with `python cos_sin_cache_extended_test.py`.
if __name__ == "__main__":
    unittest.main()
142 changes: 142 additions & 0 deletions onnxscript/rewriter/ort_fusions/gqa_extended_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Extended tests for GQA ort_fusions fusion.

Adds coverage for: missing RotaryEmbedding (negative — no GQA fusion).
"""

from __future__ import annotations

import math
import unittest

import onnx
import onnx_ir as ir
import onnx_ir.passes.common.shape_inference as shape_inference

from onnxscript import FLOAT, optimizer, script, values
from onnxscript import opset18 as op
from onnxscript.rewriter.ort_fusions.gqa import fuse_gqa
from onnxscript.rewriter.ort_fusions.sdpa import fuse_sdpa

msft_op = values.Opset("com.microsoft", 1)


class GQAOrtFusionExtendedTest(unittest.TestCase):
    """Extended negative test for GQA ort_fusion."""

    def test_no_rotary_embedding_no_gqa_fusion(self):
        """GQA source pattern without RotaryEmbedding → SDPA may fuse but GQA should NOT.

        The ort_fusions GQA rule requires RotaryEmbedding nodes to fuse the full GQA
        pattern (query/key rotary + kv-cache concat + expand + SDPA → GQA).
        Without rotary embedding, the GQA-specific fusion should not trigger.
        """
        # Grouped-query attention dimensions: 20 query heads share 10 kv heads
        # (2 query heads per kv head), each of size 16.
        head_size = 16
        num_heads = 20
        kv_num_heads = 10
        hidden_size = head_size * num_heads
        kv_hidden_size = head_size * kv_num_heads
        num_groups = num_heads // kv_num_heads

        # Single-element lists so these can be used directly as 1-D Concat inputs
        # when assembling shape tensors inside the script.
        H = [num_heads]
        Hkv = [kv_num_heads]
        Dh = [head_size]
        G = [num_groups]
        minus_1 = [-1]

        # sqrt(sqrt(head_size)): dividing BOTH Q and K by this yields the usual
        # 1/sqrt(head_size) scaling on the attention scores.
        scale_factor = math.sqrt(math.sqrt(head_size))

        @script()
        def gqa_no_rotary(query, key, value, past_key, past_value):
            # Dynamic dims read off the inputs at runtime.
            B = op.Shape(query, start=0, end=1)
            S = op.Shape(query, start=1, end=2)
            past_seq_length = op.Shape(past_key, start=2, end=3)
            total_seq_length = op.Add(past_seq_length, S)

            # Shape tensors for the reshape/expand steps below.
            shape_BSHDh = op.Concat(B, S, minus_1, Dh, axis=0)
            shape_BSHkvDh = op.Concat(B, S, minus_1, Dh, axis=0)
            shape_BSD = op.Concat(B, S, minus_1, axis=0)
            shape_BHkvGSDh = op.Concat(B, Hkv, G, total_seq_length, Dh, axis=0)
            shape_BHSDh = op.Concat(B, H, total_seq_length, Dh, axis=0)

            # [B, S, D] -> [B, S, H, Dh] -> [B, H, S, Dh]
            query_BSHDh = op.Reshape(query, shape_BSHDh)
            query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3])

            key_BSHkvDh = op.Reshape(key, shape_BSHkvDh)
            key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3])

            value_BSHkvDh = op.Reshape(value, shape_BSHkvDh)
            value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3])

            # NO RotaryEmbedding here — just use key/query directly
            key_seq = op.Concat(past_key, key_BHkvSDh, axis=-2)
            value_seq = op.Concat(past_value, value_BHkvSDh, axis=-2)

            # Expand kv heads across groups: [B, Hkv, T, Dh] -> [B, H, T, Dh]
            key_unsq = op.Unsqueeze(key_seq, [2])
            key_exp = op.Expand(key_unsq, shape_BHkvGSDh)
            key_rsh = op.Reshape(key_exp, shape_BHSDh)

            value_unsq = op.Unsqueeze(value_seq, [2])
            value_exp = op.Expand(value_unsq, shape_BHkvGSDh)
            value_rsh = op.Reshape(value_exp, shape_BHSDh)

            # Attention
            key_t = op.Transpose(key_rsh, perm=[0, 1, 3, 2])
            divisor = op.Constant(value_float=scale_factor)
            scaled_q = op.Div(query_BHSDh, divisor)
            scaled_k = op.Div(key_t, divisor)
            score = op.MatMul(scaled_q, scaled_k)
            weight = op.Softmax(score, axis=-1)
            attn = op.MatMul(weight, value_rsh)

            # [B, H, S, Dh] -> [B, S, H, Dh] -> [B, S, D]
            attn_t = op.Transpose(attn, perm=[0, 2, 1, 3])
            attn_out = op.Reshape(attn_t, shape_BSD)

            return attn_out, key_seq, value_seq

        D = hidden_size
        Dkv_val = kv_hidden_size
        Dh_val = head_size
        Hkv_val = kv_num_heads

        # Symbolic dims: B = batch, S = new tokens, P = past length, T = total.
        input_types = (
            FLOAT["B", "S", D],
            FLOAT["B", "S", Dkv_val],
            FLOAT["B", "S", Dkv_val],
            FLOAT["B", Hkv_val, "P", Dh_val],
            FLOAT["B", Hkv_val, "P", Dh_val],
        )
        output_types = (
            FLOAT["B", "S", D],
            FLOAT["B", Hkv_val, "T", Dh_val],
            FLOAT["B", Hkv_val, "T", Dh_val],
        )

        source_model = gqa_no_rotary.to_model_proto(
            input_types=input_types,
            output_types=output_types,
        )

        # Add value_info for shapes needed by fusion
        query_BSHDh_vi = onnx.helper.make_tensor_value_info(
            "query_BSHDh", onnx.TensorProto.FLOAT, ["B", "S", num_heads, head_size]
        )
        key_BSHkvDh_vi = onnx.helper.make_tensor_value_info(
            "key_BSHkvDh", onnx.TensorProto.FLOAT, ["B", "S", kv_num_heads, head_size]
        )
        source_model.graph.value_info.extend([query_BSHDh_vi, key_BSHkvDh_vi])

        model = ir.serde.from_proto(source_model)
        inferred = shape_inference.infer_shapes(model)
        optimizer.optimize(inferred)

        # SDPA might fuse, but GQA should not (no RotaryEmbedding)
        fuse_sdpa(inferred, debug=False)
        count = fuse_gqa(inferred, debug=False)
        self.assertEqual(count, 0, "GQA fusion should NOT succeed without RotaryEmbedding.")


# Allow running this test module directly with `python gqa_extended_test.py`.
if __name__ == "__main__":
    unittest.main()
Loading
Loading