Skip to content

Commit f6ff4ed

Browse files
authored
Skip invalid flash-attn tests for pi0 model (#45011)
* Skip 2 invalid test cases for the pi0 model. Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com> * Skip all FA-related test cases. Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com> --------- Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
1 parent 68abf9b commit f6ff4ed

1 file changed

Lines changed: 64 additions & 0 deletions

File tree

tests/models/pi0/test_modeling_pi0.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,70 @@ def test_training_gradient_checkpointing_use_reentrant_false(self):
213213
def test_training_gradient_checkpointing_use_reentrant_true(self):
214214
pass
215215

216+
@unittest.skip("PI0 model requires pixel_attention_mask to be provided")
217+
def test_flash_attn_2_inference_equivalence(self):
218+
pass
219+
220+
@unittest.skip("PI0 model requires pixel_attention_mask to be provided")
221+
def test_flash_attn_2_inference_equivalence_right_padding(self):
222+
pass
223+
224+
@unittest.skip("PI0 model requires pixel_attention_mask to be provided")
225+
def test_flash_attn_3_inference_equivalence(self):
226+
pass
227+
228+
@unittest.skip("PI0 model requires pixel_attention_mask to be provided")
229+
def test_flash_attn_3_inference_equivalence_right_padding(self):
230+
pass
231+
232+
@unittest.skip("PI0 model requires pixel_attention_mask to be provided")
233+
def test_flash_attn_4_inference_equivalence(self):
234+
pass
235+
236+
@unittest.skip("PI0 model requires pixel_attention_mask to be provided")
237+
def test_flash_attn_4_inference_equivalence_right_padding(self):
238+
pass
239+
240+
@unittest.skip("PI0 model requires pixel_attention_mask to be provided")
241+
def test_flash_attn_kernels_inference_equivalence(self):
242+
pass
243+
244+
@unittest.skip("PI0 model requires pixel_attention_mask to be provided")
245+
def test_flash_attn_kernels_mps_inference_equivalence(self):
246+
pass
247+
248+
@unittest.skip("PI0 model requires pixel_attention_mask to be provided")
249+
def test_flash_attn_2_can_dispatch_composite_models(self):
250+
pass
251+
252+
@unittest.skip("PI0 model requires pixel_attention_mask to be provided")
253+
def test_flash_attn_3_can_dispatch_composite_models(self):
254+
pass
255+
256+
@unittest.skip("PI0 model requires pixel_attention_mask to be provided")
257+
def test_flash_attn_4_can_dispatch_composite_models(self):
258+
pass
259+
260+
@unittest.skip("PI0 model requires pixel_attention_mask to be provided")
261+
def test_flash_attn_2_fp32_ln(self):
262+
pass
263+
264+
@unittest.skip("PI0 model requires pixel_attention_mask to be provided")
265+
def test_flash_attn_2_can_compile_with_attention_mask_None_without_graph_break(self):
266+
pass
267+
268+
@unittest.skip("PI0 model requires pixel_attention_mask to be provided")
269+
def test_flash_attn_2_from_config(self):
270+
pass
271+
272+
@unittest.skip("PI0 model requires pixel_attention_mask to be provided")
273+
def test_flash_attn_3_from_config(self):
274+
pass
275+
276+
@unittest.skip("PI0 model requires pixel_attention_mask to be provided")
277+
def test_flash_attn_4_from_config(self):
278+
pass
279+
216280
def test_full_run_smoke(self):
217281
torch.manual_seed(0)
218282
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

0 commit comments

Comments (0)