``python relbench_example.py``
``python relbench_example.py --epochs 50 --hidden_channels 128``
"""
17+
1718import argparse
19+ from typing import Tuple
1820
1921import torch
2022import torch .nn .functional as F
2426from torch_geometric .utils import from_relbench
2527
2628parser = argparse .ArgumentParser (
27- description = 'Train a heterogeneous GNN on a RelBench dataset.' )
29+ description = 'Train a heterogeneous GNN on a RelBench dataset.'
30+ )
2831parser .add_argument ('--hidden_channels' , type = int , default = 64 )
2932parser .add_argument ('--lr' , type = float , default = 0.005 )
3033parser .add_argument ('--epochs' , type = int , default = 30 )
3841dataset = get_dataset ('rel-f1' , download = True )
3942db = dataset .get_db ()
4043data = from_relbench (db )
41- print (f'Graph: { len (data .node_types )} node types, '
42- f'{ len (data .edge_types )} edge types' )
44+ print (
45+ f'Graph: { len (data .node_types )} node types, '
46+ f'{ len (data .edge_types )} edge types'
47+ )
4348
4449# 2. Prepare a node regression target.
4550# `from_relbench` preserves the original DataFrame column order from RelBench.
4651# In rel-f1, the 'standings' table has 'points' as its first numeric column:
4752target_type = 'standings'
48- y = data [target_type ].x [:, 0 ]. clone () # points column (index 0 in rel-f1)
53+ y = data [target_type ].x [:, 0 ] # points column (index 0 in rel-f1)
4954data [target_type ].x = data [target_type ].x [:, 1 :] # remove from input features
5055
5156# 3. Clean up features — fill NaN and standardize per column:
6570train_mask = torch .zeros (num_nodes , dtype = torch .bool )
6671val_mask = torch .zeros (num_nodes , dtype = torch .bool )
6772test_mask = torch .zeros (num_nodes , dtype = torch .bool )
68- train_mask [perm [:int (0.6 * num_nodes )]] = True
69- val_mask [perm [int (0.6 * num_nodes ): int (0.8 * num_nodes )]] = True
70- test_mask [perm [int (0.8 * num_nodes ):]] = True
73+ train_mask [perm [: int (0.6 * num_nodes )]] = True
74+ val_mask [perm [int (0.6 * num_nodes ) : int (0.8 * num_nodes )]] = True
75+ test_mask [perm [int (0.8 * num_nodes ) :]] = True
7176
7277# Normalize target using training set statistics only (prevents data leakage):
7378y_mean = y [train_mask ].mean ()
@@ -92,7 +97,11 @@ def __init__(self, hidden_channels: int) -> None:
9297 self .conv2 = SAGEConv ((- 1 , - 1 ), hidden_channels )
9398 self .lin = Linear (- 1 , 1 )
9499
95- def forward (self , x , edge_index ):
100+ def forward (
101+ self ,
102+ x : torch .Tensor ,
103+ edge_index : torch .Tensor ,
104+ ) -> torch .Tensor :
96105 x = self .conv1 (x , edge_index ).relu ()
97106 x = self .conv2 (x , edge_index ).relu ()
98107 return self .lin (x )
@@ -108,40 +117,45 @@ def forward(self, x, edge_index):
108117optimizer = torch .optim .Adam (model .parameters (), lr = args .lr )
109118
110119
111- def train () -> float :
120+ def train () -> torch . Tensor :
112121 model .train ()
113122 optimizer .zero_grad ()
114123 pred = model (data .x_dict , data .edge_index_dict )[target_type ].squeeze (- 1 )
115124 loss = F .mse_loss (pred [train_mask ], y_norm [train_mask ])
116125 loss .backward ()
117126 optimizer .step ()
118- return float ( loss )
127+ return loss
119128
120129
121130@torch .no_grad ()
122- def test ():
131+ def test () -> Tuple [ float , float , float ] :
123132 model .eval ()
124133 pred = model (data .x_dict , data .edge_index_dict )[target_type ].squeeze (- 1 )
125- pred_orig = pred * y_std + y_mean # denormalize for interpretable MAE
134+ # denormalize for interpretable MAE
135+ pred *= y_std
136+ pred += y_mean
126137
127- train_mae = float ((pred_orig [train_mask ] - y [train_mask ]).abs ().mean ())
128- val_mae = float ((pred_orig [val_mask ] - y [val_mask ]).abs ().mean ())
129- test_mae = float ((pred_orig [test_mask ] - y [test_mask ]).abs ().mean ())
138+ train_mae = float ((pred [train_mask ] - y [train_mask ]).abs ().mean ())
139+ val_mae = float ((pred [val_mask ] - y [val_mask ]).abs ().mean ())
140+ test_mae = float ((pred [test_mask ] - y [test_mask ]).abs ().mean ())
130141 return train_mae , val_mae , test_mae
131142
132143
133- print (
134- f'\n Training { args .epochs } epochs on "{ target_type } " point prediction...' )
144+ print (f'\n Training { args .epochs } epochs on "{ target_type } " point prediction...' )
135145print (f'Target stats (train): mean={ y_mean :.2f} , std={ y_std :.2f} \n ' )
136146
137147for epoch in range (1 , args .epochs + 1 ):
138148 loss = train ()
139149 if epoch % 5 == 0 or epoch == 1 :
140150 train_mae , val_mae , test_mae = test ()
141- print (f'Epoch: { epoch :03d} , Loss: { loss :.4f} , '
142- f'Train MAE: { train_mae :.2f} , Val MAE: { val_mae :.2f} , '
143- f'Test MAE: { test_mae :.2f} points' )
151+ print (
152+ f'Epoch: { epoch :03d} , Loss: { loss :.4f} , '
153+ f'Train MAE: { train_mae :.2f} , Val MAE: { val_mae :.2f} , '
154+ f'Test MAE: { test_mae :.2f} points'
155+ )
144156
145157train_mae , val_mae , test_mae = test ()
146- print (f'\n Final — Train MAE: { train_mae :.2f} , Val MAE: { val_mae :.2f} , '
147- f'Test MAE: { test_mae :.2f} points' )
158+ print (
159+ f'\n Final — Train MAE: { train_mae :.2f} , Val MAE: { val_mae :.2f} , '
160+ f'Test MAE: { test_mae :.2f} points'
161+ )
0 commit comments