Skip to content

Commit 4f0a257

Browse files
committed
Address akihironitta review: annotations, idioms, style, drop string forward ref
1 parent e5fa591 commit 4f0a257

2 files changed

Lines changed: 51 additions & 35 deletions

File tree

examples/relbench_example.py

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
``python relbench_example.py``
1515
``python relbench_example.py --epochs 50 --hidden_channels 128``
1616
"""
17+
1718
import argparse
19+
from typing import Tuple
1820

1921
import torch
2022
import torch.nn.functional as F
@@ -24,7 +26,8 @@
2426
from torch_geometric.utils import from_relbench
2527

2628
parser = argparse.ArgumentParser(
27-
description='Train a heterogeneous GNN on a RelBench dataset.')
29+
description='Train a heterogeneous GNN on a RelBench dataset.'
30+
)
2831
parser.add_argument('--hidden_channels', type=int, default=64)
2932
parser.add_argument('--lr', type=float, default=0.005)
3033
parser.add_argument('--epochs', type=int, default=30)
@@ -38,14 +41,16 @@
3841
dataset = get_dataset('rel-f1', download=True)
3942
db = dataset.get_db()
4043
data = from_relbench(db)
41-
print(f'Graph: {len(data.node_types)} node types, '
42-
f'{len(data.edge_types)} edge types')
44+
print(
45+
f'Graph: {len(data.node_types)} node types, '
46+
f'{len(data.edge_types)} edge types'
47+
)
4348

4449
# 2. Prepare a node regression target.
4550
# `from_relbench` preserves the original DataFrame column order from RelBench.
4651
# In rel-f1, the 'standings' table has 'points' as its first numeric column:
4752
target_type = 'standings'
48-
y = data[target_type].x[:, 0].clone() # points column (index 0 in rel-f1)
53+
y = data[target_type].x[:, 0] # points column (index 0 in rel-f1)
4954
data[target_type].x = data[target_type].x[:, 1:] # remove from input features
5055

5156
# 3. Clean up features — fill NaN and standardize per column:
@@ -65,9 +70,9 @@
6570
train_mask = torch.zeros(num_nodes, dtype=torch.bool)
6671
val_mask = torch.zeros(num_nodes, dtype=torch.bool)
6772
test_mask = torch.zeros(num_nodes, dtype=torch.bool)
68-
train_mask[perm[:int(0.6 * num_nodes)]] = True
69-
val_mask[perm[int(0.6 * num_nodes):int(0.8 * num_nodes)]] = True
70-
test_mask[perm[int(0.8 * num_nodes):]] = True
73+
train_mask[perm[: int(0.6 * num_nodes)]] = True
74+
val_mask[perm[int(0.6 * num_nodes) : int(0.8 * num_nodes)]] = True
75+
test_mask[perm[int(0.8 * num_nodes) :]] = True
7176

7277
# Normalize target using training set statistics only (prevents data leakage):
7378
y_mean = y[train_mask].mean()
@@ -92,7 +97,11 @@ def __init__(self, hidden_channels: int) -> None:
9297
self.conv2 = SAGEConv((-1, -1), hidden_channels)
9398
self.lin = Linear(-1, 1)
9499

95-
def forward(self, x, edge_index):
100+
def forward(
101+
self,
102+
x: torch.Tensor,
103+
edge_index: torch.Tensor,
104+
) -> torch.Tensor:
96105
x = self.conv1(x, edge_index).relu()
97106
x = self.conv2(x, edge_index).relu()
98107
return self.lin(x)
@@ -108,40 +117,45 @@ def forward(self, x, edge_index):
108117
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
109118

110119

111-
def train() -> float:
120+
def train() -> torch.Tensor:
112121
model.train()
113122
optimizer.zero_grad()
114123
pred = model(data.x_dict, data.edge_index_dict)[target_type].squeeze(-1)
115124
loss = F.mse_loss(pred[train_mask], y_norm[train_mask])
116125
loss.backward()
117126
optimizer.step()
118-
return float(loss)
127+
return loss
119128

120129

121130
@torch.no_grad()
122-
def test():
131+
def test() -> Tuple[float, float, float]:
123132
model.eval()
124133
pred = model(data.x_dict, data.edge_index_dict)[target_type].squeeze(-1)
125-
pred_orig = pred * y_std + y_mean # denormalize for interpretable MAE
134+
# denormalize for interpretable MAE
135+
pred *= y_std
136+
pred += y_mean
126137

127-
train_mae = float((pred_orig[train_mask] - y[train_mask]).abs().mean())
128-
val_mae = float((pred_orig[val_mask] - y[val_mask]).abs().mean())
129-
test_mae = float((pred_orig[test_mask] - y[test_mask]).abs().mean())
138+
train_mae = float((pred[train_mask] - y[train_mask]).abs().mean())
139+
val_mae = float((pred[val_mask] - y[val_mask]).abs().mean())
140+
test_mae = float((pred[test_mask] - y[test_mask]).abs().mean())
130141
return train_mae, val_mae, test_mae
131142

132143

133-
print(
134-
f'\nTraining {args.epochs} epochs on "{target_type}" point prediction...')
144+
print(f'\nTraining {args.epochs} epochs on "{target_type}" point prediction...')
135145
print(f'Target stats (train): mean={y_mean:.2f}, std={y_std:.2f}\n')
136146

137147
for epoch in range(1, args.epochs + 1):
138148
loss = train()
139149
if epoch % 5 == 0 or epoch == 1:
140150
train_mae, val_mae, test_mae = test()
141-
print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
142-
f'Train MAE: {train_mae:.2f}, Val MAE: {val_mae:.2f}, '
143-
f'Test MAE: {test_mae:.2f} points')
151+
print(
152+
f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
153+
f'Train MAE: {train_mae:.2f}, Val MAE: {val_mae:.2f}, '
154+
f'Test MAE: {test_mae:.2f} points'
155+
)
144156

145157
train_mae, val_mae, test_mae = test()
146-
print(f'\nFinal — Train MAE: {train_mae:.2f}, Val MAE: {val_mae:.2f}, '
147-
f'Test MAE: {test_mae:.2f} points')
158+
print(
159+
f'\nFinal — Train MAE: {train_mae:.2f}, Val MAE: {val_mae:.2f}, '
160+
f'Test MAE: {test_mae:.2f} points'
161+
)

torch_geometric/utils/relbench.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,11 @@
33
import numpy as np
44
import torch
55

6-
import torch_geometric
76
from torch_geometric.data import HeteroData
87
from torch_geometric.utils import sort_edge_index
98

109

11-
def from_relbench(db: Any) -> 'torch_geometric.data.HeteroData':
10+
def from_relbench(db: Any) -> HeteroData:
1211
r"""Converts a :class:`relbench.base.Database` object into a
1312
:class:`~torch_geometric.data.HeteroData` object.
1413
@@ -29,7 +28,7 @@ def from_relbench(db: Any) -> 'torch_geometric.data.HeteroData':
2928
type and each foreign key relationship maps to a pair of directed
3029
edge types.
3130
32-
Example:
31+
Examples:
3332
>>> from relbench.base import Database, Table
3433
>>> import pandas as pd
3534
>>> users = Table(
@@ -70,10 +69,11 @@ def from_relbench(db: Any) -> 'torch_geometric.data.HeteroData':
7069

7170
# Convert numeric feature columns into a node feature tensor:
7271
feature_cols = [
73-
col for col in df.columns
72+
col
73+
for col in df.columns
7474
if col not in exclude_cols and df[col].dtype.kind in ('i', 'f')
7575
]
76-
if len(feature_cols) > 0:
76+
if feature_cols:
7777
x_np = df[feature_cols].to_numpy(
7878
dtype=np.float32,
7979
na_value=np.nan,
@@ -84,16 +84,17 @@ def from_relbench(db: Any) -> 'torch_geometric.data.HeteroData':
8484
if table.time_col is not None:
8585
time_ser = df[table.time_col]
8686
if time_ser.dtype in [
87-
np.dtype("datetime64[s]"),
88-
np.dtype("datetime64[ns]"),
87+
np.dtype('datetime64[s]'),
88+
np.dtype('datetime64[ns]'),
8989
]:
90-
unix_time = time_ser.astype("int64").values
91-
if time_ser.dtype == np.dtype("datetime64[ns]"):
90+
unix_time = time_ser.astype('int64').values
91+
if time_ser.dtype == np.dtype('datetime64[ns]'):
9292
unix_time = unix_time // 10**9
9393
data[table_name].time = torch.from_numpy(unix_time)
9494
else:
9595
data[table_name].time = torch.from_numpy(
96-
time_ser.values.astype(np.float64), )
96+
time_ser.values.astype(np.float64)
97+
)
9798

9899
# Create edges from foreign key relationships:
99100
for fkey_col, pkey_table_name in table.fkey_col_to_pkey_table.items():
@@ -103,17 +104,18 @@ def from_relbench(db: Any) -> 'torch_geometric.data.HeteroData':
103104
mask = ~pkey_index.isna()
104105
fkey_idx = torch.arange(len(pkey_index))
105106
pkey_idx = torch.from_numpy(
106-
pkey_index[mask].to_numpy(dtype=np.int64), )
107+
pkey_index[mask].to_numpy(dtype=np.int64)
108+
)
107109
fkey_idx = fkey_idx[torch.from_numpy(mask.to_numpy(dtype=bool))]
108110

109111
# Forward edge: fkey table -> pkey table
110112
edge_index = torch.stack([fkey_idx, pkey_idx], dim=0)
111-
edge_type = (table_name, f"f2p_{fkey_col}", pkey_table_name)
113+
edge_type = (table_name, f'f2p_{fkey_col}', pkey_table_name)
112114
data[edge_type].edge_index = sort_edge_index(edge_index)
113115

114116
# Reverse edge: pkey table -> fkey table
115117
edge_index = torch.stack([pkey_idx, fkey_idx], dim=0)
116-
edge_type = (pkey_table_name, f"rev_f2p_{fkey_col}", table_name)
118+
edge_type = (pkey_table_name, f'rev_f2p_{fkey_col}', table_name)
117119
data[edge_type].edge_index = sort_edge_index(edge_index)
118120

119121
data.validate()

0 commit comments

Comments
 (0)