Skip to content

Add GRetriever RelBench example demo and from_relbench integration#10681

Draft
AJamal27891 wants to merge 10 commits into
pyg-team:masterfrom
AJamal27891:pr-10353-part2-gretriever
Draft

Add GRetriever RelBench example demo and from_relbench integration#10681
AJamal27891 wants to merge 10 commits into
pyg-team:masterfrom
AJamal27891:pr-10353-part2-gretriever

Conversation

@AJamal27891
Copy link
Copy Markdown
Contributor

@AJamal27891 AJamal27891 commented Apr 29, 2026

Adds examples/llm/relbench_gretriever.py, demo the integration of from_relbench utility output with GRetriever.

The example highlights a key pattern for heterogeneous-to-homogeneous conversion: calling to_homogeneous() on raw RelBench-derived HeteroData can lose node features or force padding when some node types have no attributes. To avoid this, all node types are first projected into a shared embedding space via a HeteroFeatureProjector (using HeteroDictLinear for feature-bearing types and nn.Embedding for featureless ones), and only then converted.

This pattern preserves node features and enables reliable use of heterogeneous graphs in homogeneous GNN pipelines such as GRetriever.

Includes:

  • End-to-end example from RelBench database → HeteroData → projected features → homogeneous graph → GRetriever
  • Handling of featureless node types via learned embeddings
  • Minimal training and inference loop for demo

Depends on #10628.
Ref: #10353.
Partially Closes issue #9839.

@codecov
Copy link
Copy Markdown

codecov Bot commented May 1, 2026

Codecov Report

❌ Patch coverage is 97.72727% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 81.27%. Comparing base (c211214) to head (c717340).
⚠️ Report is 201 commits behind head on master.

Files with missing lines Patch % Lines
torch_geometric/utils/relbench.py 97.67% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #10681      +/-   ##
==========================================
- Coverage   86.11%   81.27%   -4.84%     
==========================================
  Files         496      512      +16     
  Lines       33655    37863    +4208     
==========================================
+ Hits        28981    30775    +1794     
- Misses       4674     7088    +2414     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@AJamal27891 AJamal27891 changed the title Pr 10353 part2 gretriever Add GRetriever RelBench example demo and from_relbench integration May 1, 2026
Copy link
Copy Markdown
Contributor

@puririshi98 puririshi98 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@AJamal27891 can you include a log of an e2e run of the example added?

@akihironitta akihironitta added this to the 2.8.0 milestone May 1, 2026
@AJamal27891
Copy link
Copy Markdown
Contributor Author

@AJamal27891 can you include a log of an e2e run of the example added?

root@64af3f32cfc8:/workspace# python /workspace/pytorch_geometric/examples/llm/relbench_gretriever.py --epochs 10 --llm Qwen/Qwen2-0.5B
Loading RelBench rel-f1 dataset...
Loading Database object from /root/.cache/relbench/rel-f1/db...
Done in 0.07 seconds.
Graph: 9 node types, 26 edge types
Homogeneous: edge_index=[2, 338842]

Initializing GRetriever with LLM=Qwen/Qwen2-0.5B...
Setting up 'Qwen/Qwen2-0.5B' with configuration: {'revision': 'main', 'max_memory': {0: '3GiB'}, 'low_cpu_mem_usage': True, 'device_map': 'auto', 'dtype': torch.bfloat16}
/usr/local/lib/python3.12/dist-packages/torch/library.py:357: UserWarning: Warning only once for all operators,  other operators may also be overridden.
  Overriding a previously registered kernel for the same operator and the same dispatch key
  operator: flash_attn::_flash_attn_backward(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor(a6!)? dq, Tensor(a7!)? dk, Tensor(a8!)? dv, float dropout_p, float softmax_scale, bool causal, SymInt window_size_left, SymInt window_size_right, float softcap, Tensor? alibi_slopes, bool deterministic, Tensor? rng_state=None) -> Tensor
    registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:926
  dispatch key: ADInplaceOrView
  previous kernel: no debug info
       new kernel: registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:926 (Triggered internally at /opt/pytorch/pytorch/aten/src/ATen/core/dispatch/OperatorEntry.cpp:208.)
  self.m.impl(
Skipping import of cpp extensions due to incompatible torch version 2.10.0a0+a36e1d39eb.nv26.01.42222806 for torchao version 0.15.0             Please see https://github.com/pytorch/ao/issues/2919 for more info
Using device: cuda:0

Training 10 epochs on 4 samples...
Epoch 01: Loss=3.5860
Epoch 02: Loss=2.2119
Epoch 03: Loss=0.9973
Epoch 04: Loss=1.1809
Epoch 05: Loss=0.6012
Epoch 06: Loss=1.3658
Epoch 07: Loss=0.2758
Epoch 08: Loss=0.2540
Epoch 09: Loss=0.4419
Epoch 10: Loss=0.5418

Inference:
Q: Which entity types appear in this Formula 1 graph?
A: The graph contains node types such as drivers, constructors circuits races, qualifying edges.

Q: Why do we project all node types before calling to_homogeneous?
A: The projection creates a shared embedding space so GRetriever can process the graph as a single homogeneous tensor.

@AJamal27891 AJamal27891 requested a review from puririshi98 May 2, 2026 11:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants