22
33import numpy as np
44import torch
5- from matminer .utils .io import load_dataframe_from_json
65from sklearn .metrics import accuracy_score , roc_auc_score
76from sklearn .model_selection import train_test_split as split
87
98from aviary .cgcnn .data import CrystalGraphData , collate_batch
109from aviary .cgcnn .model import CrystalGraphConvNet
11- from aviary .cgcnn .utils import get_cgcnn_input
1210from aviary .utils import results_multitask , train_ensemble
1311
14- torch .manual_seed (0 ) # ensure reproducible results
1512
16-
17- def test_cgcnn_clf ():
18- data_path = os .path .join (
19- os .path .dirname (os .path .abspath (__file__ )), "data/matbench_phonons.json.gz"
20- )
13+ def test_cgcnn_clf (df_matbench_phonons ):
2114 elem_emb = "cgcnn92"
2215 targets = ["phdos_clf" ]
2316 tasks = ["classification" ]
@@ -44,26 +37,14 @@ def test_cgcnn_clf():
4437 weight_decay = 1e-6
4538 batch_size = 128
4639 workers = 0
47- device = torch . device ( "cuda" ) if torch .cuda .is_available () else torch . device ( "cpu" )
40+ device = "cuda" if torch .cuda .is_available () else "cpu"
4841
4942 task_dict = dict (zip (targets , tasks ))
5043 loss_dict = dict (zip (targets , losses ))
5144
52- assert os .path .exists (data_path ), f"{ data_path } does not exist!"
53-
54- df = load_dataframe_from_json (data_path )
55- df ["lattice" ] = [None ] * len (df )
56- df ["sites" ] = [None ] * len (df )
57- df [["lattice" , "sites" ]] = df .apply (
58- lambda x : get_cgcnn_input (x .structure ), axis = 1 , result_type = "expand"
59- )
60- df ["material_id" ] = [f"mb_phdos_{ i } " for i in range (len (df ))]
61- df ["composition" ] = df .structure .apply (
62- lambda x : x .composition .formula .replace (" " , "" )
45+ dataset = CrystalGraphData (
46+ df = df_matbench_phonons , elem_emb = elem_emb , task_dict = task_dict
6347 )
64- df ["phdos_clf" ] = np .where ((df ["last phdos peak" ] > 450 ), 1 , 0 )
65-
66- dataset = CrystalGraphData (df = df , elem_emb = elem_emb , task_dict = task_dict )
6748 n_targets = dataset .n_targets
6849 elem_emb_len = dataset .elem_emb_len
6950 nbr_fea_len = dataset .nbr_fea_dim
@@ -166,7 +147,3 @@ def test_cgcnn_clf():
166147
167148 assert ens_acc > 0.85
168149 assert ens_roc_auc > 0.9
169-
170-
171- if __name__ == "__main__" :
172- test_cgcnn_clf ()
0 commit comments