@@ -42,7 +42,7 @@ def __init__(
         is_rotary: bool,
         use_mask: bool,
         has_past_present: bool,
-        is_cross_attention_from_past: bool,
+        is_cross_attention: bool,
     ):
         super().__init__(name)
         self._double_transpose = double_transpose
@@ -51,20 +51,14 @@ def __init__(
         self._is_rotary = is_rotary
         self._use_mask = use_mask
         self._has_past_present = has_past_present
-        # Checks for cross-attention pattern when cross
-        # query and key originate from past_key and past_value.
-        self._is_cross_attention_from_past = is_cross_attention_from_past
-        # Store the key/value to check if the cross-attention is
-        # indeed from past_key and past_value.
-        self._k_from_past = None
-        self._v_from_past = None
+        self._is_cross_attention = is_cross_attention
 
     def pattern(
         self,
         op,
         query_BSD,
-        key_BSD,
-        value_BSD,
+        key,
+        value,
         mask,
         past_key,
         past_value,
@@ -83,23 +77,28 @@ def pattern(
         # Transpose from (B, S, H, D/H) to (B, H, S, D/H)
         query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3])
 
-        # Reshape from (B, S, D) to (B, S, H, D/H)
-        key_BSHDh = op.Reshape(key_BSD, pattern.ANY_VALUE, _outputs=["key_BSHDh"])
-
-        # Possible Transpose patterns for key:
-        # This scenario avoids the need for a double transpose:
-        # 1. (B, S, H, D/H) -> (B, H, D/H, S)
-        # Patterns with a double transpose of key;
-        # the double-transpose case handles this optimization:
-        # 2. (B, S, H, D/H) -> (B, H, S, D/H) -> (B, H, D/H, S)
-        # Patterns where key is reshaped to 3D, transposed, and reshaped back to 4D:
-        # 3. (B, S, H, D/H) -> (B, H, S, D/H) -> R (B, S, D) -> (B, D, S) -> R (B, H, D/H, S)
-        key_BHSDh = op.Transpose(key_BSHDh, perm=key_perm)
-
-        # Reshape from (B, S, D) to (B, S, H, D/H)
-        value_BSHDh = op.Reshape(value_BSD, pattern.ANY_VALUE, _outputs=["value_BSHDh"])
-        # Transpose from (B, S, H, D/H) to (B, H, S, D/H)
-        value_BHSDh = op.Transpose(value_BSHDh, perm=[0, 2, 1, 3])
+        if not self._is_cross_attention:
+            # Reshape from (B, S, D) to (B, S, H, D/H)
+            key_BSHDh = op.Reshape(key, pattern.ANY_VALUE, _outputs=["key_BSHDh"])
+
+            # Possible Transpose patterns for key:
+            # This scenario avoids the need for a double transpose:
+            # 1. (B, S, H, D/H) -> (B, H, D/H, S)
+            # Patterns with a double transpose of key;
+            # the double-transpose case handles this optimization:
+            # 2. (B, S, H, D/H) -> (B, H, S, D/H) -> (B, H, D/H, S)
+            # Patterns where key is reshaped to 3D, transposed, and reshaped back to 4D:
+            # 3. (B, S, H, D/H) -> (B, H, S, D/H) -> R (B, S, D) -> (B, D, S) -> R (B, H, D/H, S)
+            key_BHSDh = op.Transpose(key_BSHDh, perm=key_perm)
+
+            # Reshape from (B, S, D) to (B, S, H, D/H)
+            value_BSHDh = op.Reshape(value, pattern.ANY_VALUE, _outputs=["value_BSHDh"])
+            # Transpose from (B, S, H, D/H) to (B, H, S, D/H)
+            value_BHSDh = op.Transpose(value_BSHDh, perm=[0, 2, 1, 3])
+        else:
+            # For cross-attention, key and value are already 4D and are not reshaped
+            key_BHSDh = key
+            value_BHSDh = value
 
         if self._is_rotary:
             # This is a workaround for examples where there is a duplication of the Unsqueeze op
@@ -117,9 +116,12 @@ def pattern(
             query_BHSDh_emb = op.RotaryEmbedding(
                 query_BHSDh, position_ids_q, cos, sin, _domain="com.microsoft"
             )
-            key_BHSDh_emb = op.RotaryEmbedding(
-                key_BHSDh, position_ids_k, cos, sin, _domain="com.microsoft"
-            )
+            if not self._is_cross_attention:
+                key_BHSDh_emb = op.RotaryEmbedding(
+                    key_BHSDh, position_ids_k, cos, sin, _domain="com.microsoft"
+                )
+            else:
+                key_BHSDh_emb = key_BHSDh
         else:
             # If rotary embedding is not used, we fuse with positional_embeddings
             query_BHSDh_emb = query_BHSDh
@@ -130,23 +132,13 @@ def pattern(
         if self._has_past_present:
             key_seq = op.Concat(past_key, key_BHSDh_emb, axis=-2)
         else:
-            # For patterns where cross-attention key/value originates from past_key/past_value
-            if self._is_cross_attention_from_past:
-                key_seq = past_key
-                self._k_from_past = key_seq
-            else:
-                key_seq = key_BHSDh_emb
+            key_seq = key_BHSDh_emb
 
         # Concatenate past_value cache and current value
         if self._has_past_present:
             value_seq = op.Concat(past_value, value_BHSDh, axis=-2)
         else:
-            # For patterns where cross-attention key/value originates from past_key/past_value
-            if self._is_cross_attention_from_past:
-                value_seq = past_value
-                self._v_from_past = value_seq
-            else:
-                value_seq = value_BHSDh
+            value_seq = value_BHSDh
 
         # Key/value to be used for dot-product attention computation
         key_seq_to_sdpa = key_seq
@@ -198,8 +190,8 @@ def check(
         self,
         op,
         query_BSD,
-        key_BSD,
-        value_BSD,
+        key,
+        value,
         mask,
         past_key,
         past_value,
@@ -221,97 +213,57 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
                 f"Shape mismatch: {query_BSD} does not match expected dimensions ['B', 'S', 'D']",
                 query_BSD,
             )
-        # If cross-attention key/value originates from past_key/past_value,
-        # check whether their producer is None; this prevents the matcher from assuming
-        # that a missing key/value pattern path implies a cross-attention pattern.
-        if self._is_cross_attention_from_past:
-            if self._k_from_past is not None:
-                if self._k_from_past.producer() is not None:
-                    return check_result.fail(
-                        "Key is not from past_key/past_value. This is not a cross-attention pattern.",
-                    )
-            if self._v_from_past is not None:
-                if self._v_from_past.producer() is not None:
-                    return check_result.fail(
-                        "Value is not from past_key/past_value. This is not a cross-attention pattern.",
-                    )
-            # We only consider patterns where:
-            # 1) double_transpose = True, to avoid the pattern consuming the transpose of key.
-            # 2) is_rotary = False, since if rotary embeddings are used, the key is not from past_key.
-            # TODO: Determine what parameter conditions would make this pattern foolproof.
-            if not self._double_transpose or self._is_rotary:
-                return check_result.fail(
-                    "Key is not from past_key/past_value. This is not a cross-attention pattern.",
-                )
 
-        """
-        # Check for key transpose values
-        k_perm = _ir_utils.get_singleton_value(key_perm)
-        if k_perm is None or not isinstance(k_perm, list):
-            return check_result.fail(
-                f"Key permutation is not a list.",
-                key_perm,
-            )
-        if len(k_perm) != 4:
+        if no_match(query_BSHDh, ["B", "S", "H", "Dh"]):
             return check_result.fail(
-                f"Key permutation is not of length 4.",
-                key_perm,
+                f"Shape mismatch: {query_BSHDh} does not match expected dimensions ['B', 'S', 'H', 'Dh']",
+                query_BSHDh,
             )
-        if self._double_transpose:
-            if k_perm != [0, 2, 1, 3]:
+        # If cross-attention, key and value shapes are 4D
+        if self._is_cross_attention:
+            if no_match(key, ["B", "H", "Skv", "Dh"]):
                 return check_result.fail(
-                    f"Key permutation is not [0, 2, 1, 3].",
-                    key_perm,
+                    f"Shape mismatch: {key} does not match expected dimensions ['B', 'H', 'Skv', 'Dh']",
+                    key,
                 )
-        else:
-            if k_perm != [0, 2, 3, 1]:
+            if no_match(value, ["B", "H", "Skv", "Dv"]):
                 return check_result.fail(
-                    f"Key permutation is not [0, 2, 3, 1].",
-                    key_perm,
+                    f"Shape mismatch: {value} does not match expected dimensions ['B', 'H', 'Skv', 'Dv']",
+                    value,
                 )
-        """
-
-        if not self._is_cross_attention_from_past:
-            if no_match(key_BSD, ["B", "Skv", "D"]):
+            # Ensure that no past_key/past_value is used in cross-attention
+            if past_key is not None:
                 return check_result.fail(
-                    f"Shape mismatch: {key_BSD} does not match expected dimensions ['B', 'Skv', 'D']",
-                    query_BSD,
-                )
-            if no_match(value_BSD, ["B", "Skv", "D"]):
-                return check_result.fail(
-                    f"Shape mismatch: {value_BSD} does not match expected dimensions ['B', 'Skv', 'D']",
-                    value_BSD,
-                )
-
-        if self._has_past_present:
-            if no_match(past_key, ["B", "H", "Spast", "Dh"]):
-                return check_result.fail(
-                    f"Shape mismatch: {past_key} does not match expected dimensions ['B', 'H', 'Spast', 'Dh']",
+                    "past_key should be None in cross-attention.",
                     past_key,
                 )
-            if no_match(past_value, ["B", "H", "Spast", "Dv"]):
+            if past_value is not None:
                 return check_result.fail(
-                    f"Shape mismatch: {past_value} does not match expected dimensions ['B', 'H', 'Spast', 'Dv']",
+                    "past_value should be None in cross-attention.",
                     past_value,
                 )
-
-        if no_match(query_BSHDh, ["B", "S", "H", "Dh"]):
-            return check_result.fail(
-                f"Shape mismatch: {query_BSHDh} does not match expected dimensions ['B', 'S', 'H', 'Dh']",
-                query_BSHDh,
-            )
-
-        if not self._is_cross_attention_from_past:
-            if key_BSHDh and no_match(key_BSHDh, ["B", "S", "H", "Dh"]):
+        else:
+            if no_match(key, ["B", "Skv", "D"]):
                 return check_result.fail(
-                    f"Shape mismatch: {key_BSHDh} does not match expected dimensions ['B', 'S', 'H', 'Dh']",
-                    query_BSHDh,
+                    f"Shape mismatch: {key} does not match expected dimensions ['B', 'Skv', 'D']",
+                    key,
                 )
-            if value_BSHDh and no_match(value_BSHDh, ["B", "S", "H", "Dh"]):
+            if no_match(value, ["B", "Skv", "D"]):
                 return check_result.fail(
-                    f"Shape mismatch: {value_BSHDh} does not match expected dimensions ['B', 'S', 'H', 'Dh']",
-                    query_BSHDh,
+                    f"Shape mismatch: {value} does not match expected dimensions ['B', 'Skv', 'D']",
+                    value,
                 )
+            if self._has_past_present:
+                if no_match(past_key, ["B", "H", "Spast", "Dh"]):
+                    return check_result.fail(
+                        f"Shape mismatch: {past_key} does not match expected dimensions ['B', 'H', 'Spast', 'Dh']",
+                        past_key,
+                    )
+                if no_match(past_value, ["B", "H", "Spast", "Dv"]):
+                    return check_result.fail(
+                        f"Shape mismatch: {past_value} does not match expected dimensions ['B', 'H', 'Spast', 'Dv']",
+                        past_value,
+                    )
 
         # TODO: mask shape check: ideally, it should be (1 or B, 1 or H, S, St)
         # But this also, unfortunately, depends on ORT version.
@@ -326,8 +278,8 @@ def rewrite(
         self,
         op,
         query_BSD,
-        key_BSD,
-        value_BSD,
+        key,
+        value,
         mask,
         past_key,
         past_value,
@@ -353,35 +305,21 @@ def rewrite(
             query_BSD_emb = op.RotaryEmbedding(
                 query_BSD, position_ids, cos, sin, _domain="com.microsoft"
             )
-            key_BSD_emb = op.RotaryEmbedding(
-                key_BSD, position_ids, cos, sin, _domain="com.microsoft"
-            )
+            if not self._is_cross_attention:
+                key_BSD_emb = op.RotaryEmbedding(
+                    key, position_ids, cos, sin, _domain="com.microsoft"
+                )
+            else:
+                key_BSD_emb = key
         else:
             query_BSD_emb = query_BSD
-            key_BSD_emb = key_BSD
+            key_BSD_emb = key
 
         num_outputs = 1 + (2 * self._has_past_present)
-        # Special case for cross-attention that comes from past_key/past_value
-        if self._is_cross_attention_from_past:
-            return op.MultiHeadAttention(
-                query_BSD_emb,
-                past_key,
-                past_value,
-                None,  # bias
-                None,  # key padding mask
-                mask,  # attention mask/bias
-                None,
-                None,
-                num_heads=num_heads,
-                scale=scale,
-                _domain="com.microsoft",
-                _outputs=num_outputs,
-            )
-
         return op.MultiHeadAttention(
             query_BSD_emb,
             key_BSD_emb,
-            value_BSD,
+            value,
             None,  # bias
             None,  # key padding mask
             mask,  # attention mask/bias
@@ -402,7 +340,7 @@ def rewrite(
         "is_rotary": is_rotary,
         "use_mask": use_mask,
         "has_past_present": has_past_present,
-        "is_cross_attention_from_past": is_cross_attention_from_past,
+        "is_cross_attention": is_cross_attention,
     }
     for double_transpose in [False, True]
     for transpose_4d in (
@@ -411,9 +349,9 @@ def rewrite(
     for pre_scale_q in [True, False]
     for is_rotary in [False, True]
    for use_mask in [False, True]
-    # TODO: Avoid making this parameter order-dependent
+    # Match has_past_present=True first, to avoid missing the pattern
     for has_past_present in [True, False]
-    for is_cross_attention_from_past in [False, True]
+    for is_cross_attention in [False, True]
 ]
 
 # Dynamically create the rules
@@ -426,7 +364,7 @@ def rewrite(
             f"{'_Rotary' if params['is_rotary'] else ''}"
             f"{'_Masked' if params['use_mask'] else ''}"
             f"{'_Past' if params['has_past_present'] else ''}"
-            f"{'_CrossAttentionFromPast' if params['is_cross_attention_from_past'] else ''}",
+            f"{'_CrossAttention' if params['is_cross_attention'] else ''}",
             **params,
         )
         for params in parameter_combinations
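
Note on the reworked check: taken together, the new checks enforce two mutually exclusive shape contracts. In cross-attention, key and value arrive already projected as 4D (B, H, Skv, Dh/Dv) and no past_key/past_value may appear in the pattern; in self-attention, key and value are 3D (B, Skv, D), with an optional 4D KV cache. The standalone sketch below restates that contract in plain Python. The real no_match helper binds symbolic dimensions on ir.Value shapes, so the tuple-of-dim-names stand-in used here is an assumption for illustration only.

from typing import Optional, Tuple

Shape = Tuple[str, ...]  # symbolic dims, e.g. ("B", "H", "Skv", "Dh")

def no_match(shape: Optional[Shape], dims: Shape) -> bool:
    # Stand-in for the real no_match helper: fail when the value has no
    # shape or its rank cannot bind to the expected dimension pattern.
    return shape is None or len(shape) != len(dims)

def check_shapes(
    is_cross_attention: bool,
    has_past_present: bool,
    key: Optional[Shape],
    value: Optional[Shape],
    past_key: Optional[Shape],
    past_value: Optional[Shape],
) -> Optional[str]:
    """Return an error message, or None when the contract holds."""
    if is_cross_attention:
        # Cross-attention: key/value arrive already projected and 4D ...
        if no_match(key, ("B", "H", "Skv", "Dh")):
            return "key must be 4D (B, H, Skv, Dh)"
        if no_match(value, ("B", "H", "Skv", "Dv")):
            return "value must be 4D (B, H, Skv, Dv)"
        # ... and a KV cache must not appear in the same pattern.
        if past_key is not None or past_value is not None:
            return "past_key/past_value must be None in cross-attention"
    else:
        # Self-attention: key/value are 3D and reshaped/transposed in-pattern.
        if no_match(key, ("B", "Skv", "D")):
            return "key must be 3D (B, Skv, D)"
        if no_match(value, ("B", "Skv", "D")):
            return "value must be 3D (B, Skv, D)"
        if has_past_present:
            if no_match(past_key, ("B", "H", "Spast", "Dh")):
                return "past_key must be 4D (B, H, Spast, Dh)"
            if no_match(past_value, ("B", "H", "Spast", "Dv")):
                return "past_value must be 4D (B, H, Spast, Dv)"
    return None  # pattern accepted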
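
Note on rule generation: the parameter grid at the bottom of the diff expands into one rewrite rule per flag combination, with the rule name assembled from the flags. Below is a minimal, runnable sketch of that idiom, restricted to four of the flags for brevity. The "MHA" base name and the surrounding RewriteRuleSet wiring are not visible in these hunks, so both are assumptions.

# Sketch of the rule-generation idiom, using only the flags and name
# suffixes visible in the diff. The "MHA" base name is an assumption.
parameter_combinations = [
    {
        "is_rotary": is_rotary,
        "use_mask": use_mask,
        "has_past_present": has_past_present,
        "is_cross_attention": is_cross_attention,
    }
    for is_rotary in [False, True]
    for use_mask in [False, True]
    # True first: rules are tried in order, so the cached-KV variant is
    # attempted before the cache-free one (see the comment in the diff).
    for has_past_present in [True, False]
    for is_cross_attention in [False, True]
]

rule_names = [
    "MHA"  # assumed base name
    f"{'_Rotary' if params['is_rotary'] else ''}"
    f"{'_Masked' if params['use_mask'] else ''}"
    f"{'_Past' if params['has_past_present'] else ''}"
    f"{'_CrossAttention' if params['is_cross_attention'] else ''}"
    for params in parameter_combinations
]

print(rule_names[:4])
# ['MHA_Past', 'MHA_Past_CrossAttention', 'MHA', 'MHA_CrossAttention']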