1111from onnxscript .rewriter import pattern as orp
1212
1313
14+ class SqueezeReshape (orp .RewriteRuleClassBase ):
15+ """Replaces ``Reshape(Squeeze(x), [-1]])`` with ``Identity(x)`` for 1D x.
16+
17+ This pattern arises from the translation of pytorch symints.
18+ """
19+
20+ def __init__ (self ):
21+ super ().__init__ ("SqueezeReshape1d" , remove_nodes = False )
22+
23+ def pattern (self , op , x ):
24+ return op .Reshape (op .Squeeze (x ), [- 1 ])
25+
26+ def rewrite (self , op , x : ir .Value ):
27+ return op .Identity (x )
28+
29+ def check (self , context , x ) -> orp .MatchResult :
30+ del context # Unused
31+ check_result = orp .MatchResult ()
32+ if not ir_utils .has_rank (x , 1 ):
33+ return check_result .fail ("Input is not 1D" )
34+ return check_result
35+
36+
1437class CastIdentity (orp .RewriteRuleAsClass ):
1538 """Replaces ``Cast(., to=to)`` by ``Identity`` if possible."""
1639
@@ -23,8 +46,11 @@ def rewrite(cls, op, x: ir.Value, to: ir.Attr):
2346 return op .Identity (x )
2447
2548 @classmethod
26- def check (cls , context , x , to ) -> bool :
27- return x .dtype == to .value
49+ def check (cls , context , x , to ) -> orp .MatchResult :
50+ check_result = orp .MatchResult ()
51+ if x .dtype != to .value :
52+ return check_result .fail ("Input and output types are not the same" )
53+ return check_result
2854
2955
3056class CastCast (orp .RewriteRuleAsClass ):
@@ -42,11 +68,13 @@ def pattern(cls, op, x, to, to_ignored):
4268 return op .Cast (op .Cast (x , to = to_ignored ), to = to )
4369
4470 @classmethod
45- def check (cls , context , x : ir .Value , to : ir .Attr , to_ignored : ir .Attr ) -> bool :
46- return (
47- to .value in cls ._allowed_tensor_types
48- and to_ignored .value in cls ._allowed_tensor_types
49- )
71+ def check (cls , context , x : ir .Value , to : ir .Attr , to_ignored : ir .Attr ) -> orp .MatchResult :
72+ check_result = orp .MatchResult ()
73+ if to .value not in cls ._allowed_tensor_types :
74+ return check_result .fail (f"Output type { to .value } is not allowed" )
75+ if to_ignored .value not in cls ._allowed_tensor_types :
76+ return check_result .fail (f"Ignored type { to_ignored .value } is not allowed" )
77+ return check_result
5078
5179 @classmethod
5280 def rewrite (cls , op , x : ir .Value , to : ir .Attr , to_ignored : ir .Attr ):
@@ -65,14 +93,19 @@ def rewrite(cls, op, x: ir.Value, shape: ir.Value):
6593 return op .Identity (x )
6694
6795 @classmethod
68- def check (cls , context , x , shape ) -> bool :
96+ def check (cls , context , x , shape ) -> orp .MatchResult :
97+ check_result = orp .MatchResult ()
6998 if shape .const_value is None :
7099 # Shape is not a constant and cannot be guessed.
71- return False
100+ return check_result . fail ( "Shape is not a constant and cannot be guessed." )
72101 if (x_shape := x .shape ) is None :
73102 # We don't know the shape of the input
74- return False
75- return x_shape .dims == tuple (shape .const_value .numpy ().tolist ())
103+ return check_result .fail ("Input shape is not known." )
104+ if x_shape .dims != tuple (shape .const_value .numpy ().tolist ()):
105+ return check_result .fail (
106+ f"Input shape { x_shape .dims } does not match the shape { shape .const_value .numpy ().tolist ()} ."
107+ )
108+ return check_result
76109
77110
78111class ReshapeReshape (orp .RewriteRuleAsClass ):
@@ -90,12 +123,15 @@ def rewrite(cls, op, x: ir.Value, shape_ignored: ir.Value, shape: ir.Value):
90123 return op .Reshape (x , shape )
91124
92125 @classmethod
93- def check (cls , context , x , shape_ignored , shape ) -> bool :
94- if shape_ignored .const_value is None or shape .const_value is None :
95- return False
126+ def check (cls , context , x , shape_ignored , shape ) -> orp .MatchResult :
127+ check_result = orp .MatchResult ()
128+ if shape_ignored .const_value is None :
129+ return check_result .fail ("Shape ignored is not a constant." )
130+ if shape .const_value is None :
131+ return check_result .fail ("Shape is not a constant." )
96132 if shape .const_value .numpy ().min () <= 0 :
97- return False
98- return True
133+ return check_result . fail ( "Shape has non-positive values." )
134+ return check_result
99135
100136
101137class SlicesSplit (orp .RewriteRuleAsClass ):
@@ -108,49 +144,50 @@ def pattern(cls, op, x, begin0, end0, axes0, begin1, end1, axes1):
108144 return op .Slice (x , begin0 , end0 , axes0 ), op .Slice (x , begin1 , end1 , axes1 )
109145
110146 @classmethod
111- def check (cls , context , x , begin0 , end0 , axes0 , begin1 , end1 , axes1 ) -> bool :
147+ def check (cls , context , x , begin0 , end0 , axes0 , begin1 , end1 , axes1 ) -> orp .MatchResult :
148+ check_result = orp .MatchResult ()
112149 if (
113150 axes0 .const_value is None
114151 or axes1 .const_value is None
115152 or axes0 .const_value .numpy ().tolist () != axes1 .const_value .numpy ().tolist ()
116153 ):
117- return False
154+ return check_result . fail ( "Axes are not equal or not constant." )
118155 axes = axes0 .const_value .numpy ().tolist ()
119156 if len (axes ) != 1 :
120- return False
157+ return check_result . fail ( "Axes has more than one dimension." )
121158 if x .shape :
122159 rk = len (x .shape )
123160 else :
124161 rk = x .rank
125162 if axes [0 ] != - 1 and axes [0 ] != rk - 1 :
126- return False
163+ return check_result . fail ( "Axes is not -1 or last dimension." )
127164 if (
128165 begin0 .const_value is None
129166 or end0 .const_value is None
130167 or begin1 .const_value is None
131168 or end1 .const_value is None
132169 ):
133- return False
170+ return check_result . fail ( "Begin or end are not constant values." )
134171 if begin0 .const_value .numpy ().tolist () != [0 ]:
135- return False
172+ return check_result . fail ( "First begin value is not 0." )
136173 e0 , b1 , e1 = (
137174 end0 .const_value .numpy ().tolist (),
138175 begin1 .const_value .numpy ().tolist (),
139176 end1 .const_value .numpy ().tolist (),
140177 )
141178 if e0 [0 ] != b1 [0 ]:
142- return False
179+ return check_result . fail ( "End0 is not equal to Begin1." )
143180 shape = x .shape
144181 if shape is None :
145- return False
182+ return check_result . fail ( "Shape is not known." )
146183 last_dim = shape [- 1 ]
147184 if not isinstance (last_dim , int ):
148- return False
185+ return check_result . fail ( "Last dimension is not known." )
149186 if last_dim != e1 [0 ]:
150- return False
187+ return check_result . fail ( "Last dimension is not equal to End1." )
151188 if last_dim // 2 != b1 [0 ]:
152- return False
153- return True
189+ return check_result . fail ( "Last dimension is not equal to Begin1." )
190+ return check_result
154191
155192 @classmethod
156193 def rewrite (cls , op , x , begin0 , end0 , axes0 , begin1 , end1 , axes1 ):
@@ -167,13 +204,14 @@ def pattern(cls, op, x, perm):
167204 return op .Transpose (x , perm = perm )
168205
169206 @classmethod
170- def check (cls , context , x : ir .Value , perm : ir .Attr ) -> bool :
207+ def check (cls , context , x : ir .Value , perm : ir .Attr ) -> orp .MatchResult :
208+ check_result = orp .MatchResult ()
171209 if isinstance (perm , ir .RefAttr ):
172- return False
210+ return check_result . fail ( "Permutation is a reference attribute." )
173211 if perm .type == ir .AttributeType .INTS :
174212 if perm .value == list (range (len (perm .value ))):
175- return True
176- return False
213+ return check_result
214+ return check_result . fail ( "Permutation is not identity." )
177215
178216 @classmethod
179217 def rewrite (cls , op , x : ir .Value , perm : ir .Attr ):
@@ -190,10 +228,11 @@ def pattern(cls, op, x, perm1, perm2):
190228 return op .Transpose (op .Transpose (x , perm = perm1 ), perm = perm2 )
191229
192230 @classmethod
193- def check (cls , context , x : ir .Value , perm1 : ir .Attr , perm2 : ir .Attr ) -> bool :
231+ def check (cls , context , x : ir .Value , perm1 : ir .Attr , perm2 : ir .Attr ) -> orp .MatchResult :
232+ check_result = orp .MatchResult ()
194233 if isinstance (perm1 , ir .RefAttr ) or isinstance (perm2 , ir .RefAttr ):
195- return False
196- return True
234+ return check_result . fail ( "Permutation is a reference attribute." )
235+ return check_result
197236
198237 @classmethod
199238 def _apply_transpose (cls , perm : tuple [int , ...], on : list [int ]) -> list [int ]:
@@ -237,17 +276,18 @@ def rewrite(cls, op, x: ir.Value, axes1: ir.Value, axes2: ir.Value):
237276 return op .Unsqueeze (x , op .Constant (value = ir .tensor (axes , dtype = ir .DataType .INT64 )))
238277
239278 @classmethod
240- def check (cls , context , x , axes1 , axes2 ) -> bool :
279+ def check (cls , context , x , axes1 , axes2 ) -> orp .MatchResult :
280+ check_result = orp .MatchResult ()
241281 del context # Unused
242282 del x # Unused
243283 # Currently restricted to single element positive axis
244284 v1 = ir_utils .get_singleton_value (axes1 )
245285 v2 = ir_utils .get_singleton_value (axes2 )
246286 if v1 is None or v2 is None :
247- return False
287+ return check_result . fail ( "Axes are not constant." )
248288 if (v1 < 0 ) or (v2 < 0 ):
249- return False
250- return True
289+ return check_result . fail ( "Axes are negative." )
290+ return check_result
251291
252292
253293cast_cast_rule = orp .make_rewrite_rule_from_class (CastCast )
0 commit comments