Skip to content

Commit fe3e67c

Browse files
committed
Add relbench_example.py to demonstrate from_relbench with heterogeneous GNN training
1 parent d65b73c commit fe3e67c

1 file changed

Lines changed: 147 additions & 0 deletions

File tree

examples/relbench_example.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
"""Example demonstrating how to use ``from_relbench`` to convert a RelBench
2+
relational database into a PyG HeteroData graph and train a heterogeneous
3+
GNN for node-level prediction.
4+
5+
This example loads the Formula 1 RelBench dataset, converts it into a
6+
heterogeneous graph using ``from_relbench``, and trains a 2-layer GraphSAGE
7+
model (via ``to_hetero``) to predict championship standings points from
8+
the graph structure and node features.
9+
10+
Requirements:
11+
``pip install relbench``
12+
13+
Usage:
14+
``python relbench_example.py``
15+
``python relbench_example.py --epochs 50 --hidden_channels 128``
16+
"""
17+
import argparse
18+
19+
import torch
20+
import torch.nn.functional as F
21+
from relbench.datasets import get_dataset
22+
23+
from torch_geometric.nn import Linear, SAGEConv, to_hetero
24+
from torch_geometric.utils import from_relbench
25+
26+
# Command-line options: hidden width, learning rate and number of epochs.
parser = argparse.ArgumentParser(
    description='Train a heterogeneous GNN on a RelBench dataset.')
for flag, kind, default in (
    ('--hidden_channels', int, 64),
    ('--lr', float, 0.005),
    ('--epochs', int, 30),
):
    parser.add_argument(flag, type=kind, default=default)
args = parser.parse_args()
32+
33+
# Seed PyTorch's global RNG (the random split further down draws from it),
# then run on the GPU when PyTorch reports one, otherwise on the CPU:
torch.manual_seed(42)
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')
35+
36+
# 1. Fetch the rel-f1 RelBench dataset (with download=True) and convert its
#    database into a HeteroData graph, then report the graph's schema size:
print('Loading RelBench rel-f1 dataset...')
dataset = get_dataset('rel-f1', download=True)
db = dataset.get_db()
data = from_relbench(db)
num_node_types, num_edge_types = len(data.node_types), len(data.edge_types)
print(f'Graph: {num_node_types} node types, {num_edge_types} edge types')
43+
44+
# 2. Prepare a node regression target.
# The first feature column of the 'standings' nodes is used as the label and
# removed from the inputs so the model cannot read the answer directly.
# NOTE(review): this assumes `from_relbench` keeps the RelBench column order
# and that column 0 of 'standings' is 'points' in rel-f1 — confirm.
target_type = 'standings'
target_features = data[target_type].x
y = target_features[:, 0].clone()  # independent copy of the label column
data[target_type].x = target_features[:, 1:]  # remaining columns are inputs
50+
51+
# 3. Clean up features — fill NaN and standardize per column:
for node_type in data.node_types:
    store = data[node_type]
    if hasattr(store, 'x') and store.x is not None:
        x = torch.nan_to_num(store.x, nan=0.0)
        std, mean = torch.std_mean(x, dim=0)
        # Columns that cannot be scaled keep a unit divisor. A constant
        # column has std == 0, and a node type with a single row has
        # std == NaN (the default estimator divides by N - 1). Comparing
        # with `std == 0` alone would miss the NaN case and turn the whole
        # feature matrix into NaN; `~(std > 0)` catches both.
        std[~(std > 0)] = 1.0
        store.x = (x - mean) / std
    else:
        # Zero-feature placeholder for featureless node types (e.g. drivers):
        store.x = torch.zeros(store.num_nodes, 1)
61+
62+
# 4. Random 60/20/20 train/val/test split over the target nodes. The split is
#    made first so the target statistics below come from training nodes only.
num_nodes = data[target_type].num_nodes
perm = torch.randperm(num_nodes)
train_end = int(0.6 * num_nodes)
val_end = int(0.8 * num_nodes)

train_mask, val_mask, test_mask = (torch.zeros(num_nodes, dtype=torch.bool)
                                   for _ in range(3))
