Skip to content

Commit 6d3e0a6

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 4f0a257 commit 6d3e0a6

2 files changed

Lines changed: 18 additions & 27 deletions

File tree

examples/relbench_example.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@
2626
from torch_geometric.utils import from_relbench
2727

2828
parser = argparse.ArgumentParser(
29-
description='Train a heterogeneous GNN on a RelBench dataset.'
30-
)
29+
description='Train a heterogeneous GNN on a RelBench dataset.')
3130
parser.add_argument('--hidden_channels', type=int, default=64)
3231
parser.add_argument('--lr', type=float, default=0.005)
3332
parser.add_argument('--epochs', type=int, default=30)
@@ -41,10 +40,8 @@
4140
dataset = get_dataset('rel-f1', download=True)
4241
db = dataset.get_db()
4342
data = from_relbench(db)
44-
print(
45-
f'Graph: {len(data.node_types)} node types, '
46-
f'{len(data.edge_types)} edge types'
47-
)
43+
print(f'Graph: {len(data.node_types)} node types, '
44+
f'{len(data.edge_types)} edge types')
4845

4946
# 2. Prepare a node regression target.
5047
# `from_relbench` preserves the original DataFrame column order from RelBench.
@@ -70,9 +67,9 @@
7067
train_mask = torch.zeros(num_nodes, dtype=torch.bool)
7168
val_mask = torch.zeros(num_nodes, dtype=torch.bool)
7269
test_mask = torch.zeros(num_nodes, dtype=torch.bool)
73-
train_mask[perm[: int(0.6 * num_nodes)]] = True
74-
val_mask[perm[int(0.6 * num_nodes) : int(0.8 * num_nodes)]] = True
75-
test_mask[perm[int(0.8 * num_nodes) :]] = True
70+
train_mask[perm[:int(0.6 * num_nodes)]] = True
71+
val_mask[perm[int(0.6 * num_nodes):int(0.8 * num_nodes)]] = True
72+
test_mask[perm[int(0.8 * num_nodes):]] = True
7673

7774
# Normalize target using training set statistics only (prevents data leakage):
7875
y_mean = y[train_mask].mean()
@@ -141,21 +138,18 @@ def test() -> Tuple[float, float, float]:
141138
return train_mae, val_mae, test_mae
142139

143140

144-
print(f'\nTraining {args.epochs} epochs on "{target_type}" point prediction...')
141+
print(
142+
f'\nTraining {args.epochs} epochs on "{target_type}" point prediction...')
145143
print(f'Target stats (train): mean={y_mean:.2f}, std={y_std:.2f}\n')
146144

147145
for epoch in range(1, args.epochs + 1):
148146
loss = train()
149147
if epoch % 5 == 0 or epoch == 1:
150148
train_mae, val_mae, test_mae = test()
151-
print(
152-
f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
153-
f'Train MAE: {train_mae:.2f}, Val MAE: {val_mae:.2f}, '
154-
f'Test MAE: {test_mae:.2f} points'
155-
)
149+
print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
150+
f'Train MAE: {train_mae:.2f}, Val MAE: {val_mae:.2f}, '
151+
f'Test MAE: {test_mae:.2f} points')
156152

157153
train_mae, val_mae, test_mae = test()
158-
print(
159-
f'\nFinal — Train MAE: {train_mae:.2f}, Val MAE: {val_mae:.2f}, '
160-
f'Test MAE: {test_mae:.2f} points'
161-
)
154+
print(f'\nFinal — Train MAE: {train_mae:.2f}, Val MAE: {val_mae:.2f}, '
155+
f'Test MAE: {test_mae:.2f} points')

torch_geometric/utils/relbench.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,7 @@ def from_relbench(db: Any) -> HeteroData:
6969

7070
# Convert numeric feature columns into a node feature tensor:
7171
feature_cols = [
72-
col
73-
for col in df.columns
72+
col for col in df.columns
7473
if col not in exclude_cols and df[col].dtype.kind in ('i', 'f')
7574
]
7675
if feature_cols:
@@ -84,17 +83,16 @@ def from_relbench(db: Any) -> HeteroData:
8483
if table.time_col is not None:
8584
time_ser = df[table.time_col]
8685
if time_ser.dtype in [
87-
np.dtype('datetime64[s]'),
88-
np.dtype('datetime64[ns]'),
86+
np.dtype('datetime64[s]'),
87+
np.dtype('datetime64[ns]'),
8988
]:
9089
unix_time = time_ser.astype('int64').values
9190
if time_ser.dtype == np.dtype('datetime64[ns]'):
9291
unix_time = unix_time // 10**9
9392
data[table_name].time = torch.from_numpy(unix_time)
9493
else:
9594
data[table_name].time = torch.from_numpy(
96-
time_ser.values.astype(np.float64)
97-
)
95+
time_ser.values.astype(np.float64))
9896

9997
# Create edges from foreign key relationships:
10098
for fkey_col, pkey_table_name in table.fkey_col_to_pkey_table.items():
@@ -104,8 +102,7 @@ def from_relbench(db: Any) -> HeteroData:
104102
mask = ~pkey_index.isna()
105103
fkey_idx = torch.arange(len(pkey_index))
106104
pkey_idx = torch.from_numpy(
107-
pkey_index[mask].to_numpy(dtype=np.int64)
108-
)
105+
pkey_index[mask].to_numpy(dtype=np.int64))
109106
fkey_idx = fkey_idx[torch.from_numpy(mask.to_numpy(dtype=bool))]
110107

111108
# Forward edge: fkey table -> pkey table

0 commit comments

Comments
 (0)