|
| 1 | +"""Example demonstrating how to use ``from_relbench`` to convert a RelBench |
| 2 | +relational database into a PyG HeteroData graph and train a heterogeneous |
| 3 | +GNN for node-level prediction. |
| 4 | +
|
| 5 | +This example loads the Formula 1 RelBench dataset, converts it into a |
| 6 | +heterogeneous graph using ``from_relbench``, and trains a 2-layer GraphSAGE |
| 7 | +model (via ``to_hetero``) to predict championship standings points from |
| 8 | +the graph structure and node features. |
| 9 | +
|
| 10 | +Requirements: |
| 11 | + ``pip install relbench`` |
| 12 | +
|
| 13 | +Usage: |
| 14 | + ``python relbench_example.py`` |
| 15 | + ``python relbench_example.py --epochs 50 --hidden_channels 128`` |
| 16 | +""" |
| 17 | +import argparse |
| 18 | + |
| 19 | +import torch |
| 20 | +import torch.nn.functional as F |
| 21 | +from relbench.datasets import get_dataset |
| 22 | + |
| 23 | +from torch_geometric.nn import Linear, SAGEConv, to_hetero |
| 24 | +from torch_geometric.utils import from_relbench |
| 25 | + |
# Command-line configuration for the training run:
parser = argparse.ArgumentParser(
    description='Train a heterogeneous GNN on a RelBench dataset.')
for flag, arg_type, default in (
        ('--hidden_channels', int, 64),
        ('--lr', float, 0.005),
        ('--epochs', int, 30),
):
    parser.add_argument(flag, type=arg_type, default=default)
args = parser.parse_args()
| 32 | + |
torch.manual_seed(42)  # reproducible splits and initialization
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# 1. Load a RelBench dataset and convert to HeteroData:
print('Loading RelBench rel-f1 dataset...')
dataset = get_dataset('rel-f1', download=True)
db = dataset.get_db()
data = from_relbench(db)
num_node_types = len(data.node_types)
num_edge_types = len(data.edge_types)
print(f'Graph: {num_node_types} node types, '
      f'{num_edge_types} edge types')
| 43 | + |
# 2. Prepare a node regression target.
# `from_relbench` preserves the original DataFrame column order from RelBench.
# In rel-f1, the 'standings' table has 'points' as its first numeric column:
target_type = 'standings'
y = data[target_type].x[:, 0].clone()  # points column (index 0 in rel-f1)
# Fill NaNs in the target here: the feature-cleanup loop below sanitizes only
# the node features, not `y`, so a single NaN in the points column would
# poison the train-set mean/std used for normalization and every loss value.
y = torch.nan_to_num(y, nan=0.0)
data[target_type].x = data[target_type].x[:, 1:]  # remove from input features
| 50 | + |
# 3. Clean up features — fill NaN and standardize per column:
for node_type in data.node_types:
    store = data[node_type]
    x = getattr(store, 'x', None)
    if x is None:
        # Zero-feature placeholder for featureless node types (e.g. drivers):
        store.x = torch.zeros(store.num_nodes, 1)
        continue
    x = torch.nan_to_num(x, nan=0.0)
    mean = x.mean(dim=0)
    std = x.std(dim=0)
    # Avoid division by zero for constant columns:
    std = torch.where(std == 0, torch.ones_like(std), std)
    store.x = (x - mean) / std
| 61 | + |
# 4. Create train/val/test splits (60/20/20) before computing target stats:
num_nodes = data[target_type].num_nodes
perm = torch.randperm(num_nodes)
n_train = int(0.6 * num_nodes)
n_val = int(0.8 * num_nodes)
train_mask, val_mask, test_mask = (
    torch.zeros(num_nodes, dtype=torch.bool) for _ in range(3))
train_mask[perm[:n_train]] = True
val_mask[perm[n_train:n_val]] = True
test_mask[perm[n_val:]] = True

# Normalize target using training set statistics only (prevents data leakage):
y_mean = y[train_mask].mean()
y_std = y[train_mask].std()
if not y_std > 0:  # constant (or degenerate) target on the train split
    y_std = torch.tensor(1.0)
y_norm = (y - y_mean) / y_std
| 77 | + |
# 5. Move all tensors to device — including masks to prevent device mismatch:
data, y, y_norm = data.to(device), y.to(device), y_norm.to(device)
train_mask, val_mask, test_mask = (
    mask.to(device) for mask in (train_mask, val_mask, test_mask))
| 85 | + |
| 86 | + |
# 6. Define a 2-layer GraphSAGE model with lazy input size inference:
class GNN(torch.nn.Module):
    """Two SAGEConv layers plus a linear head for scalar node regression.

    Input sizes are declared lazily via ``(-1, -1)`` / ``-1`` so the same
    module works for every node type after ``to_hetero`` conversion.
    """
    def __init__(self, hidden_channels: int) -> None:
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), hidden_channels)
        self.lin = Linear(-1, 1)

    def forward(self, x, edge_index):
        h = self.conv1(x, edge_index).relu()
        h = self.conv2(h, edge_index).relu()
        return self.lin(h)
| 99 | + |
# Convert the homogeneous model into a heterogeneous one over the graph's
# metadata, summing messages across edge types:
model = to_hetero(GNN(args.hidden_channels), data.metadata(), aggr='sum')
model = model.to(device)

# Initialize lazy parameters via a single dry-run forward pass:
with torch.no_grad():
    model(data.x_dict, data.edge_index_dict)

optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
| 109 | + |
| 110 | + |
def train() -> float:
    """Run one full-batch optimization step; return the training MSE loss."""
    model.train()
    optimizer.zero_grad()
    out = model(data.x_dict, data.edge_index_dict)
    pred = out[target_type].squeeze(-1)
    loss = F.mse_loss(pred[train_mask], y_norm[train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()
| 119 | + |
| 120 | + |
@torch.no_grad()
def test():
    """Return (train, val, test) MAE in original point units."""
    model.eval()
    pred = model(data.x_dict, data.edge_index_dict)[target_type].squeeze(-1)
    pred_orig = pred * y_std + y_mean  # denormalize for interpretable MAE

    def mae(mask):
        # Mean absolute error restricted to the given split:
        return float((pred_orig[mask] - y[mask]).abs().mean())

    return mae(train_mask), mae(val_mask), mae(test_mask)
| 131 | + |
| 132 | + |
print(
    f'\nTraining {args.epochs} epochs on "{target_type}" point prediction...')
print(f'Target stats (train): mean={y_mean:.2f}, std={y_std:.2f}\n')

for epoch in range(1, args.epochs + 1):
    loss = train()
    # Report progress on the first epoch and every fifth one thereafter:
    if epoch == 1 or epoch % 5 == 0:
        train_mae, val_mae, test_mae = test()
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
              f'Train MAE: {train_mae:.2f}, Val MAE: {val_mae:.2f}, '
              f'Test MAE: {test_mae:.2f} points')

train_mae, val_mae, test_mae = test()
print(f'\nFinal — Train MAE: {train_mae:.2f}, Val MAE: {val_mae:.2f}, '
      f'Test MAE: {test_mae:.2f} points')
0 commit comments