Skip to content

Commit e1cb0cb

Browse files
committed
Add from_relbench utility to convert RelBench databases to HeteroData
1 parent b783d59 commit e1cb0cb

4 files changed

Lines changed: 316 additions & 0 deletions

File tree

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
77

88
### Added
99

10+
- Added `from_relbench` utility to convert RelBench databases into `HeteroData` ([#XXXX](https://github.com/pyg-team/pytorch_geometric/pull/XXXX))
11+
1012
### Changed
1113

1214
- Dropped support for TorchScript in `GATConv` and `GATv2Conv` for correctness ([#10596](https://github.com/pyg-team/pytorch_geometric/pull/10596))

test/utils/test_relbench.py

Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
from types import SimpleNamespace
2+
3+
import pandas as pd
4+
import torch
5+
6+
from torch_geometric.testing import withPackage
7+
from torch_geometric.utils import from_relbench
8+
9+
10+
def _mock_table(
    df: pd.DataFrame,
    fkey_col_to_pkey_table: dict,
    pkey_col: str = None,
    time_col: str = None,
) -> SimpleNamespace:
    """Build a lightweight stand-in for :class:`relbench.base.Table`.

    The returned namespace carries exactly the four attributes passed in
    (``df``, ``fkey_col_to_pkey_table``, ``pkey_col`` and ``time_col``), so
    the tests can run without ``relbench`` being installed.
    """
    attrs = {
        'df': df,
        'fkey_col_to_pkey_table': fkey_col_to_pkey_table,
        'pkey_col': pkey_col,
        'time_col': time_col,
    }
    return SimpleNamespace(**attrs)
23+
24+
25+
def _mock_database(table_dict: dict) -> SimpleNamespace:
    """Build a lightweight stand-in for :class:`relbench.base.Database`.

    The returned namespace exposes a single ``table_dict`` attribute holding
    the given mapping.
    """
    database = SimpleNamespace()
    database.table_dict = table_dict
    return database
28+
29+
30+
def test_from_relbench():
    """Convert a two-table database and check nodes, features and edges."""
    df_users = pd.DataFrame({
        'id': [0, 1, 2],
        'age': [25, 30, 35],
        'score': [1.0, 2.0, 3.0],
    })
    df_posts = pd.DataFrame({
        'id': [0, 1, 2, 3],
        'user_id': [0, 1, 0, 2],
        'length': [100, 200, 150, 300],
    })

    users = _mock_table(
        df=df_users,
        fkey_col_to_pkey_table={},
        pkey_col='id',
    )
    posts = _mock_table(
        df=df_posts,
        fkey_col_to_pkey_table={'user_id': 'users'},
        pkey_col='id',
    )

    db = _mock_database(table_dict={'users': users, 'posts': posts})
    data = from_relbench(db)

    # Verify node types:
    assert 'users' in data.node_types
    assert 'posts' in data.node_types

    # Verify node counts:
    assert data['users'].num_nodes == 3
    assert data['posts'].num_nodes == 4

    # Verify numeric features were extracted:
    assert data['users'].x is not None
    assert data['users'].x.size() == (3, 2)  # age, score
    assert data['posts'].x is not None
    assert data['posts'].x.size() == (4, 1)  # length

    # Verify feature values:
    assert torch.allclose(
        data['users'].x,
        torch.tensor([[25, 1.0], [30, 2.0], [35, 3.0]]),
    )
    assert torch.allclose(
        data['posts'].x,
        torch.tensor([[100.0], [200.0], [150.0], [300.0]]),
    )

    # Verify edge types (bidirectional fkey edges):
    edge_types = data.edge_types
    assert ('posts', 'f2p_user_id', 'users') in edge_types
    assert ('users', 'rev_f2p_user_id', 'posts') in edge_types

    # Verify edge index shapes (4 posts, each referencing a user):
    fwd = data['posts', 'f2p_user_id', 'users'].edge_index
    rev = data['users', 'rev_f2p_user_id', 'posts'].edge_index
    assert fwd.size() == (2, 4)
    assert rev.size() == (2, 4)

    # Verify the actual endpoints. Pairs are compared as sets so that the
    # check does not depend on how the converter orders its edges:
    fwd_pairs = {tuple(pair) for pair in fwd.t().tolist()}
    rev_pairs = {tuple(pair) for pair in rev.t().tolist()}
    assert fwd_pairs == {(0, 0), (1, 1), (2, 0), (3, 2)}  # (post, user)
    assert rev_pairs == {(0, 0), (1, 1), (0, 2), (2, 3)}  # (user, post)
86+
87+
88+
def test_from_relbench_dangling_fkeys():
    """Test that dangling (NaN) foreign keys are filtered out."""
    df_users = pd.DataFrame({'id': [0, 1]})
    df_posts = pd.DataFrame({
        'id': [0, 1, 2],
        'user_id': pd.array([0, None, 1], dtype=pd.Int64Dtype()),
    })

    users = _mock_table(
        df=df_users,
        fkey_col_to_pkey_table={},
        pkey_col='id',
    )
    posts = _mock_table(
        df=df_posts,
        fkey_col_to_pkey_table={'user_id': 'users'},
        pkey_col='id',
    )

    db = _mock_database(table_dict={'users': users, 'posts': posts})
    data = from_relbench(db)

    # Only 2 out of 3 posts have valid foreign keys:
    fwd = data['posts', 'f2p_user_id', 'users'].edge_index
    rev = data['users', 'rev_f2p_user_id', 'posts'].edge_index
    assert fwd.size() == (2, 2)
    assert rev.size() == (2, 2)

    # The row with the missing key (post 1) must be the one that is dropped;
    # checking only the edge count would not notice a wrong row being kept:
    fwd_pairs = {tuple(pair) for pair in fwd.t().tolist()}
    rev_pairs = {tuple(pair) for pair in rev.t().tolist()}
    assert fwd_pairs == {(0, 0), (2, 1)}  # (post, user)
    assert rev_pairs == {(0, 0), (1, 2)}  # (user, post)

    # Dropping an edge must not drop the node itself:
    assert data['posts'].num_nodes == 3
114+
115+
116+
def test_from_relbench_time_column():
    """Test that datetime time columns are converted to Unix seconds."""
    timestamps = pd.to_datetime(['2024-01-01', '2024-01-02', '2024-01-03'])
    df = pd.DataFrame({
        'id': [0, 1, 2],
        # Pin the resolution explicitly so that the expected values below do
        # not depend on which unit pandas infers when parsing strings:
        'ts': timestamps.astype('datetime64[ns]'),
        'val': [10, 20, 30],
    })

    events = _mock_table(
        df=df,
        fkey_col_to_pkey_table={},
        pkey_col='id',
        time_col='ts',
    )

    db = _mock_database(table_dict={'events': events})
    data = from_relbench(db)

    assert data['events'].num_nodes == 3

    time = data['events'].time
    assert time is not None
    assert time.size() == (3, )
    # Timestamps are whole seconds since the Unix epoch (one day apart):
    assert time.dtype == torch.long
    assert time.tolist() == [1704067200, 1704153600, 1704240000]

    # Time column should not appear in features:
    assert data['events'].x.size() == (3, 1)  # only 'val'
    assert data['events'].x.flatten().tolist() == [10.0, 20.0, 30.0]
140+
141+
142+
def test_from_relbench_no_features():
    """A table without numeric non-key columns must not receive ``x``."""
    frame = pd.DataFrame({
        'id': [0, 1, 2],
        'name': ['a', 'b', 'c'],  # Non-numeric, should be excluded
    })
    table = _mock_table(df=frame, fkey_col_to_pkey_table={}, pkey_col='id')

    data = from_relbench(_mock_database(table_dict={'items': table}))
    store = data['items']

    assert store.num_nodes == 3
    # Neither `name` (string) nor `id` (primary key) is a feature column, so
    # `x` is either absent or unset:
    assert getattr(store, 'x', None) is None
161+
162+
163+
@withPackage('relbench')
def test_from_relbench_with_relbench():
    """End-to-end check against real ``relbench`` table/database objects."""
    from relbench.base import Database, Table

    tables = {
        'users':
        Table(
            df=pd.DataFrame({
                'id': [0, 1, 2],
                'age': [25, 30, 35],
            }),
            fkey_col_to_pkey_table={},
            pkey_col='id',
        ),
        'posts':
        Table(
            df=pd.DataFrame({
                'id': [0, 1, 2],
                'user_id': [0, 1, 0],
                'score': [10, 20, 30],
            }),
            fkey_col_to_pkey_table={'user_id': 'users'},
            pkey_col='id',
        ),
    }

    data = from_relbench(Database(table_dict=tables))

    # Both tables hold three rows and must show up as node types:
    for node_type in ('users', 'posts'):
        assert node_type in data.node_types
        assert data[node_type].num_nodes == 3

torch_geometric/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from .convert import to_cugraph, from_cugraph
4747
from .convert import to_dgl, from_dgl
4848
from .smiles import from_rdmol, to_rdmol, from_smiles, to_smiles
49+
from .relbench import from_relbench
4950
from .random import (erdos_renyi_graph, stochastic_blockmodel_graph,
5051
barabasi_albert_graph)
5152
from ._negative_sampling import (negative_sampling, batched_negative_sampling,
@@ -135,6 +136,7 @@
135136
'to_rdmol',
136137
'from_smiles',
137138
'to_smiles',
139+
'from_relbench',
138140
'erdos_renyi_graph',
139141
'stochastic_blockmodel_graph',
140142
'barabasi_albert_graph',

torch_geometric/utils/relbench.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
from typing import Any
2+
3+
import numpy as np
4+
import torch
5+
6+
import torch_geometric
7+
from torch_geometric.data import HeteroData
8+
from torch_geometric.utils import sort_edge_index
9+
10+
11+
def from_relbench(db: Any) -> 'torch_geometric.data.HeteroData':
    r"""Converts a :class:`relbench.base.Database` object into a
    :class:`~torch_geometric.data.HeteroData` object.

    Each table in the database becomes a node type and each foreign key
    relationship becomes a bidirectional edge type.

    Numeric columns (excluding primary key, foreign key, and time columns)
    are concatenated into a :obj:`torch.float32` node feature tensor
    :obj:`x`, in which missing values are represented as :obj:`NaN`.
    If a table contains a time column, it is stored as a :obj:`time`
    attribute: datetime columns (of any resolution, timezone-aware or not)
    are converted to whole seconds since the Unix epoch
    (:obj:`torch.int64`), while any other column is cast to
    :obj:`torch.float64` as-is.

    Foreign key values are used directly as row positions of the referenced
    table. Rows whose foreign key is missing do not produce an edge.

    Args:
        db (relbench.base.Database): A RelBench database instance containing
            a dictionary of tables linked by primary-foreign key
            relationships.

    Returns:
        HeteroData: A heterogeneous graph where each table maps to a node
        type and each foreign key relationship maps to a pair of directed
        edge types.

    Example:
        >>> from relbench.base import Database, Table
        >>> import pandas as pd
        >>> users = Table(
        ...     df=pd.DataFrame({'id': [0, 1, 2], 'age': [25, 30, 35]}),
        ...     fkey_col_to_pkey_table={},
        ...     pkey_col='id',
        ... )
        >>> posts = Table(
        ...     df=pd.DataFrame({
        ...         'id': [0, 1, 2],
        ...         'user_id': [0, 1, 0],
        ...         'score': [10, 20, 30],
        ...     }),
        ...     fkey_col_to_pkey_table={'user_id': 'users'},
        ...     pkey_col='id',
        ... )
        >>> db = Database(table_dict={'users': users, 'posts': posts})
        >>> data = from_relbench(db)
        >>> data.node_types
        ['users', 'posts']
    """
    data = HeteroData()

    for table_name, table in db.table_dict.items():
        df = table.df

        # Set number of nodes:
        data[table_name].num_nodes = len(df)

        # Convert numeric feature columns into a node feature tensor:
        feature_cols = _relbench_feature_cols(table)
        if feature_cols:
            # `na_value` lets nullable extension columns pass through, whose
            # missing entries cannot be cast to float by a plain `astype`.
            feats = df[feature_cols].to_numpy(
                dtype=np.float32,
                na_value=np.nan,
            )
            # Copy into a fresh row-major buffer so the tensor never aliases
            # (possibly read-only) DataFrame memory:
            data[table_name].x = torch.from_numpy(np.array(feats, order='C'))

        # Store time column as Unix timestamp tensor:
        if table.time_col is not None:
            data[table_name].time = _relbench_unix_time(df[table.time_col])

        # Create edges from foreign key relationships:
        for fkey_col, pkey_table_name in table.fkey_col_to_pkey_table.items():
            edge_index = _relbench_fkey_edge_index(df[fkey_col])

            # Forward edge: fkey table -> pkey table
            edge_type = (table_name, f"f2p_{fkey_col}", pkey_table_name)
            data[edge_type].edge_index = sort_edge_index(edge_index)

            # Reverse edge: pkey table -> fkey table
            edge_type = (pkey_table_name, f"rev_f2p_{fkey_col}", table_name)
            data[edge_type].edge_index = sort_edge_index(edge_index.flip(0))

    data.validate()

    return data


def _relbench_feature_cols(table: Any) -> list:
    """Returns the numeric columns of a table that are not key/time columns.

    Signed integer, unsigned integer and floating point columns qualify.
    """
    excluded = {
        table.pkey_col,
        table.time_col,
        *table.fkey_col_to_pkey_table,
    }
    excluded.discard(None)  # Unset pkey/time columns exclude nothing.

    return [
        col for col, dtype in table.df.dtypes.items()
        if col not in excluded and dtype.kind in ('i', 'u', 'f')
    ]


def _relbench_unix_time(time_ser: Any) -> torch.Tensor:
    """Converts a time column into a one-dimensional tensor.

    Datetime columns become :obj:`int64` seconds since the Unix epoch,
    whatever their stored resolution. Any other column is cast to
    :obj:`float64` without rescaling.
    """
    # Both plain `datetime64[*]` and timezone-aware dtypes report kind "M":
    if getattr(time_ser.dtype, 'kind', None) == 'M':
        # `.values` is expected to yield a NumPy `datetime64` array here
        # (timezone-aware data expressed in UTC).
        stamps = np.asarray(time_ser.values)
        # Going through `datetime64[s]` normalizes every resolution to
        # seconds before the integer view is taken.
        # NOTE: Missing timestamps (NaT) presumably end up as the smallest
        # `int64` value; callers should filter them beforehand.
        seconds = stamps.astype('datetime64[s]').astype(np.int64)
        return torch.from_numpy(seconds)

    return torch.from_numpy(time_ser.values.astype(np.float64))


def _relbench_fkey_edge_index(fkey_ser: Any) -> torch.Tensor:
    """Builds the forward :obj:`[2, num_edges]` edge index of a foreign key.

    Row :obj:`0` holds the positions of the referencing rows, row :obj:`1`
    the referenced row positions. Rows with a missing key are skipped.
    """
    # Filter out dangling (NaN) foreign keys:
    valid = ~fkey_ser.isna().to_numpy()

    src = np.flatnonzero(valid).astype(np.int64)
    # Use an explicit 64-bit type, since `int` is platform-dependent:
    dst = fkey_ser[valid].to_numpy(dtype=np.int64, copy=True)

    return torch.stack([torch.from_numpy(src), torch.from_numpy(dst)], dim=0)

0 commit comments

Comments
 (0)