From a2bf4d5aad21c95fa308374e27db708629b4afc7 Mon Sep 17 00:00:00 2001 From: AJamal27891 Date: Wed, 4 Mar 2026 17:37:24 +0200 Subject: [PATCH 01/14] Add from_relbench utility to convert RelBench databases to HeteroData --- CHANGELOG.md | 2 + test/utils/test_relbench.py | 208 ++++++++++++++++++++++++++++++ torch_geometric/utils/__init__.py | 2 + torch_geometric/utils/relbench.py | 121 +++++++++++++++++ 4 files changed, 333 insertions(+) create mode 100644 test/utils/test_relbench.py create mode 100644 torch_geometric/utils/relbench.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 27266095fa85..ef3697c83e54 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). ### Added +- Added `from_relbench` utility to convert RelBench databases into `HeteroData` ([#10628](https://github.com/pyg-team/pytorch_geometric/pull/10628)) + ### Changed - Dropped support for TorchScript in `GATConv` and `GATv2Conv` for correctness ([#10596](https://github.com/pyg-team/pytorch_geometric/pull/10596)) diff --git a/test/utils/test_relbench.py b/test/utils/test_relbench.py new file mode 100644 index 000000000000..36ae4e011f01 --- /dev/null +++ b/test/utils/test_relbench.py @@ -0,0 +1,208 @@ +from types import SimpleNamespace +from typing import Any, Optional + +import torch + +from torch_geometric.testing import withPackage +from torch_geometric.utils import from_relbench + + +def _mock_table( + df: Any, + fkey_col_to_pkey_table: dict, + pkey_col: Optional[str] = None, + time_col: Optional[str] = None, +) -> SimpleNamespace: + """Create a mock object that duck-types relbench.base.Table.""" + return SimpleNamespace( + df=df, + fkey_col_to_pkey_table=fkey_col_to_pkey_table, + pkey_col=pkey_col, + time_col=time_col, + ) + + +def _mock_database(table_dict: dict) -> SimpleNamespace: + """Create a mock object that duck-types relbench.base.Database.""" + return SimpleNamespace(table_dict=table_dict) + + +@withPackage('pandas') +def test_from_relbench(): + import pandas as pd + + df_users = pd.DataFrame({ + 'id': [0, 1, 2], + 'age': [25, 30, 35], + 'score': [1.0, 2.0, 3.0], + }) + df_posts = pd.DataFrame({ + 'id': [0, 1, 2, 3], + 'user_id': [0, 1, 0, 2], + 'length': [100, 200, 150, 300], + }) + + users = _mock_table( + df=df_users, + fkey_col_to_pkey_table={}, + pkey_col='id', + ) + posts = _mock_table( + df=df_posts, + fkey_col_to_pkey_table={'user_id': 'users'}, + pkey_col='id', + ) + + db = _mock_database(table_dict={'users': users, 'posts': posts}) + data = from_relbench(db) + + # Verify node types: + assert 'users' in data.node_types + assert 'posts' in data.node_types + + # Verify node counts: + assert data['users'].num_nodes == 3 + assert data['posts'].num_nodes == 4 + + # Verify numeric features were extracted: + assert data['users'].x is not None + assert data['users'].x.size() == (3, 2) # age, score + assert data['posts'].x is not None + assert data['posts'].x.size() == (4, 1) # length + + # Verify feature values: + assert torch.allclose( + data['users'].x, + torch.tensor([[25, 1.0], [30, 2.0], [35, 3.0]]), + ) + + # Verify edge types (bidirectional fkey edges): + edge_types = data.edge_types + assert ('posts', 'f2p_user_id', 'users') in edge_types + assert ('users', 'rev_f2p_user_id', 'posts') in edge_types + + # Verify edge index shapes (4 posts, each referencing a user): + fwd = data['posts', 'f2p_user_id', 'users'].edge_index + rev = data['users', 'rev_f2p_user_id', 'posts'].edge_index + assert fwd.size() == (2, 4) + assert rev.size() == (2, 4) + + +@withPackage('pandas') +def test_from_relbench_dangling_fkeys(): + """Test that dangling (NaN) foreign keys are filtered out.""" + import pandas as pd + + df_users = pd.DataFrame({'id': [0, 1]}) + df_posts = pd.DataFrame({ + 'id': [0, 1, 2], + 'user_id': + pd.array([0, None, 1], dtype=pd.Int64Dtype()), + }) + + users = _mock_table( + df=df_users, + fkey_col_to_pkey_table={}, + pkey_col='id', + ) + posts = _mock_table( + df=df_posts, + fkey_col_to_pkey_table={'user_id': 'users'}, + pkey_col='id', + ) + + db = _mock_database(table_dict={'users': users, 'posts': posts}) + data = from_relbench(db) + + # Only 2 out of 3 posts have valid foreign keys: + fwd = data['posts', 'f2p_user_id', 'users'].edge_index + assert fwd.size() == (2, 2) + + +@withPackage('pandas') +def test_from_relbench_time_column(): + """Test that time columns are correctly converted.""" + import pandas as pd + + df = pd.DataFrame({ + 'id': [0, 1, 2], + 'ts': + pd.to_datetime(['2024-01-01', '2024-01-02', '2024-01-03']), + 'val': [10, 20, 30], + }) + + events = _mock_table( + df=df, + fkey_col_to_pkey_table={}, + pkey_col='id', + time_col='ts', + ) + + db = _mock_database(table_dict={'events': events}) + data = from_relbench(db) + + assert data['events'].num_nodes == 3 + assert data['events'].time is not None + assert data['events'].time.size() == (3, ) + # Time column should not appear in features: + assert data['events'].x.size() == (3, 1) # only 'val' + + +@withPackage('pandas') +def test_from_relbench_no_features(): + """Test tables with only pkey/fkey columns and no numeric features.""" + import pandas as pd + + df = pd.DataFrame({ + 'id': [0, 1, 2], + 'name': ['a', 'b', 'c'], # Non-numeric, should be excluded + }) + + items = _mock_table( + df=df, + fkey_col_to_pkey_table={}, + pkey_col='id', + ) + + db = _mock_database(table_dict={'items': items}) + data = from_relbench(db) + + assert data['items'].num_nodes == 3 + # No numeric feature columns (name is string, id is pkey): + assert not hasattr(data['items'], 'x') or data['items'].x is None + + +@withPackage('relbench') +def test_from_relbench_with_relbench(): + """Integration test using actual relbench objects.""" + import pandas as pd + from relbench.base import Database, Table + + df_users = pd.DataFrame({ + 'id': [0, 1, 2], + 'age': [25, 30, 35], + }) + df_posts = pd.DataFrame({ + 'id': [0, 1, 2], + 'user_id': [0, 1, 0], + 'score': [10, 20, 30], + }) + + users = Table( + df=df_users, + fkey_col_to_pkey_table={}, + pkey_col='id', + ) + posts = Table( + df=df_posts, + fkey_col_to_pkey_table={'user_id': 'users'}, + pkey_col='id', + ) + + db = Database(table_dict={'users': users, 'posts': posts}) + data = from_relbench(db) + + assert 'users' in data.node_types + assert 'posts' in data.node_types + assert data['users'].num_nodes == 3 + assert data['posts'].num_nodes == 3 diff --git a/torch_geometric/utils/__init__.py b/torch_geometric/utils/__init__.py index 9aea960c696f..0bc2817d6185 100644 --- a/torch_geometric/utils/__init__.py +++ b/torch_geometric/utils/__init__.py @@ -46,6 +46,7 @@ from .convert import to_cugraph, from_cugraph from .convert import to_dgl, from_dgl from .smiles import from_rdmol, to_rdmol, from_smiles, to_smiles +from .relbench import from_relbench from .random import (erdos_renyi_graph, stochastic_blockmodel_graph, barabasi_albert_graph) from ._negative_sampling import (negative_sampling, batched_negative_sampling, @@ -135,6 +136,7 @@ 'to_rdmol', 'from_smiles', 'to_smiles', + 'from_relbench', 'erdos_renyi_graph', 'stochastic_blockmodel_graph', 'barabasi_albert_graph', diff --git a/torch_geometric/utils/relbench.py b/torch_geometric/utils/relbench.py new file mode 100644 index 000000000000..181bbe9f6eca --- /dev/null +++ b/torch_geometric/utils/relbench.py @@ -0,0 +1,121 @@ +from typing import Any + +import numpy as np +import torch + +import torch_geometric +from torch_geometric.data import HeteroData +from torch_geometric.utils import sort_edge_index + + +def from_relbench(db: Any) -> 'torch_geometric.data.HeteroData': + r"""Converts a :class:`relbench.base.Database` object into a + :class:`~torch_geometric.data.HeteroData` object. + + Each table in the database becomes a node type and each foreign key + relationship becomes a bidirectional edge type. + + Numeric columns (excluding primary key, foreign key, and time columns) + are concatenated into a node feature tensor :obj:`x`. If a table contains + a time column, it is stored as a :obj:`time` attribute. + + Args: + db (relbench.base.Database): A RelBench database instance containing + a dictionary of tables linked by primary-foreign key + relationships. + + Returns: + HeteroData: A heterogeneous graph where each table maps to a node + type and each foreign key relationship maps to a pair of directed + edge types. + + Example: + >>> from relbench.base import Database, Table + >>> import pandas as pd + >>> users = Table( + ... df=pd.DataFrame({'id': [0, 1, 2], 'age': [25, 30, 35]}), + ... fkey_col_to_pkey_table={}, + ... pkey_col='id', + ... ) + >>> posts = Table( + ... df=pd.DataFrame({ + ... 'id': [0, 1, 2], + ... 'user_id': [0, 1, 0], + ... 'score': [10, 20, 30], + ... }), + ... fkey_col_to_pkey_table={'user_id': 'users'}, + ... pkey_col='id', + ... ) + >>> db = Database(table_dict={'users': users, 'posts': posts}) + >>> data = from_relbench(db) + >>> data.node_types + ['users', 'posts'] + """ + data = HeteroData() + + for table_name, table in db.table_dict.items(): + df = table.df + + # Determine columns to exclude from node features: + exclude_cols = set() + if table.pkey_col is not None: + exclude_cols.add(table.pkey_col) + if table.time_col is not None: + exclude_cols.add(table.time_col) + for fkey_col in table.fkey_col_to_pkey_table: + exclude_cols.add(fkey_col) + + # Set number of nodes: + data[table_name].num_nodes = len(df) + + # Convert numeric feature columns into a node feature tensor: + feature_cols = [ + col for col in df.columns + if col not in exclude_cols and df[col].dtype.kind in ('i', 'f') + ] + if len(feature_cols) > 0: + x_np = df[feature_cols].to_numpy( + dtype=np.float32, + na_value=np.nan, + ) + data[table_name].x = torch.from_numpy(x_np) + + # Store time column as Unix timestamp tensor: + if table.time_col is not None: + time_ser = df[table.time_col] + if time_ser.dtype in [ + np.dtype("datetime64[s]"), + np.dtype("datetime64[ns]"), + ]: + unix_time = time_ser.astype("int64").values + if time_ser.dtype == np.dtype("datetime64[ns]"): + unix_time = unix_time // 10**9 + data[table_name].time = torch.from_numpy(unix_time) + else: + data[table_name].time = torch.from_numpy( + time_ser.values.astype(np.float64), ) + + # Create edges from foreign key relationships: + for fkey_col, pkey_table_name in table.fkey_col_to_pkey_table.items(): + pkey_index = df[fkey_col] + + # Filter out dangling (NaN) foreign keys: + mask = ~pkey_index.isna() + fkey_idx = torch.arange(len(pkey_index)) + pkey_idx = torch.from_numpy( + pkey_index[mask].to_numpy(dtype=np.int64), ) + fkey_idx = fkey_idx[torch.from_numpy(mask.to_numpy(dtype=bool))] + + # Forward edge: fkey table -> pkey table + edge_index = torch.stack([fkey_idx, pkey_idx], dim=0) + edge_type = (table_name, f"f2p_{fkey_col}", pkey_table_name) + data[edge_type].edge_index = sort_edge_index(edge_index) + + # Reverse edge: pkey table -> fkey table + edge_index = torch.stack([pkey_idx, fkey_idx], dim=0) + edge_type = (pkey_table_name, f"rev_f2p_{fkey_col}", table_name) + data[edge_type].edge_index = sort_edge_index(edge_index) + + data.validate() + + return data From 268b7d181b79b894eea6b059eb558049d63b4fc2 Mon Sep 17 00:00:00 2001 From: AJamal27891 Date: Sun, 8 Mar 2026 17:17:07 +0200 Subject: [PATCH 02/14] Add relbench_example.py to demonstrate from_relbench with heterogeneous GNN training --- examples/relbench_example.py | 147 +++++++++++++++++++++++++++++++++++ 1 file changed, 147 insertions(+) create mode 100644 examples/relbench_example.py diff --git a/examples/relbench_example.py b/examples/relbench_example.py new file mode 100644 index 000000000000..39120e05b982 --- /dev/null +++ b/examples/relbench_example.py @@ -0,0 +1,147 @@ +"""Example demonstrating how to use ``from_relbench`` to convert a RelBench +relational database into a PyG HeteroData graph and train a heterogeneous +GNN for node-level prediction. + +This example loads the Formula 1 RelBench dataset, converts it into a +heterogeneous graph using ``from_relbench``, and trains a 2-layer GraphSAGE +model (via ``to_hetero``) to predict championship standings points from +the graph structure and node features. + +Requirements: + ``pip install relbench`` + +Usage: + ``python relbench_example.py`` + ``python relbench_example.py --epochs 50 --hidden_channels 128`` +""" +import argparse + +import torch +import torch.nn.functional as F +from relbench.datasets import get_dataset + +from torch_geometric.nn import Linear, SAGEConv, to_hetero +from torch_geometric.utils import from_relbench + +parser = argparse.ArgumentParser( + description='Train a heterogeneous GNN on a RelBench dataset.') +parser.add_argument('--hidden_channels', type=int, default=64) +parser.add_argument('--lr', type=float, default=0.005) +parser.add_argument('--epochs', type=int, default=30) +args = parser.parse_args() + +torch.manual_seed(42) +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +# 1. Load a RelBench dataset and convert to HeteroData: +print('Loading RelBench rel-f1 dataset...') +dataset = get_dataset('rel-f1', download=True) +db = dataset.get_db() +data = from_relbench(db) +print(f'Graph: {len(data.node_types)} node types, ' + f'{len(data.edge_types)} edge types') + +# 2. Prepare a node regression target. +# `from_relbench` preserves the original DataFrame column order from RelBench. +# In rel-f1, the 'standings' table has 'points' as its first numeric column: +target_type = 'standings' +y = data[target_type].x[:, 0].clone() # points column (index 0 in rel-f1) +data[target_type].x = data[target_type].x[:, 1:] # remove from input features + +# 3. Clean up features — fill NaN and standardize per column: +for node_type in data.node_types: + if hasattr(data[node_type], 'x') and data[node_type].x is not None: + x = torch.nan_to_num(data[node_type].x, nan=0.0) + std, mean = torch.std_mean(x, dim=0) + std[std == 0] = 1.0 # avoid division by zero for constant columns + data[node_type].x = (x - mean) / std + else: + # Zero-feature placeholder for featureless node types (e.g. drivers): + data[node_type].x = torch.zeros(data[node_type].num_nodes, 1) + +# 4. Create train/val/test splits (60/20/20) before computing target stats: +num_nodes = data[target_type].num_nodes +perm = torch.randperm(num_nodes) +train_mask = torch.zeros(num_nodes, dtype=torch.bool) +val_mask = torch.zeros(num_nodes, dtype=torch.bool) +test_mask = torch.zeros(num_nodes, dtype=torch.bool) +train_mask[perm[:int(0.6 * num_nodes)]] = True +val_mask[perm[int(0.6 * num_nodes):int(0.8 * num_nodes)]] = True +test_mask[perm[int(0.8 * num_nodes):]] = True + +# Normalize target using training set statistics only (prevents data leakage): +y_mean = y[train_mask].mean() +y_std = y[train_mask].std() +y_std = y_std if y_std > 0 else torch.tensor(1.0) +y_norm = (y - y_mean) / y_std + +# 5. Move all tensors to device — including masks to prevent device mismatch: +data = data.to(device) +y = y.to(device) +y_norm = y_norm.to(device) +train_mask = train_mask.to(device) +val_mask = val_mask.to(device) +test_mask = test_mask.to(device) + + +# 6. Define a 2-layer GraphSAGE model with lazy input size inference: +class GNN(torch.nn.Module): + def __init__(self, hidden_channels: int) -> None: + super().__init__() + self.conv1 = SAGEConv((-1, -1), hidden_channels) + self.conv2 = SAGEConv((-1, -1), hidden_channels) + self.lin = Linear(-1, 1) + + def forward(self, x, edge_index): + x = self.conv1(x, edge_index).relu() + x = self.conv2(x, edge_index).relu() + return self.lin(x) + + +model = GNN(args.hidden_channels) +model = to_hetero(model, data.metadata(), aggr='sum').to(device) + +# Initialize lazy parameters via a single dry-run forward pass: +with torch.no_grad(): + model(data.x_dict, data.edge_index_dict) + +optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) + + +def train() -> float: + model.train() + optimizer.zero_grad() + pred = model(data.x_dict, data.edge_index_dict)[target_type].squeeze(-1) + loss = F.mse_loss(pred[train_mask], y_norm[train_mask]) + loss.backward() + optimizer.step() + return float(loss) + + +@torch.no_grad() +def test(): + model.eval() + pred = model(data.x_dict, data.edge_index_dict)[target_type].squeeze(-1) + pred_orig = pred * y_std + y_mean # denormalize for interpretable MAE + + train_mae = float((pred_orig[train_mask] - y[train_mask]).abs().mean()) + val_mae = float((pred_orig[val_mask] - y[val_mask]).abs().mean()) + test_mae = float((pred_orig[test_mask] - y[test_mask]).abs().mean()) + return train_mae, val_mae, test_mae + + +print( + f'\nTraining {args.epochs} epochs on "{target_type}" point prediction...') +print(f'Target stats (train): mean={y_mean:.2f}, std={y_std:.2f}\n') + +for epoch in range(1, args.epochs + 1): + loss = train() + if epoch % 5 == 0 or epoch == 1: + train_mae, val_mae, test_mae = test() + print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, ' + f'Train MAE: {train_mae:.2f}, Val MAE: {val_mae:.2f}, ' + f'Test MAE: {test_mae:.2f} points') + +train_mae, val_mae, test_mae = test() +print(f'\nFinal — Train MAE: {train_mae:.2f}, Val MAE: {val_mae:.2f}, ' + f'Test MAE: {test_mae:.2f} points') From df079debf0d4997899a59d03817b4b6697597343 Mon Sep 17 00:00:00 2001 From: AJamal27891 Date: Wed, 25 Mar 2026 21:13:22 +0200 Subject: [PATCH 03/14] Address akihironitta review: annotations, idioms, style, drop string forward ref --- examples/relbench_example.py | 58 +++++++++++++++++++------------ torch_geometric/utils/relbench.py | 28 ++++++++------- 2 files changed, 51 insertions(+), 35 deletions(-) diff --git a/examples/relbench_example.py b/examples/relbench_example.py index 39120e05b982..bb525949d936 100644 --- a/examples/relbench_example.py +++ b/examples/relbench_example.py @@ -14,7 +14,9 @@ ``python relbench_example.py`` ``python relbench_example.py --epochs 50 --hidden_channels 128`` """ + import argparse +from typing import Tuple import torch import torch.nn.functional as F @@ -24,7 +26,8 @@ from torch_geometric.utils import from_relbench parser = argparse.ArgumentParser( - description='Train a heterogeneous GNN on a RelBench dataset.') + description='Train a heterogeneous GNN on a RelBench dataset.' +) parser.add_argument('--hidden_channels', type=int, default=64) parser.add_argument('--lr', type=float, default=0.005) parser.add_argument('--epochs', type=int, default=30) @@ -38,14 +41,16 @@ dataset = get_dataset('rel-f1', download=True) db = dataset.get_db() data = from_relbench(db) -print(f'Graph: {len(data.node_types)} node types, ' - f'{len(data.edge_types)} edge types') +print( + f'Graph: {len(data.node_types)} node types, ' + f'{len(data.edge_types)} edge types' +) # 2. Prepare a node regression target. # `from_relbench` preserves the original DataFrame column order from RelBench. # In rel-f1, the 'standings' table has 'points' as its first numeric column: target_type = 'standings' -y = data[target_type].x[:, 0].clone() # points column (index 0 in rel-f1) +y = data[target_type].x[:, 0] # points column (index 0 in rel-f1) data[target_type].x = data[target_type].x[:, 1:] # remove from input features # 3. Clean up features — fill NaN and standardize per column: @@ -65,9 +70,9 @@ train_mask = torch.zeros(num_nodes, dtype=torch.bool) val_mask = torch.zeros(num_nodes, dtype=torch.bool) test_mask = torch.zeros(num_nodes, dtype=torch.bool) -train_mask[perm[:int(0.6 * num_nodes)]] = True -val_mask[perm[int(0.6 * num_nodes):int(0.8 * num_nodes)]] = True -test_mask[perm[int(0.8 * num_nodes):]] = True +train_mask[perm[: int(0.6 * num_nodes)]] = True +val_mask[perm[int(0.6 * num_nodes) : int(0.8 * num_nodes)]] = True +test_mask[perm[int(0.8 * num_nodes) :]] = True # Normalize target using training set statistics only (prevents data leakage): y_mean = y[train_mask].mean() @@ -92,7 +97,11 @@ def __init__(self, hidden_channels: int) -> None: self.conv2 = SAGEConv((-1, -1), hidden_channels) self.lin = Linear(-1, 1) - def forward(self, x, edge_index): + def forward( + self, + x: torch.Tensor, + edge_index: torch.Tensor, + ) -> torch.Tensor: x = self.conv1(x, edge_index).relu() x = self.conv2(x, edge_index).relu() return self.lin(x) @@ -108,40 +117,45 @@ def forward(self, x, edge_index): optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) -def train() -> float: +def train() -> torch.Tensor: model.train() optimizer.zero_grad() pred = model(data.x_dict, data.edge_index_dict)[target_type].squeeze(-1) loss = F.mse_loss(pred[train_mask], y_norm[train_mask]) loss.backward() optimizer.step() - return float(loss) + return loss @torch.no_grad() -def test(): +def test() -> Tuple[float, float, float]: model.eval() pred = model(data.x_dict, data.edge_index_dict)[target_type].squeeze(-1) - pred_orig = pred * y_std + y_mean # denormalize for interpretable MAE + # denormalize for interpretable MAE + pred *= y_std + pred += y_mean - train_mae = float((pred_orig[train_mask] - y[train_mask]).abs().mean()) - val_mae = float((pred_orig[val_mask] - y[val_mask]).abs().mean()) - test_mae = float((pred_orig[test_mask] - y[test_mask]).abs().mean()) + train_mae = float((pred[train_mask] - y[train_mask]).abs().mean()) + val_mae = float((pred[val_mask] - y[val_mask]).abs().mean()) + test_mae = float((pred[test_mask] - y[test_mask]).abs().mean()) return train_mae, val_mae, test_mae -print( - f'\nTraining {args.epochs} epochs on "{target_type}" point prediction...') +print(f'\nTraining {args.epochs} epochs on "{target_type}" point prediction...') print(f'Target stats (train): mean={y_mean:.2f}, std={y_std:.2f}\n') for epoch in range(1, args.epochs + 1): loss = train() if epoch % 5 == 0 or epoch == 1: train_mae, val_mae, test_mae = test() - print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, ' - f'Train MAE: {train_mae:.2f}, Val MAE: {val_mae:.2f}, ' - f'Test MAE: {test_mae:.2f} points') + print( + f'Epoch: {epoch:03d}, Loss: {loss:.4f}, ' + f'Train MAE: {train_mae:.2f}, Val MAE: {val_mae:.2f}, ' + f'Test MAE: {test_mae:.2f} points' + ) train_mae, val_mae, test_mae = test() -print(f'\nFinal — Train MAE: {train_mae:.2f}, Val MAE: {val_mae:.2f}, ' - f'Test MAE: {test_mae:.2f} points') +print( + f'\nFinal — Train MAE: {train_mae:.2f}, Val MAE: {val_mae:.2f}, ' + f'Test MAE: {test_mae:.2f} points' +) diff --git a/torch_geometric/utils/relbench.py b/torch_geometric/utils/relbench.py index 181bbe9f6eca..ef3db4caf9bd 100644 --- a/torch_geometric/utils/relbench.py +++ b/torch_geometric/utils/relbench.py @@ -3,12 +3,11 @@ import numpy as np import torch -import torch_geometric from torch_geometric.data import HeteroData from torch_geometric.utils import sort_edge_index -def from_relbench(db: Any) -> 'torch_geometric.data.HeteroData': +def from_relbench(db: Any) -> HeteroData: r"""Converts a :class:`relbench.base.Database` object into a :class:`~torch_geometric.data.HeteroData` object. @@ -29,7 +28,7 @@ def from_relbench(db: Any) -> 'torch_geometric.data.HeteroData': type and each foreign key relationship maps to a pair of directed edge types. - Example: + Examples: >>> from relbench.base import Database, Table >>> import pandas as pd >>> users = Table( @@ -70,10 +69,11 @@ def from_relbench(db: Any) -> 'torch_geometric.data.HeteroData': # Convert numeric feature columns into a node feature tensor: feature_cols = [ - col for col in df.columns + col + for col in df.columns if col not in exclude_cols and df[col].dtype.kind in ('i', 'f') ] - if len(feature_cols) > 0: + if feature_cols: x_np = df[feature_cols].to_numpy( dtype=np.float32, na_value=np.nan, @@ -84,16 +84,17 @@ def from_relbench(db: Any) -> 'torch_geometric.data.HeteroData': if table.time_col is not None: time_ser = df[table.time_col] if time_ser.dtype in [ - np.dtype("datetime64[s]"), - np.dtype("datetime64[ns]"), + np.dtype('datetime64[s]'), + np.dtype('datetime64[ns]'), ]: - unix_time = time_ser.astype("int64").values - if time_ser.dtype == np.dtype("datetime64[ns]"): + unix_time = time_ser.astype('int64').values + if time_ser.dtype == np.dtype('datetime64[ns]'): unix_time = unix_time // 10**9 data[table_name].time = torch.from_numpy(unix_time) else: data[table_name].time = torch.from_numpy( - time_ser.values.astype(np.float64), ) + time_ser.values.astype(np.float64) + ) # Create edges from foreign key relationships: for fkey_col, pkey_table_name in table.fkey_col_to_pkey_table.items(): @@ -103,17 +104,18 @@ def from_relbench(db: Any) -> 'torch_geometric.data.HeteroData': mask = ~pkey_index.isna() fkey_idx = torch.arange(len(pkey_index)) pkey_idx = torch.from_numpy( - pkey_index[mask].to_numpy(dtype=np.int64), ) + pkey_index[mask].to_numpy(dtype=np.int64) + ) fkey_idx = fkey_idx[torch.from_numpy(mask.to_numpy(dtype=bool))] # Forward edge: fkey table -> pkey table edge_index = torch.stack([fkey_idx, pkey_idx], dim=0) - edge_type = (table_name, f"f2p_{fkey_col}", pkey_table_name) + edge_type = (table_name, f'f2p_{fkey_col}', pkey_table_name) data[edge_type].edge_index = sort_edge_index(edge_index) # Reverse edge: pkey table -> fkey table edge_index = torch.stack([pkey_idx, fkey_idx], dim=0) - edge_type = (pkey_table_name, f"rev_f2p_{fkey_col}", table_name) + edge_type = (pkey_table_name, f'rev_f2p_{fkey_col}', table_name) data[edge_type].edge_index = sort_edge_index(edge_index) data.validate() From 5be0a14287d90d8bc92caddf8924225db6879d32 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 28 Mar 2026 02:47:09 +0000 Subject: [PATCH 04/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/relbench_example.py | 32 +++++++++++++------------------ torch_geometric/utils/relbench.py | 13 +++++-------- 2 files changed, 18 insertions(+), 27 deletions(-) diff --git a/examples/relbench_example.py b/examples/relbench_example.py index bb525949d936..a1f98f4d293d 100644 --- a/examples/relbench_example.py +++ b/examples/relbench_example.py @@ -26,8 +26,7 @@ from torch_geometric.utils import from_relbench parser = argparse.ArgumentParser( - description='Train a heterogeneous GNN on a RelBench dataset.' -) + description='Train a heterogeneous GNN on a RelBench dataset.') parser.add_argument('--hidden_channels', type=int, default=64) parser.add_argument('--lr', type=float, default=0.005) parser.add_argument('--epochs', type=int, default=30) @@ -41,10 +40,8 @@ dataset = get_dataset('rel-f1', download=True) db = dataset.get_db() data = from_relbench(db) -print( - f'Graph: {len(data.node_types)} node types, ' - f'{len(data.edge_types)} edge types' -) +print(f'Graph: {len(data.node_types)} node types, ' + f'{len(data.edge_types)} edge types') # 2. Prepare a node regression target. # `from_relbench` preserves the original DataFrame column order from RelBench. @@ -70,9 +67,9 @@ train_mask = torch.zeros(num_nodes, dtype=torch.bool) val_mask = torch.zeros(num_nodes, dtype=torch.bool) test_mask = torch.zeros(num_nodes, dtype=torch.bool) -train_mask[perm[: int(0.6 * num_nodes)]] = True -val_mask[perm[int(0.6 * num_nodes) : int(0.8 * num_nodes)]] = True -test_mask[perm[int(0.8 * num_nodes) :]] = True +train_mask[perm[:int(0.6 * num_nodes)]] = True +val_mask[perm[int(0.6 * num_nodes):int(0.8 * num_nodes)]] = True +test_mask[perm[int(0.8 * num_nodes):]] = True # Normalize target using training set statistics only (prevents data leakage): y_mean = y[train_mask].mean() @@ -141,21 +138,18 @@ def test() -> Tuple[float, float, float]: return train_mae, val_mae, test_mae -print(f'\nTraining {args.epochs} epochs on "{target_type}" point prediction...') +print( + f'\nTraining {args.epochs} epochs on "{target_type}" point prediction...') print(f'Target stats (train): mean={y_mean:.2f}, std={y_std:.2f}\n') for epoch in range(1, args.epochs + 1): loss = train() if epoch % 5 == 0 or epoch == 1: train_mae, val_mae, test_mae = test() - print( - f'Epoch: {epoch:03d}, Loss: {loss:.4f}, ' - f'Train MAE: {train_mae:.2f}, Val MAE: {val_mae:.2f}, ' - f'Test MAE: {test_mae:.2f} points' - ) + print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, ' + f'Train MAE: {train_mae:.2f}, Val MAE: {val_mae:.2f}, ' + f'Test MAE: {test_mae:.2f} points') train_mae, val_mae, test_mae = test() -print( - f'\nFinal — Train MAE: {train_mae:.2f}, Val MAE: {val_mae:.2f}, ' - f'Test MAE: {test_mae:.2f} points' -) +print(f'\nFinal — Train MAE: {train_mae:.2f}, Val MAE: {val_mae:.2f}, ' + f'Test MAE: {test_mae:.2f} points') diff --git a/torch_geometric/utils/relbench.py b/torch_geometric/utils/relbench.py index ef3db4caf9bd..9ea4afad85a7 100644 --- a/torch_geometric/utils/relbench.py +++ b/torch_geometric/utils/relbench.py @@ -69,8 +69,7 @@ def from_relbench(db: Any) -> HeteroData: # Convert numeric feature columns into a node feature tensor: feature_cols = [ - col - for col in df.columns + col for col in df.columns if col not in exclude_cols and df[col].dtype.kind in ('i', 'f') ] if feature_cols: @@ -84,8 +83,8 @@ def from_relbench(db: Any) -> HeteroData: if table.time_col is not None: time_ser = df[table.time_col] if time_ser.dtype in [ - np.dtype('datetime64[s]'), - np.dtype('datetime64[ns]'), + np.dtype('datetime64[s]'), + np.dtype('datetime64[ns]'), ]: unix_time = time_ser.astype('int64').values if time_ser.dtype == np.dtype('datetime64[ns]'): @@ -93,8 +92,7 @@ def from_relbench(db: Any) -> HeteroData: data[table_name].time = torch.from_numpy(unix_time) else: data[table_name].time = torch.from_numpy( - time_ser.values.astype(np.float64) - ) + time_ser.values.astype(np.float64)) # Create edges from foreign key relationships: for fkey_col, pkey_table_name in table.fkey_col_to_pkey_table.items(): @@ -104,8 +102,7 @@ def from_relbench(db: Any) -> HeteroData: mask = ~pkey_index.isna() fkey_idx = torch.arange(len(pkey_index)) pkey_idx = torch.from_numpy( - pkey_index[mask].to_numpy(dtype=np.int64) - ) + pkey_index[mask].to_numpy(dtype=np.int64)) fkey_idx = fkey_idx[torch.from_numpy(mask.to_numpy(dtype=bool))] # Forward edge: fkey table -> pkey table From a33b35cfe872d29134255faf844472137c08fcc7 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Wed, 3 Jun 2026 22:11:00 -0700 Subject: [PATCH 05/14] update --- examples/relbench_example.py | 35 ++++++++++------------------------- 1 file changed, 10 insertions(+), 25 deletions(-) diff --git a/examples/relbench_example.py b/examples/relbench_example.py index a1f98f4d293d..837d3ad75896 100644 --- a/examples/relbench_example.py +++ b/examples/relbench_example.py @@ -6,13 +6,6 @@ heterogeneous graph using ``from_relbench``, and trains a 2-layer GraphSAGE model (via ``to_hetero``) to predict championship standings points from the graph structure and node features. - -Requirements: - ``pip install relbench`` - -Usage: - ``python relbench_example.py`` - ``python relbench_example.py --epochs 50 --hidden_channels 128`` """ import argparse @@ -39,7 +32,7 @@ print('Loading RelBench rel-f1 dataset...') dataset = get_dataset('rel-f1', download=True) db = dataset.get_db() -data = from_relbench(db) +data = from_relbench(db).to(device) print(f'Graph: {len(data.node_types)} node types, ' f'{len(data.edge_types)} edge types') @@ -47,7 +40,7 @@ # `from_relbench` preserves the original DataFrame column order from RelBench. # In rel-f1, the 'standings' table has 'points' as its first numeric column: target_type = 'standings' -y = data[target_type].x[:, 0] # points column (index 0 in rel-f1) +y = data[target_type].x[:, 0].to(device) # points column (index 0 in rel-f1) data[target_type].x = data[target_type].x[:, 1:] # remove from input features # 3. Clean up features — fill NaN and standardize per column: @@ -59,34 +52,26 @@ data[node_type].x = (x - mean) / std else: # Zero-feature placeholder for featureless node types (e.g. drivers): - data[node_type].x = torch.zeros(data[node_type].num_nodes, 1) + data[node_type].x = torch.zeros(data[node_type].num_nodes, 1, + device=device) # 4. Create train/val/test splits (60/20/20) before computing target stats: num_nodes = data[target_type].num_nodes -perm = torch.randperm(num_nodes) -train_mask = torch.zeros(num_nodes, dtype=torch.bool) -val_mask = torch.zeros(num_nodes, dtype=torch.bool) -test_mask = torch.zeros(num_nodes, dtype=torch.bool) +perm = torch.randperm(num_nodes, device=device) +train_mask = torch.zeros(num_nodes, dtype=torch.bool, device=device) +val_mask = torch.zeros(num_nodes, dtype=torch.bool, device=device) +test_mask = torch.zeros(num_nodes, dtype=torch.bool, device=device) train_mask[perm[:int(0.6 * num_nodes)]] = True val_mask[perm[int(0.6 * num_nodes):int(0.8 * num_nodes)]] = True test_mask[perm[int(0.8 * num_nodes):]] = True # Normalize target using training set statistics only (prevents data leakage): y_mean = y[train_mask].mean() -y_std = y[train_mask].std() -y_std = y_std if y_std > 0 else torch.tensor(1.0) +y_std = max(y[train_mask].std(), 1e-10) y_norm = (y - y_mean) / y_std -# 5. Move all tensors to device — including masks to prevent device mismatch: -data = data.to(device) -y = y.to(device) -y_norm = y_norm.to(device) -train_mask = train_mask.to(device) -val_mask = val_mask.to(device) -test_mask = test_mask.to(device) - -# 6. Define a 2-layer GraphSAGE model with lazy input size inference: +# 5. Define a 2-layer GraphSAGE model with lazy input size inference: class GNN(torch.nn.Module): def __init__(self, hidden_channels: int) -> None: super().__init__() From fa4462fb7fff2b8ec7f8c96b43a5b208be8f47fa Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Wed, 3 Jun 2026 22:16:18 -0700 Subject: [PATCH 06/14] update --- test/utils/test_relbench.py | 93 ++++++++----------------------------- 1 file changed, 20 insertions(+), 73 deletions(-) diff --git a/test/utils/test_relbench.py b/test/utils/test_relbench.py index 36ae4e011f01..5f088591322a 100644 --- a/test/utils/test_relbench.py +++ b/test/utils/test_relbench.py @@ -1,35 +1,13 @@ -from types import SimpleNamespace -from typing import Any, Optional - import torch from torch_geometric.testing import withPackage from torch_geometric.utils import from_relbench -def _mock_table( - df: Any, - fkey_col_to_pkey_table: dict, - pkey_col: Optional[str] = None, - time_col: Optional[str] = None, -) -> SimpleNamespace: - """Create a mock object that duck-types relbench.base.Table.""" - return SimpleNamespace( - df=df, - fkey_col_to_pkey_table=fkey_col_to_pkey_table, - pkey_col=pkey_col, - time_col=time_col, - ) - - -def _mock_database(table_dict: dict) -> SimpleNamespace: - """Create a mock object that duck-types relbench.base.Database.""" - return SimpleNamespace(table_dict=table_dict) - - -@withPackage('pandas') +@withPackage('relbench') def test_from_relbench(): import pandas as pd + from relbench.base import Database, Table df_users = pd.DataFrame({ 'id': [0, 1, 2], @@ -42,18 +20,18 @@ def test_from_relbench(): 'length': [100, 200, 150, 300], }) - users = _mock_table( + users = Table( df=df_users, fkey_col_to_pkey_table={}, pkey_col='id', ) - posts = _mock_table( + posts = Table( df=df_posts, fkey_col_to_pkey_table={'user_id': 'users'}, pkey_col='id', ) - db = _mock_database(table_dict={'users': users, 'posts': posts}) + db = Database(table_dict={'users': users, 'posts': posts}) data = from_relbench(db) # Verify node types: @@ -88,10 +66,11 @@ def test_from_relbench(): assert rev.size() == (2, 4) -@withPackage('pandas') +@withPackage('relbench') def test_from_relbench_dangling_fkeys(): """Test that dangling (NaN) foreign keys are filtered out.""" import pandas as pd + from relbench.base import Database, Table df_users = pd.DataFrame({'id': [0, 1]}) df_posts = pd.DataFrame({ @@ -100,18 +79,18 @@ def test_from_relbench_dangling_fkeys(): pd.array([0, None, 1], dtype=pd.Int64Dtype()), }) - users = _mock_table( + users = Table( df=df_users, fkey_col_to_pkey_table={}, pkey_col='id', ) - posts = _mock_table( + posts = Table( df=df_posts, fkey_col_to_pkey_table={'user_id': 'users'}, pkey_col='id', ) - db = _mock_database(table_dict={'users': users, 'posts': posts}) + db = Database(table_dict={'users': users, 'posts': posts}) data = from_relbench(db) # Only 2 out of 3 posts have valid foreign keys: @@ -119,10 +98,11 @@ def test_from_relbench_dangling_fkeys(): assert fwd.size() == (2, 2) -@withPackage('pandas') +@withPackage('relbench') def test_from_relbench_time_column(): """Test that time columns are correctly converted.""" import pandas as pd + from relbench.base import Database, Table df = pd.DataFrame({ 'id': [0, 1, 2], @@ -131,78 +111,45 @@ def test_from_relbench_time_column(): 'val': [10, 20, 30], }) - events = _mock_table( + events = Table( df=df, fkey_col_to_pkey_table={}, pkey_col='id', time_col='ts', ) - db = _mock_database(table_dict={'events': events}) + db = Database(table_dict={'events': events}) data = from_relbench(db) assert data['events'].num_nodes == 3 assert data['events'].time is not None assert data['events'].time.size() == (3, ) + # Verify Unix timestamps (seconds) for the three dates: + assert data['events'].time.tolist() == [1704067200, 1704153600, 1704240000] # Time column should not appear in features: assert data['events'].x.size() == (3, 1) # only 'val' -@withPackage('pandas') +@withPackage('relbench') def test_from_relbench_no_features(): """Test tables with only pkey/fkey columns and no numeric features.""" import pandas as pd + from relbench.base import Database, Table df = pd.DataFrame({ 'id': [0, 1, 2], 'name': ['a', 'b', 'c'], # Non-numeric, should be excluded }) - items = _mock_table( + items = Table( df=df, fkey_col_to_pkey_table={}, pkey_col='id', ) - db = _mock_database(table_dict={'items': items}) + db = Database(table_dict={'items': items}) data = from_relbench(db) assert data['items'].num_nodes == 3 # No numeric feature columns (name is string, id is pkey): assert not hasattr(data['items'], 'x') or data['items'].x is None - - -@withPackage('relbench') -def test_from_relbench_with_relbench(): - """Integration test using actual relbench objects.""" - import pandas as pd - from relbench.base import Database, Table - - df_users = pd.DataFrame({ - 'id': [0, 1, 2], - 'age': [25, 30, 35], - }) - df_posts = pd.DataFrame({ - 'id': [0, 1, 2], - 'user_id': [0, 1, 0], - 'score': [10, 20, 30], - }) - - users = Table( - df=df_users, - fkey_col_to_pkey_table={}, - pkey_col='id', - ) - posts = Table( - df=df_posts, - fkey_col_to_pkey_table={'user_id': 'users'}, - pkey_col='id', - ) - - db = Database(table_dict={'users': users, 'posts': posts}) - data = from_relbench(db) - - assert 'users' in data.node_types - assert 'posts' in data.node_types - assert data['users'].num_nodes == 3 - assert data['posts'].num_nodes == 3 From dc04b6c38b87e971cf23e0a88567102d0fc136c5 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Wed, 3 Jun 2026 22:17:40 -0700 Subject: [PATCH 07/14] Rename relbench.py to _relbench.py --- torch_geometric/utils/__init__.py | 2 +- torch_geometric/utils/{relbench.py => _relbench.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename torch_geometric/utils/{relbench.py => _relbench.py} (100%) diff --git a/torch_geometric/utils/__init__.py b/torch_geometric/utils/__init__.py index 0bc2817d6185..53c889ff6b50 100644 --- a/torch_geometric/utils/__init__.py +++ b/torch_geometric/utils/__init__.py @@ -46,7 +46,7 @@ from .convert import to_cugraph, from_cugraph from .convert import to_dgl, from_dgl from .smiles import from_rdmol, to_rdmol, from_smiles, to_smiles -from .relbench import from_relbench +from ._relbench import from_relbench from .random import (erdos_renyi_graph, stochastic_blockmodel_graph, barabasi_albert_graph) from ._negative_sampling import (negative_sampling, batched_negative_sampling, diff --git a/torch_geometric/utils/relbench.py b/torch_geometric/utils/_relbench.py similarity index 100% rename from torch_geometric/utils/relbench.py rename to torch_geometric/utils/_relbench.py From 74653cb3e20ee0292576ffe1d9a2dbe4aebae43f Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Thu, 4 Jun 2026 10:02:18 -0700 Subject: [PATCH 08/14] Move from_relbench utility into torch_geometric/contrib/utils/ --- examples/relbench_example.py | 2 +- test/{ => contrib}/utils/test_relbench.py | 2 +- torch_geometric/contrib/__init__.py | 1 + torch_geometric/contrib/utils/__init__.py | 3 +++ torch_geometric/{ => contrib}/utils/_relbench.py | 0 torch_geometric/utils/__init__.py | 2 -- 6 files changed, 6 insertions(+), 4 deletions(-) rename test/{ => contrib}/utils/test_relbench.py (98%) create mode 100644 torch_geometric/contrib/utils/__init__.py rename torch_geometric/{ => contrib}/utils/_relbench.py (100%) diff --git a/examples/relbench_example.py b/examples/relbench_example.py index 837d3ad75896..b110d1255d4d 100644 --- a/examples/relbench_example.py +++ b/examples/relbench_example.py @@ -15,8 +15,8 @@ import torch.nn.functional as F from relbench.datasets import get_dataset +from torch_geometric.contrib.utils import from_relbench from torch_geometric.nn import Linear, SAGEConv, to_hetero -from torch_geometric.utils import from_relbench parser = argparse.ArgumentParser( description='Train a heterogeneous GNN on a RelBench dataset.') diff --git a/test/utils/test_relbench.py b/test/contrib/utils/test_relbench.py similarity index 98% rename from test/utils/test_relbench.py rename to test/contrib/utils/test_relbench.py index 5f088591322a..3260995368f4 100644 --- a/test/utils/test_relbench.py +++ b/test/contrib/utils/test_relbench.py @@ -1,7 +1,7 @@ import torch +from torch_geometric.contrib.utils import from_relbench from torch_geometric.testing import withPackage -from torch_geometric.utils import from_relbench @withPackage('relbench') diff --git a/torch_geometric/contrib/__init__.py b/torch_geometric/contrib/__init__.py index 06eda57c8048..b727148ed7df 100644 --- a/torch_geometric/contrib/__init__.py +++ b/torch_geometric/contrib/__init__.py @@ -4,6 +4,7 @@ import torch_geometric.contrib.datasets # noqa import torch_geometric.contrib.nn # noqa import torch_geometric.contrib.explain # noqa +import torch_geometric.contrib.utils # noqa warnings.warn( "'torch_geometric.contrib' contains experimental code and is subject to " diff --git a/torch_geometric/contrib/utils/__init__.py b/torch_geometric/contrib/utils/__init__.py new file mode 100644 index 000000000000..2b49f5944841 --- /dev/null +++ b/torch_geometric/contrib/utils/__init__.py @@ -0,0 +1,3 @@ +from ._relbench import from_relbench + +__all__ = ['from_relbench'] diff --git a/torch_geometric/utils/_relbench.py b/torch_geometric/contrib/utils/_relbench.py similarity index 100% rename from torch_geometric/utils/_relbench.py rename to torch_geometric/contrib/utils/_relbench.py diff --git a/torch_geometric/utils/__init__.py b/torch_geometric/utils/__init__.py index 53c889ff6b50..9aea960c696f 100644 --- a/torch_geometric/utils/__init__.py +++ b/torch_geometric/utils/__init__.py @@ -46,7 +46,6 @@ from .convert import to_cugraph, from_cugraph from .convert import to_dgl, from_dgl from .smiles import from_rdmol, to_rdmol, from_smiles, to_smiles -from ._relbench import from_relbench from .random import (erdos_renyi_graph, stochastic_blockmodel_graph, barabasi_albert_graph) from ._negative_sampling import (negative_sampling, batched_negative_sampling, @@ -136,7 +135,6 @@ 'to_rdmol', 'from_smiles', 'to_smiles', - 'from_relbench', 'erdos_renyi_graph', 'stochastic_blockmodel_graph', 'barabasi_albert_graph', From b533422fad97cb7d6f1b3c76af30b4d8ee490d71 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Thu, 4 Jun 2026 10:03:11 -0700 Subject: [PATCH 09/14] Use built-in tuple instead of typing.Tuple in relbench example --- examples/relbench_example.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/relbench_example.py b/examples/relbench_example.py index b110d1255d4d..c24cf61d6293 100644 --- a/examples/relbench_example.py +++ b/examples/relbench_example.py @@ -9,7 +9,6 @@ """ import argparse -from typing import Tuple import torch import torch.nn.functional as F @@ -110,7 +109,7 @@ def train() -> torch.Tensor: @torch.no_grad() -def test() -> Tuple[float, float, float]: +def test() -> tuple[float, float, float]: model.eval() pred = model(data.x_dict, data.edge_index_dict)[target_type].squeeze(-1) # denormalize for interpretable MAE From 20b06e7a22023ba4ca74f2c412233e0f7307be39 Mon Sep 17 00:00:00 2001 From: AJamal27891 Date: Wed, 29 Apr 2026 20:44:44 +0300 Subject: [PATCH 10/14] Add PR 2 GRetriever RelBench example --- examples/llm/relbench_gretriever.py | 276 ++++++++++++++++++++++++++++ 1 file changed, 276 insertions(+) create mode 100644 examples/llm/relbench_gretriever.py diff --git a/examples/llm/relbench_gretriever.py b/examples/llm/relbench_gretriever.py new file mode 100644 index 000000000000..8e2ed9891e45 --- /dev/null +++ b/examples/llm/relbench_gretriever.py @@ -0,0 +1,276 @@ +"""Example demonstrating how to bridge ``from_relbench`` heterogeneous +graphs to GRetriever for graph-augmented question answering. + +This example loads the Formula 1 RelBench dataset, sanitizes the data, +projects all node types into a shared latent space (handling featureless +structural tables via learned embeddings), converts to homogeneous format, +and feeds the result into GRetriever. + +.. note:: + Calling ``to_homogeneous()`` directly on RelBench data silently + drops ALL node features (``x=None``) when any table lacks numeric + columns. This example shows the correct pattern: sanitize, project + all types to a common dimension, then convert. + +.. note:: + Due to a known upstream issue in PyG ``llm.py`` with + ``transformers >= 5.0``, this example currently requires + ``transformers 4.x``. + (``pip install "transformers>=4.51,<5.0"``) + +Requirements: + ``pip install relbench "transformers>=4.51,<5.0" sentencepiece + accelerate`` + +Usage: + ``python relbench_gretriever.py`` + ``python relbench_gretriever.py --epochs 10 --llm Qwen/Qwen2-0.5B`` +""" +import argparse + +import torch +import torch.nn as nn +from relbench.datasets import get_dataset + +from torch_geometric.llm.models import GRetriever, LLM +from torch_geometric.nn import GAT, HeteroDictLinear +from torch_geometric.utils import from_relbench + +# ── CLI ────────────────────────────────────────────────────────────── +parser = argparse.ArgumentParser( + description='RelBench -> GRetriever example.') +parser.add_argument('--dataset', type=str, default='rel-f1', + help='RelBench dataset name (default: rel-f1)') +parser.add_argument('--llm', type=str, + default='Qwen/Qwen2-0.5B', + help='HuggingFace LLM model name') +parser.add_argument('--hidden', type=int, default=64, + help='Common projection + GNN hidden dim') +parser.add_argument('--gnn_layers', type=int, default=2, + help='Number of GAT layers') +parser.add_argument('--epochs', type=int, default=5, + help='Training epochs') +parser.add_argument('--lr', type=float, default=1e-4, + help='Learning rate') +parser.add_argument('--dtype', type=str, default='bfloat16', + choices=['float32', 'bfloat16', 'float16'], + help='LLM dtype (use float32 for CPU-only)') +parser.add_argument('--n_gpus', type=int, default=1, + help='Number of GPUs for the LLM (0 for CPU)') +args = parser.parse_args() + +_dtype_map = { + 'float32': torch.float32, + 'bfloat16': torch.bfloat16, + 'float16': torch.float16, +} +args.torch_dtype = _dtype_map[args.dtype] + +# ── 1. Load & Sanitize RelBench data ───────────────────────────────── +print(f'Loading RelBench {args.dataset} dataset...') +dataset = get_dataset(args.dataset) +db = dataset.get_db() +data = from_relbench(db) + +# Replace SQL NULLs with zeros and normalize numeric features. +for node_type in data.node_types: + if hasattr(data[node_type], 'x') and data[node_type].x is not None: + x = data[node_type].x + x = torch.nan_to_num(x, nan=0.0) + std, mean = torch.std_mean(x, dim=0) + std = torch.where(std == 0, torch.ones_like(std), std) + data[node_type].x = (x - mean) / std + +print(f'Graph: {len(data.node_types)} node types, ' + f'{len(data.edge_types)} edge types') + + +# ── 2. Define Trainable Feature Projector ──────────────────────────── +class HeteroFeatureProjector(nn.Module): + """Projects heterogeneous node features to a common dimension. + + Uses ``HeteroDictLinear`` for node types with numeric features + and ``nn.Embedding`` for featureless structural tables. + """ + def __init__(self, data, common_dim: int): + super().__init__() + featured = {} + self.featureless = [] + for nt in data.node_types: + x = data[nt].get('x', None) + if x is not None and x.shape[1] > 0: + featured[nt] = x.shape[1] + else: + self.featureless.append(nt) + + self.lin = HeteroDictLinear(featured, common_dim) + self.embs = nn.ModuleDict({ + nt: nn.Embedding(data[nt].num_nodes, common_dim) + for nt in self.featureless + }) + + def forward(self, data): + """Return a dict of projected features, preserving autograd.""" + x_dict = {nt: data[nt].x for nt in self.lin.lins} + out = self.lin(x_dict) + res = {} + for nt in data.node_types: + if nt in out: + res[nt] = out[nt] + else: + res[nt] = self.embs[nt].weight + return res + + +projector = HeteroFeatureProjector(data, args.hidden) + +# ── 3. Extract Homogeneous Topology ────────────────────────────────── +# Topology (edge_index) is static, computed once. Node features (homo_x) +# are computed dynamically inside the training loop so that gradients +# flow back through the projector. +homo_topology = data.to_homogeneous() +homo_edge_index = homo_topology.edge_index +print(f'Homogeneous: edge_index={list(homo_edge_index.shape)}') + +# ── 4. Create synthetic Q&A pairs ─────────────────────────────────── +# These synthetic Q&A pairs are illustrative. +num_drivers = (data['drivers'].num_nodes + if 'drivers' in data.node_types else '?') +num_constructors = (data['constructors'].num_nodes + if 'constructors' in data.node_types else '?') +num_node_types = len(data.node_types) +num_edge_types = len(data.edge_types) + +qa_pairs = [ + ('How many drivers are in the dataset?', + f'There are {num_drivers} drivers in the Formula 1 dataset.'), + ('How many constructors are in the dataset?', + f'There are {num_constructors} constructors.'), + ('How many types of entities are in the graph?', + f'The graph has {num_node_types} node types and ' + f'{num_edge_types} edge types.'), + ('What entity types exist in the Formula 1 knowledge graph?', + f'The entity types include: {", ".join(data.node_types)}.'), + ('How are drivers connected to races?', + 'Drivers connect to races through results and qualifying entries.'), + ('What does this knowledge graph represent?', + 'This graph represents Formula 1 racing data including drivers, ' + 'teams, circuits, races, and their relationships.'), +] + +# ── 5. Build GRetriever model ──────────────────────────────────────── +print(f'\nInitializing GRetriever with LLM={args.llm}...') + +gnn = GAT( + in_channels=args.hidden, + hidden_channels=args.hidden, + num_layers=args.gnn_layers, + out_channels=args.hidden, +) + +llm = LLM( + model_name=args.llm, + n_gpus=args.n_gpus if args.n_gpus > 0 else None, + dtype=args.torch_dtype, + sys_prompt=( + 'You are an expert assistant that answers questions about ' + 'Formula 1 data using knowledge graph context. ' + 'Give concise, direct answers.' + ), +) + +model = GRetriever(llm=llm, gnn=gnn) +print('Model initialized.') + +# Move model components to the LLM device +device = model.llm.device +model.gnn = model.gnn.to(device) +projector = projector.to(device) +homo_edge_index = homo_edge_index.to(device) +data = data.to(device) +print(f'Using device: {device}') + +# ── 6. Training loop ──────────────────────────────────────────────── +# Include projector parameters so the feature embeddings actually learn. +params = [p for p in model.parameters() if p.requires_grad] +params += list(projector.parameters()) +optimizer = torch.optim.AdamW(params, lr=args.lr, weight_decay=0.05) + +context_str = ( + f'This is a Formula 1 knowledge graph with {num_node_types} entity ' + f'types ({", ".join(data.node_types)}).' +) + +print(f'\nTraining {args.epochs} epochs on {len(qa_pairs)} samples...') +model.train() +projector.train() + +for epoch in range(1, args.epochs + 1): + total_loss = 0.0 + + for q, a in qa_pairs: + optimizer.zero_grad() + + # Dynamic projection: compute inside the loop so gradients + # flow back through the projector. + projected_dict = projector(data) + # Stack in data.node_types order (same order as to_homogeneous) + homo_x = torch.cat( + [projected_dict[nt] for nt in data.node_types], dim=0) + + # Single-graph paradigm: all nodes belong to batch index 0 + batch_idx = torch.zeros( + homo_x.size(0), dtype=torch.long, device=device) + + loss = model( + question=[q], + x=homo_x, + edge_index=homo_edge_index, + batch=batch_idx, + label=[a], + additional_text_context=[context_str], + ) + + if loss.isnan(): + raise RuntimeError( + f'NaN loss on question: "{q}". ' + 'Check data normalization or reduce learning rate.') + + loss.backward() + torch.nn.utils.clip_grad_norm_(params, 0.1) + optimizer.step() + total_loss += loss.item() + + avg_loss = total_loss / len(qa_pairs) + print(f'Epoch {epoch:02d}: Loss={avg_loss:.4f}') + +# ── 7. Inference demo ──────────────────────────────────────────────── +print('\nInference:') +model.eval() +projector.eval() + +# Compute static features for inference +with torch.no_grad(): + projected_dict = projector(data) + homo_x = torch.cat( + [projected_dict[nt] for nt in data.node_types], dim=0) + +test_questions = [ + 'How many drivers are in this Formula 1 dataset?', + 'What entity types exist in the graph?', +] + +for test_q in test_questions: + with torch.no_grad(): + response = model.inference( + question=[test_q], + x=homo_x, + edge_index=homo_edge_index, + batch=torch.zeros(homo_x.size(0), dtype=torch.long, + device=device), + additional_text_context=[context_str], + max_out_tokens=64, + ) + print(f'Q: {test_q}') + print(f'A: {response[0]}') + print() From b66996cbd0146450dd44279c036457fb8bccfe37 Mon Sep 17 00:00:00 2001 From: AJamal27891 Date: Wed, 29 Apr 2026 20:59:08 +0300 Subject: [PATCH 11/14] Polish PR2 relbench_gretriever demo: simplify docs, refine comments, and tighten example Q&A --- examples/llm/relbench_gretriever.py | 125 +++++++++++++++------------- 1 file changed, 69 insertions(+), 56 deletions(-) diff --git a/examples/llm/relbench_gretriever.py b/examples/llm/relbench_gretriever.py index 8e2ed9891e45..984ab9f1e7cb 100644 --- a/examples/llm/relbench_gretriever.py +++ b/examples/llm/relbench_gretriever.py @@ -1,22 +1,11 @@ -"""Example demonstrating how to bridge ``from_relbench`` heterogeneous -graphs to GRetriever for graph-augmented question answering. - -This example loads the Formula 1 RelBench dataset, sanitizes the data, -projects all node types into a shared latent space (handling featureless -structural tables via learned embeddings), converts to homogeneous format, -and feeds the result into GRetriever. - -.. note:: - Calling ``to_homogeneous()`` directly on RelBench data silently - drops ALL node features (``x=None``) when any table lacks numeric - columns. This example shows the correct pattern: sanitize, project - all types to a common dimension, then convert. - -.. note:: - Due to a known upstream issue in PyG ``llm.py`` with - ``transformers >= 5.0``, this example currently requires - ``transformers 4.x``. - (``pip install "transformers>=4.51,<5.0"``) +"""Minimal example showing how to use ``from_relbench`` output with GRetriever. + +This script loads Formula 1 data from RelBench, sanitizes numeric node +features, projects heterogeneous node types into a shared latent space, +converts the graph to homogeneous format, and passes it to GRetriever. + +The goal is to demonstrate the projection-first pattern required before +calling ``to_homogeneous()`` on RelBench-derived graphs. Requirements: ``pip install relbench "transformers>=4.51,<5.0" sentencepiece @@ -36,7 +25,23 @@ from torch_geometric.nn import GAT, HeteroDictLinear from torch_geometric.utils import from_relbench -# ── CLI ────────────────────────────────────────────────────────────── +try: + import transformers +except ImportError as exc: + raise RuntimeError( + 'The `transformers` package is required. Install it with: ' + '`pip install "transformers>=4.51,<5.0"`.' + ) from exc + +transformers_version = tuple(int(x) for x in transformers.__version__.split('.')[:2]) +if transformers_version[0] >= 5: + raise RuntimeError( + f'Unsupported transformers version {transformers.__version__}. ' + 'This example requires transformers 4.x. Install with: ' + '`pip install "transformers>=4.51,<5.0"`.' + ) + +# CLI options parser = argparse.ArgumentParser( description='RelBench -> GRetriever example.') parser.add_argument('--dataset', type=str, default='rel-f1', @@ -66,7 +71,7 @@ } args.torch_dtype = _dtype_map[args.dtype] -# ── 1. Load & Sanitize RelBench data ───────────────────────────────── +# Load and sanitize RelBench data print(f'Loading RelBench {args.dataset} dataset...') dataset = get_dataset(args.dataset) db = dataset.get_db() @@ -85,7 +90,7 @@ f'{len(data.edge_types)} edge types') -# ── 2. Define Trainable Feature Projector ──────────────────────────── +# Define the projector for heterogeneous node features class HeteroFeatureProjector(nn.Module): """Projects heterogeneous node features to a common dimension. @@ -118,47 +123,47 @@ def forward(self, data): if nt in out: res[nt] = out[nt] else: + # These learned embeddings are only valid for the current graph. + # They do not generalize to unseen nodes in another graph. res[nt] = self.embs[nt].weight return res projector = HeteroFeatureProjector(data, args.hidden) -# ── 3. Extract Homogeneous Topology ────────────────────────────────── -# Topology (edge_index) is static, computed once. Node features (homo_x) -# are computed dynamically inside the training loop so that gradients -# flow back through the projector. +# Extract the homogeneous graph topology +# The edge structure is static, while node features are recomputed inside +# the training loop so gradients can propagate through the projector. homo_topology = data.to_homogeneous() homo_edge_index = homo_topology.edge_index print(f'Homogeneous: edge_index={list(homo_edge_index.shape)}') -# ── 4. Create synthetic Q&A pairs ─────────────────────────────────── -# These synthetic Q&A pairs are illustrative. -num_drivers = (data['drivers'].num_nodes - if 'drivers' in data.node_types else '?') -num_constructors = (data['constructors'].num_nodes - if 'constructors' in data.node_types else '?') -num_node_types = len(data.node_types) -num_edge_types = len(data.edge_types) - +# Build a small set of example questions for the GRetriever demo. +# These are meant to show the relationship between node types and edge data, +# not to model a full retrieval task. qa_pairs = [ - ('How many drivers are in the dataset?', - f'There are {num_drivers} drivers in the Formula 1 dataset.'), - ('How many constructors are in the dataset?', - f'There are {num_constructors} constructors.'), - ('How many types of entities are in the graph?', - f'The graph has {num_node_types} node types and ' - f'{num_edge_types} edge types.'), - ('What entity types exist in the Formula 1 knowledge graph?', - f'The entity types include: {", ".join(data.node_types)}.'), - ('How are drivers connected to races?', - 'Drivers connect to races through results and qualifying entries.'), - ('What does this knowledge graph represent?', - 'This graph represents Formula 1 racing data including drivers, ' - 'teams, circuits, races, and their relationships.'), + ( + 'Which entity types appear in this Formula 1 graph?', + 'The graph contains node types such as drivers, constructors, circuits, ' + 'races, and teams.', + ), + ( + 'How is the driver-to-race connection represented?', + 'Drivers are linked to races through result and qualifying edges.', + ), + ( + 'What role do constructors have in the dataset?', + 'Constructors are part of the race entry structure and connect ' + 'teams with drivers.', + ), + ( + 'Why do we project all node types before calling to_homogeneous?', + 'The projection creates a shared embedding space so GRetriever can ' + 'process the graph as a single homogeneous tensor.', + ), ] -# ── 5. Build GRetriever model ──────────────────────────────────────── +# Build the GRetriever model print(f'\nInitializing GRetriever with LLM={args.llm}...') gnn = GAT( @@ -190,15 +195,16 @@ def forward(self, data): data = data.to(device) print(f'Using device: {device}') -# ── 6. Training loop ──────────────────────────────────────────────── +# Training loop # Include projector parameters so the feature embeddings actually learn. params = [p for p in model.parameters() if p.requires_grad] params += list(projector.parameters()) optimizer = torch.optim.AdamW(params, lr=args.lr, weight_decay=0.05) context_str = ( - f'This is a Formula 1 knowledge graph with {num_node_types} entity ' - f'types ({", ".join(data.node_types)}).' + 'This Formula 1 knowledge graph includes drivers, constructors, circuits, ' + 'races, and teams, with edges representing race results, qualifying, and ' + 'entity relationships.' ) print(f'\nTraining {args.epochs} epochs on {len(qa_pairs)} samples...') @@ -214,11 +220,18 @@ def forward(self, data): # Dynamic projection: compute inside the loop so gradients # flow back through the projector. projected_dict = projector(data) - # Stack in data.node_types order (same order as to_homogeneous) + # Stack in data.node_types order, then verify the total node count. homo_x = torch.cat( [projected_dict[nt] for nt in data.node_types], dim=0) + assert homo_x.size(0) == homo_topology.num_nodes, ( + 'Expected projected homo_x to have the same number of nodes as ' + 'the homogeneous topology. If this fails, the node ordering ' + 'assumption is incorrect.' + ) - # Single-graph paradigm: all nodes belong to batch index 0 + # Single-graph paradigm: all nodes belong to batch index 0. + # For mini-batched graph training, a Batch object with graph indices + # would be required instead. batch_idx = torch.zeros( homo_x.size(0), dtype=torch.long, device=device) @@ -244,7 +257,7 @@ def forward(self, data): avg_loss = total_loss / len(qa_pairs) print(f'Epoch {epoch:02d}: Loss={avg_loss:.4f}') -# ── 7. Inference demo ──────────────────────────────────────────────── +# Inference demo print('\nInference:') model.eval() projector.eval() From 46061ab3d8db2014b4ead3b267c0c7ca6011d510 Mon Sep 17 00:00:00 2001 From: AJamal27891 Date: Thu, 30 Apr 2026 05:50:26 +0300 Subject: [PATCH 12/14] Fix relbench_gretriever example syntax and ensure PR2 branch validation --- examples/llm/relbench_gretriever.py | 67 ++++++++++++----------------- 1 file changed, 28 insertions(+), 39 deletions(-) diff --git a/examples/llm/relbench_gretriever.py b/examples/llm/relbench_gretriever.py index 984ab9f1e7cb..3bb63779741f 100644 --- a/examples/llm/relbench_gretriever.py +++ b/examples/llm/relbench_gretriever.py @@ -21,7 +21,7 @@ import torch.nn as nn from relbench.datasets import get_dataset -from torch_geometric.llm.models import GRetriever, LLM +from torch_geometric.llm.models import LLM, GRetriever from torch_geometric.nn import GAT, HeteroDictLinear from torch_geometric.utils import from_relbench @@ -30,33 +30,28 @@ except ImportError as exc: raise RuntimeError( 'The `transformers` package is required. Install it with: ' - '`pip install "transformers>=4.51,<5.0"`.' - ) from exc + '`pip install "transformers>=4.51,<5.0"`.') from exc -transformers_version = tuple(int(x) for x in transformers.__version__.split('.')[:2]) +transformers_version = tuple( + int(x) for x in transformers.__version__.split('.')[:2]) if transformers_version[0] >= 5: raise RuntimeError( f'Unsupported transformers version {transformers.__version__}. ' 'This example requires transformers 4.x. Install with: ' - '`pip install "transformers>=4.51,<5.0"`.' - ) + '`pip install "transformers>=4.51,<5.0"`.') # CLI options -parser = argparse.ArgumentParser( - description='RelBench -> GRetriever example.') +parser = argparse.ArgumentParser(description='RelBench -> GRetriever example.') parser.add_argument('--dataset', type=str, default='rel-f1', help='RelBench dataset name (default: rel-f1)') -parser.add_argument('--llm', type=str, - default='Qwen/Qwen2-0.5B', +parser.add_argument('--llm', type=str, default='Qwen/Qwen2-0.5B', help='HuggingFace LLM model name') parser.add_argument('--hidden', type=int, default=64, help='Common projection + GNN hidden dim') parser.add_argument('--gnn_layers', type=int, default=2, help='Number of GAT layers') -parser.add_argument('--epochs', type=int, default=5, - help='Training epochs') -parser.add_argument('--lr', type=float, default=1e-4, - help='Learning rate') +parser.add_argument('--epochs', type=int, default=5, help='Training epochs') +parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate') parser.add_argument('--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16', 'float16'], help='LLM dtype (use float32 for CPU-only)') @@ -110,7 +105,8 @@ def __init__(self, data, common_dim: int): self.lin = HeteroDictLinear(featured, common_dim) self.embs = nn.ModuleDict({ - nt: nn.Embedding(data[nt].num_nodes, common_dim) + nt: + nn.Embedding(data[nt].num_nodes, common_dim) for nt in self.featureless }) @@ -123,8 +119,8 @@ def forward(self, data): if nt in out: res[nt] = out[nt] else: - # These learned embeddings are only valid for the current graph. - # They do not generalize to unseen nodes in another graph. + # These learned embeddings are only valid for this graph. + # They do not generalize to unseen nodes from another graph. res[nt] = self.embs[nt].weight return res @@ -144,8 +140,8 @@ def forward(self, data): qa_pairs = [ ( 'Which entity types appear in this Formula 1 graph?', - 'The graph contains node types such as drivers, constructors, circuits, ' - 'races, and teams.', + 'The graph contains node types such as drivers, constructors, ' + 'circuits, races, and teams.', ), ( 'How is the driver-to-race connection represented?', @@ -177,15 +173,12 @@ def forward(self, data): model_name=args.llm, n_gpus=args.n_gpus if args.n_gpus > 0 else None, dtype=args.torch_dtype, - sys_prompt=( - 'You are an expert assistant that answers questions about ' - 'Formula 1 data using knowledge graph context. ' - 'Give concise, direct answers.' - ), + sys_prompt=('You are an expert assistant that answers questions about ' + 'Formula 1 data using knowledge graph context. ' + 'Give concise, direct answers.'), ) model = GRetriever(llm=llm, gnn=gnn) -print('Model initialized.') # Move model components to the LLM device device = model.llm.device @@ -204,8 +197,7 @@ def forward(self, data): context_str = ( 'This Formula 1 knowledge graph includes drivers, constructors, circuits, ' 'races, and teams, with edges representing race results, qualifying, and ' - 'entity relationships.' -) + 'entity relationships.') print(f'\nTraining {args.epochs} epochs on {len(qa_pairs)} samples...') model.train() @@ -221,19 +213,18 @@ def forward(self, data): # flow back through the projector. projected_dict = projector(data) # Stack in data.node_types order, then verify the total node count. - homo_x = torch.cat( - [projected_dict[nt] for nt in data.node_types], dim=0) + homo_x = torch.cat([projected_dict[nt] for nt in data.node_types], + dim=0) assert homo_x.size(0) == homo_topology.num_nodes, ( 'Expected projected homo_x to have the same number of nodes as ' 'the homogeneous topology. If this fails, the node ordering ' - 'assumption is incorrect.' - ) + 'assumption is incorrect.') # Single-graph paradigm: all nodes belong to batch index 0. # For mini-batched graph training, a Batch object with graph indices # would be required instead. - batch_idx = torch.zeros( - homo_x.size(0), dtype=torch.long, device=device) + batch_idx = torch.zeros(homo_x.size(0), dtype=torch.long, + device=device) loss = model( question=[q], @@ -265,12 +256,11 @@ def forward(self, data): # Compute static features for inference with torch.no_grad(): projected_dict = projector(data) - homo_x = torch.cat( - [projected_dict[nt] for nt in data.node_types], dim=0) + homo_x = torch.cat([projected_dict[nt] for nt in data.node_types], dim=0) test_questions = [ - 'How many drivers are in this Formula 1 dataset?', - 'What entity types exist in the graph?', + 'Which entity types appear in this Formula 1 graph?', + 'Why do we project all node types before calling to_homogeneous?', ] for test_q in test_questions: @@ -279,8 +269,7 @@ def forward(self, data): question=[test_q], x=homo_x, edge_index=homo_edge_index, - batch=torch.zeros(homo_x.size(0), dtype=torch.long, - device=device), + batch=torch.zeros(homo_x.size(0), dtype=torch.long, device=device), additional_text_context=[context_str], max_out_tokens=64, ) From 4f2b2f5bd78d258a24f1d77fefca4febc887e7fd Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Thu, 4 Jun 2026 10:45:21 -0700 Subject: [PATCH 13/14] Update example, fix dtype issue --- examples/llm/relbench_gretriever.py | 20 ++++++++++---------- torch_geometric/llm/models/g_retriever.py | 1 + 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/examples/llm/relbench_gretriever.py b/examples/llm/relbench_gretriever.py index 3bb63779741f..665730f33f7d 100644 --- a/examples/llm/relbench_gretriever.py +++ b/examples/llm/relbench_gretriever.py @@ -19,11 +19,13 @@ import torch import torch.nn as nn +from packaging.version import Version from relbench.datasets import get_dataset +from torch_geometric.contrib.utils import from_relbench +from torch_geometric.data import HeteroData from torch_geometric.llm.models import LLM, GRetriever from torch_geometric.nn import GAT, HeteroDictLinear -from torch_geometric.utils import from_relbench try: import transformers @@ -32,16 +34,14 @@ 'The `transformers` package is required. Install it with: ' '`pip install "transformers>=4.51,<5.0"`.') from exc -transformers_version = tuple( - int(x) for x in transformers.__version__.split('.')[:2]) -if transformers_version[0] >= 5: +if Version(transformers.__version__) >= Version('5.0'): raise RuntimeError( f'Unsupported transformers version {transformers.__version__}. ' 'This example requires transformers 4.x. Install with: ' '`pip install "transformers>=4.51,<5.0"`.') # CLI options -parser = argparse.ArgumentParser(description='RelBench -> GRetriever example.') +parser = argparse.ArgumentParser() parser.add_argument('--dataset', type=str, default='rel-f1', help='RelBench dataset name (default: rel-f1)') parser.add_argument('--llm', type=str, default='Qwen/Qwen2-0.5B', @@ -55,16 +55,16 @@ parser.add_argument('--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16', 'float16'], help='LLM dtype (use float32 for CPU-only)') -parser.add_argument('--n_gpus', type=int, default=1, +parser.add_argument('--n_gpus', type=int, default=torch.cuda.device_count(), help='Number of GPUs for the LLM (0 for CPU)') args = parser.parse_args() -_dtype_map = { +dtype_map = { 'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16, } -args.torch_dtype = _dtype_map[args.dtype] +args.torch_dtype = dtype_map[args.dtype] # Load and sanitize RelBench data print(f'Loading RelBench {args.dataset} dataset...') @@ -92,7 +92,7 @@ class HeteroFeatureProjector(nn.Module): Uses ``HeteroDictLinear`` for node types with numeric features and ``nn.Embedding`` for featureless structural tables. """ - def __init__(self, data, common_dim: int): + def __init__(self, data: HeteroData, common_dim: int) -> None: super().__init__() featured = {} self.featureless = [] @@ -110,7 +110,7 @@ def __init__(self, data, common_dim: int): for nt in self.featureless }) - def forward(self, data): + def forward(self, data: HeteroData) -> dict[str, torch.Tensor]: """Return a dict of projected features, preserving autograd.""" x_dict = {nt: data[nt].x for nt in self.lin.lins} out = self.lin(x_dict) diff --git a/torch_geometric/llm/models/g_retriever.py b/torch_geometric/llm/models/g_retriever.py index d7abc934f5f3..180ab1f6f2b9 100644 --- a/torch_geometric/llm/models/g_retriever.py +++ b/torch_geometric/llm/models/g_retriever.py @@ -208,6 +208,7 @@ def inference( if self.gnn is not None: x = self.encode(x, edge_index, batch, edge_attr) x = self.projector(x) + x = self._align_dtype(x, self.llm_generator) xs = x.split(1, dim=0) # Handle case where theres more than one embedding for each sample From 91f0212af8c9bc09a8e9f4e2e605ff07596bffb7 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Thu, 4 Jun 2026 10:49:14 -0700 Subject: [PATCH 14/14] Update CHANGELOG --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1415c411117e..3e4ba467a966 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). ### Fixed +- Fixed dtype mismatch in `GRetriever.inference` ([#10681](https://github.com/pyg-team/pytorch_geometric/pull/10681)) - Fix MovieLens dataset incompatibility with `sentence-transformers>=5.0.0` ([#10668](https://github.com/pyg-team/pytorch_geometric/pull/10668)) - Removed an unnecessary device synchronization in `torch_geometric.utils.softmax` ([#10499](https://github.com/pyg-team/pytorch_geometric/pull/10499)) - Fixed loading of legacy HuggingFace BERT checkpoints ([#10631](https://github.com/pyg-team/pytorch_geometric/pull/10631))