1+ from __future__ import annotations
2+
13import json
24import time
35from contextlib import contextmanager
46from typing import Generator , Literal
57
8+ import pandas as pd
9+ import torch
10+ from tqdm import tqdm
11+
12+ from aviary .core import BaseModelClass
13+ from aviary .utils import get_metrics
14+ from aviary .wrenformer .data import df_to_in_mem_dataloader
15+ from aviary .wrenformer .model import Wrenformer
16+
17+ __author__ = "Janosh Riebesell"
18+ __date__ = "2022-05-10"
19+
620
721def _int_keys (dct : dict ) -> dict :
822 # JSON stringifies all dict keys during serialization and does not revert
@@ -45,14 +59,14 @@ def merge_json_on_disk(
4559 pass
4660
4761 def non_serializable_handler (obj : object ) -> str :
48- # replace functions and classes in dct with string indicating a non-serializable type
62+ # replace functions and classes in dct with string indicating it's a non-serializable type
4963 return f"<not serializable: { type (obj ).__qualname__ } >"
5064
5165 with open (file_path , "w" ) as file :
5266 default = (
5367 non_serializable_handler if on_non_serializable == "annotate" else None
5468 )
55- json .dump (dct , file , default = default )
69+ json .dump (dct , file , default = default , indent = 2 )
5670
5771
5872@contextmanager
@@ -78,3 +92,110 @@ def print_walltime(
7892 finally :
7993 run_time = time .perf_counter () - start_time
8094 print (f"{ end_desc } took { run_time :.2f} sec" )
95+
96+
def make_ensemble_predictions(
    checkpoint_paths: list[str],
    df: pd.DataFrame,
    target_col: str | None = None,
    input_col: str = "wyckoff",
    model_class: type[BaseModelClass] = Wrenformer,
    device: str | None = None,
    print_metrics: bool = True,
    warn_target_mismatch: bool = False,
) -> pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame]:
    """Make predictions using an ensemble of Wrenformer models.

    Note: df is modified in place (prediction columns are added to it) and also returned.

    Args:
        checkpoint_paths (list[str]): File paths to model checkpoints created with torch.save().
            Must contain at least one path.
        df (pd.DataFrame): Dataframe to make predictions on. Will be returned with additional
            columns holding model predictions (and uncertainties for robust models) for each
            model checkpoint.
        target_col (str, optional): Column holding target values. Defaults to None. If None,
            will not print performance metrics and prediction columns carry no target prefix
            (e.g. 'pred_1', 'pred_ens').
        input_col (str, optional): Column holding input values. Defaults to 'wyckoff'.
        model_class (type[BaseModelClass], optional): Class instantiated with the checkpoint's
            'model_params' before loading its 'model_state'. Defaults to Wrenformer.
        device (str, optional): torch.device. Defaults to "cuda" if torch.cuda.is_available()
            else "cpu".
        print_metrics (bool, optional): Whether to compute and print performance metrics.
            Only has an effect if target_col is not None. Defaults to True.
        warn_target_mismatch (bool, optional): Whether to warn if target_col != target_name from
            model checkpoint. Defaults to False.

    Raises:
        ValueError: If checkpoint_paths is empty.
        AssertionError: If a checkpoint's task type is neither 'regression' nor
            'classification'.

    Returns:
        pd.DataFrame: Input dataframe with added columns for model and ensemble predictions. If
        target_col is not None and print_metrics is True, returns a tuple of that dataframe and
        a 2nd dataframe containing the metrics of each individual model.
    """
    # TODO: Add support for predicting all tasks a multi-task models was trained on. Currently only
    # handles single targets.
    if not checkpoint_paths:
        # without any checkpoint there is no task_type and nothing to average over
        raise ValueError("checkpoint_paths must contain at least one model checkpoint")

    device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    data_loader = df_to_in_mem_dataloader(
        df=df,
        target_col=target_col,
        input_col=input_col,
        batch_size=512,
        embedding_type="wyckoff",
    )

    print(f"Predicting with {len(checkpoint_paths):,} model checkpoints(s)")

    # Prefix shared by all columns this function adds. Empty when no target column was given so
    # columns are not named 'None_...'.
    col_prefix = f"{target_col}_" if target_col else ""
    # Track the columns created here explicitly instead of regex-matching df afterwards, which
    # would miss un-prefixed 'pred_<idx>' columns and pick up unrelated pre-existing ones.
    pred_cols: list[str] = []
    aleatoric_cols: list[str] = []

    for idx, checkpoint_path in enumerate(tqdm(checkpoint_paths), 1):
        # NOTE(security): torch.load unpickles the file, only load checkpoints you trust.
        checkpoint = torch.load(checkpoint_path, map_location=device)

        model_params = checkpoint["model_params"]
        # only the first task in the checkpoint's task_dict is used (see TODO above)
        target_name, task_type = list(model_params["task_dict"].items())[0]
        assert task_type in ("regression", "classification"), f"invalid {task_type=}"
        if target_name != target_col and warn_target_mismatch:
            print(
                f"Warning: {target_col=} does not match {target_name=} in checkpoint. "
                "If this is not by accident, disable this warning by passing "
                "warn_target_mismatch=False."
            )
        model = model_class(**model_params)
        model.to(device)

        model.load_state_dict(checkpoint["model_state"])
        # no_grad() alone does not turn off dropout/batch-norm updates; eval() makes inference
        # deterministic
        model.eval()

        with torch.no_grad():
            predictions = torch.cat([model(*inputs)[0] for inputs, *_ in data_loader])

        if model.robust:
            # robust models output the prediction and its log standard deviation side by side
            predictions, aleat_log_std = predictions.chunk(2, dim=1)
            aleat_std = aleat_log_std.exp().cpu().numpy().squeeze()
            aleat_col = f"aleatoric_std_{idx}"
            df[aleat_col] = aleat_std.tolist()
            aleatoric_cols.append(aleat_col)

        predictions = predictions.cpu().numpy().squeeze()
        pred_col = f"{col_prefix}pred_{idx}"
        df[pred_col] = predictions.tolist()
        pred_cols.append(pred_col)

    df_preds = df[pred_cols]
    df[f"{col_prefix}pred_ens"] = ensemble_preds = df_preds.mean(axis=1)
    # spread between ensemble members; NaN when only a single checkpoint was given
    df[f"{col_prefix}epistemic_std_ens"] = epistemic_std = df_preds.std(axis=1)

    if aleatoric_cols:
        aleatoric_std = df[aleatoric_cols].mean(axis=1)
        df[f"{col_prefix}aleatoric_std_ens"] = aleatoric_std
        # combine both uncertainty contributions in quadrature
        df[f"{col_prefix}total_std_ens"] = (
            epistemic_std**2 + aleatoric_std**2
        ) ** 0.5

    if target_col and print_metrics:
        targets = df[target_col]
        # task_type is the one read from the last checkpoint in the loop above
        all_model_metrics = pd.DataFrame(
            [get_metrics(targets, df_preds[col], task_type) for col in df_preds],
            index=df_preds.columns,
        )

        print("\nSingle model performance:")
        print(all_model_metrics.describe().round(4).loc[["mean", "std"]])

        ensemble_metrics = get_metrics(targets, ensemble_preds, task_type)

        print("\nEnsemble performance:")
        for key, val in ensemble_metrics.items():
            print(f"{key:<8} {val:.3}")
        return df, all_model_metrics

    return df
0 commit comments