
Commit 983a0b8

add function df_to_in_mem_dataloader() to extract train/test_loader creation from run_wrenformer()
1 parent: 3c100bb

4 files changed: 95 additions & 62 deletions


aviary/data.py

Lines changed: 5 additions & 3 deletions
@@ -15,15 +15,17 @@ class InMemoryDataLoader:

     Args:
         *tensors: List of arrays or tensors. Must all have the same length in dimension 0.
-        batch_size (int, optional): Defaults to 32.
+        batch_size (int, optional): Usually 64, 128 or 256. Can be larger for test set loaders
+            to speedup inference. Defaults to 64.
         shuffle (bool, optional): If True, shuffle the data *in-place* whenever an
             iterator is created from this object. Defaults to False.
         collate_fn (Callable, optional): Should accept variadic list of tensors and
             output a minibatch of data ready for model consumption. Defaults to tuple().
     """

-    tensors: list[Tensor]
-    batch_size: int = 32
+    # each item must be indexable (usually torch.tensor, np.array or pd.Series)
+    tensors: list[Tensor | np.ndarray]
+    batch_size: int = 64
     shuffle: bool = False
     collate_fn: Callable = tuple
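
For context, a minimal usage sketch of the updated dataclass (the sample arrays are made up; only the fields visible in the diff above are assumed, and the per-batch structure follows the collate_fn docstring):

import numpy as np

from aviary.data import InMemoryDataLoader

# three index-aligned columns of equal length: inputs, targets, ids
inputs = np.random.rand(10, 3)
targets = np.random.rand(10)
ids = np.arange(10)

# with the default collate_fn=tuple, each batch should come out as a tuple of
# three array slices, one per column passed in
loader = InMemoryDataLoader([inputs, targets, ids], batch_size=4, shuffle=True)
for x, y, batch_ids in loader:
    print(x.shape, y.shape, batch_ids)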

aviary/wrenformer/data.py

Lines changed: 56 additions & 0 deletions
@@ -1,13 +1,16 @@
 from __future__ import annotations

 import json
+from typing import Literal

 import numpy as np
+import pandas as pd
 import torch
 from pymatgen.core import Composition
 from torch import LongTensor, Tensor, nn

 from aviary import PKG_DIR
+from aviary.data import InMemoryDataLoader
 from aviary.wren.data import parse_aflow_wyckoff_str

@@ -118,3 +121,56 @@ def get_composition_embedding(formula: str) -> Tensor:
     combined_features = torch.cat([element_ratios, element_features], dim=1).float()

     return combined_features
+
+
+def df_to_in_mem_dataloader(
+    df: pd.DataFrame,
+    target_col: str,
+    input_col: str = "wyckoff",
+    id_col: str = "material_id",
+    embedding_type: Literal["wyckoff", "composition"] = "wyckoff",
+    device: str = None,
+    **kwargs,
+) -> InMemoryDataLoader:
+    """Construct an InMemoryDataLoader with Wrenformer batch collation from a dataframe.
+    Can also be used for Roostformer.
+
+    Args:
+        df (pd.DataFrame): Expected to have columns input_col, target_col, id_col.
+        target_col (str): Column name holding the target values.
+        input_col (str): Column name holding the input values (Aflow Wyckoff labels or
+            composition strings) from which initial embeddings will be constructed.
+            Defaults to "wyckoff".
+        id_col (str): Column name holding material identifiers. Defaults to "material_id".
+        embedding_type ('wyckoff' | 'composition'): Defaults to "wyckoff".
+        device (str): torch.device to load tensors onto. Defaults to
+            "cuda" if torch.cuda.is_available() else "cpu".
+        kwargs (dict): Keyword arguments like batch_size: int and shuffle: bool
+            to pass to InMemoryDataLoader. Defaults to None.
+
+    Returns:
+        InMemoryDataLoader: Ready for use in model.evaluate(data_loader) or
+            [model(x) for x in data_loader]
+    """
+    if device is None:
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    if embedding_type not in ["wyckoff", "composition"]:
+        raise ValueError(f"{embedding_type = } must be 'wyckoff' or 'composition'")
+
+    initial_embeddings = df[input_col].map(
+        wyckoff_embedding_from_aflow_str
+        if embedding_type == "wyckoff"
+        else get_composition_embedding
+    )
+    targets = torch.tensor(df[target_col], device=device)
+    if targets.dtype == torch.bool:
+        targets = targets.long()  # convert binary classification targets to 0 and 1
+    inputs = np.empty(len(initial_embeddings), dtype=object)
+    for idx, tensor in enumerate(initial_embeddings):
+        inputs[idx] = tensor.to(device)
+
+    ids = df[id_col].to_numpy()
+    data_loader = InMemoryDataLoader(
+        [inputs, targets, ids], collate_fn=collate_batch, **kwargs
+    )
+    return data_loader
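
A hedged usage sketch of the new helper (file names and column names here are hypothetical; the dataframe is assumed to already contain Aflow Wyckoff labels in input_col, since embeddings are generated eagerly at loader-construction time):

import pandas as pd

from aviary.wrenformer.data import df_to_in_mem_dataloader

df_train = pd.read_json("train.json")  # hypothetical: wyckoff, e_form, material_id columns
train_loader = df_to_in_mem_dataloader(
    df_train, target_col="e_form", batch_size=128, shuffle=True
)

# per the updated InMemoryDataLoader docstring, test loaders can afford a larger
# batch size since inference needs no gradient buffers
df_test = pd.read_json("test.json")
test_loader = df_to_in_mem_dataloader(df_test, target_col="e_form", batch_size=512)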

examples/mat_bench/run_wrenformer.py

Lines changed: 1 addition & 1 deletion
@@ -64,7 +64,7 @@ def run_wrenformer_on_matbench(
         target_col=target,
         task_type=task_type,
         # set to None to disable logging
-        wandb_project=kwargs.get("wandb_project", "mp-wbm"),
+        wandb_project=kwargs.pop("wandb_project", "mp-wbm"),
         id_col=id_col,
         run_params={
             "dataset": dataset_name,

examples/wrenformer.py

Lines changed: 33 additions & 58 deletions
@@ -9,14 +9,9 @@

 from aviary import ROOT
 from aviary.core import Normalizer, TaskType
-from aviary.data import InMemoryDataLoader
 from aviary.losses import RobustL1Loss
 from aviary.utils import get_metrics
-from aviary.wrenformer.data import (
-    collate_batch,
-    get_composition_embedding,
-    wyckoff_embedding_from_aflow_str,
-)
+from aviary.wrenformer.data import df_to_in_mem_dataloader
 from aviary.wrenformer.model import Wrenformer
 from aviary.wrenformer.utils import print_walltime

@@ -42,6 +37,7 @@ def run_wrenformer(
     target_col: str,
     epochs: int,
     timestamp: str = None,
+    input_col: str = None,
     id_col: str = "material_id",
     n_attn_layers: int = 4,
     wandb_project: str = None,

@@ -67,6 +63,8 @@ def run_wrenformer(
         train_df (pd.DataFrame): Dataframe containing the training data.
         test_df (pd.DataFrame): Dataframe containing the test data.
         target_col (str): Name of df column containing the target values.
+        input_col (str): Name of df column containing the input values. Defaults to 'wyckoff'
+            if 'wren' in run_name else 'composition'.
         id_col (str): Name of df column containing material IDs.
         epochs (int): How many epochs to train for. Defaults to 100.
         timestamp (str): Will be included in run_params and used as file name prefix for model

@@ -118,23 +116,16 @@ def run_wrenformer(
     device = "cuda" if torch.cuda.is_available() else "cpu"
     print(f"Pytorch running on {device=}")

-    for label, df in [("training set", train_df), ("test set", test_df)]:
-        if "wren" in run_name.lower():
-            err_msg = "Missing 'wyckoff' column in dataframe. "
-            err_msg += (
-                "Please generate Aflow Wyckoff labels ahead of time."
-                if "structure" in df
-                else "Trying to deploy Wrenformer on composition-only task?"
-            )
-            assert "wyckoff" in df, err_msg
-            with print_walltime(
-                start_desc=f"Generating Wyckoff embeddings for {label}", newline=False
-            ):
-                df["features"] = df.wyckoff.map(wyckoff_embedding_from_aflow_str)
-        elif "roost" in run_name.lower():
-            df["features"] = df.composition.map(get_composition_embedding)
-        else:
-            raise ValueError(f"{run_name = } must contain 'roost' or 'wren'")
+    if "wren" in run_name.lower():
+        input_col = input_col or "wyckoff"
+        embedding_type = "wyckoff"
+    elif "roost" in run_name.lower():
+        input_col = input_col or "composition"
+        embedding_type = "composition"
+    else:
+        raise ValueError(
+            f"{run_name = } must contain 'roost' or 'wren' (case insensitive)"
+        )

     robust = "robust" in run_name.lower()
     loss_func = (

@@ -145,42 +136,24 @@ def run_wrenformer(
     loss_dict = {target_col: (task_type, loss_func)}
     normalizer_dict = {target_col: Normalizer() if task_type == reg_key else None}

-    features, targets, ids = (train_df[x] for x in ["features", target_col, id_col])
-    targets = torch.tensor(targets, device=device)
-    if targets.dtype == torch.bool:
-        targets = targets.long()
-    inputs = np.empty(len(features), dtype=object)
-    for idx, tensor in enumerate(features):
-        inputs[idx] = tensor.to(device)
-
-    train_loader = InMemoryDataLoader(
-        [inputs, targets, ids],
-        batch_size=batch_size,
-        shuffle=True,
-        collate_fn=collate_batch,
+    data_loader_kwargs = dict(
+        target_col=target_col,
+        input_col=input_col,
+        id_col=id_col,
+        embedding_type=embedding_type,
     )
-
-    features, targets, ids = (test_df[x] for x in ["features", target_col, id_col])
-    targets = torch.tensor(targets, device=device)
-    if targets.dtype == torch.bool:
-        targets = targets.long()
-    inputs = np.empty(len(features), dtype=object)
-    for idx, tensor in enumerate(features):
-        inputs[idx] = tensor.to(device)
-
-    test_loader = InMemoryDataLoader(
-        [inputs, targets, ids], batch_size=512, collate_fn=collate_batch
+    train_loader = df_to_in_mem_dataloader(
+        train_df, batch_size=batch_size, shuffle=True, **data_loader_kwargs
     )

-    # n_features is the length of the embedding vector for a Wyckoff position encoding
-    # the element type (usually 200-dim matscholar embeddings) and Wyckoff position (see
-    # 'bra-alg-off.json') + 1 for the weight of that element/Wyckoff position in the
-    # material's composition
-    embedding_len = features[0].shape[-1]
-    assert embedding_len in (
-        200 + 1,
-        200 + 1 + 444,
-    )  # Roost and Wren embedding size resp.
+    test_loader = df_to_in_mem_dataloader(test_df, batch_size=512, **data_loader_kwargs)
+
+    # embedding_len is the length of the embedding vector for a Wyckoff position encoding the
+    # element type (usually 200-dim matscholar embeddings) and Wyckoff position (see
+    # 'bra-alg-off.json') + 1 for the weight of that Wyckoff position (or element) in the material
+    embedding_len = train_loader.tensors[0][0].shape[-1]
+    # Roost and Wren embedding size resp.
+    assert embedding_len in (200 + 1, 200 + 1 + 444), f"{embedding_len=}"

     model_params = dict(
         # 1 for regression, n_classes for classification

@@ -242,6 +215,7 @@ def run_wrenformer(
         "training_samples": len(train_df),
         "test_samples": len(test_df),
         "trainable_params": model.num_params,
+        "task_type": task_type,
         "swa": {
             "start": swa_start,
             "epochs": int(swa_start * epochs),

@@ -272,6 +246,7 @@ def run_wrenformer(
     for epoch in range(epochs):
         if verbose:
             print(f"Epoch {epoch + 1}/{epochs}")
+
         train_metrics = model.evaluate(
             train_loader,
             loss_dict,

@@ -333,9 +308,9 @@ def run_wrenformer(
         predictions = predictions.softmax(dim=1)

     predictions = predictions.cpu().numpy().squeeze()
-    targets = targets.cpu().numpy()
+    targets = test_df[target_col]
     pred_col = f"{target_col}_pred"
-    test_df[pred_col] = predictions.tolist()
+    test_df[pred_col] = predictions.tolist()  # requires shuffle=False for test_loader

     test_metrics = get_metrics(targets, predictions, task_type)
     test_metrics["test_size"] = len(test_df)
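
The inline comment added on the last changed line above is load-bearing: predictions are produced in loader order and written back to test_df by position. A sketch of the alignment assumption (the batch structure is assumed from collate_batch; model and test_loader come from the surrounding script, so this is illustrative only):

import torch

preds = []
for inputs, _targets, _ids in test_loader:  # batches arrive in loader order
    preds.append(model(inputs))
# row i of test_df receives prediction i; this only lines up because the test
# loader was built with shuffle=False (the InMemoryDataLoader default)
test_df["e_form_pred"] = torch.cat(preds).detach().cpu().numpy().tolist()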
