Add GRetriever RelBench example demo and from_relbench integration#10681
Draft
AJamal27891 wants to merge 10 commits into
Draft
Add GRetriever RelBench example demo and from_relbench integration#10681AJamal27891 wants to merge 10 commits into
AJamal27891 wants to merge 10 commits into
Conversation
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #10681 +/- ##
==========================================
- Coverage 86.11% 81.27% -4.84%
==========================================
Files 496 512 +16
Lines 33655 37863 +4208
==========================================
+ Hits 28981 30775 +1794
- Misses 4674 7088 +2414 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
puririshi98
requested changes
May 1, 2026
Contributor
puririshi98
left a comment
There was a problem hiding this comment.
@AJamal27891 can you include a log of an e2e run of the example added?
Contributor
Author
root@64af3f32cfc8:/workspace# python /workspace/pytorch_geometric/examples/llm/relbench_gretriever.py --epochs 10 --llm Qwen/Qwen2-0.5B
Loading RelBench rel-f1 dataset...
Loading Database object from /root/.cache/relbench/rel-f1/db...
Done in 0.07 seconds.
Graph: 9 node types, 26 edge types
Homogeneous: edge_index=[2, 338842]
Initializing GRetriever with LLM=Qwen/Qwen2-0.5B...
Setting up 'Qwen/Qwen2-0.5B' with configuration: {'revision': 'main', 'max_memory': {0: '3GiB'}, 'low_cpu_mem_usage': True, 'device_map': 'auto', 'dtype': torch.bfloat16}
/usr/local/lib/python3.12/dist-packages/torch/library.py:357: UserWarning: Warning only once for all operators, other operators may also be overridden.
Overriding a previously registered kernel for the same operator and the same dispatch key
operator: flash_attn::_flash_attn_backward(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor(a6!)? dq, Tensor(a7!)? dk, Tensor(a8!)? dv, float dropout_p, float softmax_scale, bool causal, SymInt window_size_left, SymInt window_size_right, float softcap, Tensor? alibi_slopes, bool deterministic, Tensor? rng_state=None) -> Tensor
registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:926
dispatch key: ADInplaceOrView
previous kernel: no debug info
new kernel: registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:926 (Triggered internally at /opt/pytorch/pytorch/aten/src/ATen/core/dispatch/OperatorEntry.cpp:208.)
self.m.impl(
Skipping import of cpp extensions due to incompatible torch version 2.10.0a0+a36e1d39eb.nv26.01.42222806 for torchao version 0.15.0 Please see https://github.com/pytorch/ao/issues/2919 for more info
Using device: cuda:0
Training 10 epochs on 4 samples...
Epoch 01: Loss=3.5860
Epoch 02: Loss=2.2119
Epoch 03: Loss=0.9973
Epoch 04: Loss=1.1809
Epoch 05: Loss=0.6012
Epoch 06: Loss=1.3658
Epoch 07: Loss=0.2758
Epoch 08: Loss=0.2540
Epoch 09: Loss=0.4419
Epoch 10: Loss=0.5418
Inference:
Q: Which entity types appear in this Formula 1 graph?
A: The graph contains node types such as drivers, constructors circuits races, qualifying edges.
Q: Why do we project all node types before calling to_homogeneous?
A: The projection creates a shared embedding space so GRetriever can process the graph as a single homogeneous tensor. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Adds
examples/llm/relbench_gretriever.py, demo the integration offrom_relbenchutility output withGRetriever.The example highlights a key pattern for heterogeneous-to-homogeneous conversion: calling
to_homogeneous()on raw RelBench-derivedHeteroDatacan lose node features or force padding when some node types have no attributes. To avoid this, all node types are first projected into a shared embedding space via aHeteroFeatureProjector(usingHeteroDictLinearfor feature-bearing types andnn.Embeddingfor featureless ones), and only then converted.This pattern preserves node features and enables reliable use of heterogeneous graphs in homogeneous GNN pipelines such as
GRetriever.Includes:
HeteroData→ projected features → homogeneous graph →GRetrieverDepends on #10628.
Ref: #10353.
Partially Closes issue #9839.