99
1010from aviary import ROOT
1111from aviary .core import Normalizer , TaskType
12- from aviary .data import InMemoryDataLoader
1312from aviary .losses import RobustL1Loss
1413from aviary .utils import get_metrics
15- from aviary .wrenformer .data import (
16- collate_batch ,
17- get_composition_embedding ,
18- wyckoff_embedding_from_aflow_str ,
19- )
14+ from aviary .wrenformer .data import df_to_in_mem_dataloader
2015from aviary .wrenformer .model import Wrenformer
2116from aviary .wrenformer .utils import print_walltime
2217
@@ -42,6 +37,7 @@ def run_wrenformer(
4237 target_col : str ,
4338 epochs : int ,
4439 timestamp : str = None ,
40+ input_col : str = None ,
4541 id_col : str = "material_id" ,
4642 n_attn_layers : int = 4 ,
4743 wandb_project : str = None ,
@@ -67,6 +63,8 @@ def run_wrenformer(
6763 train_df (pd.DataFrame): Dataframe containing the training data.
6864 test_df (pd.DataFrame): Dataframe containing the test data.
6965 target_col (str): Name of df column containing the target values.
66+ input_col (str): Name of df column containing the input values. Defaults to 'wyckoff' if
67+ 'wren' in run_name else 'composition'.
7068 id_col (str): Name of df column containing material IDs.
7169 epochs (int): How many epochs to train for. Defaults to 100.
7270 timestamp (str): Will be included in run_params and used as file name prefix for model
@@ -118,23 +116,16 @@ def run_wrenformer(
118116 device = "cuda" if torch .cuda .is_available () else "cpu"
119117 print (f"Pytorch running on { device = } " )
120118
121- for label , df in [("training set" , train_df ), ("test set" , test_df )]:
122- if "wren" in run_name .lower ():
123- err_msg = "Missing 'wyckoff' column in dataframe. "
124- err_msg += (
125- "Please generate Aflow Wyckoff labels ahead of time."
126- if "structure" in df
127- else "Trying to deploy Wrenformer on composition-only task?"
128- )
129- assert "wyckoff" in df , err_msg
130- with print_walltime (
131- start_desc = f"Generating Wyckoff embeddings for { label } " , newline = False
132- ):
133- df ["features" ] = df .wyckoff .map (wyckoff_embedding_from_aflow_str )
134- elif "roost" in run_name .lower ():
135- df ["features" ] = df .composition .map (get_composition_embedding )
136- else :
137- raise ValueError (f"{ run_name = } must contain 'roost' or 'wren'" )
119+ if "wren" in run_name .lower ():
120+ input_col = input_col or "wyckoff"
121+ embedding_type = "wyckoff"
122+ elif "roost" in run_name .lower ():
123+ input_col = input_col or "composition"
124+ embedding_type = "composition"
125+ else :
126+ raise ValueError (
127+ f"{ run_name = } must contain 'roost' or 'wren' (case insensitive)"
128+ )
138129
139130 robust = "robust" in run_name .lower ()
140131 loss_func = (
@@ -145,42 +136,24 @@ def run_wrenformer(
145136 loss_dict = {target_col : (task_type , loss_func )}
146137 normalizer_dict = {target_col : Normalizer () if task_type == reg_key else None }
147138
148- features , targets , ids = (train_df [x ] for x in ["features" , target_col , id_col ])
149- targets = torch .tensor (targets , device = device )
150- if targets .dtype == torch .bool :
151- targets = targets .long ()
152- inputs = np .empty (len (features ), dtype = object )
153- for idx , tensor in enumerate (features ):
154- inputs [idx ] = tensor .to (device )
155-
156- train_loader = InMemoryDataLoader (
157- [inputs , targets , ids ],
158- batch_size = batch_size ,
159- shuffle = True ,
160- collate_fn = collate_batch ,
139+ data_loader_kwargs = dict (
140+ target_col = target_col ,
141+ input_col = input_col ,
142+ id_col = id_col ,
143+ embedding_type = embedding_type ,
161144 )
162-
163- features , targets , ids = (test_df [x ] for x in ["features" , target_col , id_col ])
164- targets = torch .tensor (targets , device = device )
165- if targets .dtype == torch .bool :
166- targets = targets .long ()
167- inputs = np .empty (len (features ), dtype = object )
168- for idx , tensor in enumerate (features ):
169- inputs [idx ] = tensor .to (device )
170-
171- test_loader = InMemoryDataLoader (
172- [inputs , targets , ids ], batch_size = 512 , collate_fn = collate_batch
145+ train_loader = df_to_in_mem_dataloader (
146+ train_df , batch_size = batch_size , shuffle = True , ** data_loader_kwargs
173147 )
174148
175- # n_features is the length of the embedding vector for a Wyckoff position encoding
176- # the element type (usually 200-dim matscholar embeddings) and Wyckoff position (see
177- # 'bra-alg-off.json') + 1 for the weight of that element/Wyckoff position in the
178- # material's composition
179- embedding_len = features [0 ].shape [- 1 ]
180- assert embedding_len in (
181- 200 + 1 ,
182- 200 + 1 + 444 ,
183- ) # Roost and Wren embedding size resp.
149+ test_loader = df_to_in_mem_dataloader (test_df , batch_size = 512 , ** data_loader_kwargs )
150+
151+ # embedding_len is the length of the embedding vector for a Wyckoff position encoding the
152+ # element type (usually 200-dim matscholar embeddings) and Wyckoff position (see
153+ # 'bra-alg-off.json') + 1 for the weight of that Wyckoff position (or element) in the material
154+ embedding_len = train_loader .tensors [0 ][0 ].shape [- 1 ]
155+ # Roost and Wren embedding size resp.
156+ assert embedding_len in (200 + 1 , 200 + 1 + 444 ), f"{ embedding_len = } "
184157
185158 model_params = dict (
186159 # 1 for regression, n_classes for classification
@@ -242,6 +215,7 @@ def run_wrenformer(
242215 "training_samples" : len (train_df ),
243216 "test_samples" : len (test_df ),
244217 "trainable_params" : model .num_params ,
218+ "task_type" : task_type ,
245219 "swa" : {
246220 "start" : swa_start ,
247221 "epochs" : int (swa_start * epochs ),
@@ -272,6 +246,7 @@ def run_wrenformer(
272246 for epoch in range (epochs ):
273247 if verbose :
274248 print (f"Epoch { epoch + 1 } /{ epochs } " )
249+
275250 train_metrics = model .evaluate (
276251 train_loader ,
277252 loss_dict ,
@@ -333,9 +308,9 @@ def run_wrenformer(
333308 predictions = predictions .softmax (dim = 1 )
334309
335310 predictions = predictions .cpu ().numpy ().squeeze ()
336- targets = targets . cpu (). numpy ()
311+ targets = test_df [ target_col ]
337312 pred_col = f"{ target_col } _pred"
338- test_df [pred_col ] = predictions .tolist ()
313+ test_df [pred_col ] = predictions .tolist () # requires shuffle=False for test_loader
339314
340315 test_metrics = get_metrics (targets , predictions , task_type )
341316 test_metrics ["test_size" ] = len (test_df )
0 commit comments