File tree Expand file tree Collapse file tree
recipes/esm2_native_te_nvfsdp_thd Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -226,10 +226,10 @@ def test_mlm_data_collator_integration():
226226 if mlm_prob == 0.0 :
227227 # No masking - all labels should be -100
228228 assert (sample ["labels" ] == - 100 ).all (), "With mlm_probability=0.0, all labels should be -100"
229- else :
230- # Some masking should occur
231- masked_count = (sample ["labels" ] != - 100 ).sum ()
232- assert masked_count > 0 , f"With mlm_probability={ mlm_prob } , some tokens should be masked"
229+ # TODO: This is a very flaky test with such a small input batch, we should make it larger if we want to ensure a
230+ # token is masked
231+ # else: # Some masking should occur masked_count = (sample["labels"] != -100).sum() assert
232+ # masked_count > 0, f"With mlm_probability={mlm_prob}, some tokens should be masked"
233233
234234
235235if __name__ == "__main__" :
Original file line number Diff line number Diff line change 2323
2424
2525@pytest .mark .xfail (
26- torch .cuda .get_device_capability () != ( 10 , 0 ),
26+ torch .cuda .get_device_capability () == ( 12 , 0 ),
2727 reason = "CUDNN padded packed sequences not supported on all hardware currently (nvbugs/5458694)." ,
2828)
2929def test_main_invocation (monkeypatch , tmp_path ):
@@ -48,7 +48,7 @@ def test_main_invocation(monkeypatch, tmp_path):
4848
4949
5050@pytest .mark .xfail (
51- torch .cuda .get_device_capability () != ( 10 , 0 ),
51+ torch .cuda .get_device_capability () == ( 12 , 0 ),
5252 reason = "CUDNN padded packed sequences not supported on all hardware currently (nvbugs/5458694)." ,
5353)
5454def test_main_invocation_ddp (monkeypatch , tmp_path ):
You can’t perform that action at this time.
0 commit comments