|
26 | 26 | from torch_geometric.utils import from_relbench |
27 | 27 |
|
28 | 28 | parser = argparse.ArgumentParser( |
29 | | - description='Train a heterogeneous GNN on a RelBench dataset.' |
30 | | -) |
| 29 | + description='Train a heterogeneous GNN on a RelBench dataset.') |
31 | 30 | parser.add_argument('--hidden_channels', type=int, default=64) |
32 | 31 | parser.add_argument('--lr', type=float, default=0.005) |
33 | 32 | parser.add_argument('--epochs', type=int, default=30) |
|
41 | 40 | dataset = get_dataset('rel-f1', download=True) |
42 | 41 | db = dataset.get_db() |
43 | 42 | data = from_relbench(db) |
44 | | -print( |
45 | | - f'Graph: {len(data.node_types)} node types, ' |
46 | | - f'{len(data.edge_types)} edge types' |
47 | | -) |
| 43 | +print(f'Graph: {len(data.node_types)} node types, ' |
| 44 | + f'{len(data.edge_types)} edge types') |
48 | 45 |
|
49 | 46 | # 2. Prepare a node regression target. |
50 | 47 | # `from_relbench` preserves the original DataFrame column order from RelBench. |
|
70 | 67 | train_mask = torch.zeros(num_nodes, dtype=torch.bool) |
71 | 68 | val_mask = torch.zeros(num_nodes, dtype=torch.bool) |
72 | 69 | test_mask = torch.zeros(num_nodes, dtype=torch.bool) |
73 | | -train_mask[perm[: int(0.6 * num_nodes)]] = True |
74 | | -val_mask[perm[int(0.6 * num_nodes) : int(0.8 * num_nodes)]] = True |
75 | | -test_mask[perm[int(0.8 * num_nodes) :]] = True |
| 70 | +train_mask[perm[:int(0.6 * num_nodes)]] = True |
| 71 | +val_mask[perm[int(0.6 * num_nodes):int(0.8 * num_nodes)]] = True |
| 72 | +test_mask[perm[int(0.8 * num_nodes):]] = True |
76 | 73 |
|
77 | 74 | # Normalize target using training set statistics only (prevents data leakage): |
78 | 75 | y_mean = y[train_mask].mean() |
@@ -141,21 +138,18 @@ def test() -> Tuple[float, float, float]: |
141 | 138 | return train_mae, val_mae, test_mae |
142 | 139 |
|
143 | 140 |
|
144 | | -print(f'\nTraining {args.epochs} epochs on "{target_type}" point prediction...') |
| 141 | +print( |
| 142 | + f'\nTraining {args.epochs} epochs on "{target_type}" point prediction...') |
145 | 143 | print(f'Target stats (train): mean={y_mean:.2f}, std={y_std:.2f}\n') |
146 | 144 |
|
147 | 145 | for epoch in range(1, args.epochs + 1): |
148 | 146 | loss = train() |
149 | 147 | if epoch % 5 == 0 or epoch == 1: |
150 | 148 | train_mae, val_mae, test_mae = test() |
151 | | - print( |
152 | | - f'Epoch: {epoch:03d}, Loss: {loss:.4f}, ' |
153 | | - f'Train MAE: {train_mae:.2f}, Val MAE: {val_mae:.2f}, ' |
154 | | - f'Test MAE: {test_mae:.2f} points' |
155 | | - ) |
| 149 | + print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, ' |
| 150 | + f'Train MAE: {train_mae:.2f}, Val MAE: {val_mae:.2f}, ' |
| 151 | + f'Test MAE: {test_mae:.2f} points') |
156 | 152 |
|
157 | 153 | train_mae, val_mae, test_mae = test() |
158 | | -print( |
159 | | - f'\nFinal — Train MAE: {train_mae:.2f}, Val MAE: {val_mae:.2f}, ' |
160 | | - f'Test MAE: {test_mae:.2f} points' |
161 | | -) |
| 154 | +print(f'\nFinal — Train MAE: {train_mae:.2f}, Val MAE: {val_mae:.2f}, ' |
| 155 | + f'Test MAE: {test_mae:.2f} points') |
0 commit comments