We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 7b4e562 commit be8da13Copy full SHA for be8da13
1 file changed
recipes/esm2_native_te_mfsdp/test_train.py
@@ -28,10 +28,12 @@
28
from train_mfsdp import main as main_mfsdp
29
30
31
-random.seed(42)
32
-torch.manual_seed(42)
33
-if torch.cuda.is_available():
34
- torch.cuda.manual_seed_all(42)
+@pytest.fixture(autouse=True)
+def set_seed():
+ random.seed(42)
+ torch.manual_seed(42)
35
+ if torch.cuda.is_available():
36
+ torch.cuda.manual_seed_all(42)
37
38
39
requires_multi_gpu = pytest.mark.skipif(
0 commit comments