@@ -62,6 +62,83 @@ def test_sdpa_mask_recent_torch(self):
6262 got = patched_sdpa_mask_recent_torch (** kwargs )
6363 self .assertEqualArray (expected , got )
6464
@requires_transformers("4.99")
def test_sdpa_mask_recent_torch_is_running(self):
    """Compare ``patched_sdpa_mask_recent_torch`` with a local reference copy.

    The reference, ``copy_of_sdpa_mask_recent_torch``, is defined inline so
    the comparison does not depend on the function shipped by the installed
    transformers release (NOTE(review): presumably it mirrors
    ``transformers.masking_utils.sdpa_mask_recent_torch`` -- confirm).

    Both implementations receive the same keyword arguments, once with
    ``allow_is_causal_skip=True`` and once with ``False``.
    """

    def _copy_vmap_for_bhqkv(mask_function, bh_indices=True):
        # Each ``torch.vmap`` layer maps one positional argument, in this
        # order: kv index, q index, then (optionally) head and batch index.
        dimensions = [(None, None, None, 0), (None, None, 0, None)]
        if bh_indices:
            dimensions.extend([(None, 0, None, None), (0, None, None, None)])
        for dims in dimensions:
            mask_function = torch.vmap(mask_function, in_dims=dims, out_dims=0)
        return mask_function

    def copy_of_sdpa_mask_recent_torch(
        batch_size,
        cache_position,
        kv_length,
        kv_offset=0,
        mask_function=transformers.masking_utils.causal_mask_function,
        attention_mask=None,
        local_size=None,
        allow_is_causal_skip=True,
        **kwargs,
    ):
        # ``**kwargs`` absorbs the options this copy does not use
        # (the test passes ``allow_is_bidirectional_skip``).
        q_length = cache_position.shape[0]
        padding_mask = transformers.masking_utils.prepare_padding_mask(
            attention_mask, kv_length, kv_offset
        )
        # No mask is built when skipping is allowed and transformers'
        # ``_ignore_causal_mask_sdpa`` check is truthy.
        if allow_is_causal_skip and transformers.masking_utils._ignore_causal_mask_sdpa(
            padding_mask, q_length, kv_length, kv_offset, local_size
        ):
            return None
        kv_arange = torch.arange(kv_length, device=cache_position.device)
        kv_arange += kv_offset
        if padding_mask is not None:
            mask_function = transformers.masking_utils.and_masks(
                mask_function,
                transformers.masking_utils.padding_mask_function(padding_mask),
            )

        batch_arange = torch.arange(batch_size, device=cache_position.device)
        head_arange = torch.arange(1, device=cache_position.device)
        with transformers.masking_utils.TransformGetItemToIndex():
            causal_mask = _copy_vmap_for_bhqkv(mask_function)(
                batch_arange, head_arange, cache_position, kv_arange
            )
        return causal_mask

    # The two scenarios differ only in ``allow_is_causal_skip``.
    for allow_is_causal_skip in (True, False):
        with self.subTest(allow_is_causal_skip=allow_is_causal_skip):
            kwargs = {
                "batch_size": 1,
                "cache_position": torch.tensor([3], dtype=torch.int64),
                "kv_length": 4,
                "kv_offset": 0,
                "mask_function": transformers.masking_utils.causal_mask_function,
                "attention_mask": torch.tensor([[True, True, True, True]]),
                "local_size": None,
                "allow_is_causal_skip": allow_is_causal_skip,
                "allow_is_bidirectional_skip": False,
            }
            expected = copy_of_sdpa_mask_recent_torch(**kwargs)
            got = patch_transformers.patched_sdpa_mask_recent_torch(**kwargs)
            if expected is None or got is None:
                # ``assertEqual`` on two multi-element tensors would raise
                # instead of failing cleanly, so ``None`` (mask skipped) is
                # checked explicitly and tensors go through the array check.
                self.assertIsNone(expected)
                self.assertIsNone(got)
            else:
                self.assertEqualArray(expected, got)
141+
65142 def test_sdpa_attention_forward_not_causal (self ):
66143 sdpa_attention_forward = sdpa_attention .sdpa_attention_forward
67144 patched_sdpa_attention_forward = patch_transformers .patched_sdpa_attention_forward
0 commit comments