@@ -102,6 +102,7 @@ def make_ensemble_predictions(
102102 model_class : type [BaseModelClass ] = Wrenformer ,
103103 device : str = None ,
104104 print_metrics : bool = True ,
105+ warn_target_mismatch : bool = False ,
105106) -> pd .DataFrame | tuple [pd .DataFrame , pd .DataFrame ]:
106107 """Make predictions using an ensemble of Wrenformer models.
107108
@@ -117,11 +118,15 @@ def make_ensemble_predictions(
117118 else "cpu".
118119 print_metrics (bool, optional): Whether to print performance metrics. Defaults to True
119120 if target_col is not None.
121+ warn_target_mismatch (bool, optional): Whether to warn if target_col != target_name from
122+ model checkpoint. Defaults to False.
120123
121124 Returns:
122125 pd.DataFrame: Input dataframe with added columns for model and ensemble predictions. If
123126 target_col is not None, returns a 2nd dataframe containing model and ensemble metrics.
124127 """
128+ # TODO: Add support for predicting all tasks a multi-task model was trained on. Currently only
129+ # handles single targets.
125130 device = device or ("cuda" if torch .cuda .is_available () else "cpu" )
126131
127132 data_loader = df_to_in_mem_dataloader (
@@ -138,11 +143,12 @@ def make_ensemble_predictions(
138143 checkpoint = torch .load (checkpoint_path , map_location = device )
139144
140145 model_params = checkpoint ["model_params" ]
141- target_name , task_type = next (model_params ["task_dict" ].items ())
146+ target_name , task_type = list (model_params ["task_dict" ].items ())[ 0 ]
142147 assert task_type in ("regression" , "classification" ), f"invalid { task_type = } "
143- if target_name != target_col :
148+ if target_name != target_col and warn_target_mismatch :
144149 print (
145- f"Warning: { target_col = } does not match { target_name = } in checkpoint."
150+ f"Warning: { target_col = } does not match { target_name = } in checkpoint. "
151+ "If this is not by accident, disable this warning by passing warn_target=False."
146152 )
147153 model = model_class (** model_params )
148154 model .to (device )
@@ -155,15 +161,22 @@ def make_ensemble_predictions(
155161 if model .robust :
156162 predictions , aleat_log_std = predictions .chunk (2 , dim = 1 )
157163 aleat_std = aleat_log_std .exp ().cpu ().numpy ().squeeze ()
158- df [f"aleat_std_ { idx } " ] = aleat_std .tolist ()
164+ df [f"aleatoric_std_ { idx } " ] = aleat_std .tolist ()
159165
160166 predictions = predictions .cpu ().numpy ().squeeze ()
161167 pred_col = f"{ target_col } _pred_{ idx } " if target_col else f"pred_{ idx } "
162168 df [pred_col ] = predictions .tolist ()
163169
164170 df_preds = df .filter (regex = r"_pred_\d" )
165171 df [f"{ target_col } _pred_ens" ] = ensemble_preds = df_preds .mean (axis = 1 )
166- df [f"{ target_col } _ens_epistemic_std" ] = df_preds .std (axis = 1 )
172+ df [f"{ target_col } _epistemic_std_ens" ] = epistemic_std = df_preds .std (axis = 1 )
173+
174+ if df .columns .str .startswith ("aleatoric_std_" ).sum () > 0 :
175+ aleatoric_std = df .filter (regex = r"aleatoric_std_\d" ).mean (axis = 1 )
176+ df [f"{ target_col } _aleatoric_std_ens" ] = aleatoric_std
177+ df [f"{ target_col } _total_std_ens" ] = (
178+ epistemic_std ** 2 + aleatoric_std ** 2
179+ ) ** 0.5
167180
168181 if target_col and print_metrics :
169182 targets = df [target_col ]
@@ -175,12 +188,12 @@ def make_ensemble_predictions(
175188 index = df_preds .columns ,
176189 )
177190
178- print ("Single model performance:" )
179- print (all_model_metrics .describe ().loc [["mean" , "std" ]])
191+ print ("\n Single model performance:" )
192+ print (all_model_metrics .describe ().round ( 4 ). loc [["mean" , "std" ]])
180193
181194 ensemble_metrics = get_metrics (targets , ensemble_preds , task_type )
182195
183- print ("Ensemble performance:" )
196+ print ("\n Ensemble performance:" )
184197 for key , val in ensemble_metrics .items ():
185198 print (f"{ key :<8} { val :.3} " )
186199 return df , all_model_metrics
0 commit comments