Skip to content

Commit 87be039

Browse files
gramalingamCopilot
andcommitted
Add unit tests for RmsNormFusion rules (rms_normalization.py)
4 tests covering: - Both Mul orderings: scale*normalized and normalized*scale (parameterized) - Mixed-precision: fp16 input with fp32 compute via Cast - Integer input dtype rejected (negative) All positive tests include numerical validation via ORT. Uses symbolic dims ("B", "S") and @script() model construction. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent f114eb9 commit 87be039

1 file changed

Lines changed: 171 additions & 0 deletions

File tree

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
"""Unit tests for RmsNormFusion rules (rms_normalization.py).
5+
6+
The rule detects the RMS-normalization pattern:
7+
x_norm = x / sqrt(mean(x^2) + eps)
8+
output = x_norm * scale
9+
and fuses it into SimplifiedLayerNormalization.
10+
11+
Covers both mul-orderings, optional Casts (mixed-precision),
12+
and a negative case (integer input dtype).
13+
"""
14+
15+
from __future__ import annotations
16+
17+
import unittest
18+
19+
import numpy as np
20+
import onnx_ir as ir
21+
from parameterized import parameterized
22+
23+
from onnxscript import FLOAT, FLOAT16, script
24+
from onnxscript import opset18 as op
25+
from onnxscript.optimizer import optimize
26+
from onnxscript.rewriter.ort_fusions import _test_utils as test_utils
27+
from onnxscript.rewriter.ort_fusions.rms_normalization import fuse_rms_normalization
28+
29+
# Concrete batch, sequence and feature sizes used when generating test inputs.
# Only _D appears in the declared model shapes; the first two dimensions are
# declared symbolically as "B" and "S".
_B, _S, _D = 2, 8, 16
30+
# Epsilon as a one-element float32 tensor; each pattern function below embeds it
# in the graph through op.Constant(value=_EPS).
_EPS = ir.tensor(np.array([1e-6], dtype=np.float32))
31+
32+
33+
# --- Pattern: Mul(scale, normalized) — mul_order=False ---
34+
35+
36+
@script()
def _rms_norm_scale_first(x, scale):
    # RMS-normalization subgraph whose final Mul takes ``scale`` as its first operand.
    squared = op.Pow(x, 2.0)
    mean_of_squares = op.ReduceMean(squared, [-1], keepdims=1, noop_with_empty_axes=0)
    epsilon = op.Constant(value=_EPS)
    shifted = op.Add(mean_of_squares, epsilon)
    root_mean_square = op.Sqrt(shifted)
    reciprocal_rms = op.Reciprocal(root_mean_square)
    x_normalized = op.Mul(x, reciprocal_rms)
    return op.Mul(scale, x_normalized)
45+
46+
47+
# --- Pattern: Mul(normalized, scale) — mul_order=True ---
48+
49+
50+
@script()
def _rms_norm_norm_first(x, scale):
    # RMS-normalization subgraph whose final Mul takes the normalized value as
    # its first operand and ``scale`` as its second.
    squared = op.Pow(x, 2.0)
    mean_of_squares = op.ReduceMean(squared, [-1], keepdims=1, noop_with_empty_axes=0)
    epsilon = op.Constant(value=_EPS)
    shifted = op.Add(mean_of_squares, epsilon)
    root_mean_square = op.Sqrt(shifted)
    reciprocal_rms = op.Reciprocal(root_mean_square)
    x_normalized = op.Mul(x, reciprocal_rms)
    return op.Mul(x_normalized, scale)
59+
60+
61+
# --- Pattern with Cast on input (mixed-precision: fp16 input, fp32 compute) ---
62+
63+
64+
@script()
def _rms_norm_with_cast_input(x, scale):
    # Mixed-precision variant: the statistics are computed on a float32 copy of
    # ``x`` and the normalized value is cast to float16 before the final Mul.
    x_wide = op.Cast(x, to=ir.DataType.FLOAT)
    squared = op.Pow(x_wide, 2.0)
    mean_of_squares = op.ReduceMean(squared, [-1], keepdims=1, noop_with_empty_axes=0)
    epsilon = op.Constant(value=_EPS)
    shifted = op.Add(mean_of_squares, epsilon)
    root_mean_square = op.Sqrt(shifted)
    reciprocal_rms = op.Reciprocal(root_mean_square)
    normalized_wide = op.Mul(x_wide, reciprocal_rms)
    normalized_half = op.Cast(normalized_wide, to=ir.DataType.FLOAT16)
    return op.Mul(normalized_half, scale)
