Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).

### Added

- Added `from_relbench` utility to convert RelBench databases into `HeteroData` ([#10628](https://github.com/pyg-team/pytorch_geometric/pull/10628))
- Added `txt2qa.py` example for synthetic multi-hop QA generation from text documents, supporting vLLM (local) and NVIDIA NIM (API) backends ([#10559](https://github.com/pyg-team/pytorch_geometric/pull/10559))

### Changed
Expand All @@ -23,7 +24,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).

### Fixed

- Fix MovieLens dataset incompatibility with `sentence-transformers>=5.0.0` ([#10668](https://github.com/pyg-team/pytorch_geometric/pull/10668)
- Fix MovieLens dataset incompatibility with `sentence-transformers>=5.0.0` ([#10668](https://github.com/pyg-team/pytorch_geometric/pull/10668))
- Removed an unnecessary device synchronization in `torch_geometric.utils.softmax` ([#10499](https://github.com/pyg-team/pytorch_geometric/pull/10499))
- Fixed loading of legacy HuggingFace BERT checkpoints ([#10631](https://github.com/pyg-team/pytorch_geometric/pull/10631))
- Fixed `return_attention_weights: bool` being not respected in `GATConv` and `GATv2Conv` ([#10596](https://github.com/pyg-team/pytorch_geometric/pull/10596))
Expand Down
155 changes: 155 additions & 0 deletions examples/relbench_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
"""Example demonstrating how to use ``from_relbench`` to convert a RelBench
relational database into a PyG HeteroData graph and train a heterogeneous
GNN for node-level prediction.

This example loads the Formula 1 RelBench dataset, converts it into a
heterogeneous graph using ``from_relbench``, and trains a 2-layer GraphSAGE
model (via ``to_hetero``) to predict championship standings points from
the graph structure and node features.

Requirements:
``pip install relbench``

Usage:
``python relbench_example.py``
``python relbench_example.py --epochs 50 --hidden_channels 128``
"""

import argparse
from typing import Tuple

import torch
import torch.nn.functional as F
from relbench.datasets import get_dataset

from torch_geometric.nn import Linear, SAGEConv, to_hetero
from torch_geometric.utils import from_relbench

parser = argparse.ArgumentParser(
description='Train a heterogeneous GNN on a RelBench dataset.')
parser.add_argument('--hidden_channels', type=int, default=64)
parser.add_argument('--lr', type=float, default=0.005)
parser.add_argument('--epochs', type=int, default=30)
args = parser.parse_args()

torch.manual_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 1. Load a RelBench dataset and convert to HeteroData:
print('Loading RelBench rel-f1 dataset...')
dataset = get_dataset('rel-f1', download=True)
db = dataset.get_db()
data = from_relbench(db)
print(f'Graph: {len(data.node_types)} node types, '
f'{len(data.edge_types)} edge types')

# 2. Prepare a node regression target.
# `from_relbench` preserves the original DataFrame column order from RelBench.
# In rel-f1, the 'standings' table has 'points' as its first numeric column:
target_type = 'standings'
y = data[target_type].x[:, 0] # points column (index 0 in rel-f1)
data[target_type].x = data[target_type].x[:, 1:] # remove from input features

# 3. Clean up features — fill NaN and standardize per column:
for node_type in data.node_types:
if hasattr(data[node_type], 'x') and data[node_type].x is not None:
x = torch.nan_to_num(data[node_type].x, nan=0.0)
std, mean = torch.std_mean(x, dim=0)
std[std == 0] = 1.0 # avoid division by zero for constant columns
data[node_type].x = (x - mean) / std
else:
# Zero-feature placeholder for featureless node types (e.g. drivers):
data[node_type].x = torch.zeros(data[node_type].num_nodes, 1)

# 4. Create train/val/test splits (60/20/20) before computing target stats:
num_nodes = data[target_type].num_nodes
perm = torch.randperm(num_nodes)
train_mask = torch.zeros(num_nodes, dtype=torch.bool)
val_mask = torch.zeros(num_nodes, dtype=torch.bool)
test_mask = torch.zeros(num_nodes, dtype=torch.bool)
train_mask[perm[:int(0.6 * num_nodes)]] = True
val_mask[perm[int(0.6 * num_nodes):int(0.8 * num_nodes)]] = True
test_mask[perm[int(0.8 * num_nodes):]] = True

# Normalize target using training set statistics only (prevents data leakage):
y_mean = y[train_mask].mean()
y_std = y[train_mask].std()
y_std = y_std if y_std > 0 else torch.tensor(1.0)
y_norm = (y - y_mean) / y_std

# 5. Move all tensors to device — including masks to prevent device mismatch:
data = data.to(device)
y = y.to(device)
y_norm = y_norm.to(device)
train_mask = train_mask.to(device)
val_mask = val_mask.to(device)
test_mask = test_mask.to(device)


# 6. Define a 2-layer GraphSAGE model with lazy input size inference:
class GNN(torch.nn.Module):
def __init__(self, hidden_channels: int) -> None:
super().__init__()
self.conv1 = SAGEConv((-1, -1), hidden_channels)
self.conv2 = SAGEConv((-1, -1), hidden_channels)
self.lin = Linear(-1, 1)

def forward(
self,
x: torch.Tensor,
edge_index: torch.Tensor,
) -> torch.Tensor:
x = self.conv1(x, edge_index).relu()
x = self.conv2(x, edge_index).relu()
return self.lin(x)


model = GNN(args.hidden_channels)
model = to_hetero(model, data.metadata(), aggr='sum').to(device)

# Initialize lazy parameters via a single dry-run forward pass:
with torch.no_grad():
model(data.x_dict, data.edge_index_dict)

optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)


def train() -> torch.Tensor:
model.train()
optimizer.zero_grad()
pred = model(data.x_dict, data.edge_index_dict)[target_type].squeeze(-1)
loss = F.mse_loss(pred[train_mask], y_norm[train_mask])
loss.backward()
optimizer.step()
return loss


@torch.no_grad()
def test() -> Tuple[float, float, float]:
model.eval()
pred = model(data.x_dict, data.edge_index_dict)[target_type].squeeze(-1)
# denormalize for interpretable MAE
pred *= y_std
pred += y_mean

train_mae = float((pred[train_mask] - y[train_mask]).abs().mean())
val_mae = float((pred[val_mask] - y[val_mask]).abs().mean())
test_mae = float((pred[test_mask] - y[test_mask]).abs().mean())
return train_mae, val_mae, test_mae


print(
f'\nTraining {args.epochs} epochs on "{target_type}" point prediction...')
print(f'Target stats (train): mean={y_mean:.2f}, std={y_std:.2f}\n')

for epoch in range(1, args.epochs + 1):
loss = train()
if epoch % 5 == 0 or epoch == 1:
train_mae, val_mae, test_mae = test()
print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
f'Train MAE: {train_mae:.2f}, Val MAE: {val_mae:.2f}, '
f'Test MAE: {test_mae:.2f} points')

train_mae, val_mae, test_mae = test()
print(f'\nFinal — Train MAE: {train_mae:.2f}, Val MAE: {val_mae:.2f}, '
f'Test MAE: {test_mae:.2f} points')
208 changes: 208 additions & 0 deletions test/utils/test_relbench.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
from types import SimpleNamespace
from typing import Any, Optional

import torch

from torch_geometric.testing import withPackage
from torch_geometric.utils import from_relbench


def _mock_table(
df: Any,
fkey_col_to_pkey_table: dict,
pkey_col: Optional[str] = None,
time_col: Optional[str] = None,
) -> SimpleNamespace:
"""Create a mock object that duck-types relbench.base.Table."""
return SimpleNamespace(
df=df,
fkey_col_to_pkey_table=fkey_col_to_pkey_table,
pkey_col=pkey_col,
time_col=time_col,
)


def _mock_database(table_dict: dict) -> SimpleNamespace:
"""Create a mock object that duck-types relbench.base.Database."""
return SimpleNamespace(table_dict=table_dict)


@withPackage('pandas')
def test_from_relbench():
import pandas as pd

df_users = pd.DataFrame({
'id': [0, 1, 2],
'age': [25, 30, 35],
'score': [1.0, 2.0, 3.0],
})
df_posts = pd.DataFrame({
'id': [0, 1, 2, 3],
'user_id': [0, 1, 0, 2],
'length': [100, 200, 150, 300],
})

users = _mock_table(
df=df_users,
fkey_col_to_pkey_table={},
pkey_col='id',
)
posts = _mock_table(
df=df_posts,
fkey_col_to_pkey_table={'user_id': 'users'},
pkey_col='id',
)

db = _mock_database(table_dict={'users': users, 'posts': posts})
data = from_relbench(db)

# Verify node types:
assert 'users' in data.node_types
assert 'posts' in data.node_types

# Verify node counts:
assert data['users'].num_nodes == 3
assert data['posts'].num_nodes == 4

# Verify numeric features were extracted:
assert data['users'].x is not None
assert data['users'].x.size() == (3, 2) # age, score
assert data['posts'].x is not None
assert data['posts'].x.size() == (4, 1) # length

# Verify feature values:
assert torch.allclose(
data['users'].x,
torch.tensor([[25, 1.0], [30, 2.0], [35, 3.0]]),
)

# Verify edge types (bidirectional fkey edges):
edge_types = data.edge_types
assert ('posts', 'f2p_user_id', 'users') in edge_types
assert ('users', 'rev_f2p_user_id', 'posts') in edge_types

# Verify edge index shapes (4 posts, each referencing a user):
fwd = data['posts', 'f2p_user_id', 'users'].edge_index
rev = data['users', 'rev_f2p_user_id', 'posts'].edge_index
assert fwd.size() == (2, 4)
assert rev.size() == (2, 4)


@withPackage('pandas')
def test_from_relbench_dangling_fkeys():
"""Test that dangling (NaN) foreign keys are filtered out."""
import pandas as pd

df_users = pd.DataFrame({'id': [0, 1]})
df_posts = pd.DataFrame({
'id': [0, 1, 2],
'user_id':
pd.array([0, None, 1], dtype=pd.Int64Dtype()),
})

users = _mock_table(
df=df_users,
fkey_col_to_pkey_table={},
pkey_col='id',
)
posts = _mock_table(
df=df_posts,
fkey_col_to_pkey_table={'user_id': 'users'},
pkey_col='id',
)

db = _mock_database(table_dict={'users': users, 'posts': posts})
data = from_relbench(db)

# Only 2 out of 3 posts have valid foreign keys:
fwd = data['posts', 'f2p_user_id', 'users'].edge_index
assert fwd.size() == (2, 2)


@withPackage('pandas')
def test_from_relbench_time_column():
"""Test that time columns are correctly converted."""
import pandas as pd

df = pd.DataFrame({
'id': [0, 1, 2],
'ts':
pd.to_datetime(['2024-01-01', '2024-01-02', '2024-01-03']),
'val': [10, 20, 30],
})

events = _mock_table(
df=df,
fkey_col_to_pkey_table={},
pkey_col='id',
time_col='ts',
)

db = _mock_database(table_dict={'events': events})
data = from_relbench(db)

assert data['events'].num_nodes == 3
assert data['events'].time is not None
assert data['events'].time.size() == (3, )
# Time column should not appear in features:
assert data['events'].x.size() == (3, 1) # only 'val'


@withPackage('pandas')
def test_from_relbench_no_features():
"""Test tables with only pkey/fkey columns and no numeric features."""
import pandas as pd

df = pd.DataFrame({
'id': [0, 1, 2],
'name': ['a', 'b', 'c'], # Non-numeric, should be excluded
})

items = _mock_table(
df=df,
fkey_col_to_pkey_table={},
pkey_col='id',
)

db = _mock_database(table_dict={'items': items})
data = from_relbench(db)

assert data['items'].num_nodes == 3
# No numeric feature columns (name is string, id is pkey):
assert not hasattr(data['items'], 'x') or data['items'].x is None


@withPackage('relbench')
def test_from_relbench_with_relbench():
"""Integration test using actual relbench objects."""
import pandas as pd
from relbench.base import Database, Table

df_users = pd.DataFrame({
'id': [0, 1, 2],
'age': [25, 30, 35],
})
df_posts = pd.DataFrame({
'id': [0, 1, 2],
'user_id': [0, 1, 0],
'score': [10, 20, 30],
})

users = Table(
df=df_users,
fkey_col_to_pkey_table={},
pkey_col='id',
)
posts = Table(
df=df_posts,
fkey_col_to_pkey_table={'user_id': 'users'},
pkey_col='id',
)

db = Database(table_dict={'users': users, 'posts': posts})
data = from_relbench(db)

assert 'users' in data.node_types
assert 'posts' in data.node_types
assert data['users'].num_nodes == 3
assert data['posts'].num_nodes == 3
2 changes: 2 additions & 0 deletions torch_geometric/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from .convert import to_cugraph, from_cugraph
from .convert import to_dgl, from_dgl
from .smiles import from_rdmol, to_rdmol, from_smiles, to_smiles
from .relbench import from_relbench
from .random import (erdos_renyi_graph, stochastic_blockmodel_graph,
barabasi_albert_graph)
from ._negative_sampling import (negative_sampling, batched_negative_sampling,
Expand Down Expand Up @@ -135,6 +136,7 @@
'to_rdmol',
'from_smiles',
'to_smiles',
'from_relbench',
'erdos_renyi_graph',
'stochastic_blockmodel_graph',
'barabasi_albert_graph',
Expand Down
Loading
Loading