Skip to content

Commit f2fa786

Browse files
authored
pytest fixtures for test data loading (#39)
1 parent 5b89eea commit f2fa786

11 files changed

Lines changed: 75 additions & 155 deletions

examples/cgcnn-example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def main( # noqa: C901
4848
weight_decay=1e-6,
4949
batch_size=128,
5050
workers=0,
51-
device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
51+
device="cuda" if torch.cuda.is_available() else "cpu",
5252
**kwargs,
5353
):
5454

examples/roost-example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def main( # noqa: C901
4242
weight_decay=1e-6,
4343
batch_size=128,
4444
workers=0,
45-
device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
45+
device="cuda" if torch.cuda.is_available() else "cpu",
4646
**kwargs,
4747
):
4848
if not len(targets) == len(tasks) == len(losses):

examples/wren-example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def main( # noqa: C901
4444
weight_decay=1e-6,
4545
batch_size=128,
4646
workers=0,
47-
device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
47+
device="cuda" if torch.cuda.is_available() else "cpu",
4848
**kwargs,
4949
):
5050

tests/conftest.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import os
2+
3+
import pytest
4+
import torch
5+
from matminer.datasets import load_dataset
6+
7+
from aviary.cgcnn.utils import get_cgcnn_input
8+
from aviary.wren.utils import get_aflow_label_spglib
9+
10+
torch.manual_seed(0)  # fix torch's global RNG seed at conftest import time so results are reproducible (applies to all tests)
11+
12+
13+
@pytest.fixture(scope="session")
def df_matbench_phonons():
    """Load the Matbench phonons dataset as a dataframe prepared for the model tests.

    On top of the raw dataset the following columns are added:

    - ``lattice`` / ``sites``: the output of ``get_cgcnn_input`` for each structure
    - ``material_id``: a synthetic identifier of the form ``mb_phdos_<row number>``
    - ``composition``: the structure's formula with all spaces removed
    - ``phdos_clf``: binary label, 1 where ``last phdos peak`` exceeds 450, else 0

    The fixture is session-scoped, so the dataset is loaded and processed only once.
    """
    df = load_dataset("matbench_phonons")
    n_rows = len(df)

    # structure-derived inputs for CGCNN
    cgcnn_inputs = [get_cgcnn_input(struct) for struct in df["structure"]]
    df[["lattice", "sites"]] = cgcnn_inputs

    df["material_id"] = [f"mb_phdos_{row}" for row in range(n_rows)]

    formulas = [struct.composition.formula.replace(" ", "") for struct in df["structure"]]
    df["composition"] = formulas

    # binary classification target derived from the regression target
    df["phdos_clf"] = [int(peak > 450) for peak in df["last phdos peak"]]

    return df
25+
26+
27+
@pytest.fixture(scope="session")
def df_matbench_phonons_wyckoff(df_matbench_phonons):
    """Return the Matbench phonons dataframe with an extra ``wyckoff`` column.

    Getting Aflow labels is expensive so we split into a separate fixture to avoid
    paying for it unless needed.

    The labels are attached to a new dataframe rather than written into
    ``df_matbench_phonons``: that fixture is session-scoped and shared, so mutating
    it in place would make the columns other tests see depend on execution order.
    """
    wyckoff_labels = [
        get_aflow_label_spglib(struct) for struct in df_matbench_phonons.structure
    ]

    # DataFrame.assign returns a new object and leaves the shared fixture untouched
    return df_matbench_phonons.assign(wyckoff=wyckoff_labels)
37+
38+
39+
@pytest.fixture(scope="session")
def tests_dir():
    """Return the absolute path of the directory that contains this file."""
    this_file = os.path.abspath(__file__)
    return os.path.dirname(this_file)

tests/test_cgcnn_classification.py

Lines changed: 4 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,15 @@
22

33
import numpy as np
44
import torch
5-
from matminer.utils.io import load_dataframe_from_json
65
from sklearn.metrics import accuracy_score, roc_auc_score
76
from sklearn.model_selection import train_test_split as split
87

98
from aviary.cgcnn.data import CrystalGraphData, collate_batch
109
from aviary.cgcnn.model import CrystalGraphConvNet
11-
from aviary.cgcnn.utils import get_cgcnn_input
1210
from aviary.utils import results_multitask, train_ensemble
1311

14-
torch.manual_seed(0) # ensure reproducible results
1512

16-
17-
def test_cgcnn_clf():
18-
data_path = os.path.join(
19-
os.path.dirname(os.path.abspath(__file__)), "data/matbench_phonons.json.gz"
20-
)
13+
def test_cgcnn_clf(df_matbench_phonons):
2114
elem_emb = "cgcnn92"
2215
targets = ["phdos_clf"]
2316
tasks = ["classification"]
@@ -44,26 +37,14 @@ def test_cgcnn_clf():
4437
weight_decay = 1e-6
4538
batch_size = 128
4639
workers = 0
47-
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
40+
device = "cuda" if torch.cuda.is_available() else "cpu"
4841

4942
task_dict = dict(zip(targets, tasks))
5043
loss_dict = dict(zip(targets, losses))
5144

52-
assert os.path.exists(data_path), f"{data_path} does not exist!"
53-
54-
df = load_dataframe_from_json(data_path)
55-
df["lattice"] = [None] * len(df)
56-
df["sites"] = [None] * len(df)
57-
df[["lattice", "sites"]] = df.apply(
58-
lambda x: get_cgcnn_input(x.structure), axis=1, result_type="expand"
59-
)
60-
df["material_id"] = [f"mb_phdos_{i}" for i in range(len(df))]
61-
df["composition"] = df.structure.apply(
62-
lambda x: x.composition.formula.replace(" ", "")
45+
dataset = CrystalGraphData(
46+
df=df_matbench_phonons, elem_emb=elem_emb, task_dict=task_dict
6347
)
64-
df["phdos_clf"] = np.where((df["last phdos peak"] > 450), 1, 0)
65-
66-
dataset = CrystalGraphData(df=df, elem_emb=elem_emb, task_dict=task_dict)
6748
n_targets = dataset.n_targets
6849
elem_emb_len = dataset.elem_emb_len
6950
nbr_fea_len = dataset.nbr_fea_dim
@@ -166,7 +147,3 @@ def test_cgcnn_clf():
166147

167148
assert ens_acc > 0.85
168149
assert ens_roc_auc > 0.9
169-
170-
171-
if __name__ == "__main__":
172-
test_cgcnn_clf()

tests/test_cgcnn_regression.py

Lines changed: 4 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,15 @@
22

33
import numpy as np
44
import torch
5-
from matminer.utils.io import load_dataframe_from_json
65
from sklearn.metrics import r2_score
76
from sklearn.model_selection import train_test_split as split
87

98
from aviary.cgcnn.data import CrystalGraphData, collate_batch
109
from aviary.cgcnn.model import CrystalGraphConvNet
11-
from aviary.cgcnn.utils import get_cgcnn_input
1210
from aviary.utils import results_multitask, train_ensemble
1311

14-
torch.manual_seed(0) # ensure reproducible results
1512

16-
17-
def test_cgcnn_regression():
18-
data_path = os.path.join(
19-
os.path.dirname(os.path.abspath(__file__)), "data/matbench_phonons.json.gz"
20-
)
13+
def test_cgcnn_regression(df_matbench_phonons):
2114
elem_emb = "cgcnn92"
2215
targets = ["last phdos peak"]
2316
tasks = ["regression"]
@@ -44,25 +37,14 @@ def test_cgcnn_regression():
4437
weight_decay = 1e-6
4538
batch_size = 128
4639
workers = 0
47-
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
40+
device = "cuda" if torch.cuda.is_available() else "cpu"
4841

4942
task_dict = dict(zip(targets, tasks))
5043
loss_dict = dict(zip(targets, losses))
5144

52-
assert os.path.exists(data_path), f"{data_path} does not exist!"
53-
54-
df = load_dataframe_from_json(data_path)
55-
df["lattice"] = [None] * len(df)
56-
df["sites"] = [None] * len(df)
57-
df[["lattice", "sites"]] = df.apply(
58-
lambda x: get_cgcnn_input(x.structure), axis=1, result_type="expand"
59-
)
60-
df["material_id"] = [f"mb_phdos_{i}" for i in range(len(df))]
61-
df["composition"] = df.structure.apply(
62-
lambda x: x.composition.formula.replace(" ", "")
45+
dataset = CrystalGraphData(
46+
df=df_matbench_phonons, elem_emb=elem_emb, task_dict=task_dict
6347
)
64-
65-
dataset = CrystalGraphData(df=df, elem_emb=elem_emb, task_dict=task_dict)
6648
n_targets = dataset.n_targets
6749
elem_emb_len = dataset.elem_emb_len
6850
nbr_fea_len = dataset.nbr_fea_dim
@@ -164,7 +146,3 @@ def test_cgcnn_regression():
164146
assert r2 > 0.7
165147
assert mae < 150
166148
assert rmse < 300
167-
168-
169-
if __name__ == "__main__":
170-
test_cgcnn_regression()

tests/test_roost_classification.py

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,15 @@
22

33
import numpy as np
44
import torch
5-
from matminer.utils.io import load_dataframe_from_json
65
from sklearn.metrics import accuracy_score, roc_auc_score
76
from sklearn.model_selection import train_test_split as split
87

98
from aviary.roost.data import CompositionData, collate_batch
109
from aviary.roost.model import Roost
1110
from aviary.utils import results_multitask, train_ensemble
1211

13-
torch.manual_seed(0) # ensure reproducible results
1412

15-
16-
def test_roost_clf():
17-
data_path = os.path.join(
18-
os.path.dirname(os.path.abspath(__file__)), "data/matbench_phonons.json.gz"
19-
)
13+
def test_roost_clf(df_matbench_phonons):
2014
elem_emb = "matscholar200"
2115
targets = ["phdos_clf"]
2216
tasks = ["classification"]
@@ -41,21 +35,14 @@ def test_roost_clf():
4135
weight_decay = 1e-6
4236
batch_size = 128
4337
workers = 0
44-
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
38+
device = "cuda" if torch.cuda.is_available() else "cpu"
4539

4640
task_dict = dict(zip(targets, tasks))
4741
loss_dict = dict(zip(targets, losses))
4842

49-
assert os.path.exists(data_path), f"{data_path} does not exist!"
50-
51-
df = load_dataframe_from_json(data_path)
52-
df["material_id"] = [f"mb_phdos_{i}" for i in range(len(df))]
53-
df["composition"] = df.structure.apply(
54-
lambda x: x.composition.formula.replace(" ", "")
43+
dataset = CompositionData(
44+
df=df_matbench_phonons, elem_emb=elem_emb, task_dict=task_dict
5545
)
56-
df["phdos_clf"] = np.where((df["last phdos peak"] > 450), 1, 0)
57-
58-
dataset = CompositionData(df=df, elem_emb=elem_emb, task_dict=task_dict)
5946
n_targets = dataset.n_targets
6047
elem_emb_len = dataset.elem_emb_len
6148

@@ -162,7 +149,3 @@ def test_roost_clf():
162149

163150
assert ens_acc > 0.9
164151
assert ens_roc_auc > 0.9
165-
166-
167-
if __name__ == "__main__":
168-
test_roost_clf()

tests/test_roost_regression.py

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,15 @@
22

33
import numpy as np
44
import torch
5-
from matminer.utils.io import load_dataframe_from_json
65
from sklearn.metrics import r2_score
76
from sklearn.model_selection import train_test_split as split
87

98
from aviary.roost.data import CompositionData, collate_batch
109
from aviary.roost.model import Roost
1110
from aviary.utils import results_multitask, train_ensemble
1211

13-
torch.manual_seed(0) # ensure reproducible results
1412

15-
16-
def test_roost_regression():
17-
data_path = os.path.join(
18-
os.path.dirname(os.path.abspath(__file__)), "data/matbench_phonons.json.gz"
19-
)
13+
def test_roost_regression(df_matbench_phonons):
2014
elem_emb = "matscholar200"
2115
targets = ["last phdos peak"]
2216
tasks = ["regression"]
@@ -41,20 +35,14 @@ def test_roost_regression():
4135
weight_decay = 1e-6
4236
batch_size = 128
4337
workers = 0
44-
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
38+
device = "cuda" if torch.cuda.is_available() else "cpu"
4539

4640
task_dict = dict(zip(targets, tasks))
4741
loss_dict = dict(zip(targets, losses))
4842

49-
assert os.path.exists(data_path), f"{data_path} does not exist!"
50-
51-
df = load_dataframe_from_json(data_path)
52-
df["material_id"] = [f"mb_phdos_{i}" for i in range(len(df))]
53-
df["composition"] = df.structure.apply(
54-
lambda x: x.composition.formula.replace(" ", "")
43+
dataset = CompositionData(
44+
df=df_matbench_phonons, elem_emb=elem_emb, task_dict=task_dict
5545
)
56-
57-
dataset = CompositionData(df=df, elem_emb=elem_emb, task_dict=task_dict)
5846
n_targets = dataset.n_targets
5947
elem_emb_len = dataset.elem_emb_len
6048

@@ -160,7 +148,3 @@ def test_roost_regression():
160148
assert r2 > 0.7
161149
assert mae < 150
162150
assert rmse < 300
163-
164-
165-
if __name__ == "__main__":
166-
test_roost_regression()

tests/test_wren_classification.py

Lines changed: 6 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,15 @@
22

33
import numpy as np
44
import torch
5-
from matminer.utils.io import load_dataframe_from_json
65
from sklearn.metrics import accuracy_score, roc_auc_score
76
from sklearn.model_selection import train_test_split as split
87

98
from aviary.utils import results_multitask, train_ensemble
109
from aviary.wren.data import WyckoffData, collate_batch
1110
from aviary.wren.model import Wren
12-
from aviary.wren.utils import get_aflow_label_spglib
1311

14-
torch.manual_seed(0) # ensure reproducible results
1512

16-
17-
def test_wren_clf():
18-
data_path = os.path.join(
19-
os.path.dirname(os.path.abspath(__file__)), "data/matbench_phonons.json.gz"
20-
)
13+
def test_wren_clf(df_matbench_phonons_wyckoff):
2114
elem_emb = "matscholar200"
2215
sym_emb = "bra-alg-off"
2316
targets = ["phdos_clf"]
@@ -44,23 +37,16 @@ def test_wren_clf():
4437
weight_decay = 1e-6
4538
batch_size = 128
4639
workers = 0
47-
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
40+
device = "cuda" if torch.cuda.is_available() else "cpu"
4841

4942
task_dict = dict(zip(targets, tasks))
5043
loss_dict = dict(zip(targets, losses))
5144

52-
assert os.path.exists(data_path), f"{data_path} does not exist!"
53-
54-
df = load_dataframe_from_json(data_path)
55-
df["wyckoff"] = df.structure.apply(get_aflow_label_spglib)
56-
df["material_id"] = [f"mb_phdos_{i}" for i in range(len(df))]
57-
df["composition"] = df.structure.apply(
58-
lambda x: x.composition.formula.replace(" ", "")
59-
)
60-
df["phdos_clf"] = np.where((df["last phdos peak"] > 450), 1, 0)
61-
6245
dataset = WyckoffData(
63-
df=df, elem_emb=elem_emb, sym_emb=sym_emb, task_dict=task_dict
46+
df=df_matbench_phonons_wyckoff,
47+
elem_emb=elem_emb,
48+
sym_emb=sym_emb,
49+
task_dict=task_dict,
6450
)
6551
n_targets = dataset.n_targets
6652
elem_emb_len = dataset.elem_emb_len
@@ -171,7 +157,3 @@ def test_wren_clf():
171157

172158
assert ens_acc > 0.85
173159
assert ens_roc_auc > 0.9
174-
175-
176-
if __name__ == "__main__":
177-
test_wren_clf()

0 commit comments

Comments
 (0)