@@ -213,6 +213,70 @@ def test_training_gradient_checkpointing_use_reentrant_false(self):
213213 def test_training_gradient_checkpointing_use_reentrant_true (self ):
214214 pass
215215
216+ @unittest .skip ("PI0 model requires pixel_attention_mask to be provided" )
217+ def test_flash_attn_2_inference_equivalence (self ):
218+ pass
219+
220+ @unittest .skip ("PI0 model requires pixel_attention_mask to be provided" )
221+ def test_flash_attn_2_inference_equivalence_right_padding (self ):
222+ pass
223+
224+ @unittest .skip ("PI0 model requires pixel_attention_mask to be provided" )
225+ def test_flash_attn_3_inference_equivalence (self ):
226+ pass
227+
228+ @unittest .skip ("PI0 model requires pixel_attention_mask to be provided" )
229+ def test_flash_attn_3_inference_equivalence_right_padding (self ):
230+ pass
231+
232+ @unittest .skip ("PI0 model requires pixel_attention_mask to be provided" )
233+ def test_flash_attn_4_inference_equivalence (self ):
234+ pass
235+
236+ @unittest .skip ("PI0 model requires pixel_attention_mask to be provided" )
237+ def test_flash_attn_4_inference_equivalence_right_padding (self ):
238+ pass
239+
240+ @unittest .skip ("PI0 model requires pixel_attention_mask to be provided" )
241+ def test_flash_attn_kernels_inference_equivalence (self ):
242+ pass
243+
244+ @unittest .skip ("PI0 model requires pixel_attention_mask to be provided" )
245+ def test_flash_attn_kernels_mps_inference_equivalence (self ):
246+ pass
247+
248+ @unittest .skip ("PI0 model requires pixel_attention_mask to be provided" )
249+ def test_flash_attn_2_can_dispatch_composite_models (self ):
250+ pass
251+
252+ @unittest .skip ("PI0 model requires pixel_attention_mask to be provided" )
253+ def test_flash_attn_3_can_dispatch_composite_models (self ):
254+ pass
255+
256+ @unittest .skip ("PI0 model requires pixel_attention_mask to be provided" )
257+ def test_flash_attn_4_can_dispatch_composite_models (self ):
258+ pass
259+
260+ @unittest .skip ("PI0 model requires pixel_attention_mask to be provided" )
261+ def test_flash_attn_2_fp32_ln (self ):
262+ pass
263+
264+ @unittest .skip ("PI0 model requires pixel_attention_mask to be provided" )
265+ def test_flash_attn_2_can_compile_with_attention_mask_None_without_graph_break (self ):
266+ pass
267+
268+ @unittest .skip ("PI0 model requires pixel_attention_mask to be provided" )
269+ def test_flash_attn_2_from_config (self ):
270+ pass
271+
272+ @unittest .skip ("PI0 model requires pixel_attention_mask to be provided" )
273+ def test_flash_attn_3_from_config (self ):
274+ pass
275+
276+ @unittest .skip ("PI0 model requires pixel_attention_mask to be provided" )
277+ def test_flash_attn_4_from_config (self ):
278+ pass
279+
216280 def test_full_run_smoke (self ):
217281 torch .manual_seed (0 )
218282 config , inputs_dict = self .model_tester .prepare_config_and_inputs_for_common ()
0 commit comments