55from typing import Sequence , Union
66
77import onnxscript .ir as ir
8- from onnxscript .rewriter import _fusion_utils , pattern
8+ from onnxscript .rewriter import _fusion_utils , _ir_utils , pattern
99
1010Dim = Union [int , ir .SymbolicDim ]
1111
@@ -36,6 +36,12 @@ def pattern(
3636 attention_bias ,
3737 num_heads ,
3838 # scale,
39+ start1 ,
40+ end1 ,
41+ start2 ,
42+ end2 ,
43+ start3 ,
44+ end3 ,
3945 q_mul ,
4046 k_mul ,
4147 v_mul ,
@@ -45,28 +51,28 @@ def pattern(
4551 key_BSD = op .MatMul (input , k_mul )
4652 value_BSD = op .MatMul (input , v_mul )
4753 else :
48- projected = op .MatMul (input , qkv_weight )
54+ projected = op .MatMul (input , qkv_weight , _outputs = [ "projected" ] )
4955
5056 # Slice packed Matmul QKV into Q, K, and V
5157 # Q, K, and V are of shape (B, S, D)
5258 query_BSD = op .Slice (
5359 projected ,
54- pattern . ANY_VALUE , # starts
55- pattern . ANY_VALUE , # ends
60+ start1 , # starts
61+ end1 , # ends
5662 [2 ], # axes
5763 _outputs = ["query_mm_sliced" ],
5864 )
5965 key_BSD = op .Slice (
6066 projected ,
61- pattern . ANY_VALUE , # starts
62- pattern . ANY_VALUE , # ends
67+ start2 , # starts
68+ end2 , # ends
6369 [2 ], # axes
6470 _outputs = ["key_mm_sliced" ],
6571 )
6672 value_BSD = op .Slice (
6773 projected ,
68- pattern . ANY_VALUE , # starts
69- pattern . ANY_VALUE , # ends
74+ start3 , # starts
75+ end3 , # ends
7076 [2 ], # axes
7177 _outputs = ["value_mm_sliced" ],
7278 )
@@ -135,9 +141,16 @@ def check(
135141 op ,
136142 input ,
137143 qkv_weight ,
144+ projected = None ,
138145 query_mm_sliced = None ,
139146 key_mm_sliced = None ,
140147 value_mm_sliced = None ,
148+ start1 = None ,
149+ end1 = None ,
150+ start2 = None ,
151+ end2 = None ,
152+ start3 = None ,
153+ end3 = None ,
141154 q_mul = None ,
142155 k_mul = None ,
143156 v_mul = None ,
@@ -155,6 +168,23 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
155168 input ,
156169 )
157170 if not self ._no_slice :
171+ # Ensure slicing is done correctly
172+ if projected is None or projected .shape is None or len (projected .shape ) != 3 :
173+ return check_result .fail ("Input projection is not a 3D tensor." , projected )
174+ hidden_size = projected .shape [2 ]
175+ if not isinstance (hidden_size , int ):
176+ return check_result .fail ("Hidden size is not an integer." , projected )
177+ if not (
178+ _ir_utils .is_singleton_value (start1 , 0 )
179+ and _ir_utils .get_singleton_value (end1 ) == _ir_utils .get_singleton_value (start2 )
180+ and _ir_utils .get_singleton_value (end2 ) == _ir_utils .get_singleton_value (start3 )
181+ and _ir_utils .is_singleton_value (end3 , lambda x : x >= hidden_size )
182+ ):
183+ return check_result .fail (
184+ "Projected input is not being split into q, k, v correctly based on hidden sizes." ,
185+ projected ,
186+ )
187+
158188 if no_match (qkv_weight , ["D" , "Dh" ]):
159189 return check_result .fail (
160190 f"Shape mismatch: { qkv_weight } does not match expected dimensions ['D', 'Dh']" ,
0 commit comments