train_mask[perm[:train_end]] = True
val_mask[perm[train_end:val_end]] = True
test_mask[perm[val_end:]] = True

# Standardize the target with training-set statistics only, so nothing about
# the validation/test labels leaks into the normalization:
y_train = y[train_mask]
y_mean = y_train.mean()
y_std = y_train.std()
if not y_std > 0:
    # Fall back to a unit scale whenever the std is not strictly positive.
    y_std = torch.tensor(1.0)
y_norm = (y - y_mean) / y_std
77+
78+
# 5. Put the graph, both target tensors and the three split masks on the
#    same device so indexing and the loss never mix devices:
data = data.to(device)
y, y_norm = y.to(device), y_norm.to(device)
train_mask, val_mask, test_mask = (mask.to(device)
                                   for mask in (train_mask, val_mask,
                                                test_mask))
85+
86+
87+
# 6. Define a 2-layer GraphSAGE model with lazy input size inference:
88+
class GNN(torch.nn.Module):
    """Two SAGEConv layers with ReLU, followed by a one-output linear head.

    All input sizes are passed as ``-1`` so they are inferred lazily on the
    first forward pass.

    Args:
        hidden_channels: Output width of both ``SAGEConv`` layers.
    """
    def __init__(self, hidden_channels: int) -> None:
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), hidden_channels)
        self.lin = Linear(-1, 1)

    def forward(self, x, edge_index):
        """Return one value per node from features ``x`` and ``edge_index``."""
        hidden = self.conv1(x, edge_index)
        hidden = hidden.relu()
        hidden = self.conv2(hidden, edge_index)
        hidden = hidden.relu()
        return self.lin(hidden)
99+
100+
101+
# Build the homogeneous GNN, convert it to a heterogeneous model for this
# graph's metadata (summing messages across edge types), and move it to the
# target device:
model = to_hetero(GNN(args.hidden_channels), data.metadata(), aggr='sum')
model = model.to(device)

# One gradient-free forward pass so the lazy layers get their parameter
# shapes before the optimizer collects `model.parameters()`:
with torch.no_grad():
    model(data.x_dict, data.edge_index_dict)

optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
109+
110+
111+
def train() -> float:
    """Run one full-graph optimization step and return the training loss.

    The loss is the MSE between the predictions for the target node type and
    the normalized target, restricted to the training mask.
    """
    model.train()
    optimizer.zero_grad()
    out = model(data.x_dict, data.edge_index_dict)
    pred = out[target_type].squeeze(-1)
    loss = F.mse_loss(pred[train_mask], y_norm[train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()
119+
120+
121+
@torch.no_grad()
def test():
    """Evaluate the model and return ``(train_mae, val_mae, test_mae)``.

    Predictions are mapped back to the original target scale with ``y_std``
    and ``y_mean`` before the mean absolute error is taken, so the three
    values are in the units of the raw target.
    """
    model.eval()
    out = model(data.x_dict, data.edge_index_dict)
    pred_orig = out[target_type].squeeze(-1) * y_std + y_mean  # denormalize
    abs_err = (pred_orig - y).abs()
    return tuple(
        float(abs_err[mask].mean())
        for mask in (train_mask, val_mask, test_mask))
131+
132+
133+
# Announce the run and the training-set target statistics used for scaling:
print(
    f'\nTraining {args.epochs} epochs on "{target_type}" point prediction...')
print(f'Target stats (train): mean={y_mean:.2f}, std={y_std:.2f}\n')

for epoch in range(1, args.epochs + 1):
    loss = train()
    # Only evaluate and log on the first epoch and on every fifth epoch.
    if epoch != 1 and epoch % 5 != 0:
        continue
    train_mae, val_mae, test_mae = test()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
          f'Train MAE: {train_mae:.2f}, Val MAE: {val_mae:.2f}, '
          f'Test MAE: {test_mae:.2f} points')

# Final evaluation after the last epoch:
train_mae, val_mae, test_mae = test()
print(f'\nFinal — Train MAE: {train_mae:.2f}, Val MAE: {val_mae:.2f}, '
      f'Test MAE: {test_mae:.2f} points')

0 commit comments

Comments
 (0)