Skip to content

Commit 92f083a

Browse files
committed
fix models/ dir creation
should be the concern of save_checkpoint() who relies on it being there
1 parent f2fa786 commit 92f083a

10 files changed

Lines changed: 2 additions & 30 deletions

aviary/core.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import gc
4+
import os
45
import shutil
56
import sys
67
from abc import ABC, abstractmethod
@@ -510,6 +511,7 @@ def save_checkpoint(
510511
model_name (str): String describing the model.
511512
run_id (int): Unique identifier of the model run.
512513
"""
514+
os.makedirs(f"models/{model_name}", exist_ok=True)
513515
checkpoint = f"models/{model_name}/checkpoint-r{run_id}.pth.tar"
514516
best = f"models/{model_name}/best-r{run_id}.pth.tar"
515517

examples/cgcnn-example.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,8 +202,6 @@ def main( # noqa: C901
202202
"n_hidden": n_hidden,
203203
}
204204

205-
os.makedirs(f"models/{model_name}/", exist_ok=True)
206-
207205
if log:
208206
os.makedirs("runs/", exist_ok=True)
209207

examples/roost-example.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,8 +186,6 @@ def main( # noqa: C901
186186
"out_hidden": [256, 128, 64],
187187
}
188188

189-
os.makedirs(f"models/{model_name}/", exist_ok=True)
190-
191189
if log:
192190
os.makedirs("runs/", exist_ok=True)
193191

examples/wren-example.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,8 +199,6 @@ def main( # noqa: C901
199199
"trunk_hidden": [128, 64],
200200
}
201201

202-
os.makedirs(f"models/{model_name}/", exist_ok=True)
203-
204202
if log:
205203
os.makedirs("runs/", exist_ok=True)
206204

tests/test_cgcnn_classification.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import os
2-
31
import numpy as np
42
import torch
53
from sklearn.metrics import accuracy_score, roc_auc_score
@@ -98,8 +96,6 @@ def test_cgcnn_clf(df_matbench_phonons):
9896
"n_hidden": n_hidden,
9997
}
10098

101-
os.makedirs(f"models/{model_name}", exist_ok=True)
102-
10399
train_ensemble(
104100
model_class=CrystalGraphConvNet,
105101
model_name=model_name,

tests/test_cgcnn_regression.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import os
2-
31
import numpy as np
42
import torch
53
from sklearn.metrics import r2_score
@@ -98,8 +96,6 @@ def test_cgcnn_regression(df_matbench_phonons):
9896
"n_hidden": n_hidden,
9997
}
10098

101-
os.makedirs(f"models/{model_name}", exist_ok=True)
102-
10399
train_ensemble(
104100
model_class=CrystalGraphConvNet,
105101
model_name=model_name,

tests/test_roost_classification.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import os
2-
31
import numpy as np
42
import torch
53
from sklearn.metrics import accuracy_score, roc_auc_score
@@ -100,8 +98,6 @@ def test_roost_clf(df_matbench_phonons):
10098
"out_hidden": [128, 64],
10199
}
102100

103-
os.makedirs(f"models/{model_name}", exist_ok=True)
104-
105101
train_ensemble(
106102
model_class=Roost,
107103
model_name=model_name,

tests/test_roost_regression.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import os
2-
31
import numpy as np
42
import torch
53
from sklearn.metrics import r2_score
@@ -100,8 +98,6 @@ def test_roost_regression(df_matbench_phonons):
10098
"out_hidden": [128, 64],
10199
}
102100

103-
os.makedirs(f"models/{model_name}", exist_ok=True)
104-
105101
train_ensemble(
106102
model_class=Roost,
107103
model_name=model_name,

tests/test_wren_classification.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import os
2-
31
import numpy as np
42
import torch
53
from sklearn.metrics import accuracy_score, roc_auc_score
@@ -108,8 +106,6 @@ def test_wren_clf(df_matbench_phonons_wyckoff):
108106
"trunk_hidden": [64],
109107
}
110108

111-
os.makedirs(f"models/{model_name}", exist_ok=True)
112-
113109
train_ensemble(
114110
model_class=Wren,
115111
model_name=model_name,

tests/test_wren_regression.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import os
2-
31
import numpy as np
42
import torch
53
from sklearn.metrics import r2_score
@@ -108,8 +106,6 @@ def test_wren_regression(df_matbench_phonons_wyckoff):
108106
"trunk_hidden": [64],
109107
}
110108

111-
os.makedirs(f"models/{model_name}", exist_ok=True)
112-
113109
train_ensemble(
114110
model_class=Wren,
115111
model_name=model_name,

0 commit comments

Comments
 (0)