Skip to content

Commit 4b48f59

Browse files
committed
add total std calculation and kwarg warn_target_mismatch to make_ensemble_predictions()
1 parent 0121faa commit 4b48f59

2 files changed

Lines changed: 41 additions & 9 deletions

File tree

aviary/wrenformer/utils.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def make_ensemble_predictions(
102102
model_class: type[BaseModelClass] = Wrenformer,
103103
device: str = None,
104104
print_metrics: bool = True,
105+
warn_target_mismatch: bool = False,
105106
) -> pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame]:
106107
"""Make predictions using an ensemble of Wrenformer models.
107108
@@ -117,11 +118,15 @@ def make_ensemble_predictions(
117118
else "cpu".
118119
print_metrics (bool, optional): Whether to print performance metrics. Defaults to True
119120
if target_col is not None.
121+
warn_target_mismatch (bool, optional): Whether to warn if target_col != target_name from
122+
model checkpoint. Defaults to False.
120123
121124
Returns:
122125
pd.DataFrame: Input dataframe with added columns for model and ensemble predictions. If
123126
target_col is not None, returns a 2nd dataframe containing model and ensemble metrics.
124127
"""
128+
# TODO: Add support for predicting all tasks a multi-task model was trained on. Currently only
129+
# handles single targets.
125130
device = device or ("cuda" if torch.cuda.is_available() else "cpu")
126131

127132
data_loader = df_to_in_mem_dataloader(
@@ -138,11 +143,12 @@ def make_ensemble_predictions(
138143
checkpoint = torch.load(checkpoint_path, map_location=device)
139144

140145
model_params = checkpoint["model_params"]
141-
target_name, task_type = next(model_params["task_dict"].items())
146+
target_name, task_type = list(model_params["task_dict"].items())[0]
142147
assert task_type in ("regression", "classification"), f"invalid {task_type = }"
143-
if target_name != target_col:
148+
if target_name != target_col and warn_target_mismatch:
144149
print(
145-
f"Warning: {target_col = } does not match {target_name = } in checkpoint."
150+
f"Warning: {target_col = } does not match {target_name = } in checkpoint. "
151+
"If this is not by accident, disable this warning by passing warn_target_mismatch=False."
146152
)
147153
model = model_class(**model_params)
148154
model.to(device)
@@ -155,15 +161,22 @@ def make_ensemble_predictions(
155161
if model.robust:
156162
predictions, aleat_log_std = predictions.chunk(2, dim=1)
157163
aleat_std = aleat_log_std.exp().cpu().numpy().squeeze()
158-
df[f"aleat_std_{idx}"] = aleat_std.tolist()
164+
df[f"aleatoric_std_{idx}"] = aleat_std.tolist()
159165

160166
predictions = predictions.cpu().numpy().squeeze()
161167
pred_col = f"{target_col}_pred_{idx}" if target_col else f"pred_{idx}"
162168
df[pred_col] = predictions.tolist()
163169

164170
df_preds = df.filter(regex=r"_pred_\d")
165171
df[f"{target_col}_pred_ens"] = ensemble_preds = df_preds.mean(axis=1)
166-
df[f"{target_col}_ens_epistemic_std"] = df_preds.std(axis=1)
172+
df[f"{target_col}_epistemic_std_ens"] = epistemic_std = df_preds.std(axis=1)
173+
174+
if df.columns.str.startswith("aleatoric_std_").sum() > 0:
175+
aleatoric_std = df.filter(regex=r"aleatoric_std_\d").mean(axis=1)
176+
df[f"{target_col}_aleatoric_std_ens"] = aleatoric_std
177+
df[f"{target_col}_total_std_ens"] = (
178+
epistemic_std**2 + aleatoric_std**2
179+
) ** 0.5
167180

168181
if target_col and print_metrics:
169182
targets = df[target_col]
@@ -175,12 +188,12 @@ def make_ensemble_predictions(
175188
index=df_preds.columns,
176189
)
177190

178-
print("Single model performance:")
179-
print(all_model_metrics.describe().loc[["mean", "std"]])
191+
print("\nSingle model performance:")
192+
print(all_model_metrics.describe().round(4).loc[["mean", "std"]])
180193

181194
ensemble_metrics = get_metrics(targets, ensemble_preds, task_type)
182195

183-
print("Ensemble performance:")
196+
print("\nEnsemble performance:")
184197
for key, val in ensemble_metrics.items():
185198
print(f"{key:<8} {val:.3}")
186199
return df, all_model_metrics

examples/mp_wbm/use_trained_wrenformer_ensemble.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
the MP+WBM dataset and makes predictions on the test set, then prints ensemble metrics.
1717
"""
1818

19+
1920
data_path = f"{ROOT}/datasets/2022-06-09-mp+wbm.json.gz"
2021
target_col = "e_form"
2122
test_size = 0.05
@@ -34,6 +35,10 @@
3435

3536
runs = wandb_api.runs("aviary/mp-wbm", filters={"tags": {"$in": ["ensemble-id-2"]}})
3637

38+
print(
39+
f"Loading checkpoints for the following run IDs:\n{', '.join(run.id for run in runs)}\n"
40+
)
41+
3742
checkpoint_paths: list[str] = []
3843
for run in runs:
3944
run_path = "/".join(run.path)
@@ -59,4 +64,18 @@
5964
checkpoint_paths, df=test_df, target_col=target_col
6065
)
6166

62-
test_df.to_csv(f"{ROOT}/examples/mp_wbm/ensemble-test-predictions.csv")
67+
test_df.to_csv(f"{ROOT}/examples/mp_wbm/ensemble-predictions.csv")
68+
69+
70+
# print output:
71+
# Predicting with 10 model checkpoints(s)
72+
#
73+
# Single model performance:
74+
# MAE RMSE R2
75+
# mean 0.0369 0.1218 0.9864
76+
# std 0.0005 0.0014 0.0003
77+
#
78+
# Ensemble performance:
79+
# MAE 0.0308
80+
# RMSE 0.118
81+
# R2 0.987

0 commit comments

Comments
 (0)