75+
76+
77+
# --- Negative: integer input ---
78+
79+
80+
@script()
def _rms_norm_int_input(x, scale):
    # Same op sequence as the normalized*scale pattern, preceded by a Cast of
    # ``x`` to float32; the negative test declares ``x`` as INT32.
    x_as_float = op.Cast(x, to=ir.DataType.FLOAT)
    squared = op.Pow(x_as_float, 2.0)
    mean_of_squares = op.ReduceMean(squared, [-1], keepdims=1, noop_with_empty_axes=0)
    epsilon = op.Constant(value=_EPS)
    shifted = op.Add(mean_of_squares, epsilon)
    root_mean_square = op.Sqrt(shifted)
    reciprocal_rms = op.Reciprocal(root_mean_square)
    x_normalized = op.Mul(x_as_float, reciprocal_rms)
    return op.Mul(x_normalized, scale)
90+
91+
92+
class RmsNormFusionTest(unittest.TestCase):
    """Unit tests for RmsNormFusion rewrite rules.

    Each test builds a model from one of the module-level ``@script()`` pattern
    functions, applies ``fuse_rms_normalization`` and checks the number of
    fusions it reports.  Positive tests additionally compare ONNX Runtime
    outputs taken before and after the rewrite.
    """

    def setUp(self) -> None:
        # A fixed seed per test makes any numerical mismatch reproducible.
        self._rng = np.random.default_rng(0)

    def _random_input(self, shape, dtype) -> np.ndarray:
        """Return standard-normal samples of ``shape`` converted to ``dtype``."""
        return self._rng.standard_normal(shape).astype(dtype)

    def _build(self, script_fn, input_types, output_types) -> ir.Model:
        """Convert ``script_fn`` to an IR model and run the optimizer on it."""
        model_proto = script_fn.to_model_proto(
            input_types=input_types, output_types=output_types
        )
        model = ir.serde.deserialize_model(model_proto)
        optimize(model)
        return model

    def _apply(self, model: ir.Model) -> int:
        """Apply the fusion to ``model`` in place and return the reported count."""
        return fuse_rms_normalization(model)

    def _count_op(self, model: ir.Model, op_type: str) -> int:
        """Count the nodes of the main graph whose op type equals ``op_type``."""
        return sum(1 for n in model.graph if n.op_type == op_type)

    def _check_numerical_equivalence(
        self, model: ir.Model, inputs: dict, expected_count: int
    ) -> None:
        """Fuse ``model`` and check that its ORT outputs are unchanged.

        The model is run once before the rewrite and once after it; the fusion
        count must equal ``expected_count`` and both outputs must be close.
        """
        original_output = test_utils.ort_run("Original", model, inputs)
        count = self._apply(model)
        self.assertEqual(count, expected_count)
        fused_output = test_utils.ort_run("Fused", model, inputs)
        test_utils.assert_allclose(original_output, fused_output)

    # --- Positive tests ---

    @parameterized.expand(
        [
            ("scale_times_normalized", _rms_norm_scale_first),
            ("normalized_times_scale", _rms_norm_norm_first),
        ]
    )
    def test_mul_order_variants(self, _name, script_fn):
        """Both Mul orderings (scale*norm and norm*scale) should fuse."""
        model = self._build(
            script_fn,
            input_types=[FLOAT["B", "S", _D], FLOAT[_D]],
            output_types=[FLOAT["B", "S", _D]],
        )
        inputs = {
            "x": self._random_input((_B, _S, _D), np.float32),
            "scale": self._random_input((_D,), np.float32),
        }
        self._check_numerical_equivalence(model, inputs, expected_count=1)
        self.assertEqual(self._count_op(model, "SimplifiedLayerNormalization"), 1)
        # The ops of the original pattern must be gone after fusion.
        self.assertEqual(self._count_op(model, "Pow"), 0)
        self.assertEqual(self._count_op(model, "ReduceMean"), 0)

    def test_cast_input_mixed_precision(self):
        """fp16 input Cast to fp32 for compute, Cast back → still fuses."""
        model = self._build(
            _rms_norm_with_cast_input,
            input_types=[FLOAT16["B", "S", _D], FLOAT16[_D]],
            output_types=[FLOAT16["B", "S", _D]],
        )
        inputs = {
            "x": self._random_input((_B, _S, _D), np.float16),
            "scale": self._random_input((_D,), np.float16),
        }
        self._check_numerical_equivalence(model, inputs, expected_count=1)
        self.assertEqual(self._count_op(model, "SimplifiedLayerNormalization"), 1)

    # --- Negative tests ---

    def test_int_input_no_fusion(self):
        """Integer input, cast to float ahead of the pattern, must not be fused."""
        from onnxscript import INT32

        model = self._build(
            _rms_norm_int_input,
            input_types=[INT32["B", "S", _D], FLOAT[_D]],
            output_types=[FLOAT["B", "S", _D]],
        )
        count = self._apply(model)
        self.assertEqual(count, 0)
        self.assertEqual(self._count_op(model, "SimplifiedLayerNormalization"), 0)
168+
169+
170+
if __name__ == "__main__":
171+
unittest.main()

0 commit comments

Comments
 (0)